# Source: vpn-btcpay-provisioner/scripts/init_db.py (157 lines, 5.4 KiB, Python)

# scripts/init_db.py
import os
import sys
import logging
from pathlib import Path
from datetime import datetime
# Configure logging before the project imports below, so any log records
# emitted while those modules load already use this format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Make the project root (the parent of scripts/) importable so `app.*`
# resolves when this file is run directly. Prepend rather than append so the
# project's own `app` package takes precedence over any same-named package
# already on sys.path.
project_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(project_root))
# Import only what's needed for DB initialization
from sqlalchemy import create_engine, text
from app.utils.db.models import Base
def get_db_path():
    """Return the location of the SQLite database file.

    The file lives at ``<project_root>/data/vpn.db``. The ``data`` directory
    is created if it does not exist yet, so callers can open the path directly.
    """
    storage_dir = project_root.joinpath('data')
    storage_dir.mkdir(exist_ok=True)
    return storage_dir.joinpath('vpn.db')
def backup_existing_db():
    """Move the current database file aside as a timestamped backup.

    The database is renamed to ``vpn_backup_<YYYYmmdd_HHMMSS>.db`` in the same
    directory. If that name is taken, a ``_<N>`` counter is appended. Any
    SQLite side files (``-journal``, ``-wal``, ``-shm``) are moved along with
    it so they stay paired with the database they belong to.

    Returns:
        Path | None: the backup path, or ``None`` when there was no database
        to back up or a filesystem error occurred. Errors are logged, not
        raised.
    """
    try:
        db_path = get_db_path()
        if not db_path.exists():
            return None

        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        backup_path = db_path.parent / f'vpn_backup_{timestamp}.db'
        # Two backups within the same second would share a name. On POSIX
        # rename() silently replaces the earlier backup; on Windows it fails.
        counter = 1
        while backup_path.exists():
            backup_path = db_path.parent / f'vpn_backup_{timestamp}_{counter}.db'
            counter += 1

        db_path.rename(backup_path)

        # A hot journal/WAL left next to a freshly created vpn.db would be
        # mispaired with it, so move the side files with the backup.
        for suffix in ('-journal', '-wal', '-shm'):
            side_file = db_path.with_name(db_path.name + suffix)
            if side_file.exists():
                side_file.rename(backup_path.with_name(backup_path.name + suffix))

        logger.info(f"Created backup at: {backup_path}")
        return backup_path
    except OSError as e:
        logger.error(f"Backup failed: {str(e)}")
        return None
def init_db(force=False):
    """Create the SQLite database file and every table declared on ``Base``.

    Args:
        force: If the database file already exists, move it aside with
            ``backup_existing_db()`` and build a fresh one. Without ``force``
            an existing database is left untouched.

    Returns:
        The SQLAlchemy engine bound to the database, or ``None`` when the
        database already exists and ``force`` is false.

    Raises:
        RuntimeError: ``force`` was requested but the existing database file
            could not be moved aside.
        Exception: any error from engine or table creation is logged and
            re-raised.
    """
    engine = None
    try:
        from sqlalchemy import event

        db_path = get_db_path()

        if db_path.exists():
            if not force:
                logger.warning(f"Database already exists at {db_path}")
                logger.warning("Use --force to recreate the database")
                return None
            backup_existing_db()
            # backup_existing_db() logs and swallows its own failures. If the
            # old file is still in place, create_all() below would just reuse
            # it, and --force would report success without recreating anything.
            if db_path.exists():
                raise RuntimeError(
                    f"Could not move existing database at {db_path} aside; "
                    "refusing to continue with --force"
                )

        logger.info(f"Initializing database at: {db_path}")
        db_url = f"sqlite:///{db_path}"
        engine = create_engine(
            db_url,
            connect_args={"check_same_thread": False}
        )

        # PRAGMA foreign_keys is a per-connection setting in SQLite, so running
        # it once on a throwaway connection does not affect later connections.
        # Apply it to every DBAPI connection the engine opens instead.
        @event.listens_for(engine, "connect")
        def _enable_foreign_keys(dbapi_connection, connection_record):
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("PRAGMA foreign_keys = ON")
            finally:
                cursor.close()

        Base.metadata.create_all(engine)
        logger.info("Successfully created all database tables")

        logger.info("Created tables:")
        for table in Base.metadata.tables.keys():
            logger.info(f" - {table}")
        return engine
    except Exception as e:
        logger.error(f"Database initialization failed: {str(e)}")
        # The caller never receives the engine on failure, so release its pool.
        if engine is not None:
            engine.dispose()
        raise
def verify_tables(engine):
    """Check that every table declared on ``Base`` exists, and log each schema.

    Args:
        engine: SQLAlchemy engine connected to the SQLite database to inspect.

    Returns:
        bool: ``True`` when no declared table is missing. ``False`` when a
        table is missing or the inspection itself raised; the error is logged,
        not propagated.
    """
    try:
        with engine.connect() as conn:
            # sqlite_master is SQLite's catalogue; skip internal sqlite_* tables.
            query = text("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
            existing_tables = {row[0] for row in conn.execute(query)}

            expected_tables = set(Base.metadata.tables.keys())
            missing_tables = expected_tables - existing_tables
            if missing_tables:
                logger.error(f"Missing tables: {sorted(missing_tables)}")
                return False

            # Sorted so the schema dump comes out in a stable order.
            for table_name in sorted(existing_tables):
                # PRAGMA arguments cannot be bound as parameters, so quote the
                # identifier. An unquoted reserved word (e.g. "order") or a name
                # containing spaces would otherwise be a syntax error.
                quoted_name = '"' + table_name.replace('"', '""') + '"'
                result = conn.execute(text(f"PRAGMA table_info({quoted_name})"))
                logger.info(f"\nSchema for {table_name}:")
                # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
                for _cid, name, type_, notnull, dflt_value, pk in result:
                    parts = ["NOT NULL" if notnull else "NULL"]
                    if dflt_value is not None:
                        parts.append(f"DEFAULT {dflt_value}")
                    if pk:
                        parts.append("PRIMARY KEY")
                    logger.info(f" - {name} ({type_}) {' '.join(parts)}")

        logger.info("All expected tables were created successfully")
        return True
    except Exception as e:
        logger.error(f"Table verification failed: {str(e)}")
        return False
def main():
    """Command-line entry point: parse flags, build the database, verify it.

    Returns:
        int: process exit status. 0 on success; 1 when the database already
        exists without ``--force``, when verification fails, or on any error.
    """
    import argparse

    engine = None
    try:
        parser = argparse.ArgumentParser(description='Initialize the VPN database')
        parser.add_argument('--force', action='store_true',
                            help='Force database recreation')
        args = parser.parse_args()

        logger.info("Starting database initialization")
        engine = init_db(force=args.force)
        if engine is None:
            return 1

        if not verify_tables(engine):
            # verify_tables() returns False for missing tables *and* for
            # errors raised while inspecting, so keep the message general.
            logger.error("Database initialization failed - table verification failed")
            return 1
        logger.info("Database initialization completed successfully")
        return 0
    except Exception as e:
        # Top-level boundary: convert any failure into exit status 1 and keep
        # the traceback in the log.
        logger.exception(f"Database initialization failed: {str(e)}")
        return 1
    finally:
        # Release pooled SQLite connections before the process exits.
        if engine is not None:
            engine.dispose()
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())