diff --git a/ansible/playbooks/vpn_provision.yml b/ansible/playbooks/vpn_provision.yml
index 4d210d8..d4dc662 100644
--- a/ansible/playbooks/vpn_provision.yml
+++ b/ansible/playbooks/vpn_provision.yml
@@ -44,7 +44,23 @@
dest: "{{ server_dir }}/public.key"
mode: '0644'
when: not server_config.stat.exists
-
+
+ - name: Update vault with server details
+ block:
+ - name: Read server public key
+ shell: "cat {{ server_dir }}/public.key"
+ register: pubkey_content
+ changed_when: false
+
+ - name: Save server details to vault
+ copy:
+ content: |
+ wireguard_server_public_key: "{{ pubkey_content.stdout }}"
+ wireguard_server_endpoint: "{{ ansible_host }}"
+ dest: "{{ playbook_dir }}/../group_vars/vpn_servers/vault.yml"
+ mode: '0600'
+ when: not server_config.stat.exists
+
- name: Create initial server config
template:
src: templates/server.conf.j2
diff --git a/app/__init__.py b/app/__init__.py
index f99f452..688ef97 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -1,5 +1,3 @@
-# app/__init__.py
-
from flask import Flask, request, jsonify, render_template
import logging
from .handlers.webhook_handler import handle_payment_webhook
diff --git a/app/data/vpn.db b/app/data/vpn.db
new file mode 100644
index 0000000..6b7c17d
Binary files /dev/null and b/app/data/vpn.db differ
diff --git a/app/handlers/payment_handler.py b/app/handlers/payment_handler.py
index 7fc8ec7..2bb3984 100644
--- a/app/handlers/payment_handler.py
+++ b/app/handlers/payment_handler.py
@@ -1,15 +1,11 @@
-# app/handlers/payment_handler.py
-
import logging
import requests
import os
-from flask import jsonify
-import smtplib
-from email.mime.text import MIMEText
-from email.mime.multipart import MIMEMultipart
from pathlib import Path
import traceback
from .webhook_handler import get_vault_values
+from ..utils.db.operations import DatabaseManager
+from ..utils.db.models import SubscriptionStatus
logger = logging.getLogger(__name__)
@@ -25,9 +21,6 @@ class BTCPayHandler:
self.api_key = vault_values['btcpay_api_key']
self.store_id = vault_values['btcpay_store_id']
- # Email configuration
- self.smtp_config = vault_values.get('smtp_config', {})
-
logger.info(f"BTCPayHandler initialized with base URL: {self.base_url}")
except Exception as e:
@@ -35,10 +28,16 @@ class BTCPayHandler:
logger.error(traceback.format_exc())
raise
- def create_invoice(self, amount_sats, duration_hours, email):
+ def create_invoice(self, amount_sats, duration_hours, user_id, public_key):
"""Create BTCPay invoice for VPN subscription"""
try:
- logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for {email}")
+ logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for user {user_id}")
+
+ # First, get or create user
+ user = DatabaseManager.get_user_by_uuid(user_id)
+ if not user:
+ user = DatabaseManager.create_user(user_id)
+ logger.info(f"Created new user with ID: {user_id}")
headers = {
'Authorization': f'token {self.api_key}',
@@ -54,11 +53,12 @@ class BTCPayHandler:
'currency': 'SATS',
'metadata': {
'duration_hours': duration_hours,
- 'email': email,
- 'orderId': f'vpn_sub_{duration_hours}h',
+ 'userId': user_id,
+ 'publicKey': public_key,
+ 'orderId': f'vpn_sub_{duration_hours}h'
},
'checkout': {
- 'redirectURL': f'{app_url}/payment/success',
+ 'redirectURL': f'{app_url}/payment/success?userId={user_id}',
'redirectAutomatically': True
}
}
@@ -80,11 +80,22 @@ class BTCPayHandler:
return None
invoice_data = response.json()
- logger.info(f"Successfully created invoice {invoice_data.get('id')}")
+ invoice_id = invoice_data['id']
+ logger.info(f"Successfully created invoice {invoice_id}")
+
+ # Create pending subscription
+ subscription = DatabaseManager.create_subscription(
+ user_id,
+ invoice_id,
+ public_key,
+ duration_hours
+ )
+ logger.info(f"Created pending subscription for invoice {invoice_id}")
return {
- 'invoice_id': invoice_data['id'],
- 'checkout_url': invoice_data['checkoutLink']
+ 'invoice_id': invoice_id,
+ 'checkout_url': invoice_data['checkoutLink'],
+ 'user_id': user_id
}
except Exception as e:
@@ -92,53 +103,27 @@ class BTCPayHandler:
logger.error(traceback.format_exc())
return None
- def send_confirmation_email(self, email, config_data):
- """Send VPN configuration details via email"""
+ def get_subscription_config(self, user_id):
+ """Get WireGuard configuration details for a subscription"""
try:
- logger.info(f"Sending confirmation email to {email}")
-
- if not self.smtp_config:
- logger.warning("SMTP configuration not found in vault")
- return False
-
- msg = MIMEMultipart()
- msg['From'] = self.smtp_config['sender_email']
- msg['To'] = email
- msg['Subject'] = "Your VPN Configuration"
-
- body = f"""
- Thank you for subscribing to our VPN service!
-
- Please find your WireGuard configuration below:
-
- {config_data}
-
- Installation instructions:
- 1. Install WireGuard client for your platform from https://www.wireguard.com/install/
- 2. Save the above configuration to a file named 'wg0.conf'
- 3. Import the configuration file into your WireGuard client
-
- Need help? Reply to this email for support.
- """
-
- msg.attach(MIMEText(body, 'plain'))
-
- logger.debug("Connecting to SMTP server")
- with smtplib.SMTP(
- self.smtp_config['server'],
- self.smtp_config.get('port', 587)
- ) as server:
- server.starttls()
- server.login(
- self.smtp_config['username'],
- self.smtp_config['password']
- )
- server.send_message(msg)
-
- logger.info("Confirmation email sent successfully")
- return True
-
+ logger.info(f"Fetching subscription config for user {user_id}")
+ user = DatabaseManager.get_user_by_uuid(user_id)
+ if not user:
+ logger.error(f"User {user_id} not found")
+ return None
+
+ subscription = DatabaseManager.get_active_subscription_for_user(user.id)
+ if not subscription:
+ logger.error(f"No active subscription found for user {user_id}")
+ return None
+
+ return {
+ 'serverPublicKey': os.getenv('WIREGUARD_SERVER_PUBLIC_KEY'),
+ 'serverEndpoint': os.getenv('WIREGUARD_SERVER_ENDPOINT'),
+ 'clientIp': subscription.assigned_ip
+ }
+
except Exception as e:
- logger.error("Error sending confirmation email:")
+ logger.error(f"Error getting subscription config: {str(e)}")
logger.error(traceback.format_exc())
- return False
\ No newline at end of file
+ return None
\ No newline at end of file
diff --git a/app/handlers/webhook_handler.py b/app/handlers/webhook_handler.py
index b3295f9..9791a40 100644
--- a/app/handlers/webhook_handler.py
+++ b/app/handlers/webhook_handler.py
@@ -6,11 +6,12 @@ import logging
import hmac
import hashlib
import yaml
-import json
import datetime
import traceback
from pathlib import Path
from dotenv import load_dotenv
+from ..utils.db.operations import DatabaseManager
+from ..utils.db.models import SubscriptionStatus
load_dotenv()
@@ -23,7 +24,6 @@ logger = logging.getLogger(__name__)
BASE_DIR = Path(__file__).resolve().parent.parent.parent
PLAYBOOK_PATH = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_provision.yml'
CLEANUP_PLAYBOOK = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_cleanup.yml'
-SUBSCRIPTION_DB = BASE_DIR / 'data' / 'subscriptions.json'
def get_vault_values():
"""Get decrypted values from Ansible vault"""
@@ -80,31 +80,6 @@ def verify_signature(payload_body, signature_header):
logger.error(f"Signature verification failed: {str(e)}")
return False
-def load_subscriptions():
- """Load subscription data from JSON file"""
- if not SUBSCRIPTION_DB.parent.exists():
- SUBSCRIPTION_DB.parent.mkdir(parents=True)
-
- if not SUBSCRIPTION_DB.exists():
- return {}
-
- with open(SUBSCRIPTION_DB, 'r') as f:
- return json.load(f)
-
-def save_subscription(subscription_data):
- """Save subscription data to JSON file"""
- subscriptions = load_subscriptions()
- sub_id = subscription_data['subscriptionId']
- subscriptions[sub_id] = subscription_data
-
- with open(SUBSCRIPTION_DB, 'w') as f:
- json.dump(subscriptions, f, indent=2)
-
-def calculate_expiry(duration_hours):
- """Calculate expiry date based on subscription duration"""
- return (datetime.datetime.now() +
- datetime.timedelta(hours=duration_hours)).isoformat()
-
def run_ansible_playbook(invoice_id):
"""Run the VPN provisioning playbook"""
vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '')
@@ -142,12 +117,15 @@ def handle_subscription_status(data):
logger.info(f"Processing subscription status update: {sub_id} -> {status}")
- # Store subscription data
- data['last_updated'] = datetime.datetime.now().isoformat()
- save_subscription(data)
-
- if status != 'Active':
- # Run cleanup playbook for inactive subscriptions
+ subscription = DatabaseManager.get_subscription_by_invoice(sub_id)
+ if not subscription:
+ logger.error(f"Subscription {sub_id} not found")
+ return jsonify({"error": "Subscription not found"}), 404
+
+ if status == 'Active':
+ DatabaseManager.activate_subscription(sub_id)
+ else:
+ # Run cleanup for inactive subscriptions
result = subprocess.run([
'ansible-playbook',
str(CLEANUP_PLAYBOOK),
@@ -158,7 +136,8 @@ def handle_subscription_status(data):
if result.returncode != 0:
logger.error(f"Failed to clean up subscription {sub_id}: {result.stderr}")
-
+
+ DatabaseManager.expire_subscription(subscription.id)
logger.info(f"Subscription {sub_id} is no longer active")
return jsonify({
@@ -171,9 +150,10 @@ def handle_subscription_renewal(data):
sub_id = data['subscriptionId']
logger.info(f"Processing subscription renewal request: {sub_id}")
- # Update subscription data
- data['renewal_requested'] = datetime.datetime.now().isoformat()
- save_subscription(data)
+ subscription = DatabaseManager.get_subscription_by_invoice(sub_id)
+ if not subscription:
+ logger.error(f"Subscription {sub_id} not found")
+ return jsonify({"error": "Subscription not found"}), 404
# TODO: Send renewal notification to user
return jsonify({
@@ -199,6 +179,16 @@ def handle_payment_webhook(request):
data = request.json
logger.info(f"Received webhook data: {data}")
+
+ # Handle test webhooks
+ invoice_id = data.get('invoiceId', '')
+ if invoice_id.startswith('__test__'):
+ logger.info(f"Received test webhook, acknowledging: {data.get('type')}")
+ return jsonify({
+ "status": "success",
+ "message": "Test webhook acknowledged"
+ })
+
webhook_type = data.get('type')
if webhook_type == 'SubscriptionStatusUpdated':
@@ -207,20 +197,13 @@ def handle_payment_webhook(request):
elif webhook_type == 'SubscriptionRenewalRequested':
return handle_subscription_renewal(data)
- elif webhook_type == 'InvoiceSettled':
- # Handle regular invoice payment
+ elif webhook_type in ['InvoiceSettled', 'InvoicePaymentSettled']:
invoice_id = data.get('invoiceId')
- metadata = data.get('metadata', {})
-
if not invoice_id:
logger.error("Missing invoiceId in webhook data")
return jsonify({"error": "Missing invoiceId"}), 400
-
- if invoice_id.startswith('__test__') and invoice_id.endswith('__test__'):
- invoice_id = invoice_id[8:-8]
- logger.info(f"Stripped test markers from invoice ID: {invoice_id}")
- # Run Ansible playbook with enhanced logging
+ # Get subscription and run Ansible playbook
logger.info(f"Starting VPN provisioning for invoice {invoice_id}")
result = run_ansible_playbook(invoice_id)
@@ -236,46 +219,18 @@ def handle_payment_webhook(request):
"stderr": result.stderr
}), 500
+ # Get subscription and activate it
+ subscription = DatabaseManager.get_subscription_by_invoice(invoice_id)
+ if subscription:
+ subscription = DatabaseManager.activate_subscription(invoice_id)
+ DatabaseManager.record_payment(
+ subscription.user_id,
+ subscription.id,
+ invoice_id,
+ data.get('amount', 0)
+ )
+
logger.info(f"VPN provisioning completed for invoice {invoice_id}")
-
- # Update subscription database
- try:
- duration_hours = metadata.get('duration_hours', 24)
- subscriptions = load_subscriptions()
- subscriptions[invoice_id] = {
- 'email': metadata.get('email'),
- 'duration_hours': duration_hours,
- 'start_time': datetime.datetime.now().isoformat(),
- 'expiry': calculate_expiry(duration_hours),
- 'status': 'Active'
- }
- with open(SUBSCRIPTION_DB, 'w') as f:
- json.dump(subscriptions, f, indent=2)
- logger.info(f"Updated subscription database for invoice {invoice_id}")
- except Exception as e:
- logger.error(f"Error updating subscription database: {str(e)}")
-
- # Send email confirmation if email is provided
- try:
- email = metadata.get('email')
- if email:
- config_path = f"/etc/wireguard/clients/{invoice_id}/wg0.conf"
- if os.path.exists(config_path):
- with open(config_path, 'r') as f:
- config_data = f.read()
-
- btcpay_handler = BTCPayHandler()
- email_sent = btcpay_handler.send_confirmation_email(email, config_data)
- if email_sent:
- logger.info(f"Sent configuration email to {email}")
- else:
- logger.warning(f"Failed to send configuration email to {email}")
- else:
- logger.warning(f"Config file not found at {config_path}")
- except Exception as e:
- logger.error(f"Error sending confirmation email: {str(e)}")
-
- logger.info(f"Successfully processed invoice {invoice_id}")
return jsonify({
"status": "success",
"invoice_id": invoice_id,
@@ -283,7 +238,6 @@ def handle_payment_webhook(request):
})
else:
- # Log other webhook types as info instead of warning
logger.info(f"Received {webhook_type} webhook - no action required")
return jsonify({
"status": "success",
diff --git a/app/static/js/components/WireGuardPayment.jsx b/app/static/js/components/WireGuardPayment.jsx
new file mode 100644
index 0000000..30c065b
--- /dev/null
+++ b/app/static/js/components/WireGuardPayment.jsx
@@ -0,0 +1,213 @@
+import React, { useState, useEffect } from 'react';
+import { generateKeys } from '../utils/wireguard';
+import { Card, CardHeader, CardTitle, CardContent } from '@/components/ui/card';
+import { Button } from '@/components/ui/button';
+import { Input } from '@/components/ui/input';
+import { AlertCircle } from 'lucide-react';
+import { Alert, AlertDescription } from '@/components/ui/alert';
+
+const WireGuardPayment = () => {
+ const [keyData, setKeyData] = useState(null);
+ const [duration, setDuration] = useState(24);
+ const [price, setPrice] = useState(0);
+ const [loading, setLoading] = useState(false);
+ const [error, setError] = useState('');
+ const [userId, setUserId] = useState('');
+
+ useEffect(() => {
+ // Generate random userId and keys when component mounts
+ const init = async () => {
+ try {
+ const randomId = crypto.randomUUID();
+ setUserId(randomId);
+ const keys = await generateKeys();
+ setKeyData(keys);
+ calculatePrice(duration);
+ } catch (err) {
+ setError('Failed to initialize keys');
+ console.error(err);
+ }
+ };
+ init();
+ }, []);
+
+ const calculatePrice = async (hours) => {
+ try {
+ const response = await fetch('/api/calculate-price', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({ hours: parseInt(hours) })
+ });
+ const data = await response.json();
+ setPrice(data.price);
+ } catch (err) {
+ setError('Failed to calculate price');
+ }
+ };
+
+ const handleDurationChange = (event) => {
+ const newDuration = parseInt(event.target.value);
+ setDuration(newDuration);
+ calculatePrice(newDuration);
+ };
+
+ const handlePayment = async () => {
+ if (!keyData) {
+ setError('No keys generated. Please refresh the page.');
+ return;
+ }
+
+ try {
+ setLoading(true);
+ const response = await fetch('/create-invoice', {
+ method: 'POST',
+ headers: { 'Content-Type': 'application/json' },
+ body: JSON.stringify({
+ duration,
+ userId,
+ publicKey: keyData.publicKey,
+ // Don't send private key to server!
+ configuration: {
+ type: 'wireguard',
+ publicKey: keyData.publicKey
+ }
+ })
+ });
+
+ if (!response.ok) {
+ throw new Error('Failed to create payment');
+ }
+
+ const data = await response.json();
+
+ // Save private key to localStorage before redirecting
+ localStorage.setItem(`vpn_keys_${userId}`, JSON.stringify({
+ privateKey: keyData.privateKey,
+ publicKey: keyData.publicKey,
+ createdAt: new Date().toISOString()
+ }));
+
+ window.location.href = data.checkout_url;
+ } catch (err) {
+ setError('Failed to initiate payment');
+ console.error(err);
+ } finally {
+ setLoading(false);
+ }
+ };
+
+ const handleRegenerateKeys = async () => {
+ try {
+ const keys = await generateKeys();
+ setKeyData(keys);
+ } catch (err) {
+ setError('Failed to regenerate keys');
+ }
+ };
+
+ return (
+
+
+ WireGuard VPN Configuration
+
+
+ {error && (
+
+
+ {error}
+
+ )}
+
+
+
+
+
+ Save this ID - you'll need it to manage your subscription
+
+
+
+ {keyData && (
+
+
+
+
+
+ )}
+
+
+
+
+
+
+
+
+
+
+
{duration} hours
+
+
+
+
+
+
+
+
+
• Keys are generated securely in your browser
+
• Your private key never leaves your device
+
• Configuration will be available after payment
+
+
+
+ );
+};
+
+export default WireGuardPayment;
\ No newline at end of file
diff --git a/app/static/js/pricing.js b/app/static/js/pricing.js
index 5169466..698d9a7 100644
--- a/app/static/js/pricing.js
+++ b/app/static/js/pricing.js
@@ -1,11 +1,39 @@
-// app/static/js/pricing.js
-document.addEventListener('DOMContentLoaded', function() {
+// Base64 encoding/decoding utilities
+const b64 = {
+ encode: array => btoa(String.fromCharCode.apply(null, array)),
+ decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0))
+};
+
+async function generateKeyPair() {
+ const keyPair = await window.crypto.subtle.generateKey(
+ {
+ name: 'X25519',
+ namedCurve: 'X25519',
+ },
+ true,
+ ['deriveKey', 'deriveBits']
+ );
+
+ const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey);
+ const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey);
+
+ return {
+ privateKey: b64.encode(new Uint8Array(privateKey)),
+ publicKey: b64.encode(new Uint8Array(publicKey))
+ };
+}
+
+document.addEventListener('DOMContentLoaded', async function() {
const form = document.getElementById('subscription-form');
const slider = document.getElementById('duration-slider');
const durationDisplay = document.getElementById('duration-display');
const priceDisplay = document.getElementById('price-display');
const presetButtons = document.querySelectorAll('.duration-preset');
- const emailInput = document.getElementById('email');
+ const userIdInput = document.getElementById('user-id');
+ const publicKeyInput = document.getElementById('public-key');
+ const regenerateButton = document.getElementById('regenerate-keys');
+
+ let currentKeyPair = null;
function formatDuration(hours) {
if (hours < 24) return `${hours} hour${hours === 1 ? '' : 's'}`;
@@ -14,6 +42,32 @@ document.addEventListener('DOMContentLoaded', function() {
return `${Math.floor(hours / 720)} month${hours === 720 ? '' : 's'}`;
}
+ async function generateNewKeys() {
+ try {
+ currentKeyPair = await generateKeyPair();
+ publicKeyInput.value = currentKeyPair.publicKey;
+
+ // Save private key to localStorage
+ const keyData = {
+ privateKey: currentKeyPair.privateKey,
+ publicKey: currentKeyPair.publicKey,
+ createdAt: new Date().toISOString()
+ };
+ localStorage.setItem(`vpn_keys_${userIdInput.value}`, JSON.stringify(keyData));
+ } catch (error) {
+ console.error('Failed to generate keys:', error);
+ alert('Failed to generate WireGuard keys. Please try again.');
+ }
+ }
+
+ async function initializeForm() {
+ // Generate user ID
+ userIdInput.value = crypto.randomUUID();
+
+ // Generate initial keys
+ await generateNewKeys();
+ }
+
async function updatePrice(hours) {
try {
const response = await fetch('/api/calculate-price', {
@@ -29,7 +83,27 @@ document.addEventListener('DOMContentLoaded', function() {
}
}
- async function createInvoice(duration, email) {
+ // Event listeners
+ slider.addEventListener('input', () => updatePrice(slider.value));
+
+ regenerateButton.addEventListener('click', generateNewKeys);
+
+ presetButtons.forEach(button => {
+ button.addEventListener('click', (e) => {
+ const hours = e.target.dataset.hours;
+ slider.value = hours;
+ updatePrice(hours);
+ });
+ });
+
+ form.addEventListener('submit', async (e) => {
+ e.preventDefault();
+
+ if (!currentKeyPair) {
+ alert('No keys generated. Please refresh the page.');
+ return;
+ }
+
try {
const response = await fetch('/create-invoice', {
method: 'POST',
@@ -37,8 +111,9 @@ document.addEventListener('DOMContentLoaded', function() {
'Content-Type': 'application/json'
},
body: JSON.stringify({
- duration: parseInt(duration),
- email: email
+ duration: parseInt(slider.value),
+ userId: userIdInput.value,
+ publicKey: currentKeyPair.publicKey
})
});
@@ -53,32 +128,10 @@ document.addEventListener('DOMContentLoaded', function() {
console.error('Error creating invoice:', error);
alert('Failed to create payment invoice. Please try again.');
}
- }
-
- // Event listeners
- slider.addEventListener('input', () => updatePrice(slider.value));
-
- presetButtons.forEach(button => {
- button.addEventListener('click', (e) => {
- const hours = e.target.dataset.hours;
- slider.value = hours;
- updatePrice(hours);
- });
- });
-
- form.addEventListener('submit', async (e) => {
- e.preventDefault();
-
- const email = emailInput.value.trim();
- if (!email) {
- alert('Please enter your email address');
- return;
- }
-
- const duration = slider.value;
- await createInvoice(duration, email);
});
+ // Initialize the form
+ await initializeForm();
// Initial price calculation
updatePrice(slider.value);
});
\ No newline at end of file
diff --git a/app/static/js/utils/wireguard.js b/app/static/js/utils/wireguard.js
new file mode 100644
index 0000000..21985ba
--- /dev/null
+++ b/app/static/js/utils/wireguard.js
@@ -0,0 +1,54 @@
+// Base64 encoding/decoding utilities
+const b64 = {
+ encode: array => btoa(String.fromCharCode.apply(null, array)),
+ decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0))
+ };
+
+ async function generateKeyPair() {
+ // Generate a random key pair using Web Crypto API
+ const keyPair = await window.crypto.subtle.generateKey(
+ {
+ name: 'X25519',
+ namedCurve: 'X25519',
+ },
+ true,
+ ['deriveKey', 'deriveBits']
+ );
+
+ // Export keys in raw format
+ const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey);
+ const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey);
+
+ // Convert to base64
+ return {
+ privateKey: b64.encode(new Uint8Array(privateKey)),
+ publicKey: b64.encode(new Uint8Array(publicKey))
+ };
+ }
+
+ export async function generateWireGuardConfig(serverPublicKey, serverEndpoint, address) {
+ const keys = await generateKeyPair();
+
+ return {
+ keys,
+ config: `[Interface]
+ PrivateKey = ${keys.privateKey}
+ Address = ${address}
+ DNS = 1.1.1.1
+
+ [Peer]
+ PublicKey = ${serverPublicKey}
+ Endpoint = ${serverEndpoint}
+ AllowedIPs = 0.0.0.0/0
+ PersistentKeepalive = 25`
+ };
+ }
+
+ export async function generateKeys() {
+ try {
+ return await generateKeyPair();
+ } catch (error) {
+ console.error('Failed to generate WireGuard keys:', error);
+ throw error;
+ }
+ }
\ No newline at end of file
diff --git a/app/templates/index.html b/app/templates/index.html
index 2125312..284d890 100644
--- a/app/templates/index.html
+++ b/app/templates/index.html
@@ -1,30 +1,45 @@
{% extends "base.html" %}
-
{% block content %}
Subscribe to VPN Service
-
diff --git a/app/templates/payment_success.html b/app/templates/payment_success.html
index 05bd9ba..d309da0 100644
--- a/app/templates/payment_success.html
+++ b/app/templates/payment_success.html
@@ -2,17 +2,110 @@
{% block content %}
-
-
Payment Successful!
-
- Thank you for your payment. Your VPN configuration will be sent to your email shortly.
-
-
- Please check your email for further instructions on setting up your VPN connection.
-
-
- Return to Home
-
+
+
Payment Successful!
+
+
+
+ Your VPN subscription is now active. Please follow the instructions below to set up your VPN connection.
+
+
+
+
+
+
+
Your WireGuard Configuration
+
+
+
+
+
+
+
Installation Instructions
+
+ - Download WireGuard for your platform:
+
+
+ - Create a new tunnel in WireGuard
+ - Copy the configuration above and paste it into the new tunnel
+ - Activate the tunnel to connect
+
+
+
+
+
• Save your configuration securely - you'll need it to reconnect
+
• Your private key is stored in your browser's local storage
+
• For security, clear your browser data after saving the configuration
+
+
+
+
+
+
{% endblock %}
\ No newline at end of file
diff --git a/app/utils/db/__init__.py b/app/utils/db/__init__.py
new file mode 100644
index 0000000..8fdadd7
--- /dev/null
+++ b/app/utils/db/__init__.py
@@ -0,0 +1,23 @@
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from pathlib import Path
+
+def get_db_path():
+ base_dir = Path(__file__).resolve().parent.parent.parent
+ data_dir = base_dir / 'data'
+ data_dir.mkdir(exist_ok=True)
+ return data_dir / 'vpn.db'
+
+def init_db():
+ """Initialize the database"""
+ from .models import Base
+ db_url = f"sqlite:///{get_db_path()}"
+ engine = create_engine(db_url)
+ Base.metadata.create_all(engine)
+ return engine
+
+def get_session():
+ """Get a database session"""
+ engine = init_db()
+ Session = sessionmaker(bind=engine)
+ return Session()
\ No newline at end of file
diff --git a/app/utils/db/models.py b/app/utils/db/models.py
new file mode 100644
index 0000000..9390fda
--- /dev/null
+++ b/app/utils/db/models.py
@@ -0,0 +1,52 @@
+from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Enum, Text
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+import enum
+import datetime
+
+Base = declarative_base()
+
+class SubscriptionStatus(enum.Enum):
+ ACTIVE = "active"
+ EXPIRED = "expired"
+ PENDING = "pending"
+ CANCELLED = "cancelled"
+
+class User(Base):
+ __tablename__ = 'users'
+
+ id = Column(Integer, primary_key=True)
+ user_id = Column(String, unique=True, nullable=False) # UUID generated in frontend
+ created_at = Column(DateTime, default=datetime.datetime.utcnow)
+
+ subscriptions = relationship("Subscription", back_populates="user")
+ payments = relationship("Payment", back_populates="user")
+
+class Subscription(Base):
+ __tablename__ = 'subscriptions'
+
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('users.id'))
+ invoice_id = Column(String, unique=True)
+ public_key = Column(Text, nullable=False) # WireGuard public key
+ start_time = Column(DateTime, nullable=False)
+ expiry_time = Column(DateTime, nullable=False)
+ status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.PENDING)
+ warning_sent = Column(Integer, default=0)
+ assigned_ip = Column(String) # WireGuard IP address assigned to this subscription
+
+ user = relationship("User", back_populates="subscriptions")
+ payments = relationship("Payment", back_populates="subscription")
+
+class Payment(Base):
+ __tablename__ = 'payments'
+
+ id = Column(Integer, primary_key=True)
+ user_id = Column(Integer, ForeignKey('users.id'))
+ subscription_id = Column(Integer, ForeignKey('subscriptions.id'))
+ invoice_id = Column(String, unique=True)
+ amount = Column(Integer, nullable=False) # Amount in sats
+ timestamp = Column(DateTime, default=datetime.datetime.utcnow)
+
+ user = relationship("User", back_populates="payments")
+ subscription = relationship("Subscription", back_populates="payments")
\ No newline at end of file
diff --git a/app/utils/db/operations.py b/app/utils/db/operations.py
new file mode 100644
index 0000000..e5f00f5
--- /dev/null
+++ b/app/utils/db/operations.py
@@ -0,0 +1,152 @@
+from datetime import datetime
+from sqlalchemy.exc import SQLAlchemyError
+from . import get_session
+from .models import User, Subscription, Payment, SubscriptionStatus
+import logging
+import ipaddress
+
+logger = logging.getLogger(__name__)
+
+class DatabaseManager:
+ @staticmethod
+ def get_next_available_ip():
+ """Get the next available IP from the WireGuard subnet"""
+ with get_session() as session:
+ # Get all assigned IPs
+ assigned_ips = session.query(Subscription.assigned_ip)\
+ .filter(Subscription.assigned_ip.isnot(None))\
+ .all()
+ assigned_ips = [ip[0] for ip in assigned_ips]
+
+ # Start from 10.8.0.2 (10.8.0.1 is server)
+ network = ipaddress.IPv4Network('10.8.0.0/24')
+ for ip in network.hosts():
+ str_ip = str(ip)
+ if str_ip == '10.8.0.1': # Skip server IP
+ continue
+ if str_ip not in assigned_ips:
+ return str_ip
+
+ raise Exception("No available IPs in the subnet")
+
+ @staticmethod
+ def create_user(user_id):
+ """Create a new user with UUID"""
+ with get_session() as session:
+ user = User(user_id=user_id)
+ session.add(user)
+ session.commit()
+ return user
+
+ @staticmethod
+ def get_user_by_uuid(user_id):
+ """Get user by UUID"""
+ with get_session() as session:
+ return session.query(User).filter(User.user_id == user_id).first()
+
+ @staticmethod
+ def get_subscription_by_invoice(invoice_id):
+ """Get subscription by invoice ID"""
+ with get_session() as session:
+ return session.query(Subscription)\
+ .filter(Subscription.invoice_id == invoice_id)\
+ .first()
+
+ @staticmethod
+ def create_subscription(user_id, invoice_id, public_key, duration_hours):
+ """Create a new subscription"""
+ with get_session() as session:
+ try:
+ # Get user or create if doesn't exist
+ user = DatabaseManager.get_user_by_uuid(user_id)
+ if not user:
+ user = DatabaseManager.create_user(user_id)
+
+ start_time = datetime.utcnow()
+ expiry_time = start_time + datetime.timedelta(hours=duration_hours)
+
+ # Get next available IP
+ assigned_ip = DatabaseManager.get_next_available_ip()
+
+ subscription = Subscription(
+ user_id=user.id,
+ invoice_id=invoice_id,
+ public_key=public_key,
+ start_time=start_time,
+ expiry_time=expiry_time,
+ status=SubscriptionStatus.PENDING,
+ assigned_ip=assigned_ip
+ )
+ session.add(subscription)
+ session.commit()
+ return subscription
+ except Exception as e:
+ logger.error(f"Error creating subscription: {str(e)}")
+ session.rollback()
+ raise
+
+ @staticmethod
+ def activate_subscription(invoice_id):
+ """Activate a subscription after payment"""
+ with get_session() as session:
+ subscription = session.query(Subscription)\
+ .filter(Subscription.invoice_id == invoice_id)\
+ .first()
+
+ if subscription:
+ subscription.status = SubscriptionStatus.ACTIVE
+ session.commit()
+ return subscription
+
+ @staticmethod
+ def record_payment(user_id, subscription_id, invoice_id, amount):
+ """Record a payment"""
+ with get_session() as session:
+ user = DatabaseManager.get_user_by_uuid(user_id)
+ if not user:
+ raise ValueError(f"User {user_id} not found")
+
+ payment = Payment(
+ user_id=user.id,
+ subscription_id=subscription_id,
+ invoice_id=invoice_id,
+ amount=amount
+ )
+ session.add(payment)
+ session.commit()
+ return payment
+
+ @staticmethod
+ def get_active_subscriptions():
+ """Get all active subscriptions"""
+ with get_session() as session:
+ return session.query(Subscription)\
+ .filter(Subscription.status == SubscriptionStatus.ACTIVE)\
+ .all()
+
+ @staticmethod
+ def expire_subscription(subscription_id):
+ """Mark a subscription as expired"""
+ with get_session() as session:
+ subscription = session.query(Subscription)\
+ .filter(Subscription.id == subscription_id)\
+ .first()
+
+ if subscription:
+ subscription.status = SubscriptionStatus.EXPIRED
+ session.commit()
+ return True
+ return False
+
+ @staticmethod
+ def update_warning_sent(subscription_id):
+ """Update the warning_sent flag for a subscription"""
+ with get_session() as session:
+ subscription = session.query(Subscription)\
+ .filter(Subscription.id == subscription_id)\
+ .first()
+ if subscription:
+ subscription.warning_sent = 1
+ session.commit()
+ return True
+ return False
\ No newline at end of file
diff --git a/app/utils/db/utils.py b/app/utils/db/utils.py
new file mode 100644
index 0000000..17f1e58
--- /dev/null
+++ b/app/utils/db/utils.py
@@ -0,0 +1,59 @@
+import sqlite3
+import shutil
+from datetime import datetime
+from pathlib import Path
+import logging
+from .. import get_db_path
+
+logger = logging.getLogger(__name__)
+
+def backup_database():
+ """Create a backup of the SQLite database"""
+ try:
+ db_path = get_db_path()
+ if not db_path.exists():
+ logger.error("Database file not found")
+ return False
+
+ backup_dir = db_path.parent / 'backups'
+ backup_dir.mkdir(exist_ok=True)
+
+ timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+ backup_path = backup_dir / f'vpn_db_backup_{timestamp}.db'
+
+ shutil.copy2(db_path, backup_path)
+ logger.info(f"Database backed up successfully to {backup_path}")
+
+ # Keep only last 5 backups
+ backups = sorted(backup_dir.glob('vpn_db_backup_*.db'))
+ if len(backups) > 5:
+ for old_backup in backups[:-5]:
+ old_backup.unlink()
+ logger.info(f"Removed old backup: {old_backup}")
+
+ return True
+
+ except Exception as e:
+ logger.error(f"Backup failed: {str(e)}")
+ return False
+
+def verify_database():
+ """Check database integrity"""
+ try:
+ db_path = get_db_path()
+ conn = sqlite3.connect(db_path)
+ cursor = conn.cursor()
+
+ cursor.execute("PRAGMA integrity_check")
+ result = cursor.fetchone()[0]
+
+ if result != "ok":
+ logger.error(f"Database integrity check failed: {result}")
+ return False
+
+ logger.info("Database integrity verified")
+ return True
+
+ except Exception as e:
+ logger.error(f"Database verification failed: {str(e)}")
+ return False
\ No newline at end of file
diff --git a/data/vpn.db b/data/vpn.db
new file mode 100644
index 0000000..6ab8f2a
Binary files /dev/null and b/data/vpn.db differ
diff --git a/requirements.txt b/requirements.txt
index 99ca526..5e36ecb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ pyyaml==6.0.1
python-dotenv==1.0.0
cryptography==41.0.7
ansible==9.1.0
-requests==2.31.0
\ No newline at end of file
+requests==2.31.0
+SQLAlchemy==2.0.25
\ No newline at end of file
diff --git a/scripts/init_db.py b/scripts/init_db.py
new file mode 100644
index 0000000..cb7649e
--- /dev/null
+++ b/scripts/init_db.py
@@ -0,0 +1,28 @@
+# scripts/init_db.py
+import sys
+from pathlib import Path
+
+# Add the project root to Python path
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))
+
+from app.utils.db.models import Base
+from sqlalchemy import create_engine
+
+def init_db():
+ # Create data directory if it doesn't exist
+ data_dir = project_root / 'data'
+ data_dir.mkdir(exist_ok=True)
+
+ # Create database
+ db_path = data_dir / 'vpn.db'
+ db_url = f"sqlite:///{db_path}"
+ engine = create_engine(db_url)
+
+ # Create all tables
+ Base.metadata.create_all(engine)
+ print(f"Database initialized at: {db_path}")
+ return engine
+
+if __name__ == "__main__":
+ init_db()
\ No newline at end of file
diff --git a/scripts/migrate_db.py b/scripts/migrate_db.py
new file mode 100644
index 0000000..4e68743
--- /dev/null
+++ b/scripts/migrate_db.py
@@ -0,0 +1,151 @@
+import sys
+from pathlib import Path
+import sqlite3
+import uuid
+import logging
+from datetime import datetime
+
+# Add the project root to Python path
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))
+
+from app.utils.db.models import Base, SubscriptionStatus
+from app.utils.db import get_db_path, get_session
+from sqlalchemy import create_engine
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def backup_database():
+ """Create a backup of the existing database"""
+ db_path = get_db_path()
+ if not db_path.exists():
+ logger.info("No existing database found, skipping backup")
+ return
+
+ backup_path = db_path.parent / f"{db_path.stem}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}{db_path.suffix}"
+ db_path.rename(backup_path)
+ logger.info(f"Created database backup at {backup_path}")
+ return backup_path
+
+def migrate_database():
+ """Perform database migration"""
+ try:
+ # Create backup
+ backup_path = backup_database()
+
+ # Initialize new database
+ engine = create_engine(f"sqlite:///{get_db_path()}")
+ Base.metadata.create_all(engine)
+
+ # If there's an old database, migrate the data
+ if backup_path and backup_path.exists():
+ logger.info("Migrating data from old database...")
+
+ # Connect to old database
+ old_conn = sqlite3.connect(backup_path)
+ old_cursor = old_conn.cursor()
+
+ # Get schema information
+ old_cursor.execute("SELECT sql FROM sqlite_master WHERE type='table'")
+ old_schema = old_cursor.fetchall()
+ logger.info("Old database schema:")
+ for table in old_schema:
+ logger.info(table[0])
+
+ with get_session() as session:
+ try:
+ # Start transaction
+ session.begin()
+
+ # Migrate users
+ logger.info("Migrating users...")
+ old_cursor.execute("SELECT id, email, created_at FROM users")
+ users = old_cursor.fetchall()
+ user_id_map = {} # Map old IDs to new UUIDs
+
+ for old_id, email, created_at in users:
+ new_user_id = str(uuid.uuid4())
+ user_id_map[old_id] = new_user_id
+ session.execute(
+ "INSERT INTO users (user_id, created_at) VALUES (?, ?)",
+ [new_user_id, created_at or datetime.utcnow()]
+ )
+
+ # Migrate subscriptions
+ logger.info("Migrating subscriptions...")
+ old_cursor.execute("""
+ SELECT id, user_id, invoice_id, start_time, expiry_time,
+ status, warning_sent
+ FROM subscriptions
+ """)
+ subscriptions = old_cursor.fetchall()
+
+ for sub in subscriptions:
+ old_id, old_user_id, invoice_id, start_time, expiry_time, status, warning_sent = sub
+ if old_user_id in user_id_map:
+ # Generate a placeholder public key for existing subscriptions
+ placeholder_pubkey = f"MIGRATED_{uuid.uuid4()}"
+ session.execute("""
+ INSERT INTO subscriptions
+ (user_id, invoice_id, public_key, start_time, expiry_time,
+ status, warning_sent, assigned_ip)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ """, [
+ user_id_map[old_user_id],
+ invoice_id,
+ placeholder_pubkey,
+ start_time,
+ expiry_time,
+ status,
+ warning_sent,
+ f"10.8.0.{2 + old_id}" # Simple IP assignment
+ ])
+
+ # Migrate payments
+ logger.info("Migrating payments...")
+ old_cursor.execute("""
+ SELECT user_id, subscription_id, invoice_id, amount, timestamp
+ FROM payments
+ """)
+ payments = old_cursor.fetchall()
+
+ for payment in payments:
+ old_user_id, sub_id, invoice_id, amount, timestamp = payment
+ if old_user_id in user_id_map:
+ session.execute("""
+ INSERT INTO payments
+ (user_id, subscription_id, invoice_id, amount, timestamp)
+ VALUES (?, ?, ?, ?, ?)
+ """, [
+ user_id_map[old_user_id],
+ sub_id,
+ invoice_id,
+ amount,
+ timestamp
+ ])
+
+ # Commit transaction
+ session.commit()
+ logger.info("Migration completed successfully")
+
+ except Exception as e:
+ session.rollback()
+ logger.error(f"Migration failed: {str(e)}")
+ raise
+
+ old_conn.close()
+
+ else:
+ logger.info("No existing database to migrate")
+
+ except Exception as e:
+ logger.error(f"Migration failed: {str(e)}")
+ raise
+
+if __name__ == '__main__':
+ try:
+ migrate_database()
+ except Exception as e:
+ logger.error(f"Migration failed: {str(e)}")
+ sys.exit(1)
\ No newline at end of file
diff --git a/scripts/subscription_checker.py b/scripts/subscription_checker.py
index e3c3acc..9e6f368 100644
--- a/scripts/subscription_checker.py
+++ b/scripts/subscription_checker.py
@@ -1,14 +1,12 @@
-# scripts/subscription_checker.py
-
-import json
-import datetime
import logging
import subprocess
import os
import tempfile
from pathlib import Path
-from dateutil.parser import parse
-from dateutil.relativedelta import relativedelta
+from datetime import datetime, timedelta
+from utils.db.operations import DatabaseManager
+from utils.db.models import SubscriptionStatus
+from app.handlers.payment_handler import BTCPayHandler
logging.basicConfig(
level=logging.INFO,
@@ -19,20 +17,19 @@ logger = logging.getLogger(__name__)
# Path setup
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent
-SUBSCRIPTION_DB = PROJECT_ROOT / 'data' / 'subscriptions.json'
CLEANUP_PLAYBOOK = PROJECT_ROOT / 'ansible' / 'playbooks' / 'vpn_cleanup.yml'
INVENTORY_FILE = PROJECT_ROOT / 'inventory.ini'
# Notification thresholds configuration
NOTIFICATION_THRESHOLDS = {
- 'minimum_duration': datetime.timedelta(hours=1), # Minimum subscription duration
+ 'minimum_duration': timedelta(hours=1), # Minimum subscription duration
'short_term': {
- 'max_duration': datetime.timedelta(days=1),
+ 'max_duration': timedelta(days=1),
'warning_fraction': 0.5, # Warn when 50% of time remains
'grace_fraction': 0.1 # Grace period of 10% of subscription length
},
'medium_term': {
- 'max_duration': datetime.timedelta(days=7),
+ 'max_duration': timedelta(days=7),
'warning_fraction': 0.25, # Warn when 25% of time remains
'grace_hours': 12 # Fixed 12-hour grace period
},
@@ -42,25 +39,9 @@ NOTIFICATION_THRESHOLDS = {
}
}
-def load_subscriptions():
- """Load subscription data from JSON file"""
- if not SUBSCRIPTION_DB.parent.exists():
- SUBSCRIPTION_DB.parent.mkdir(parents=True)
-
- if not SUBSCRIPTION_DB.exists():
- return {}
-
- with open(SUBSCRIPTION_DB, 'r') as f:
- return json.load(f)
-
-def save_subscriptions(subscriptions):
- """Save subscriptions to JSON file"""
- with open(SUBSCRIPTION_DB, 'w') as f:
- json.dump(subscriptions, f, indent=2)
-
-def run_cleanup_playbook(sub_id):
+def run_cleanup_playbook(subscription_id):
"""Run the VPN cleanup playbook"""
- logger.info(f"Running cleanup playbook for subscription {sub_id}")
+ logger.info(f"Running cleanup playbook for subscription {subscription_id}")
vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '')
if not vault_pass:
@@ -74,7 +55,7 @@ def run_cleanup_playbook(sub_id):
'ansible-playbook',
str(CLEANUP_PLAYBOOK),
'-i', str(INVENTORY_FILE),
- '-e', f'subscription_id={sub_id}',
+ '-e', f'subscription_id={subscription_id}',
'--vault-password-file', vault_pass_file.name,
'-vvv'
]
@@ -122,42 +103,36 @@ def calculate_notification_times(start_time, end_time):
warning_delta = duration * NOTIFICATION_THRESHOLDS['medium_term']['warning_fraction']
grace_hours = NOTIFICATION_THRESHOLDS['medium_term']['grace_hours']
warning_time = end_time - warning_delta
- grace_end_time = end_time + datetime.timedelta(hours=grace_hours)
+ grace_end_time = end_time + timedelta(hours=grace_hours)
# Long-term subscriptions (> 7 days)
else:
warning_days = NOTIFICATION_THRESHOLDS['long_term']['warning_days']
grace_days = NOTIFICATION_THRESHOLDS['long_term']['grace_days']
- warning_time = end_time - datetime.timedelta(days=warning_days)
- grace_end_time = end_time + datetime.timedelta(days=grace_days)
+ warning_time = end_time - timedelta(days=warning_days)
+ grace_end_time = end_time + timedelta(days=grace_days)
return warning_time, grace_end_time
-def get_notification_message(sub_data, remaining_time):
+def get_notification_message(subscription, remaining_time):
"""Generate appropriate notification message based on subscription duration"""
- if remaining_time < datetime.timedelta(hours=1):
+ if remaining_time < timedelta(hours=1):
return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 60)} minutes!"
- elif remaining_time < datetime.timedelta(days=1):
+ elif remaining_time < timedelta(days=1):
return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 3600)} hours!"
else:
return f"Your VPN subscription expires in {remaining_time.days} days!"
-def notify_user(user_id, message):
+def notify_user(subscription, message):
"""Send notification to user about subscription status"""
try:
- subscriptions = load_subscriptions()
- sub_data = subscriptions.get(user_id)
-
- if not sub_data or 'email' not in sub_data:
- logger.error(f"No email found for user {user_id}")
+ if not subscription.user or not subscription.user.email:
+ logger.error(f"No email found for subscription {subscription.id}")
return False
-
- # Import BTCPayHandler here to avoid circular imports
- from app.handlers.payment_handler import BTCPayHandler
btcpay_handler = BTCPayHandler()
email_sent = btcpay_handler.send_confirmation_email(
- sub_data['email'],
+ subscription.user.email,
f"""
VPN Subscription Update
@@ -168,9 +143,9 @@ def notify_user(user_id, message):
)
if email_sent:
- logger.info(f"Sent notification to {sub_data['email']}: {message}")
+ logger.info(f"Sent notification to {subscription.user.email}: {message}")
else:
- logger.warning(f"Failed to send notification to {sub_data['email']}")
+ logger.warning(f"Failed to send notification to {subscription.user.email}")
return email_sent
@@ -182,68 +157,64 @@ def check_subscriptions():
"""Check subscription status and clean up expired ones"""
logger.info("Starting subscription check")
- subscriptions = load_subscriptions()
- now = datetime.datetime.now()
- modified = False
-
- logger.info(f"Checking {len(subscriptions)} subscriptions")
-
- for sub_id, sub_data in list(subscriptions.items()):
- try:
- logger.debug(f"Processing subscription {sub_id}")
-
- if sub_data.get('status') != 'Active':
- logger.debug(f"Skipping inactive subscription {sub_id}")
- continue
+ try:
+ active_subscriptions = DatabaseManager.get_active_subscriptions()
+ logger.info(f"Checking {len(active_subscriptions)} active subscriptions")
+
+ now = datetime.utcnow()
+
+ for subscription in active_subscriptions:
+ try:
+ logger.debug(f"Processing subscription {subscription.id}")
- start_time = parse(sub_data['start_time'])
- expiry = parse(sub_data['expiry'])
- warning_time, grace_end_time = calculate_notification_times(start_time, expiry)
-
- # Calculate remaining time
- remaining_time = expiry - now
- logger.debug(f"Subscription {sub_id} has {remaining_time} remaining")
-
- # Handle warnings
- if now >= warning_time and not sub_data.get('warning_sent'):
- message = get_notification_message(sub_data, remaining_time)
- logger.info(f"Sending notification for subscription {sub_id}: {message}")
+ warning_time, grace_end_time = calculate_notification_times(
+ subscription.start_time,
+ subscription.expiry_time
+ )
- notify_user(sub_id, message)
+ # Calculate remaining time
+ remaining_time = subscription.expiry_time - now
+ logger.debug(f"Subscription {subscription.id} has {remaining_time} remaining")
- sub_data['warning_sent'] = True
- modified = True
-
- # Handle expiration
- if now >= grace_end_time and sub_data.get('status') == 'Active':
- logger.info(f"Processing expiration for subscription {sub_id}")
-
- try:
- logger.debug(f"Running cleanup playbook for {sub_id}")
- result = run_cleanup_playbook(sub_id)
+ # Handle warnings
+ if now >= warning_time and not subscription.warning_sent:
+ message = get_notification_message(subscription, remaining_time)
+ logger.info(f"Sending notification for subscription {subscription.id}: {message}")
- if result.returncode == 0:
- sub_data['status'] = 'Expired'
- sub_data['cleanup_date'] = now.isoformat()
- modified = True
- logger.info(f"Successfully cleaned up subscription {sub_id}")
- else:
- logger.error(f"Cleanup failed: {result.stderr}")
+ if notify_user(subscription, message):
+ DatabaseManager.update_warning_sent(subscription.id)
+
+ # Handle expiration
+ if now >= grace_end_time:
+ logger.info(f"Processing expiration for subscription {subscription.id}")
+
+ try:
+ result = run_cleanup_playbook(subscription.invoice_id)
- except Exception as e:
- logger.error(f"Error during cleanup: {str(e)}")
- logger.error(traceback.format_exc())
-
- except Exception as e:
- logger.error(f"Error processing subscription {sub_id}: {str(e)}")
- logger.error(traceback.format_exc())
- continue
-
- if modified:
- save_subscriptions(subscriptions)
- logger.info("Updated subscription database")
-
- logger.info("Subscription check completed")
+ if result.returncode == 0:
+ DatabaseManager.expire_subscription(subscription.id)
+ logger.info(f"Successfully cleaned up subscription {subscription.id}")
+
+ # Send final notification
+ notify_user(subscription, "Your VPN subscription has expired and been deactivated.")
+ else:
+ logger.error(f"Cleanup failed: {result.stderr}")
+
+ except Exception as e:
+ logger.error(f"Error during cleanup: {str(e)}")
+ logger.error(traceback.format_exc())
+
+ except Exception as e:
+ logger.error(f"Error processing subscription {subscription.id}: {str(e)}")
+ logger.error(traceback.format_exc())
+ continue
+
+ logger.info("Subscription check completed")
+
+ except Exception as e:
+ logger.error(f"Subscription checker failed: {str(e)}")
+ logger.error(traceback.format_exc())
+ raise
if __name__ == '__main__':
try:
diff --git a/venv/include/site/python3.11/greenlet/greenlet.h b/venv/include/site/python3.11/greenlet/greenlet.h
new file mode 100644
index 0000000..d02a16e
--- /dev/null
+++ b/venv/include/site/python3.11/greenlet/greenlet.h
@@ -0,0 +1,164 @@
+/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
+
+/* Greenlet object interface */
+
+#ifndef Py_GREENLETOBJECT_H
+#define Py_GREENLETOBJECT_H
+
+
+#include
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* This is deprecated and undocumented. It does not change. */
+#define GREENLET_VERSION "1.0.0"
+
+#ifndef GREENLET_MODULE
+#define implementation_ptr_t void*
+#endif
+
+typedef struct _greenlet {
+ PyObject_HEAD
+ PyObject* weakreflist;
+ PyObject* dict;
+ implementation_ptr_t pimpl;
+} PyGreenlet;
+
+#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
+
+
+/* C API functions */
+
+/* Total number of symbols that are exported */
+#define PyGreenlet_API_pointers 12
+
+#define PyGreenlet_Type_NUM 0
+#define PyExc_GreenletError_NUM 1
+#define PyExc_GreenletExit_NUM 2
+
+#define PyGreenlet_New_NUM 3
+#define PyGreenlet_GetCurrent_NUM 4
+#define PyGreenlet_Throw_NUM 5
+#define PyGreenlet_Switch_NUM 6
+#define PyGreenlet_SetParent_NUM 7
+
+#define PyGreenlet_MAIN_NUM 8
+#define PyGreenlet_STARTED_NUM 9
+#define PyGreenlet_ACTIVE_NUM 10
+#define PyGreenlet_GET_PARENT_NUM 11
+
+#ifndef GREENLET_MODULE
+/* This section is used by modules that uses the greenlet C API */
+static void** _PyGreenlet_API = NULL;
+
+# define PyGreenlet_Type \
+ (*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
+
+# define PyExc_GreenletError \
+ ((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
+
+# define PyExc_GreenletExit \
+ ((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
+
+/*
+ * PyGreenlet_New(PyObject *args)
+ *
+ * greenlet.greenlet(run, parent=None)
+ */
+# define PyGreenlet_New \
+ (*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
+ _PyGreenlet_API[PyGreenlet_New_NUM])
+
+/*
+ * PyGreenlet_GetCurrent(void)
+ *
+ * greenlet.getcurrent()
+ */
+# define PyGreenlet_GetCurrent \
+ (*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
+
+/*
+ * PyGreenlet_Throw(
+ * PyGreenlet *greenlet,
+ * PyObject *typ,
+ * PyObject *val,
+ * PyObject *tb)
+ *
+ * g.throw(...)
+ */
+# define PyGreenlet_Throw \
+ (*(PyObject * (*)(PyGreenlet * self, \
+ PyObject * typ, \
+ PyObject * val, \
+ PyObject * tb)) \
+ _PyGreenlet_API[PyGreenlet_Throw_NUM])
+
+/*
+ * PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
+ *
+ * g.switch(*args, **kwargs)
+ */
+# define PyGreenlet_Switch \
+ (*(PyObject * \
+ (*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
+ _PyGreenlet_API[PyGreenlet_Switch_NUM])
+
+/*
+ * PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
+ *
+ * g.parent = new_parent
+ */
+# define PyGreenlet_SetParent \
+ (*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
+ _PyGreenlet_API[PyGreenlet_SetParent_NUM])
+
+/*
+ * PyGreenlet_GetParent(PyObject* greenlet)
+ *
+ * return greenlet.parent;
+ *
+ * This could return NULL even if there is no exception active.
+ * If it does not return NULL, you are responsible for decrementing the
+ * reference count.
+ */
+# define PyGreenlet_GetParent \
+ (*(PyGreenlet* (*)(PyGreenlet*)) \
+ _PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
+
+/*
+ * deprecated, undocumented alias.
+ */
+# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
+
+# define PyGreenlet_MAIN \
+ (*(int (*)(PyGreenlet*)) \
+ _PyGreenlet_API[PyGreenlet_MAIN_NUM])
+
+# define PyGreenlet_STARTED \
+ (*(int (*)(PyGreenlet*)) \
+ _PyGreenlet_API[PyGreenlet_STARTED_NUM])
+
+# define PyGreenlet_ACTIVE \
+ (*(int (*)(PyGreenlet*)) \
+ _PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
+
+
+
+
+/* Macro that imports greenlet and initializes C API */
+/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
+ keep the older definition to be sure older code that might have a copy of
+ the header still works. */
+# define PyGreenlet_Import() \
+ { \
+ _PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
+ }
+
+#endif /* GREENLET_MODULE */
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* !Py_GREENLETOBJECT_H */