vpn-btcpay-provisioner/app/utils/db/operations.py
2025-01-10 08:20:21 +00:00

217 lines
8.0 KiB
Python

from datetime import datetime, timedelta
from sqlalchemy.exc import SQLAlchemyError
from . import get_session
from .models import User, Subscription, Payment, SubscriptionStatus
import logging
import ipaddress
from .models import User, Subscription, Payment, ProvisionLog, SubscriptionStatus
logger = logging.getLogger(__name__)
class DatabaseManager:
    """Static helpers wrapping the ORM for users, subscriptions, payments and provision logs.

    Every public method opens its own session via ``get_session()``.

    NOTE(review): most methods return ORM instances after the ``with get_session()``
    block has exited. Whether their attributes stay readable afterwards (and whether
    lazy relationships can load) depends on how ``get_session`` configures and closes
    the session (e.g. ``expire_on_commit``), which is not visible here — confirm
    before relying on it.
    """

    # Subnet handed out to WireGuard peers, and the address reserved for the server.
    WG_SUBNET = '10.8.0.0/24'
    WG_SERVER_IP = '10.8.0.1'

    @staticmethod
    def _first_free_ip(assigned_ips, subnet='10.8.0.0/24', server_ip='10.8.0.1'):
        """Return the lowest host address of *subnet* not in *assigned_ips* and not *server_ip*.

        Args:
            assigned_ips: iterable of dotted-quad strings already in use.
            subnet: IPv4 network in CIDR notation to allocate from.
            server_ip: address that is never handed out.

        Raises:
            RuntimeError: when every host address of *subnet* is taken.
        """
        # A set makes each membership test O(1) instead of scanning a list per host.
        taken = set(assigned_ips)
        taken.add(server_ip)
        for host in ipaddress.IPv4Network(subnet).hosts():
            candidate = str(host)
            if candidate not in taken:
                return candidate
        raise RuntimeError("No available IPs in the subnet")

    @staticmethod
    def get_next_available_ip(session=None):
        """Return the lowest free host address in ``WG_SUBNET``.

        An address counts as taken if any Subscription row (whatever its status)
        has it in ``assigned_ip``, or if it equals ``WG_SERVER_IP``.

        Args:
            session: optional open session to query with. Passing the caller's
                session keeps the lookup in the caller's transaction (it then sees
                rows flushed but not yet committed); when omitted, a new session is
                opened for the lookup.

        Raises:
            RuntimeError: when every host address is taken.

        NOTE(review): addresses of EXPIRED subscriptions are never released here, so
        the pool (253 usable hosts in a /24) only shrinks over time. Reclaiming them
        would need a guarantee that the old peer was removed from the server — confirm
        before changing. Two concurrent callers can also pick the same address unless
        ``assigned_ip`` has a unique constraint (not visible from this file).
        """
        if session is None:
            with get_session() as own_session:
                return DatabaseManager.get_next_available_ip(own_session)
        rows = (
            session.query(Subscription.assigned_ip)
            .filter(Subscription.assigned_ip.isnot(None))
            .all()
        )
        return DatabaseManager._first_free_ip(
            (row[0] for row in rows),
            DatabaseManager.WG_SUBNET,
            DatabaseManager.WG_SERVER_IP,
        )

    @staticmethod
    def create_user(user_id):
        """Insert and commit a new User whose ``user_id`` column is *user_id*; return it."""
        with get_session() as session:
            user = User(user_id=user_id)
            session.add(user)
            session.commit()
            return user

    @staticmethod
    def get_user_by_uuid(user_id):
        """Return the first User whose ``user_id`` column equals *user_id*, or None."""
        with get_session() as session:
            return session.query(User).filter(User.user_id == user_id).first()

    @staticmethod
    def get_subscription_by_invoice(invoice_id):
        """Return the first Subscription with the given ``invoice_id``, or None."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.invoice_id == invoice_id)
                .first()
            )

    @staticmethod
    def get_subscription_by_id(subscription_id):
        """Return the Subscription whose ``id`` equals *subscription_id*, or None."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )

    @staticmethod
    def get_user_subscriptions(user_id):
        """Return all subscriptions for a user, newest ``start_time`` first.

        Note: *user_id* is compared with ``Subscription.user_id``, which
        ``create_subscription`` fills with ``User.id`` (the row id), not with the
        external ``User.user_id`` value.
        """
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.user_id == user_id)
                .order_by(Subscription.start_time.desc())
                .all()
            )

    @staticmethod
    def get_active_subscription_for_user(user_id):
        """Return one ACTIVE subscription for the user, or None.

        *user_id* is matched against ``Subscription.user_id`` (i.e. ``User.id``).
        If several ACTIVE rows exist, which one is returned is unspecified.
        """
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.user_id == user_id)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                .first()
            )

    @staticmethod
    def create_subscription(user_id, invoice_id, public_key, duration_hours):
        """Create a PENDING subscription, creating the User first if needed.

        Args:
            user_id: external identifier matched against ``User.user_id``.
            invoice_id: invoice the subscription is tied to.
            public_key: the client's public key, stored as-is.
            duration_hours: subscription length; ``expiry_time = start_time + duration``.

        Returns:
            dict with the new subscription's fields; ``user_id`` is ``User.id``
            and ``status`` is the enum's ``.value``.

        Raises:
            Any error from the database or IP allocation is logged, the session is
            rolled back, and the error is re-raised.
        """
        with get_session() as session:
            try:
                user = session.query(User).filter(User.user_id == user_id).first()
                if not user:
                    user = User(user_id=user_id)
                    session.add(user)
                    # Flush so user.id is populated before it is referenced below.
                    session.flush()
                start_time = datetime.utcnow()
                expiry_time = start_time + timedelta(hours=duration_hours)
                # Allocate within this session/transaction rather than a separate one.
                assigned_ip = DatabaseManager.get_next_available_ip(session)
                subscription = Subscription(
                    user_id=user.id,
                    invoice_id=invoice_id,
                    public_key=public_key,
                    start_time=start_time,
                    expiry_time=expiry_time,
                    status=SubscriptionStatus.PENDING,
                    assigned_ip=assigned_ip,
                )
                session.add(subscription)
                session.commit()
                # Build a plain dict while the session is still open so callers
                # don't depend on the ORM instance after the session closes.
                return {
                    'id': subscription.id,
                    'user_id': user.id,
                    'invoice_id': subscription.invoice_id,
                    'public_key': subscription.public_key,
                    'assigned_ip': subscription.assigned_ip,
                    'start_time': subscription.start_time,
                    'expiry_time': subscription.expiry_time,
                    'status': subscription.status.value,
                }
            except Exception as e:
                logger.exception("Error creating subscription: %s", e)
                session.rollback()
                raise

    @staticmethod
    def activate_subscription(invoice_id):
        """Set the subscription for *invoice_id* to ACTIVE and commit.

        Returns the subscription, or None when no row matches (nothing is committed).
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.invoice_id == invoice_id)
                .first()
            )
            if subscription:
                subscription.status = SubscriptionStatus.ACTIVE
                session.commit()
            return subscription

    @staticmethod
    def record_payment(user_id, subscription_id, invoice_id, amount):
        """Insert and commit a Payment row.

        Args:
            user_id: external identifier matched against ``User.user_id``; the
                payment stores the matching ``User.id``.
            subscription_id, invoice_id, amount: stored as-is.

        Raises:
            ValueError: when no User has the given ``user_id``.
        """
        with get_session() as session:
            # Look the user up in this same session instead of via a separate one.
            user = session.query(User).filter(User.user_id == user_id).first()
            if not user:
                raise ValueError(f"User {user_id} not found")
            payment = Payment(
                user_id=user.id,
                subscription_id=subscription_id,
                invoice_id=invoice_id,
                amount=amount,
            )
            session.add(payment)
            session.commit()
            return payment

    @staticmethod
    def get_active_subscriptions():
        """Return every subscription whose status is ACTIVE."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                .all()
            )

    @staticmethod
    def expire_subscription(subscription_id):
        """Mark a subscription EXPIRED and commit; return True, or False if it does not exist.

        Only the status changes; ``assigned_ip`` is left on the row.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )
            if not subscription:
                return False
            subscription.status = SubscriptionStatus.EXPIRED
            session.commit()
            return True

    @staticmethod
    def update_warning_sent(subscription_id):
        """Set ``warning_sent`` to 1 and commit; return True, or False if the subscription does not exist."""
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )
            if not subscription:
                return False
            subscription.warning_sent = 1
            session.commit()
            return True

    @staticmethod
    def create_provision_log(log_data):
        """Insert and commit a ProvisionLog built from *log_data*.

        Args:
            log_data: mapping with required keys ``subscription_id``, ``action``,
                ``status``, ``ansible_output`` and optional ``error_message``
                (defaults to None).

        Raises:
            KeyError if a required key is missing; database errors are logged,
            rolled back, and re-raised.
        """
        with get_session() as session:
            try:
                provision_log = ProvisionLog(
                    subscription_id=log_data['subscription_id'],
                    action=log_data['action'],
                    status=log_data['status'],
                    ansible_output=log_data['ansible_output'],
                    error_message=log_data.get('error_message'),
                )
                session.add(provision_log)
                session.commit()
                return provision_log
            except Exception as e:
                session.rollback()
                logger.exception("Error creating provision log: %s", e)
                raise

    @staticmethod
    def get_provision_logs(subscription_id=None, limit=100):
        """Return up to *limit* provision logs, newest ``timestamp`` first.

        When *subscription_id* is given (any value other than None, including 0),
        only logs for that subscription are returned.
        """
        with get_session() as session:
            query = session.query(ProvisionLog)
            if subscription_id is not None:
                query = query.filter(ProvisionLog.subscription_id == subscription_id)
            return query.order_by(ProvisionLog.timestamp.desc()).limit(limit).all()