vpn-btcpay-provisioner/app/utils/db/operations.py
2024-12-30 06:03:07 +00:00

152 lines
5.3 KiB
Python

import ipaddress
import logging
from datetime import datetime, timedelta

from sqlalchemy.exc import SQLAlchemyError

from . import get_session
from .models import User, Subscription, Payment, SubscriptionStatus
logger = logging.getLogger(__name__)
class DatabaseManager:
    """Static helpers for persisting VPN users, subscriptions and payments.

    Every public method opens its own session via ``get_session()``; work that
    must be atomic (e.g. creating a user together with its subscription) is
    done inside a single session using the private ``_find_*`` / ``_allocate_ip``
    helpers, which take the caller's session.

    NOTE(review): instances returned after ``session.commit()`` are detached once
    the ``with`` block exits. Whether their attributes remain readable depends on
    how ``get_session`` configures ``expire_on_commit`` — confirm before relying
    on returned objects' attributes.
    """

    # WireGuard tunnel subnet from which client addresses are allocated.
    WG_SUBNET = '10.8.0.0/24'
    # Address held by the WireGuard server itself; never handed to a client.
    WG_SERVER_IP = '10.8.0.1'

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _first_free_ip(assigned_ips, subnet=WG_SUBNET, server_ip=WG_SERVER_IP):
        """Return the lowest host address in *subnet* not already in use.

        Args:
            assigned_ips: Iterable of IP address strings already assigned.
            subnet: CIDR string of the pool to allocate from.
            server_ip: Address to skip because the server owns it.

        Returns:
            The first free host address, as a string.

        Raises:
            RuntimeError: If every host address in the subnet is taken.
        """
        # A set gives O(1) membership tests instead of scanning a list per host.
        taken = set(assigned_ips)
        taken.add(server_ip)
        for host in ipaddress.IPv4Network(subnet).hosts():
            candidate = str(host)
            if candidate not in taken:
                return candidate
        raise RuntimeError("No available IPs in the subnet")

    @staticmethod
    def _allocate_ip(session):
        """Return the next free WireGuard IP, as seen through *session*."""
        # NOTE(review): expired subscriptions keep their assigned_ip (see
        # expire_subscription), so their addresses are never reused and the
        # pool eventually runs out. Also, two concurrent callers can pick the
        # same address unless the column has a unique constraint — confirm.
        rows = (
            session.query(Subscription.assigned_ip)
            .filter(Subscription.assigned_ip.isnot(None))
            .all()
        )
        return DatabaseManager._first_free_ip(row[0] for row in rows)

    @staticmethod
    def _find_user(session, user_id):
        """Return the User whose ``user_id`` matches, or None."""
        return session.query(User).filter(User.user_id == user_id).first()

    @staticmethod
    def _find_subscription_by_invoice(session, invoice_id):
        """Return the Subscription for *invoice_id*, or None."""
        return (
            session.query(Subscription)
            .filter(Subscription.invoice_id == invoice_id)
            .first()
        )

    @staticmethod
    def _find_subscription_by_id(session, subscription_id):
        """Return the Subscription with primary key *subscription_id*, or None."""
        return (
            session.query(Subscription)
            .filter(Subscription.id == subscription_id)
            .first()
        )

    # ------------------------------------------------------------------ public

    @staticmethod
    def get_next_available_ip():
        """Get the next available IP from the WireGuard subnet.

        Raises:
            RuntimeError: If the subnet is exhausted.
        """
        with get_session() as session:
            return DatabaseManager._allocate_ip(session)

    @staticmethod
    def create_user(user_id):
        """Create, commit and return a new User with the given UUID."""
        with get_session() as session:
            user = User(user_id=user_id)
            session.add(user)
            session.commit()
            return user

    @staticmethod
    def get_user_by_uuid(user_id):
        """Return the User with the given UUID, or None."""
        with get_session() as session:
            return DatabaseManager._find_user(session, user_id)

    @staticmethod
    def get_subscription_by_invoice(invoice_id):
        """Return the Subscription for *invoice_id*, or None."""
        with get_session() as session:
            return DatabaseManager._find_subscription_by_invoice(session, invoice_id)

    @staticmethod
    def create_subscription(user_id, invoice_id, public_key, duration_hours):
        """Create a PENDING subscription, creating the user if needed.

        The user (if new), the IP allocation and the subscription are all
        handled in one session, so a failure rolls back everything.

        Args:
            user_id: The user's UUID.
            invoice_id: Invoice identifier to associate with the subscription.
            public_key: Client public key stored on the subscription.
            duration_hours: Length of the subscription, in hours.

        Returns:
            The newly committed Subscription.

        Raises:
            Any exception from the database layer or IP allocation
            (e.g. RuntimeError when the subnet is full); it is logged and the
            session is rolled back before re-raising.
        """
        with get_session() as session:
            try:
                user = DatabaseManager._find_user(session, user_id)
                if user is None:
                    user = User(user_id=user_id)
                    session.add(user)
                    # Flush so the new row gets its primary key before we use user.id.
                    session.flush()

                # Naive UTC timestamps, as before; switching to aware datetimes
                # would need a matching change wherever expiry_time is compared.
                start_time = datetime.utcnow()
                expiry_time = start_time + timedelta(hours=duration_hours)

                subscription = Subscription(
                    user_id=user.id,
                    invoice_id=invoice_id,
                    public_key=public_key,
                    start_time=start_time,
                    expiry_time=expiry_time,
                    status=SubscriptionStatus.PENDING,
                    assigned_ip=DatabaseManager._allocate_ip(session),
                )
                session.add(subscription)
                session.commit()
                return subscription
            except Exception as e:
                logger.exception("Error creating subscription: %s", e)
                session.rollback()
                raise

    @staticmethod
    def activate_subscription(invoice_id):
        """Mark the subscription for *invoice_id* ACTIVE; return it (or None)."""
        with get_session() as session:
            subscription = DatabaseManager._find_subscription_by_invoice(session, invoice_id)
            if subscription:
                subscription.status = SubscriptionStatus.ACTIVE
                session.commit()
            return subscription

    @staticmethod
    def record_payment(user_id, subscription_id, invoice_id, amount):
        """Record a payment for an existing user.

        Raises:
            ValueError: If no user with *user_id* exists.
        """
        with get_session() as session:
            # Look the user up in this session so user.id is read from a live,
            # attached instance.
            user = DatabaseManager._find_user(session, user_id)
            if user is None:
                raise ValueError(f"User {user_id} not found")
            payment = Payment(
                user_id=user.id,
                subscription_id=subscription_id,
                invoice_id=invoice_id,
                amount=amount,
            )
            session.add(payment)
            session.commit()
            return payment

    @staticmethod
    def get_active_subscriptions():
        """Return all subscriptions whose status is ACTIVE."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                .all()
            )

    @staticmethod
    def expire_subscription(subscription_id):
        """Mark a subscription EXPIRED; return True if it existed.

        Only the status changes — assigned_ip is left as-is.
        """
        with get_session() as session:
            subscription = DatabaseManager._find_subscription_by_id(session, subscription_id)
            if subscription:
                subscription.status = SubscriptionStatus.EXPIRED
                session.commit()
                return True
            return False

    @staticmethod
    def update_warning_sent(subscription_id):
        """Set warning_sent to 1 on a subscription; return True if it existed."""
        with get_session() as session:
            subscription = DatabaseManager._find_subscription_by_id(session, subscription_id)
            if subscription:
                subscription.warning_sent = 1
                session.commit()
                return True
            return False