157 lines
5.4 KiB
Python
157 lines
5.4 KiB
Python
# scripts/init_db.py
|
|
import os
|
|
import sys
|
|
import logging
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
# Configure logging before the project imports further down, so anything
# those modules log at import time already uses this format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Make the project root importable so `app.*` resolves when this script is
# run directly (e.g. `python scripts/init_db.py`). Prepend rather than append
# so the local package takes precedence over any same-named module elsewhere
# on sys.path, and skip the insert if the entry is already there.
project_root = Path(__file__).resolve().parent.parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
|
|
|
|
# Import only what's needed for DB initialization
|
|
from sqlalchemy import create_engine, text
|
|
from app.utils.db.models import Base
|
|
|
|
def get_db_path():
    """Return the location of the SQLite database file.

    The file is ``<project_root>/data/vpn.db``. The ``data`` folder is
    created on demand (no error if it already exists); the database file
    itself is not created here.
    """
    storage_dir = project_root.joinpath('data')
    storage_dir.mkdir(exist_ok=True)
    return storage_dir.joinpath('vpn.db')
|
|
|
|
def backup_existing_db():
    """Move the current database file aside to a timestamped backup.

    The backup is placed next to the database as
    ``vpn_backup_<YYYYmmdd_HHMMSS>.db`` (with a ``_<n>`` counter appended if
    that name is already taken). Any SQLite sidecar files (``-journal``,
    ``-wal``, ``-shm``) are moved alongside it.

    Returns:
        The backup path, or None if there is no database to back up or the
        main rename failed. Failures are logged, not raised.
    """
    try:
        db_path = get_db_path()
        if not db_path.exists():
            return None

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_path = db_path.parent / f'vpn_backup_{timestamp}.db'
        # On POSIX, Path.rename silently replaces an existing target, so two
        # runs within the same second would overwrite the earlier backup.
        # Use the first unused name instead.
        counter = 1
        while backup_path.exists():
            backup_path = db_path.parent / f'vpn_backup_{timestamp}_{counter}.db'
            counter += 1

        db_path.rename(backup_path)
        logger.info(f"Created backup at: {backup_path}")
    except Exception as e:
        logger.error(f"Backup failed: {str(e)}")
        return None

    # SQLite locates a hot journal / WAL by the database's file name, so a
    # leftover one could be paired with the fresh database created under the
    # old name. Keep them with the backup instead. This is best effort: the
    # main file has already been moved, so a failure here is only a warning.
    for suffix in ('-journal', '-wal', '-shm'):
        sidecar = db_path.with_name(db_path.name + suffix)
        if sidecar.exists():
            try:
                sidecar.rename(backup_path.with_name(backup_path.name + suffix))
            except OSError as e:
                logger.warning(f"Could not move {sidecar} to backup: {e}")

    return backup_path
|
|
|
|
def init_db(force=False):
    """Create the SQLite database file and all tables defined on ``Base``.

    Args:
        force: If True and a database file already exists, move it aside
            via ``backup_existing_db()`` and build a fresh database. If False
            and the file exists, nothing is changed.

    Returns:
        The SQLAlchemy engine for the database, or None when the database
        already exists and ``force`` is False.

    Raises:
        RuntimeError: If ``force`` is set but the existing database could
            not be moved aside.
        Exception: Any error from engine or table creation is logged and
            re-raised.
    """
    from sqlalchemy import event

    try:
        db_path = get_db_path()

        if db_path.exists():
            if not force:
                logger.warning(f"Database already exists at {db_path}")
                logger.warning("Use --force to recreate the database")
                return None
            # backup_existing_db() logs and returns None on failure rather
            # than raising. If the old file is still in place, create_all()
            # would just reuse it and "--force" would recreate nothing.
            if backup_existing_db() is None and db_path.exists():
                raise RuntimeError(
                    f"Could not back up existing database at {db_path}; "
                    "refusing to continue"
                )

        logger.info(f"Initializing database at: {db_path}")

        db_url = f"sqlite:///{db_path}"
        engine = create_engine(
            db_url,
            connect_args={"check_same_thread": False},
        )

        # PRAGMA foreign_keys is per-connection state in SQLite. Running it
        # once on a single connection does not affect other connections the
        # engine opens later (create_all, callers of the returned engine).
        # So enable it on every new DBAPI connection instead.
        @event.listens_for(engine, "connect")
        def _enable_foreign_keys(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys = ON")
            finally:
                cursor.close()

        # Create all tables
        Base.metadata.create_all(engine)
        logger.info("Successfully created all database tables")

        logger.info("Created tables:")
        for table in Base.metadata.tables.keys():
            logger.info(f" - {table}")

        return engine

    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}")
        raise
