import ipaddress
import logging
from datetime import datetime, timedelta

from sqlalchemy.exc import SQLAlchemyError

from . import get_session
from .models import Payment, ProvisionLog, Subscription, SubscriptionStatus, User

logger = logging.getLogger(__name__)

# Default WireGuard address pool; its first usable host is reserved for the server.
DEFAULT_WIREGUARD_SUBNET = '10.8.0.0/24'


class DatabaseManager:
    """Static data-access helpers for users, subscriptions, payments and provision logs.

    Each public method opens its own session via ``get_session()``
    (``get_next_available_ip`` can optionally reuse a caller's session).
    ORM objects returned from these methods outlive that session, so callers
    should rely only on their already-loaded column attributes.
    NOTE(review): whether the session is closed on exit depends on
    ``get_session`` (not visible here) — confirm.
    """

    @staticmethod
    def get_next_available_ip(session=None, subnet=DEFAULT_WIREGUARD_SUBNET, server_ip=None):
        """Get the next available IP from the WireGuard subnet.

        Args:
            session: Optional open session to run the lookup in. Pass the
                caller's session when allocating inside a transaction so the
                lookup sees that transaction's state. When omitted, a new
                session is opened just for the lookup.
            subnet: IPv4 network (CIDR string) to allocate from.
            server_ip: Address reserved for the server that is never handed
                out. Defaults to the first usable host of ``subnet``
                (``10.8.0.1`` for the default subnet).

        Returns:
            The lowest host address of ``subnet``, as a string, that is not
            already stored in ``Subscription.assigned_ip``.

        Raises:
            RuntimeError: if every host address in ``subnet`` is taken.
        """
        if session is None:
            with get_session() as own_session:
                return DatabaseManager._find_free_ip(own_session, subnet, server_ip)
        return DatabaseManager._find_free_ip(session, subnet, server_ip)

    @staticmethod
    def _find_free_ip(session, subnet, server_ip):
        """Scan ``subnet`` in ascending order and return the first unassigned, non-server host."""
        rows = (
            session.query(Subscription.assigned_ip)
            .filter(Subscription.assigned_ip.isnot(None))
            .all()
        )
        # A set gives O(1) membership tests instead of scanning a list for every host.
        taken = {row[0] for row in rows}

        network = ipaddress.IPv4Network(subnet)
        if server_ip is None:
            # hosts() may return a list for /32 networks, hence iter().
            server_ip = next(iter(network.hosts()), None)
        reserved = str(server_ip) if server_ip is not None else None

        for host in network.hosts():
            candidate = str(host)
            if candidate != reserved and candidate not in taken:
                return candidate
        raise RuntimeError("No available IPs in the subnet")

    @staticmethod
    def create_user(user_id):
        """Create and persist a new user identified by ``user_id`` (the UUID).

        Returns:
            The committed ``User`` with its attributes reloaded after commit.
        """
        with get_session() as session:
            user = User(user_id=user_id)
            session.add(user)
            session.commit()
            # commit() expires attributes by default; reload them so the object
            # stays readable if get_session closes the session on exit.
            session.refresh(user)
            return user

    @staticmethod
    def get_user_by_uuid(user_id):
        """Return the ``User`` whose ``user_id`` (UUID) matches, or ``None``."""
        with get_session() as session:
            return session.query(User).filter(User.user_id == user_id).first()

    @staticmethod
    def get_subscription_by_invoice(invoice_id):
        """Return the ``Subscription`` with the given ``invoice_id``, or ``None``."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.invoice_id == invoice_id)
                .first()
            )

    @staticmethod
    def get_subscription_by_id(subscription_id):
        """Return the ``Subscription`` with primary key ``subscription_id``, or ``None``."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )

    @staticmethod
    def get_user_subscriptions(user_id):
        """Return all subscriptions of a user, newest ``start_time`` first.

        NOTE(review): this compares against ``Subscription.user_id``, which
        ``create_subscription`` fills with the internal ``User.id`` — not the
        UUID accepted by ``get_user_by_uuid``. Confirm what callers pass.
        """
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.user_id == user_id)
                .order_by(Subscription.start_time.desc())
                .all()
            )

    @staticmethod
    def get_active_subscription_for_user(user_id):
        """Return the user's most recently started ACTIVE subscription, or ``None``.

        NOTE(review): like ``get_user_subscriptions``, ``user_id`` here is
        matched against ``Subscription.user_id`` (the internal ``User.id``).
        """
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.user_id == user_id)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                # Deterministic pick when several are active: newest first.
                .order_by(Subscription.start_time.desc())
                .first()
            )

    @staticmethod
    def create_subscription(user_id, invoice_id, public_key, duration_hours):
        """Create a PENDING subscription, creating the user first if needed.

        Args:
            user_id: The user's UUID (matched against ``User.user_id``).
            invoice_id: Invoice that will pay for this subscription.
            public_key: The client's WireGuard public key.
            duration_hours: Subscription length, measured from now (UTC).

        Returns:
            A plain dict snapshot of the new subscription. ``status`` holds
            the enum's ``.value``.

        Raises:
            RuntimeError: if no IP is free in the subnet.
            Exception: any database error. The transaction is rolled back and
                the error is re-raised.
        """
        with get_session() as session:
            try:
                user = session.query(User).filter(User.user_id == user_id).first()
                if not user:
                    user = User(user_id=user_id)
                    session.add(user)
                    session.flush()  # assigns user.id without committing

                start_time = datetime.utcnow()
                expiry_time = start_time + timedelta(hours=duration_hours)

                # Allocate inside this session/transaction. NOTE(review): concurrent
                # requests can still pick the same IP unless assigned_ip has a UNIQUE
                # constraint (not visible here) — confirm.
                assigned_ip = DatabaseManager.get_next_available_ip(session=session)

                subscription = Subscription(
                    user_id=user.id,
                    invoice_id=invoice_id,
                    public_key=public_key,
                    start_time=start_time,
                    expiry_time=expiry_time,
                    status=SubscriptionStatus.PENDING,
                    assigned_ip=assigned_ip,
                )
                session.add(subscription)
                session.commit()

                # Built while the session is still open, so expired attributes can reload.
                return {
                    'id': subscription.id,
                    'user_id': user.id,
                    'invoice_id': subscription.invoice_id,
                    'public_key': subscription.public_key,
                    'assigned_ip': subscription.assigned_ip,
                    'start_time': subscription.start_time,
                    'expiry_time': subscription.expiry_time,
                    'status': subscription.status.value,
                }
            except Exception as e:
                session.rollback()
                logger.exception("Error creating subscription: %s", e)
                raise

    @staticmethod
    def activate_subscription(invoice_id):
        """Mark the subscription for ``invoice_id`` as ACTIVE.

        Returns:
            The updated ``Subscription``, or ``None`` if no subscription has
            that invoice.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.invoice_id == invoice_id)
                .first()
            )
            if subscription:
                subscription.status = SubscriptionStatus.ACTIVE
                session.commit()
                # Reload attributes expired by commit so the returned object stays readable.
                session.refresh(subscription)
            return subscription

    @staticmethod
    def record_payment(user_id, subscription_id, invoice_id, amount):
        """Record a payment made by the user with UUID ``user_id``.

        Returns:
            The committed ``Payment``.

        Raises:
            ValueError: if no user has that UUID.
        """
        with get_session() as session:
            # Look the user up in this same session instead of opening a second one.
            user = session.query(User).filter(User.user_id == user_id).first()
            if not user:
                raise ValueError(f"User {user_id} not found")

            payment = Payment(
                user_id=user.id,
                subscription_id=subscription_id,
                invoice_id=invoice_id,
                amount=amount,
            )
            session.add(payment)
            session.commit()
            session.refresh(payment)
            return payment

    @staticmethod
    def get_active_subscriptions():
        """Return every subscription whose status is ACTIVE."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                .all()
            )

    @staticmethod
    def expire_subscription(subscription_id):
        """Mark a subscription as EXPIRED.

        Returns:
            ``True`` if the subscription existed and was updated, else ``False``.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )
            if not subscription:
                return False
            subscription.status = SubscriptionStatus.EXPIRED
            session.commit()
            return True

    @staticmethod
    def update_warning_sent(subscription_id):
        """Set ``warning_sent`` to 1 for a subscription.

        Returns:
            ``True`` if the subscription existed and was updated, else ``False``.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )
            if not subscription:
                return False
            # Stored as the integer 1 rather than True; presumably an integer column.
            subscription.warning_sent = 1
            session.commit()
            return True

    @staticmethod
    def create_provision_log(log_data):
        """Create and persist a ``ProvisionLog`` entry.

        Args:
            log_data: Mapping with required keys ``subscription_id``,
                ``action``, ``status`` and ``ansible_output``, plus an optional
                ``error_message``.

        Returns:
            The committed ``ProvisionLog``.

        Raises:
            KeyError: if a required key is missing.
            Exception: any database error. The transaction is rolled back and
                the error is re-raised.
        """
        with get_session() as session:
            try:
                provision_log = ProvisionLog(
                    subscription_id=log_data['subscription_id'],
                    action=log_data['action'],
                    status=log_data['status'],
                    ansible_output=log_data['ansible_output'],
                    error_message=log_data.get('error_message'),
                )
                session.add(provision_log)
                session.commit()
                session.refresh(provision_log)
                return provision_log
            except Exception as e:
                session.rollback()
                logger.exception("Error creating provision log: %s", e)
                raise

    @staticmethod
    def get_provision_logs(subscription_id=None, limit=100):
        """Return up to ``limit`` provision logs, newest first.

        Args:
            subscription_id: When not ``None``, only logs for this subscription
                are returned.
            limit: Maximum number of rows to return.
        """
        with get_session() as session:
            query = session.query(ProvisionLog)
            # `is not None` so a falsy id such as 0 still filters.
            if subscription_id is not None:
                query = query.filter(ProvisionLog.subscription_id == subscription_id)
            return query.order_by(ProvisionLog.timestamp.desc()).limit(limit).all()