152 lines
5.3 KiB
Python
152 lines
5.3 KiB
Python
import ipaddress
import logging
from datetime import datetime, timedelta

from sqlalchemy.exc import SQLAlchemyError

from . import get_session
from .models import User, Subscription, Payment, SubscriptionStatus
|
|
|
|
logger = logging.getLogger(__name__)


class DatabaseManager:
    """Stateless data-access helpers for users, subscriptions and payments.

    Every public method opens its own session with ``get_session()``. Any ORM
    object it returns is used after that ``with`` block has exited. Methods
    that commit therefore reload the object first, so that its column
    attributes remain readable afterwards.
    """

    # WireGuard subnet that client addresses are allocated from.
    WG_NETWORK = '10.8.0.0/24'
    # Address of the WireGuard server inside WG_NETWORK; never handed out.
    WG_SERVER_IP = '10.8.0.1'

    # ------------------------------------------------------------------
    # Private helpers (pure, or operating on an already-open session)
    # ------------------------------------------------------------------

    @staticmethod
    def _first_free_ip(assigned_ips, network=None, server_ip=None):
        """Return the lowest host address in *network* not already taken.

        Args:
            assigned_ips: iterable of IP address strings already in use.
            network: subnet in CIDR notation; defaults to ``WG_NETWORK``.
            server_ip: address that is never returned; defaults to
                ``WG_SERVER_IP``.

        Returns:
            The first free host address, as a string.

        Raises:
            RuntimeError: if every host address in the subnet is taken.
        """
        if network is None:
            network = DatabaseManager.WG_NETWORK
        if server_ip is None:
            server_ip = DatabaseManager.WG_SERVER_IP

        # A set gives O(1) membership tests instead of scanning a list per host.
        taken = set(assigned_ips)
        taken.add(str(server_ip))

        for host in ipaddress.IPv4Network(network).hosts():
            candidate = str(host)
            if candidate not in taken:
                return candidate

        raise RuntimeError("No available IPs in the subnet")

    @staticmethod
    def _next_available_ip(session):
        """Return the first free client IP, reading assignments via *session*."""
        rows = session.query(Subscription.assigned_ip)\
            .filter(Subscription.assigned_ip.isnot(None))\
            .all()
        # NOTE(review): addresses held by EXPIRED subscriptions still count as
        # taken (nothing here clears assigned_ip), so the pool only shrinks.
        # Confirm whether they should be recycled.
        return DatabaseManager._first_free_ip(row[0] for row in rows)

    @staticmethod
    def _find_user(session, user_id):
        """Return the User whose ``user_id`` equals *user_id*, or None."""
        return session.query(User).filter(User.user_id == user_id).first()

    @staticmethod
    def _find_subscription(session, subscription_id):
        """Return the Subscription whose ``id`` equals *subscription_id*, or None."""
        return session.query(Subscription)\
            .filter(Subscription.id == subscription_id)\
            .first()

    @staticmethod
    def _find_subscription_by_invoice(session, invoice_id):
        """Return the Subscription whose ``invoice_id`` equals *invoice_id*, or None."""
        return session.query(Subscription)\
            .filter(Subscription.invoice_id == invoice_id)\
            .first()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def get_next_available_ip():
        """Get the next available IP from the WireGuard subnet.

        Returns:
            The lowest host address in ``WG_NETWORK`` that is neither the
            server address nor assigned to any subscription.

        Raises:
            RuntimeError: if the subnet has no free address left.
        """
        with get_session() as session:
            return DatabaseManager._next_available_ip(session)

    @staticmethod
    def create_user(user_id):
        """Create, commit and return a new User with the given UUID."""
        with get_session() as session:
            user = User(user_id=user_id)
            session.add(user)
            session.commit()
            # With SQLAlchemy's default expire_on_commit=True, commit expires
            # all attributes. Reload them so the returned object (e.g. its
            # generated ``id``) stays readable after the session is closed.
            session.refresh(user)
            return user

    @staticmethod
    def get_user_by_uuid(user_id):
        """Get the User with the given UUID, or None if there is none."""
        with get_session() as session:
            return DatabaseManager._find_user(session, user_id)

    @staticmethod
    def get_subscription_by_invoice(invoice_id):
        """Get the Subscription for *invoice_id*, or None if there is none."""
        with get_session() as session:
            return DatabaseManager._find_subscription_by_invoice(session, invoice_id)

    @staticmethod
    def create_subscription(user_id, invoice_id, public_key, duration_hours):
        """Create a PENDING subscription, creating the user first if needed.

        The following steps run in one session and are committed together, so
        a failure does not leave an orphan user behind:

        - user lookup/creation;
        - IP allocation;
        - subscription insert.

        Args:
            user_id: user UUID, matched against ``User.user_id``.
            invoice_id: invoice identifier stored on the subscription.
            public_key: public key stored on the subscription.
            duration_hours: length of the subscription; the expiry time is
                ``datetime.utcnow()`` plus this many hours.

        Returns:
            The committed Subscription.

        Raises:
            RuntimeError: if no IP is free in the WireGuard subnet.
            Exception: any other error is logged, rolled back and re-raised.
        """
        with get_session() as session:
            try:
                user = DatabaseManager._find_user(session, user_id)
                if user is None:
                    user = User(user_id=user_id)
                    session.add(user)
                    # Assign user.id now without committing, so the user and
                    # the subscription land in the same transaction.
                    session.flush()

                # Naive UTC timestamps, as in the original implementation.
                start_time = datetime.utcnow()
                expiry_time = start_time + timedelta(hours=duration_hours)

                # NOTE(review): concurrent calls may still pick the same IP
                # unless assigned_ip has a unique constraint (not visible here).
                assigned_ip = DatabaseManager._next_available_ip(session)

                subscription = Subscription(
                    user_id=user.id,
                    invoice_id=invoice_id,
                    public_key=public_key,
                    start_time=start_time,
                    expiry_time=expiry_time,
                    status=SubscriptionStatus.PENDING,
                    assigned_ip=assigned_ip,
                )
                session.add(subscription)
                session.commit()
                # Reload expired attributes so the caller can read them.
                session.refresh(subscription)
                return subscription

            except Exception as e:
                logger.exception("Error creating subscription: %s", e)
                session.rollback()
                raise

    @staticmethod
    def activate_subscription(invoice_id):
        """Mark the subscription for *invoice_id* ACTIVE after payment.

        Returns:
            The updated Subscription, or None if no subscription matches.
        """
        with get_session() as session:
            subscription = DatabaseManager._find_subscription_by_invoice(session, invoice_id)
            if subscription:
                subscription.status = SubscriptionStatus.ACTIVE
                session.commit()
                # Reload expired attributes so the caller can read them.
                session.refresh(subscription)
            return subscription

    @staticmethod
    def record_payment(user_id, subscription_id, invoice_id, amount):
        """Record a payment made by an existing user.

        Returns:
            The committed Payment.

        Raises:
            ValueError: if no user with *user_id* exists.
        """
        with get_session() as session:
            user = DatabaseManager._find_user(session, user_id)
            if user is None:
                raise ValueError(f"User {user_id} not found")

            payment = Payment(
                user_id=user.id,
                subscription_id=subscription_id,
                invoice_id=invoice_id,
                amount=amount,
            )
            session.add(payment)
            session.commit()
            # Reload expired attributes so the caller can read them.
            session.refresh(payment)
            return payment

    @staticmethod
    def get_active_subscriptions():
        """Get all subscriptions whose status is ACTIVE."""
        with get_session() as session:
            return session.query(Subscription)\
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)\
                .all()

    @staticmethod
    def expire_subscription(subscription_id):
        """Mark a subscription EXPIRED.

        Returns:
            True if the subscription existed and was updated, else False.
        """
        with get_session() as session:
            subscription = DatabaseManager._find_subscription(session, subscription_id)
            if subscription:
                subscription.status = SubscriptionStatus.EXPIRED
                session.commit()
                return True
            return False

    @staticmethod
    def update_warning_sent(subscription_id):
        """Set the ``warning_sent`` flag (to 1) on a subscription.

        Returns:
            True if the subscription existed and was updated, else False.
        """
        with get_session() as session:
            subscription = DatabaseManager._find_subscription(session, subscription_id)
            if subscription:
                subscription.warning_sent = 1
                session.commit()
                return True
            return False