|
|
|
|
def verify_tables(engine):
    """Check that every table and column defined on ``Base`` exists in the DB.

    Reads the live schema from SQLite (``sqlite_master`` and
    ``pragma_table_info``), logs each table's columns, and compares the
    result against ``Base.metadata``.

    Args:
        engine: SQLAlchemy engine connected to the SQLite database.

    Returns:
        True if all expected tables and their columns are present, otherwise
        False. Errors raised while inspecting the database are logged and
        reported as False.
    """
    try:
        with engine.connect() as conn:
            # SQLite specific query to get the user tables
            query = text("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
            existing_tables = {row[0] for row in conn.execute(query)}

            expected_tables = set(Base.metadata.tables.keys())

            missing_tables = expected_tables - existing_tables
            if missing_tables:
                logger.error(f"Missing tables: {missing_tables}")
                return False

            # Table-valued pragma with a bound parameter, so table names
            # containing reserved words, quotes or ':' are handled safely
            # (they are never interpolated into the SQL text).
            schema_query = text("SELECT * FROM pragma_table_info(:table_name)")
            schema_ok = True

            # Sorted so the log output is stable between runs.
            for table_name in sorted(existing_tables):
                result = conn.execute(schema_query, {"table_name": table_name})
                logger.info(f"\nSchema for {table_name}:")
                actual_columns = set()
                # Rows are (cid, name, type, notnull, dflt_value, pk)
                for cid, name, type_, notnull, dflt_value, pk in result:
                    actual_columns.add(name)
                    pk_str = "PRIMARY KEY" if pk else ""
                    null_str = "NOT NULL" if notnull else "NULL"
                    default_str = f"DEFAULT {dflt_value}" if dflt_value is not None else ""
                    logger.info(f" - {name} ({type_}) {null_str} {default_str} {pk_str}".strip())

                # Only model-defined tables have an expected column set.
                if table_name in expected_tables:
                    expected_columns = {
                        column.name
                        for column in Base.metadata.tables[table_name].columns
                    }
                    missing_columns = expected_columns - actual_columns
                    if missing_columns:
                        logger.error(
                            f"Table {table_name} is missing columns: {sorted(missing_columns)}"
                        )
                        schema_ok = False

        if schema_ok:
            logger.info("All expected tables were created successfully")
        return schema_ok

    except Exception as e:
        logger.error(f"Table verification failed: {str(e)}")
        return False
|
|
|
|
def main(argv=None):
    """Command-line entry point: parse arguments, build and verify the DB.

    Args:
        argv: Argument list to parse, without the program name. None (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        Process exit code. Returns 0 on success. Returns 1 if the database
        already exists without ``--force``, if initialization raised, or if
        verification failed.

    Raises:
        SystemExit: From argparse, for ``--help`` or invalid arguments.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Initialize the VPN database')
    parser.add_argument('--force', action='store_true',
                        help='Force database recreation')
    args = parser.parse_args(argv)

    engine = None
    try:
        logger.info("Starting database initialization")
        engine = init_db(force=args.force)

        # init_db returns None when the database exists and --force is unset.
        if engine is None:
            return 1

        if verify_tables(engine):
            logger.info("Database initialization completed successfully")
            return 0

        logger.error("Database initialization failed - tables missing")
        return 1

    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}")
        return 1

    finally:
        # Release the engine's pooled connections before the process exits.
        if engine is not None:
            engine.dispose()
|
|
|
|
# Run the CLI and use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())