import ipaddress
import logging
from datetime import datetime, timedelta

from sqlalchemy.exc import SQLAlchemyError

from . import get_session
from .models import User, Subscription, Payment, SubscriptionStatus

logger = logging.getLogger(__name__)

# Subnet that WireGuard peer addresses are allocated from, and the address in
# it reserved for the server itself (never handed out to a subscription).
DEFAULT_WIREGUARD_NETWORK = '10.8.0.0/24'
WIREGUARD_SERVER_IP = '10.8.0.1'


class DatabaseManager:
    """Static helpers for persisting users, subscriptions and payments.

    Every public method opens its own session via ``get_session()``. Objects
    returned after a commit are refreshed first so their column attributes are
    loaded before the session is left.
    """

    @staticmethod
    def _allocate_ip(session, network=DEFAULT_WIREGUARD_NETWORK,
                     server_ip=WIREGUARD_SERVER_IP):
        """Return the first host address in *network* not used by any subscription.

        Runs inside the caller's *session* so that the lookup and the insert
        that consumes the address happen in the same transaction.

        Raises:
            Exception: if every host address other than *server_ip* is taken.
        """
        rows = session.query(Subscription.assigned_ip)\
            .filter(Subscription.assigned_ip.isnot(None))\
            .all()
        # NOTE(review): addresses of subscriptions in any status (including
        # EXPIRED) count as taken, so they are never reclaimed here -- confirm
        # that is intended.
        # A set gives O(1) membership tests instead of scanning a list per host.
        taken = {row[0] for row in rows}
        for host in ipaddress.IPv4Network(network).hosts():
            candidate = str(host)
            if candidate == server_ip:
                continue
            if candidate not in taken:
                return candidate
        raise Exception("No available IPs in the subnet")

    @staticmethod
    def _get_or_create_user(session, user_id):
        """Return the User with *user_id* in *session*, adding one if missing.

        A newly created user is flushed (not committed) so that ``user.id`` is
        populated while the surrounding transaction stays open.
        """
        user = session.query(User).filter(User.user_id == user_id).first()
        if user is None:
            user = User(user_id=user_id)
            session.add(user)
            session.flush()
        return user

    @staticmethod
    def get_next_available_ip(network=DEFAULT_WIREGUARD_NETWORK,
                              server_ip=WIREGUARD_SERVER_IP):
        """Get the next available IP from the WireGuard subnet.

        Args:
            network: CIDR of the subnet to allocate from (default 10.8.0.0/24).
            server_ip: address in that subnet reserved for the server.

        Returns:
            The lowest free host address as a string.

        Raises:
            Exception: if no address is free.
        """
        with get_session() as session:
            return DatabaseManager._allocate_ip(session, network, server_ip)

    @staticmethod
    def create_user(user_id):
        """Create, commit and return a new User with the given *user_id*."""
        with get_session() as session:
            user = User(user_id=user_id)
            session.add(user)
            session.commit()
            # commit() expires attributes by default; reload them so the
            # returned object is usable after the session is closed.
            session.refresh(user)
            return user

    @staticmethod
    def get_user_by_uuid(user_id):
        """Return the User whose ``user_id`` matches, or None."""
        with get_session() as session:
            return session.query(User).filter(User.user_id == user_id).first()

    @staticmethod
    def get_subscription_by_invoice(invoice_id):
        """Return the Subscription with the given *invoice_id*, or None."""
        with get_session() as session:
            return session.query(Subscription)\
                .filter(Subscription.invoice_id == invoice_id)\
                .first()

    @staticmethod
    def create_subscription(user_id, invoice_id, public_key, duration_hours):
        """Create a PENDING subscription, creating the user if needed.

        The user lookup/creation, IP allocation and subscription insert all
        happen in one session, so a failure rolls back all of them together.

        Args:
            user_id: external user identifier (``User.user_id``).
            invoice_id: invoice the subscription is tied to.
            public_key: peer public key stored on the subscription.
            duration_hours: length of the subscription, added to the start time.

        Returns:
            The committed Subscription.

        Raises:
            Exception: re-raised after logging and rolling back on any failure,
                including when no IP address is available.
        """
        with get_session() as session:
            try:
                user = DatabaseManager._get_or_create_user(session, user_id)

                # NOTE: naive UTC timestamps, kept as-is to match whatever the
                # model columns already store.
                start_time = datetime.utcnow()
                expiry_time = start_time + timedelta(hours=duration_hours)

                assigned_ip = DatabaseManager._allocate_ip(session)

                subscription = Subscription(
                    user_id=user.id,
                    invoice_id=invoice_id,
                    public_key=public_key,
                    start_time=start_time,
                    expiry_time=expiry_time,
                    status=SubscriptionStatus.PENDING,
                    assigned_ip=assigned_ip,
                )
                session.add(subscription)
                session.commit()
                session.refresh(subscription)
                return subscription
            except Exception as e:
                logger.exception("Error creating subscription: %s", e)
                session.rollback()
                raise

    @staticmethod
    def activate_subscription(invoice_id):
        """Mark the subscription for *invoice_id* ACTIVE.

        Returns:
            The updated Subscription, or None if no subscription matches.
        """
        with get_session() as session:
            subscription = session.query(Subscription)\
                .filter(Subscription.invoice_id == invoice_id)\
                .first()
            if subscription:
                subscription.status = SubscriptionStatus.ACTIVE
                session.commit()
                session.refresh(subscription)
            return subscription

    @staticmethod
    def record_payment(user_id, subscription_id, invoice_id, amount):
        """Record a payment for an existing user.

        Raises:
            ValueError: if no user with *user_id* exists.
        """
        with get_session() as session:
            user = session.query(User).filter(User.user_id == user_id).first()
            if not user:
                raise ValueError(f"User {user_id} not found")

            payment = Payment(
                user_id=user.id,
                subscription_id=subscription_id,
                invoice_id=invoice_id,
                amount=amount,
            )
            session.add(payment)
            session.commit()
            session.refresh(payment)
            return payment

    @staticmethod
    def get_active_subscriptions():
        """Return all subscriptions whose status is ACTIVE."""
        with get_session() as session:
            return session.query(Subscription)\
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)\
                .all()

    @staticmethod
    def expire_subscription(subscription_id):
        """Mark a subscription EXPIRED; return True if it existed, else False."""
        with get_session() as session:
            subscription = session.query(Subscription)\
                .filter(Subscription.id == subscription_id)\
                .first()
            if subscription:
                subscription.status = SubscriptionStatus.EXPIRED
                session.commit()
                return True
            return False

    @staticmethod
    def update_warning_sent(subscription_id):
        """Set ``warning_sent`` to 1; return True if the subscription existed."""
        with get_session() as session:
            subscription = session.query(Subscription)\
                .filter(Subscription.id == subscription_id)\
                .first()
            if subscription:
                subscription.warning_sent = 1
                session.commit()
                return True
            return False