# scripts/init_db.py
"""Create (or, with ``--force``, back up and recreate) the VPN SQLite database.

Usage: ``python scripts/init_db.py [--force]``
"""
import argparse
import logging
import os
import sys
from datetime import datetime
from pathlib import Path

# Configure logging before the project imports below so anything they log at
# import time uses this format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Make the project root importable so `app.*` resolves when this script is run
# directly. Insert at the front (instead of appending) so the local package
# wins over any same-named package installed in site-packages.
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Import only what's needed for DB initialization
from sqlalchemy import create_engine, event, text  # noqa: E402
from app.utils.db.models import Base  # noqa: E402

# Files SQLite may keep next to the main database file. They are tied to that
# specific database file and must be moved together with it; a stale journal
# or WAL left beside a fresh database can be replayed into it.
_SQLITE_SIDECAR_SUFFIXES = ('-journal', '-wal', '-shm')


def get_db_path():
    """Return the database file path, creating the ``data/`` directory if needed.

    Returns:
        Path: ``<project_root>/data/vpn.db`` (the file itself may not exist yet).
    """
    data_dir = project_root / 'data'
    data_dir.mkdir(exist_ok=True)
    return data_dir / 'vpn.db'


def backup_existing_db():
    """Move the existing database (and its SQLite sidecar files) to a backup.

    The backup is named ``vpn_backup_<YYYYmmdd_HHMMSS>[_<n>].db`` in the same
    directory; a numeric suffix is added if that name is already taken so an
    earlier backup is never overwritten.

    Returns:
        Path | None: the backup path, or None if there was no database or the
        backup failed (failures are logged, not raised).
    """
    try:
        db_path = get_db_path()
        if not db_path.exists():
            return None

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_path = db_path.parent / f'vpn_backup_{timestamp}.db'
        # Path.rename silently replaces an existing target on POSIX, so pick a
        # name that is not in use (e.g. two runs within the same second).
        counter = 1
        while backup_path.exists():
            backup_path = db_path.parent / f'vpn_backup_{timestamp}_{counter}.db'
            counter += 1

        db_path.rename(backup_path)
        for suffix in _SQLITE_SIDECAR_SUFFIXES:
            sidecar = db_path.with_name(db_path.name + suffix)
            if sidecar.exists():
                sidecar.rename(backup_path.with_name(backup_path.name + suffix))

        logger.info(f"Created backup at: {backup_path}")
        return backup_path
    except Exception as e:
        logger.error(f"Backup failed: {str(e)}")
        return None


def _enable_sqlite_foreign_keys(dbapi_connection, connection_record):
    """SQLAlchemy ``connect`` listener: turn on FK enforcement for each connection.

    ``PRAGMA foreign_keys`` is per-connection in SQLite, so it must be issued
    on every new DBAPI connection rather than once on a single connection.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys = ON")
    finally:
        cursor.close()


def init_db(force=False):
    """Create all tables defined on ``Base.metadata`` in the SQLite database.

    Args:
        force: if the database file already exists, back it up (see
            ``backup_existing_db``) and create a fresh one. Without it, an
            existing database is left untouched.

    Returns:
        Engine | None: an engine with foreign-key enforcement enabled on every
        connection, or None when the database exists and ``force`` is False.

    Raises:
        RuntimeError: ``force`` was given but the existing database could not
            be backed up (recreating it would otherwise silently not happen).
        Exception: any error from engine creation or table creation is logged
            and re-raised.
    """
    engine = None
    try:
        db_path = get_db_path()

        # Check if database already exists
        if db_path.exists() and not force:
            logger.warning(f"Database already exists at {db_path}")
            logger.warning("Use --force to recreate the database")
            return None

        # Backup existing database if force is True; refuse to continue if the
        # old file could not be moved out of the way.
        if force and db_path.exists():
            if backup_existing_db() is None:
                raise RuntimeError(
                    f"Could not back up existing database at {db_path}; "
                    "refusing to recreate it"
                )

        logger.info(f"Initializing database at: {db_path}")

        # Create database URL
        db_url = f"sqlite:///{db_path}"

        engine = create_engine(
            db_url,
            connect_args={"check_same_thread": False}
        )
        # Enable foreign key support on every connection the engine opens,
        # including the ones create_all() and later callers use.
        event.listen(engine, "connect", _enable_sqlite_foreign_keys)

        # Create all tables
        Base.metadata.create_all(engine)
        logger.info("Successfully created all database tables")

        # Log created tables
        logger.info("Created tables:")
        for table in Base.metadata.tables.keys():
            logger.info(f" - {table}")

        return engine
    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}")
        # Don't leak pooled connections on failure.
        if engine is not None:
            engine.dispose()
        raise


def _quote_identifier(name):
    """Quote ``name`` as an SQLite identifier, doubling embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


def verify_tables(engine):
    """Check that every table in ``Base.metadata`` exists and log each schema.

    Args:
        engine: a SQLAlchemy engine connected to the SQLite database.

    Returns:
        bool: True if no expected table is missing; False if any are missing or
        verification itself failed (errors are logged, not raised).
    """
    try:
        with engine.connect() as conn:
            # SQLite specific query to get user table names
            query = text("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
            result = conn.execute(query)
            existing_tables = {row[0] for row in result}

            # Get list of all tables defined in models
            expected_tables = set(Base.metadata.tables.keys())

            # Check for missing tables
            missing_tables = expected_tables - existing_tables
            if missing_tables:
                logger.error(f"Missing tables: {missing_tables}")
                return False

            # Log table schemas in a stable order. Table names come from the
            # database, so quote them rather than splicing them in raw.
            for table_name in sorted(existing_tables):
                schema_query = text(f"PRAGMA table_info({_quote_identifier(table_name)})")
                result = conn.execute(schema_query)

                logger.info(f"\nSchema for {table_name}:")
                # SQLite PRAGMA table_info returns: (cid, name, type, notnull, dflt_value, pk)
                for _cid, name, type_, notnull, dflt_value, pk in result:
                    parts = [f"{name} ({type_})", "NOT NULL" if notnull else "NULL"]
                    if dflt_value is not None:
                        parts.append(f"DEFAULT {dflt_value}")
                    if pk:
                        parts.append("PRIMARY KEY")
                    logger.info(" - " + " ".join(parts))

            logger.info("All expected tables were created successfully")
            return True

    except Exception as e:
        logger.error(f"Table verification failed: {str(e)}")
        return False


def main():
    """Parse CLI arguments, initialize the database and verify it.

    Returns:
        int: process exit code — 0 on success, 1 on any failure (including
        "database already exists" without ``--force``).
    """
    engine = None
    try:
        parser = argparse.ArgumentParser(description='Initialize the VPN database')
        parser.add_argument('--force', action='store_true', help='Force database recreation')
        args = parser.parse_args()

        logger.info("Starting database initialization")
        engine = init_db(force=args.force)

        if engine is None:
            return 1

        # Verify tables were created correctly
        if verify_tables(engine):
            logger.info("Database initialization completed successfully")
            return 0
        logger.error("Database initialization failed - tables missing")
        return 1

    except Exception as e:
        # logger.exception keeps the traceback for diagnosing the failure.
        logger.exception(f"Database initialization failed: {str(e)}")
        return 1
    finally:
        if engine is not None:
            engine.dispose()


if __name__ == "__main__":
    sys.exit(main())