217 lines
8.0 KiB
Python
217 lines
8.0 KiB
Python
from datetime import datetime, timedelta
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from . import get_session
|
|
from .models import User, Subscription, Payment, SubscriptionStatus
|
|
import logging
|
|
import ipaddress
|
|
from .models import User, Subscription, Payment, ProvisionLog, SubscriptionStatus
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class DatabaseManager:
    """Static data-access helpers for users, subscriptions, payments and provision logs.

    Each public method opens its own session with ``get_session()`` and does all
    of its database work inside that ``with`` block.
    """

    # Subnet from which WireGuard client addresses are allocated.
    WG_SUBNET = '10.8.0.0/24'
    # Addresses the allocator must never hand out (10.8.0.1 is the server).
    WG_RESERVED_IPS = ('10.8.0.1',)

    @staticmethod
    def _first_free_ip(taken, network=None, reserved=None):
        """Return the lowest host address in *network* that is neither reserved nor taken.

        Args:
            taken: Iterable of IP strings that are already assigned.
            network: CIDR string to allocate from; defaults to ``WG_SUBNET``.
            reserved: IP strings that must never be handed out; defaults to
                ``WG_RESERVED_IPS``.

        Returns:
            The first free host address as a dotted-quad string.

        Raises:
            RuntimeError: If every host address is reserved or taken.
        """
        if network is None:
            network = DatabaseManager.WG_SUBNET
        if reserved is None:
            reserved = DatabaseManager.WG_RESERVED_IPS
        # Build the exclusion set once: O(1) membership per host instead of a list scan.
        unavailable = set(taken) | set(reserved)
        for host in ipaddress.IPv4Network(network).hosts():
            candidate = str(host)
            if candidate not in unavailable:
                return candidate
        raise RuntimeError("No available IPs in the subnet")

    @staticmethod
    def _allocate_ip(session):
        """Return the next free client IP as seen by *session*.

        Every subscription row with a non-NULL ``assigned_ip`` counts as taken,
        whatever its status, so addresses of expired subscriptions are not reused.

        Raises:
            RuntimeError: If the subnet is exhausted.
        """
        rows = (
            session.query(Subscription.assigned_ip)
            .filter(Subscription.assigned_ip.isnot(None))
            .all()
        )
        return DatabaseManager._first_free_ip(row[0] for row in rows)

    @staticmethod
    def get_next_available_ip():
        """Return the next unassigned IP from the WireGuard subnet.

        Raises:
            RuntimeError: If the subnet is exhausted.
        """
        with get_session() as session:
            return DatabaseManager._allocate_ip(session)

    @staticmethod
    def create_user(user_id):
        """Insert a new ``User`` with the given external ``user_id`` (UUID) and return it."""
        with get_session() as session:
            user = User(user_id=user_id)
            session.add(user)
            session.commit()
            # commit() expires loaded attributes under SQLAlchemy's default
            # expire_on_commit; reload them while the session is still open so
            # callers can read the columns after the session is closed.
            session.refresh(user)
            return user

    @staticmethod
    def get_user_by_uuid(user_id):
        """Return the ``User`` whose ``user_id`` equals *user_id*, or ``None``."""
        with get_session() as session:
            return session.query(User).filter(User.user_id == user_id).first()

    @staticmethod
    def get_subscription_by_invoice(invoice_id):
        """Return the first ``Subscription`` with the given ``invoice_id``, or ``None``."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.invoice_id == invoice_id)
                .first()
            )

    @staticmethod
    def get_subscription_by_id(subscription_id):
        """Return the ``Subscription`` with primary key *subscription_id*, or ``None``."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )

    @staticmethod
    def get_user_subscriptions(user_id):
        """Return all subscriptions of a user, newest ``start_time`` first.

        *user_id* is compared with ``Subscription.user_id``, which
        ``create_subscription`` fills with the internal ``User.id``, not the
        external UUID.
        """
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.user_id == user_id)
                .order_by(Subscription.start_time.desc())
                .all()
            )

    @staticmethod
    def get_active_subscription_for_user(user_id):
        """Return one ACTIVE subscription for the user (internal ``User.id``), or ``None``."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.user_id == user_id)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                .first()
            )

    @staticmethod
    def create_subscription(user_id, invoice_id, public_key, duration_hours):
        """Create a PENDING subscription, creating the ``User`` first if it does not exist.

        The client IP is allocated inside this same session/transaction, so the
        allocator and the insert share one view of the database.

        Args:
            user_id: External user identifier (matched against ``User.user_id``).
            invoice_id: Invoice the subscription is tied to.
            public_key: Client public key stored on the subscription.
            duration_hours: Length of the subscription, from now.

        Returns:
            A plain dict (id, user_id as internal ``User.id``, invoice_id,
            public_key, assigned_ip, start_time, expiry_time, status value). It
            is built before the session closes, so callers never touch ORM
            objects.

        Raises:
            RuntimeError: If no IP is free. Any other error is logged, the
            session is rolled back, and the error is re-raised.
        """
        with get_session() as session:
            try:
                user = session.query(User).filter(User.user_id == user_id).first()
                if not user:
                    user = User(user_id=user_id)
                    session.add(user)
                    session.flush()  # assign user.id for the foreign key below

                # Naive UTC timestamps. NOTE(review): utcnow() is deprecated since
                # Python 3.12; switching to aware datetimes needs a schema check.
                start_time = datetime.utcnow()
                expiry_time = start_time + timedelta(hours=duration_hours)
                # NOTE(review): concurrent callers can still pick the same address;
                # a unique constraint on assigned_ip would make that fail loudly.
                assigned_ip = DatabaseManager._allocate_ip(session)

                subscription = Subscription(
                    user_id=user.id,
                    invoice_id=invoice_id,
                    public_key=public_key,
                    start_time=start_time,
                    expiry_time=expiry_time,
                    status=SubscriptionStatus.PENDING,
                    assigned_ip=assigned_ip,
                )
                session.add(subscription)
                session.commit()

                return {
                    'id': subscription.id,
                    'user_id': user.id,
                    'invoice_id': subscription.invoice_id,
                    'public_key': subscription.public_key,
                    'assigned_ip': subscription.assigned_ip,
                    'start_time': subscription.start_time,
                    'expiry_time': subscription.expiry_time,
                    'status': subscription.status.value,
                }
            except Exception as e:
                # Boundary handler: keep the traceback, undo partial work, re-raise.
                logger.exception("Error creating subscription: %s", e)
                session.rollback()
                raise

    @staticmethod
    def activate_subscription(invoice_id):
        """Mark the subscription for *invoice_id* ACTIVE.

        Returns:
            The updated ``Subscription``, or ``None`` if no row matches.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.invoice_id == invoice_id)
                .first()
            )
            if subscription:
                subscription.status = SubscriptionStatus.ACTIVE
                session.commit()
                # Reload attributes expired by commit() so they stay readable
                # after the session closes.
                session.refresh(subscription)
            return subscription

    @staticmethod
    def record_payment(user_id, subscription_id, invoice_id, amount):
        """Insert a ``Payment`` for the user with external id *user_id* and return it.

        Raises:
            ValueError: If no ``User`` has ``user_id == user_id``.
        """
        with get_session() as session:
            # Look the user up in this session instead of calling get_user_by_uuid(),
            # which would open a second session for the same unit of work.
            user = session.query(User).filter(User.user_id == user_id).first()
            if not user:
                raise ValueError(f"User {user_id} not found")

            payment = Payment(
                user_id=user.id,
                subscription_id=subscription_id,
                invoice_id=invoice_id,
                amount=amount,
            )
            session.add(payment)
            session.commit()
            # Reload attributes expired by commit() so they stay readable
            # after the session closes.
            session.refresh(payment)
            return payment

    @staticmethod
    def get_active_subscriptions():
        """Return every subscription whose status is ACTIVE."""
        with get_session() as session:
            return (
                session.query(Subscription)
                .filter(Subscription.status == SubscriptionStatus.ACTIVE)
                .all()
            )

    @staticmethod
    def expire_subscription(subscription_id):
        """Set the subscription's status to EXPIRED.

        Returns:
            ``True`` if the subscription existed and was updated, else ``False``.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )
            if subscription:
                subscription.status = SubscriptionStatus.EXPIRED
                session.commit()
                return True
            return False

    @staticmethod
    def update_warning_sent(subscription_id):
        """Set ``warning_sent`` to 1 on the subscription.

        Returns:
            ``True`` if the subscription existed and was updated, else ``False``.
        """
        with get_session() as session:
            subscription = (
                session.query(Subscription)
                .filter(Subscription.id == subscription_id)
                .first()
            )
            if subscription:
                subscription.warning_sent = 1
                session.commit()
                return True
            return False

    @staticmethod
    def create_provision_log(log_data):
        """Insert a ``ProvisionLog`` built from *log_data* and return it.

        Args:
            log_data: Mapping with required keys ``subscription_id``, ``action``,
                ``status`` and ``ansible_output``, and optional ``error_message``.

        Raises:
            KeyError: If a required key is missing. This and any database error
            are logged, the session is rolled back, and the error is re-raised.
        """
        with get_session() as session:
            try:
                provision_log = ProvisionLog(
                    subscription_id=log_data['subscription_id'],
                    action=log_data['action'],
                    status=log_data['status'],
                    ansible_output=log_data['ansible_output'],
                    error_message=log_data.get('error_message'),
                )
                session.add(provision_log)
                session.commit()
                # Reload attributes expired by commit() so they stay readable
                # after the session closes.
                session.refresh(provision_log)
                return provision_log
            except Exception as e:
                session.rollback()
                logger.exception("Error creating provision log: %s", e)
                raise

    @staticmethod
    def get_provision_logs(subscription_id=None, limit=100):
        """Return up to *limit* provision logs, newest ``timestamp`` first.

        Args:
            subscription_id: When not ``None``, only logs for this subscription
                are returned.
            limit: Maximum number of rows.
        """
        with get_session() as session:
            query = session.query(ProvisionLog)
            # Compare against None explicitly so a falsy id still filters.
            if subscription_id is not None:
                query = query.filter(ProvisionLog.subscription_id == subscription_id)
            return query.order_by(ProvisionLog.timestamp.desc()).limit(limit).all()