DB update.

This commit is contained in:
Enki 2024-12-30 06:03:07 +00:00
parent bcd6c435f7
commit 72972efe5e
20 changed files with 1302 additions and 314 deletions

View File

@ -44,7 +44,23 @@
dest: "{{ server_dir }}/public.key" dest: "{{ server_dir }}/public.key"
mode: '0644' mode: '0644'
when: not server_config.stat.exists when: not server_config.stat.exists
- name: Update vault with server details
block:
- name: Read server public key
shell: "cat {{ server_dir }}/public.key"
register: pubkey_content
changed_when: false
- name: Save server details to vault
copy:
content: |
wireguard_server_public_key: "{{ pubkey_content.stdout }}"
wireguard_server_endpoint: "{{ ansible_host }}"
dest: "{{ playbook_dir }}/../group_vars/vpn_servers/vault.yml"
mode: '0600'
when: not server_config.stat.exists
- name: Create initial server config - name: Create initial server config
template: template:
src: templates/server.conf.j2 src: templates/server.conf.j2

View File

@ -1,5 +1,3 @@
# app/__init__.py
from flask import Flask, request, jsonify, render_template from flask import Flask, request, jsonify, render_template
import logging import logging
from .handlers.webhook_handler import handle_payment_webhook from .handlers.webhook_handler import handle_payment_webhook

BIN
app/data/vpn.db Normal file

Binary file not shown.

View File

@ -1,15 +1,11 @@
# app/handlers/payment_handler.py
import logging import logging
import requests import requests
import os import os
from flask import jsonify
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from pathlib import Path from pathlib import Path
import traceback import traceback
from .webhook_handler import get_vault_values from .webhook_handler import get_vault_values
from ..utils.db.operations import DatabaseManager
from ..utils.db.models import SubscriptionStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,9 +21,6 @@ class BTCPayHandler:
self.api_key = vault_values['btcpay_api_key'] self.api_key = vault_values['btcpay_api_key']
self.store_id = vault_values['btcpay_store_id'] self.store_id = vault_values['btcpay_store_id']
# Email configuration
self.smtp_config = vault_values.get('smtp_config', {})
logger.info(f"BTCPayHandler initialized with base URL: {self.base_url}") logger.info(f"BTCPayHandler initialized with base URL: {self.base_url}")
except Exception as e: except Exception as e:
@ -35,10 +28,16 @@ class BTCPayHandler:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
def create_invoice(self, amount_sats, duration_hours, email): def create_invoice(self, amount_sats, duration_hours, user_id, public_key):
"""Create BTCPay invoice for VPN subscription""" """Create BTCPay invoice for VPN subscription"""
try: try:
logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for {email}") logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for user {user_id}")
# First, get or create user
user = DatabaseManager.get_user_by_uuid(user_id)
if not user:
user = DatabaseManager.create_user(user_id)
logger.info(f"Created new user with ID: {user_id}")
headers = { headers = {
'Authorization': f'token {self.api_key}', 'Authorization': f'token {self.api_key}',
@ -54,11 +53,12 @@ class BTCPayHandler:
'currency': 'SATS', 'currency': 'SATS',
'metadata': { 'metadata': {
'duration_hours': duration_hours, 'duration_hours': duration_hours,
'email': email, 'userId': user_id,
'orderId': f'vpn_sub_{duration_hours}h', 'publicKey': public_key,
'orderId': f'vpn_sub_{duration_hours}h'
}, },
'checkout': { 'checkout': {
'redirectURL': f'{app_url}/payment/success', 'redirectURL': f'{app_url}/payment/success?userId={user_id}',
'redirectAutomatically': True 'redirectAutomatically': True
} }
} }
@ -80,11 +80,22 @@ class BTCPayHandler:
return None return None
invoice_data = response.json() invoice_data = response.json()
logger.info(f"Successfully created invoice {invoice_data.get('id')}") invoice_id = invoice_data['id']
logger.info(f"Successfully created invoice {invoice_id}")
# Create pending subscription
subscription = DatabaseManager.create_subscription(
user_id,
invoice_id,
public_key,
duration_hours
)
logger.info(f"Created pending subscription for invoice {invoice_id}")
return { return {
'invoice_id': invoice_data['id'], 'invoice_id': invoice_id,
'checkout_url': invoice_data['checkoutLink'] 'checkout_url': invoice_data['checkoutLink'],
'user_id': user_id
} }
except Exception as e: except Exception as e:
@ -92,53 +103,27 @@ class BTCPayHandler:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return None return None
def send_confirmation_email(self, email, config_data): def get_subscription_config(self, user_id):
"""Send VPN configuration details via email""" """Get WireGuard configuration details for a subscription"""
try: try:
logger.info(f"Sending confirmation email to {email}") logger.info(f"Fetching subscription config for user {user_id}")
user = DatabaseManager.get_user_by_uuid(user_id)
if not self.smtp_config: if not user:
logger.warning("SMTP configuration not found in vault") logger.error(f"User {user_id} not found")
return False return None
msg = MIMEMultipart() subscription = DatabaseManager.get_active_subscription_for_user(user.id)
msg['From'] = self.smtp_config['sender_email'] if not subscription:
msg['To'] = email logger.error(f"No active subscription found for user {user_id}")
msg['Subject'] = "Your VPN Configuration" return None
body = f""" return {
Thank you for subscribing to our VPN service! 'serverPublicKey': os.getenv('WIREGUARD_SERVER_PUBLIC_KEY'),
'serverEndpoint': os.getenv('WIREGUARD_SERVER_ENDPOINT'),
Please find your WireGuard configuration below: 'clientIp': subscription.assigned_ip
}
{config_data}
Installation instructions:
1. Install WireGuard client for your platform from https://www.wireguard.com/install/
2. Save the above configuration to a file named 'wg0.conf'
3. Import the configuration file into your WireGuard client
Need help? Reply to this email for support.
"""
msg.attach(MIMEText(body, 'plain'))
logger.debug("Connecting to SMTP server")
with smtplib.SMTP(
self.smtp_config['server'],
self.smtp_config.get('port', 587)
) as server:
server.starttls()
server.login(
self.smtp_config['username'],
self.smtp_config['password']
)
server.send_message(msg)
logger.info("Confirmation email sent successfully")
return True
except Exception as e: except Exception as e:
logger.error("Error sending confirmation email:") logger.error(f"Error getting subscription config: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return None

View File

@ -6,11 +6,12 @@ import logging
import hmac import hmac
import hashlib import hashlib
import yaml import yaml
import json
import datetime import datetime
import traceback import traceback
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
from ..utils.db.operations import DatabaseManager
from ..utils.db.models import SubscriptionStatus
load_dotenv() load_dotenv()
@ -23,7 +24,6 @@ logger = logging.getLogger(__name__)
BASE_DIR = Path(__file__).resolve().parent.parent.parent BASE_DIR = Path(__file__).resolve().parent.parent.parent
PLAYBOOK_PATH = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_provision.yml' PLAYBOOK_PATH = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_provision.yml'
CLEANUP_PLAYBOOK = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_cleanup.yml' CLEANUP_PLAYBOOK = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_cleanup.yml'
SUBSCRIPTION_DB = BASE_DIR / 'data' / 'subscriptions.json'
def get_vault_values(): def get_vault_values():
"""Get decrypted values from Ansible vault""" """Get decrypted values from Ansible vault"""
@ -80,31 +80,6 @@ def verify_signature(payload_body, signature_header):
logger.error(f"Signature verification failed: {str(e)}") logger.error(f"Signature verification failed: {str(e)}")
return False return False
def load_subscriptions():
"""Load subscription data from JSON file"""
if not SUBSCRIPTION_DB.parent.exists():
SUBSCRIPTION_DB.parent.mkdir(parents=True)
if not SUBSCRIPTION_DB.exists():
return {}
with open(SUBSCRIPTION_DB, 'r') as f:
return json.load(f)
def save_subscription(subscription_data):
"""Save subscription data to JSON file"""
subscriptions = load_subscriptions()
sub_id = subscription_data['subscriptionId']
subscriptions[sub_id] = subscription_data
with open(SUBSCRIPTION_DB, 'w') as f:
json.dump(subscriptions, f, indent=2)
def calculate_expiry(duration_hours):
"""Calculate expiry date based on subscription duration"""
return (datetime.datetime.now() +
datetime.timedelta(hours=duration_hours)).isoformat()
def run_ansible_playbook(invoice_id): def run_ansible_playbook(invoice_id):
"""Run the VPN provisioning playbook""" """Run the VPN provisioning playbook"""
vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '') vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '')
@ -142,12 +117,15 @@ def handle_subscription_status(data):
logger.info(f"Processing subscription status update: {sub_id} -> {status}") logger.info(f"Processing subscription status update: {sub_id} -> {status}")
# Store subscription data subscription = DatabaseManager.get_subscription_by_invoice(sub_id)
data['last_updated'] = datetime.datetime.now().isoformat() if not subscription:
save_subscription(data) logger.error(f"Subscription {sub_id} not found")
return jsonify({"error": "Subscription not found"}), 404
if status != 'Active':
# Run cleanup playbook for inactive subscriptions if status == 'Active':
DatabaseManager.activate_subscription(sub_id)
else:
# Run cleanup for inactive subscriptions
result = subprocess.run([ result = subprocess.run([
'ansible-playbook', 'ansible-playbook',
str(CLEANUP_PLAYBOOK), str(CLEANUP_PLAYBOOK),
@ -158,7 +136,8 @@ def handle_subscription_status(data):
if result.returncode != 0: if result.returncode != 0:
logger.error(f"Failed to clean up subscription {sub_id}: {result.stderr}") logger.error(f"Failed to clean up subscription {sub_id}: {result.stderr}")
DatabaseManager.expire_subscription(subscription.id)
logger.info(f"Subscription {sub_id} is no longer active") logger.info(f"Subscription {sub_id} is no longer active")
return jsonify({ return jsonify({
@ -171,9 +150,10 @@ def handle_subscription_renewal(data):
sub_id = data['subscriptionId'] sub_id = data['subscriptionId']
logger.info(f"Processing subscription renewal request: {sub_id}") logger.info(f"Processing subscription renewal request: {sub_id}")
# Update subscription data subscription = DatabaseManager.get_subscription_by_invoice(sub_id)
data['renewal_requested'] = datetime.datetime.now().isoformat() if not subscription:
save_subscription(data) logger.error(f"Subscription {sub_id} not found")
return jsonify({"error": "Subscription not found"}), 404
# TODO: Send renewal notification to user # TODO: Send renewal notification to user
return jsonify({ return jsonify({
@ -199,6 +179,16 @@ def handle_payment_webhook(request):
data = request.json data = request.json
logger.info(f"Received webhook data: {data}") logger.info(f"Received webhook data: {data}")
# Handle test webhooks
invoice_id = data.get('invoiceId', '')
if invoice_id.startswith('__test__'):
logger.info(f"Received test webhook, acknowledging: {data.get('type')}")
return jsonify({
"status": "success",
"message": "Test webhook acknowledged"
})
webhook_type = data.get('type') webhook_type = data.get('type')
if webhook_type == 'SubscriptionStatusUpdated': if webhook_type == 'SubscriptionStatusUpdated':
@ -207,20 +197,13 @@ def handle_payment_webhook(request):
elif webhook_type == 'SubscriptionRenewalRequested': elif webhook_type == 'SubscriptionRenewalRequested':
return handle_subscription_renewal(data) return handle_subscription_renewal(data)
elif webhook_type == 'InvoiceSettled': elif webhook_type in ['InvoiceSettled', 'InvoicePaymentSettled']:
# Handle regular invoice payment
invoice_id = data.get('invoiceId') invoice_id = data.get('invoiceId')
metadata = data.get('metadata', {})
if not invoice_id: if not invoice_id:
logger.error("Missing invoiceId in webhook data") logger.error("Missing invoiceId in webhook data")
return jsonify({"error": "Missing invoiceId"}), 400 return jsonify({"error": "Missing invoiceId"}), 400
if invoice_id.startswith('__test__') and invoice_id.endswith('__test__'):
invoice_id = invoice_id[8:-8]
logger.info(f"Stripped test markers from invoice ID: {invoice_id}")
# Run Ansible playbook with enhanced logging # Get subscription and run Ansible playbook
logger.info(f"Starting VPN provisioning for invoice {invoice_id}") logger.info(f"Starting VPN provisioning for invoice {invoice_id}")
result = run_ansible_playbook(invoice_id) result = run_ansible_playbook(invoice_id)
@ -236,46 +219,18 @@ def handle_payment_webhook(request):
"stderr": result.stderr "stderr": result.stderr
}), 500 }), 500
# Get subscription and activate it
subscription = DatabaseManager.get_subscription_by_invoice(invoice_id)
if subscription:
subscription = DatabaseManager.activate_subscription(invoice_id)
DatabaseManager.record_payment(
subscription.user_id,
subscription.id,
invoice_id,
data.get('amount', 0)
)
logger.info(f"VPN provisioning completed for invoice {invoice_id}") logger.info(f"VPN provisioning completed for invoice {invoice_id}")
# Update subscription database
try:
duration_hours = metadata.get('duration_hours', 24)
subscriptions = load_subscriptions()
subscriptions[invoice_id] = {
'email': metadata.get('email'),
'duration_hours': duration_hours,
'start_time': datetime.datetime.now().isoformat(),
'expiry': calculate_expiry(duration_hours),
'status': 'Active'
}
with open(SUBSCRIPTION_DB, 'w') as f:
json.dump(subscriptions, f, indent=2)
logger.info(f"Updated subscription database for invoice {invoice_id}")
except Exception as e:
logger.error(f"Error updating subscription database: {str(e)}")
# Send email confirmation if email is provided
try:
email = metadata.get('email')
if email:
config_path = f"/etc/wireguard/clients/{invoice_id}/wg0.conf"
if os.path.exists(config_path):
with open(config_path, 'r') as f:
config_data = f.read()
btcpay_handler = BTCPayHandler()
email_sent = btcpay_handler.send_confirmation_email(email, config_data)
if email_sent:
logger.info(f"Sent configuration email to {email}")
else:
logger.warning(f"Failed to send configuration email to {email}")
else:
logger.warning(f"Config file not found at {config_path}")
except Exception as e:
logger.error(f"Error sending confirmation email: {str(e)}")
logger.info(f"Successfully processed invoice {invoice_id}")
return jsonify({ return jsonify({
"status": "success", "status": "success",
"invoice_id": invoice_id, "invoice_id": invoice_id,
@ -283,7 +238,6 @@ def handle_payment_webhook(request):
}) })
else: else:
# Log other webhook types as info instead of warning
logger.info(f"Received {webhook_type} webhook - no action required") logger.info(f"Received {webhook_type} webhook - no action required")
return jsonify({ return jsonify({
"status": "success", "status": "success",

View File

@ -0,0 +1,213 @@
import React, { useState, useEffect } from 'react';
import { generateKeys } from '../utils/wireguard';
import { Card, CardHeader, CardTitle, CardContent } from '@/components/ui/card';
import { Button } from '@/components/ui/button';
import { Input } from '@/components/ui/input';
import { AlertCircle } from 'lucide-react';
import { Alert, AlertDescription } from '@/components/ui/alert';
const WireGuardPayment = () => {
const [keyData, setKeyData] = useState(null);
const [duration, setDuration] = useState(24);
const [price, setPrice] = useState(0);
const [loading, setLoading] = useState(false);
const [error, setError] = useState('');
const [userId, setUserId] = useState('');
useEffect(() => {
// Generate random userId and keys when component mounts
const init = async () => {
try {
const randomId = crypto.randomUUID();
setUserId(randomId);
const keys = await generateKeys();
setKeyData(keys);
calculatePrice(duration);
} catch (err) {
setError('Failed to initialize keys');
console.error(err);
}
};
init();
}, []);
const calculatePrice = async (hours) => {
try {
const response = await fetch('/api/calculate-price', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ hours: parseInt(hours) })
});
const data = await response.json();
setPrice(data.price);
} catch (err) {
setError('Failed to calculate price');
}
};
const handleDurationChange = (event) => {
const newDuration = parseInt(event.target.value);
setDuration(newDuration);
calculatePrice(newDuration);
};
const handlePayment = async () => {
if (!keyData) {
setError('No keys generated. Please refresh the page.');
return;
}
try {
setLoading(true);
const response = await fetch('/create-invoice', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
duration,
userId,
publicKey: keyData.publicKey,
// Don't send private key to server!
configuration: {
type: 'wireguard',
publicKey: keyData.publicKey
}
})
});
if (!response.ok) {
throw new Error('Failed to create payment');
}
const data = await response.json();
// Save private key to localStorage before redirecting
localStorage.setItem(`vpn_keys_${userId}`, JSON.stringify({
privateKey: keyData.privateKey,
publicKey: keyData.publicKey,
createdAt: new Date().toISOString()
}));
window.location.href = data.checkout_url;
} catch (err) {
setError('Failed to initiate payment');
console.error(err);
} finally {
setLoading(false);
}
};
const handleRegenerateKeys = async () => {
try {
const keys = await generateKeys();
setKeyData(keys);
} catch (err) {
setError('Failed to regenerate keys');
}
};
return (
<Card className="w-full max-w-xl mx-auto">
<CardHeader>
<CardTitle>WireGuard VPN Configuration</CardTitle>
</CardHeader>
<CardContent className="space-y-4">
{error && (
<Alert variant="destructive">
<AlertCircle className="h-4 w-4" />
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
<div className="space-y-2">
<label className="text-sm font-medium">Your User ID:</label>
<Input value={userId} readOnly className="font-mono text-sm" />
<p className="text-sm text-gray-500">
Save this ID - you'll need it to manage your subscription
</p>
</div>
{keyData && (
<div className="space-y-2">
<label className="text-sm font-medium">Your Public Key:</label>
<Input value={keyData.publicKey} readOnly className="font-mono text-sm" />
<Button
onClick={handleRegenerateKeys}
variant="outline"
size="sm"
className="w-full"
>
Regenerate Keys
</Button>
</div>
)}
<div className="space-y-2">
<label className="text-sm font-medium">Duration:</label>
<div className="space-y-1">
<Input
type="range"
min="1"
max="720"
value={duration}
onChange={handleDurationChange}
className="w-full"
/>
<div className="flex justify-between text-sm text-gray-500">
<Button
variant="ghost"
size="sm"
onClick={() => {
setDuration(24);
calculatePrice(24);
}}
>
1 Day
</Button>
<Button
variant="ghost"
size="sm"
onClick={() => {
setDuration(168);
calculatePrice(168);
}}
>
1 Week
</Button>
<Button
variant="ghost"
size="sm"
onClick={() => {
setDuration(720);
calculatePrice(720);
}}
>
30 Days
</Button>
</div>
<p className="text-center font-medium">{duration} hours</p>
</div>
</div>
<div className="text-center py-4">
<p className="text-3xl font-bold text-blue-500">{price} sats</p>
</div>
<Button
onClick={handlePayment}
disabled={loading || !keyData}
className="w-full"
>
{loading ? 'Processing...' : 'Pay with Bitcoin'}
</Button>
<div className="mt-4 text-sm text-gray-500 space-y-1">
<p> Keys are generated securely in your browser</p>
<p> Your private key never leaves your device</p>
<p> Configuration will be available after payment</p>
</div>
</CardContent>
</Card>
);
};
export default WireGuardPayment;

View File

@ -1,11 +1,39 @@
// app/static/js/pricing.js // Base64 encoding/decoding utilities
document.addEventListener('DOMContentLoaded', function() { const b64 = {
encode: array => btoa(String.fromCharCode.apply(null, array)),
decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0))
};
async function generateKeyPair() {
const keyPair = await window.crypto.subtle.generateKey(
{
name: 'X25519',
namedCurve: 'X25519',
},
true,
['deriveKey', 'deriveBits']
);
const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey);
const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey);
return {
privateKey: b64.encode(new Uint8Array(privateKey)),
publicKey: b64.encode(new Uint8Array(publicKey))
};
}
document.addEventListener('DOMContentLoaded', async function() {
const form = document.getElementById('subscription-form'); const form = document.getElementById('subscription-form');
const slider = document.getElementById('duration-slider'); const slider = document.getElementById('duration-slider');
const durationDisplay = document.getElementById('duration-display'); const durationDisplay = document.getElementById('duration-display');
const priceDisplay = document.getElementById('price-display'); const priceDisplay = document.getElementById('price-display');
const presetButtons = document.querySelectorAll('.duration-preset'); const presetButtons = document.querySelectorAll('.duration-preset');
const emailInput = document.getElementById('email'); const userIdInput = document.getElementById('user-id');
const publicKeyInput = document.getElementById('public-key');
const regenerateButton = document.getElementById('regenerate-keys');
let currentKeyPair = null;
function formatDuration(hours) { function formatDuration(hours) {
if (hours < 24) return `${hours} hour${hours === 1 ? '' : 's'}`; if (hours < 24) return `${hours} hour${hours === 1 ? '' : 's'}`;
@ -14,6 +42,32 @@ document.addEventListener('DOMContentLoaded', function() {
return `${Math.floor(hours / 720)} month${hours === 720 ? '' : 's'}`; return `${Math.floor(hours / 720)} month${hours === 720 ? '' : 's'}`;
} }
async function generateNewKeys() {
try {
currentKeyPair = await generateKeyPair();
publicKeyInput.value = currentKeyPair.publicKey;
// Save private key to localStorage
const keyData = {
privateKey: currentKeyPair.privateKey,
publicKey: currentKeyPair.publicKey,
createdAt: new Date().toISOString()
};
localStorage.setItem(`vpn_keys_${userIdInput.value}`, JSON.stringify(keyData));
} catch (error) {
console.error('Failed to generate keys:', error);
alert('Failed to generate WireGuard keys. Please try again.');
}
}
async function initializeForm() {
// Generate user ID
userIdInput.value = crypto.randomUUID();
// Generate initial keys
await generateNewKeys();
}
async function updatePrice(hours) { async function updatePrice(hours) {
try { try {
const response = await fetch('/api/calculate-price', { const response = await fetch('/api/calculate-price', {
@ -29,7 +83,27 @@ document.addEventListener('DOMContentLoaded', function() {
} }
} }
async function createInvoice(duration, email) { // Event listeners
slider.addEventListener('input', () => updatePrice(slider.value));
regenerateButton.addEventListener('click', generateNewKeys);
presetButtons.forEach(button => {
button.addEventListener('click', (e) => {
const hours = e.target.dataset.hours;
slider.value = hours;
updatePrice(hours);
});
});
form.addEventListener('submit', async (e) => {
e.preventDefault();
if (!currentKeyPair) {
alert('No keys generated. Please refresh the page.');
return;
}
try { try {
const response = await fetch('/create-invoice', { const response = await fetch('/create-invoice', {
method: 'POST', method: 'POST',
@ -37,8 +111,9 @@ document.addEventListener('DOMContentLoaded', function() {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
}, },
body: JSON.stringify({ body: JSON.stringify({
duration: parseInt(duration), duration: parseInt(slider.value),
email: email userId: userIdInput.value,
publicKey: currentKeyPair.publicKey
}) })
}); });
@ -53,32 +128,10 @@ document.addEventListener('DOMContentLoaded', function() {
console.error('Error creating invoice:', error); console.error('Error creating invoice:', error);
alert('Failed to create payment invoice. Please try again.'); alert('Failed to create payment invoice. Please try again.');
} }
}
// Event listeners
slider.addEventListener('input', () => updatePrice(slider.value));
presetButtons.forEach(button => {
button.addEventListener('click', (e) => {
const hours = e.target.dataset.hours;
slider.value = hours;
updatePrice(hours);
});
});
form.addEventListener('submit', async (e) => {
e.preventDefault();
const email = emailInput.value.trim();
if (!email) {
alert('Please enter your email address');
return;
}
const duration = slider.value;
await createInvoice(duration, email);
}); });
// Initialize the form
await initializeForm();
// Initial price calculation // Initial price calculation
updatePrice(slider.value); updatePrice(slider.value);
}); });

View File

@ -0,0 +1,54 @@
// Base64 encoding/decoding utilities
const b64 = {
encode: array => btoa(String.fromCharCode.apply(null, array)),
decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0))
};
async function generateKeyPair() {
// Generate a random key pair using Web Crypto API
const keyPair = await window.crypto.subtle.generateKey(
{
name: 'X25519',
namedCurve: 'X25519',
},
true,
['deriveKey', 'deriveBits']
);
// Export keys in raw format
const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey);
const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey);
// Convert to base64
return {
privateKey: b64.encode(new Uint8Array(privateKey)),
publicKey: b64.encode(new Uint8Array(publicKey))
};
}
export async function generateWireGuardConfig(serverPublicKey, serverEndpoint, address) {
const keys = await generateKeyPair();
return {
keys,
config: `[Interface]
PrivateKey = ${keys.privateKey}
Address = ${address}
DNS = 1.1.1.1
[Peer]
PublicKey = ${serverPublicKey}
Endpoint = ${serverEndpoint}
AllowedIPs = 0.0.0.0/0
PersistentKeepalive = 25`
};
}
export async function generateKeys() {
try {
return await generateKeyPair();
} catch (error) {
console.error('Failed to generate WireGuard keys:', error);
throw error;
}
}

View File

@ -1,30 +1,45 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block content %} {% block content %}
<div class="min-h-screen bg-dark py-8 px-4"> <div class="min-h-screen bg-dark py-8 px-4">
<div class="max-w-xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6"> <div class="max-w-xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6">
<h1 class="text-2xl font-bold mb-6 text-center">Subscribe to VPN Service</h1> <h1 class="text-2xl font-bold mb-6 text-center">Subscribe to VPN Service</h1>
<form id="subscription-form" class="space-y-6"> <form id="subscription-form" class="space-y-6">
<div> <div>
<label for="email" class="block text-sm font-medium mb-2">Email Address</label> <label class="block text-sm font-medium mb-2">User ID</label>
<input <input
type="email" type="text"
id="email" id="user-id"
required readonly
class="w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white focus:outline-none focus:border-blue-500" class="w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white focus:outline-none focus:border-blue-500 font-mono text-sm"
placeholder="your@email.com"
> >
<p class="mt-1 text-sm text-gray-400">Save this ID - you'll need it to manage your subscription</p>
</div> </div>
<div>
<label class="block text-sm font-medium mb-2">WireGuard Public Key</label>
<input
type="text"
id="public-key"
readonly
class="w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white focus:outline-none focus:border-blue-500 font-mono text-sm"
>
<button
type="button"
id="regenerate-keys"
class="mt-2 w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white hover:bg-gray-800 transition-colors"
>
Regenerate Keys
</button>
</div>
<div> <div>
<label class="block text-sm font-medium mb-2">Duration</label> <label class="block text-sm font-medium mb-2">Duration</label>
<div class="space-y-4"> <div class="space-y-4">
<input <input
type="range" type="range"
id="duration-slider" id="duration-slider"
min="1" min="1"
max="720" max="720"
value="24" value="24"
class="w-full" class="w-full"
> >
@ -44,12 +59,18 @@
</p> </p>
</div> </div>
<button <button
type="submit" type="submit"
class="w-full bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded transition-colors" class="w-full bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded transition-colors"
> >
Pay with Bitcoin Pay with Bitcoin
</button> </button>
<div class="mt-4 text-sm text-gray-400 space-y-1">
<p>• Keys are generated securely in your browser</p>
<p>• Your private key never leaves your device</p>
<p>• Configuration will be available after payment</p>
</div>
</form> </form>
</div> </div>
</div> </div>

View File

@ -2,17 +2,110 @@
{% block content %} {% block content %}
<div class="min-h-screen bg-dark py-8 px-4"> <div class="min-h-screen bg-dark py-8 px-4">
<div class="max-w-xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6 text-center"> <div class="max-w-2xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6">
<h1 class="text-2xl font-bold mb-4">Payment Successful!</h1> <h1 class="text-2xl font-bold mb-4 text-center">Payment Successful!</h1>
<p class="text-gray-300 mb-6">
Thank you for your payment. Your VPN configuration will be sent to your email shortly. <div class="mb-8">
</p> <p class="text-gray-300 text-center mb-4">
<p class="text-gray-400 mb-4"> Your VPN subscription is now active. Please follow the instructions below to set up your VPN connection.
Please check your email for further instructions on setting up your VPN connection. </p>
</p> </div>
<a href="/" class="inline-block bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded transition-colors">
Return to Home <div class="space-y-6">
</a> <!-- WireGuard Configuration Section -->
<div id="config-section" class="hidden">
<h2 class="text-xl font-semibold mb-2">Your WireGuard Configuration</h2>
<pre id="wireguard-config" class="bg-dark p-4 rounded-md font-mono text-sm overflow-x-auto"></pre>
<button
id="copy-config"
class="mt-2 w-full bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded transition-colors"
>
Copy Configuration
</button>
</div>
<!-- Installation Instructions -->
<div class="space-y-4">
<h2 class="text-xl font-semibold">Installation Instructions</h2>
<ol class="list-decimal list-inside space-y-2 text-gray-300">
<li>Download WireGuard for your platform:
<div class="ml-6 mt-2 space-y-2">
<a href="https://www.wireguard.com/install/"
target="_blank"
class="text-blue-400 hover:text-blue-300 block">
Official WireGuard Downloads
</a>
</div>
</li>
<li>Create a new tunnel in WireGuard</li>
<li>Copy the configuration above and paste it into the new tunnel</li>
<li>Activate the tunnel to connect</li>
</ol>
</div>
<div class="space-y-2 text-sm text-gray-400">
<p>• Save your configuration securely - you'll need it to reconnect</p>
<p>• Your private key is stored in your browser's local storage</p>
<p>• For security, clear your browser data after saving the configuration</p>
</div>
</div>
<div class="mt-8 text-center">
<a href="/" class="text-blue-400 hover:text-blue-300">
Return to Home
</a>
</div>
</div> </div>
</div> </div>
<script>
document.addEventListener('DOMContentLoaded', function() {
const userId = new URLSearchParams(window.location.search).get('userId');
if (userId) {
// Retrieve keys from localStorage
const keyData = localStorage.getItem(`vpn_keys_${userId}`);
if (keyData) {
const keys = JSON.parse(keyData);
// Make config section visible
document.getElementById('config-section').classList.remove('hidden');
// Get server details from response
fetch(`/api/subscription/config?userId=${userId}`)
.then(response => response.json())
.then(data => {
const config = `[Interface]
PrivateKey = ${keys.privateKey}
Address = ${data.clientIp}/24
DNS = 1.1.1.1
[Peer]
PublicKey = ${data.serverPublicKey}
Endpoint = ${data.serverEndpoint}:51820
AllowedIPs = 0.0.0.0/0
PersistentKeepalive = 25`;
document.getElementById('wireguard-config').textContent = config;
})
.catch(error => {
console.error('Error fetching configuration:', error);
});
// Setup copy button
document.getElementById('copy-config').addEventListener('click', function() {
const config = document.getElementById('wireguard-config').textContent;
navigator.clipboard.writeText(config).then(() => {
this.textContent = 'Copied!';
setTimeout(() => {
this.textContent = 'Copy Configuration';
}, 2000);
});
});
// Clear keys from localStorage after showing config
localStorage.removeItem(`vpn_keys_${userId}`);
}
}
});
</script>
{% endblock %} {% endblock %}

23
app/utils/db/__init__.py Normal file
View File

@ -0,0 +1,23 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from pathlib import Path
def get_db_path():
base_dir = Path(__file__).resolve().parent.parent.parent
data_dir = base_dir / 'data'
data_dir.mkdir(exist_ok=True)
return data_dir / 'vpn.db'
def init_db():
"""Initialize the database"""
from .models import Base
db_url = f"sqlite:///{get_db_path()}"
engine = create_engine(db_url)
Base.metadata.create_all(engine)
return engine
def get_session():
"""Get a database session"""
engine = init_db()
Session = sessionmaker(bind=engine)
return Session()

52
app/utils/db/models.py Normal file
View File

@ -0,0 +1,52 @@
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Enum, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
import enum
import datetime
Base = declarative_base()
class SubscriptionStatus(enum.Enum):
ACTIVE = "active"
EXPIRED = "expired"
PENDING = "pending"
CANCELLED = "cancelled"
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
user_id = Column(String, unique=True, nullable=False) # UUID generated in frontend
created_at = Column(DateTime, default=datetime.datetime.utcnow)
subscriptions = relationship("Subscription", back_populates="user")
payments = relationship("Payment", back_populates="user")
class Subscription(Base):
__tablename__ = 'subscriptions'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('users.id'))
invoice_id = Column(String, unique=True)
public_key = Column(Text, nullable=False) # WireGuard public key
start_time = Column(DateTime, nullable=False)
expiry_time = Column(DateTime, nullable=False)
status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.PENDING)
warning_sent = Column(Integer, default=0)
assigned_ip = Column(String) # WireGuard IP address assigned to this subscription
user = relationship("User", back_populates="subscriptions")
payments = relationship("Payment", back_populates="subscription")
class Payment(Base):
__tablename__ = 'payments'
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey('users.id'))
subscription_id = Column(Integer, ForeignKey('subscriptions.id'))
invoice_id = Column(String, unique=True)
amount = Column(Integer, nullable=False) # Amount in sats
timestamp = Column(DateTime, default=datetime.datetime.utcnow)
user = relationship("User", back_populates="payments")
subscription = relationship("Subscription", back_populates="payments")

152
app/utils/db/operations.py Normal file
View File

@ -0,0 +1,152 @@
from datetime import datetime
from sqlalchemy.exc import SQLAlchemyError
from . import get_session
from .models import User, Subscription, Payment, SubscriptionStatus
import logging
import ipaddress
logger = logging.getLogger(__name__)
class DatabaseManager:
@staticmethod
def get_next_available_ip():
"""Get the next available IP from the WireGuard subnet"""
with get_session() as session:
# Get all assigned IPs
assigned_ips = session.query(Subscription.assigned_ip)\
.filter(Subscription.assigned_ip.isnot(None))\
.all()
assigned_ips = [ip[0] for ip in assigned_ips]
# Start from 10.8.0.2 (10.8.0.1 is server)
network = ipaddress.IPv4Network('10.8.0.0/24')
for ip in network.hosts():
str_ip = str(ip)
if str_ip == '10.8.0.1': # Skip server IP
continue
if str_ip not in assigned_ips:
return str_ip
raise Exception("No available IPs in the subnet")
@staticmethod
def create_user(user_id):
"""Create a new user with UUID"""
with get_session() as session:
user = User(user_id=user_id)
session.add(user)
session.commit()
return user
@staticmethod
def get_user_by_uuid(user_id):
"""Get user by UUID"""
with get_session() as session:
return session.query(User).filter(User.user_id == user_id).first()
@staticmethod
def get_subscription_by_invoice(invoice_id):
"""Get subscription by invoice ID"""
with get_session() as session:
return session.query(Subscription)\
.filter(Subscription.invoice_id == invoice_id)\
.first()
@staticmethod
def create_subscription(user_id, invoice_id, public_key, duration_hours):
"""Create a new subscription"""
with get_session() as session:
try:
# Get user or create if doesn't exist
user = DatabaseManager.get_user_by_uuid(user_id)
if not user:
user = DatabaseManager.create_user(user_id)
start_time = datetime.utcnow()
expiry_time = start_time + datetime.timedelta(hours=duration_hours)
# Get next available IP
assigned_ip = DatabaseManager.get_next_available_ip()
subscription = Subscription(
user_id=user.id,
invoice_id=invoice_id,
public_key=public_key,
start_time=start_time,
expiry_time=expiry_time,
status=SubscriptionStatus.PENDING,
assigned_ip=assigned_ip
)
session.add(subscription)
session.commit()
return subscription
except Exception as e:
logger.error(f"Error creating subscription: {str(e)}")
session.rollback()
raise
@staticmethod
def activate_subscription(invoice_id):
"""Activate a subscription after payment"""
with get_session() as session:
subscription = session.query(Subscription)\
.filter(Subscription.invoice_id == invoice_id)\
.first()
if subscription:
subscription.status = SubscriptionStatus.ACTIVE
session.commit()
return subscription
@staticmethod
def record_payment(user_id, subscription_id, invoice_id, amount):
"""Record a payment"""
with get_session() as session:
user = DatabaseManager.get_user_by_uuid(user_id)
if not user:
raise ValueError(f"User {user_id} not found")
payment = Payment(
user_id=user.id,
subscription_id=subscription_id,
invoice_id=invoice_id,
amount=amount
)
session.add(payment)
session.commit()
return payment
@staticmethod
def get_active_subscriptions():
"""Get all active subscriptions"""
with get_session() as session:
return session.query(Subscription)\
.filter(Subscription.status == SubscriptionStatus.ACTIVE)\
.all()
@staticmethod
def expire_subscription(subscription_id):
"""Mark a subscription as expired"""
with get_session() as session:
subscription = session.query(Subscription)\
.filter(Subscription.id == subscription_id)\
.first()
if subscription:
subscription.status = SubscriptionStatus.EXPIRED
session.commit()
return True
return False
@staticmethod
def update_warning_sent(subscription_id):
"""Update the warning_sent flag for a subscription"""
with get_session() as session:
subscription = session.query(Subscription)\
.filter(Subscription.id == subscription_id)\
.first()
if subscription:
subscription.warning_sent = 1
session.commit()
return True
return False

59
app/utils/db/utils.py Normal file
View File

@ -0,0 +1,59 @@
import sqlite3
import shutil
from datetime import datetime
from pathlib import Path
import logging
from .. import get_db_path
logger = logging.getLogger(__name__)
def backup_database():
"""Create a backup of the SQLite database"""
try:
db_path = get_db_path()
if not db_path.exists():
logger.error("Database file not found")
return False
backup_dir = db_path.parent / 'backups'
backup_dir.mkdir(exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
backup_path = backup_dir / f'vpn_db_backup_{timestamp}.db'
shutil.copy2(db_path, backup_path)
logger.info(f"Database backed up successfully to {backup_path}")
# Keep only last 5 backups
backups = sorted(backup_dir.glob('vpn_db_backup_*.db'))
if len(backups) > 5:
for old_backup in backups[:-5]:
old_backup.unlink()
logger.info(f"Removed old backup: {old_backup}")
return True
except Exception as e:
logger.error(f"Backup failed: {str(e)}")
return False
def verify_database():
"""Check database integrity"""
try:
db_path = get_db_path()
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("PRAGMA integrity_check")
result = cursor.fetchone()[0]
if result != "ok":
logger.error(f"Database integrity check failed: {result}")
return False
logger.info("Database integrity verified")
return True
except Exception as e:
logger.error(f"Database verification failed: {str(e)}")
return False

BIN
data/vpn.db Normal file

Binary file not shown.

View File

@ -3,4 +3,5 @@ pyyaml==6.0.1
python-dotenv==1.0.0 python-dotenv==1.0.0
cryptography==41.0.7 cryptography==41.0.7
ansible==9.1.0 ansible==9.1.0
requests==2.31.0 requests==2.31.0
SQLAlchemy==2.0.25

28
scripts/init_db.py Normal file
View File

@ -0,0 +1,28 @@
# scripts/init_db.py
import sys
from pathlib import Path
# Add the project root to Python path
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
from app.utils.db.models import Base
from sqlalchemy import create_engine
def init_db():
# Create data directory if it doesn't exist
data_dir = project_root / 'data'
data_dir.mkdir(exist_ok=True)
# Create database
db_path = data_dir / 'vpn.db'
db_url = f"sqlite:///{db_path}"
engine = create_engine(db_url)
# Create all tables
Base.metadata.create_all(engine)
print(f"Database initialized at: {db_path}")
return engine
if __name__ == "__main__":
init_db()

151
scripts/migrate_db.py Normal file
View File

@ -0,0 +1,151 @@
import sys
from pathlib import Path
import sqlite3
import uuid
import logging
from datetime import datetime
# Add the project root to Python path
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
from app.utils.db.models import Base, SubscriptionStatus
from app.utils.db import get_db_path, get_session
from sqlalchemy import create_engine
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def backup_database():
"""Create a backup of the existing database"""
db_path = get_db_path()
if not db_path.exists():
logger.info("No existing database found, skipping backup")
return
backup_path = db_path.parent / f"{db_path.stem}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}{db_path.suffix}"
db_path.rename(backup_path)
logger.info(f"Created database backup at {backup_path}")
return backup_path
def migrate_database():
"""Perform database migration"""
try:
# Create backup
backup_path = backup_database()
# Initialize new database
engine = create_engine(f"sqlite:///{get_db_path()}")
Base.metadata.create_all(engine)
# If there's an old database, migrate the data
if backup_path and backup_path.exists():
logger.info("Migrating data from old database...")
# Connect to old database
old_conn = sqlite3.connect(backup_path)
old_cursor = old_conn.cursor()
# Get schema information
old_cursor.execute("SELECT sql FROM sqlite_master WHERE type='table'")
old_schema = old_cursor.fetchall()
logger.info("Old database schema:")
for table in old_schema:
logger.info(table[0])
with get_session() as session:
try:
# Start transaction
session.begin()
# Migrate users
logger.info("Migrating users...")
old_cursor.execute("SELECT id, email, created_at FROM users")
users = old_cursor.fetchall()
user_id_map = {} # Map old IDs to new UUIDs
for old_id, email, created_at in users:
new_user_id = str(uuid.uuid4())
user_id_map[old_id] = new_user_id
session.execute(
"INSERT INTO users (user_id, created_at) VALUES (?, ?)",
[new_user_id, created_at or datetime.utcnow()]
)
# Migrate subscriptions
logger.info("Migrating subscriptions...")
old_cursor.execute("""
SELECT id, user_id, invoice_id, start_time, expiry_time,
status, warning_sent
FROM subscriptions
""")
subscriptions = old_cursor.fetchall()
for sub in subscriptions:
old_id, old_user_id, invoice_id, start_time, expiry_time, status, warning_sent = sub
if old_user_id in user_id_map:
# Generate a placeholder public key for existing subscriptions
placeholder_pubkey = f"MIGRATED_{uuid.uuid4()}"
session.execute("""
INSERT INTO subscriptions
(user_id, invoice_id, public_key, start_time, expiry_time,
status, warning_sent, assigned_ip)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", [
user_id_map[old_user_id],
invoice_id,
placeholder_pubkey,
start_time,
expiry_time,
status,
warning_sent,
f"10.8.0.{2 + old_id}" # Simple IP assignment
])
# Migrate payments
logger.info("Migrating payments...")
old_cursor.execute("""
SELECT user_id, subscription_id, invoice_id, amount, timestamp
FROM payments
""")
payments = old_cursor.fetchall()
for payment in payments:
old_user_id, sub_id, invoice_id, amount, timestamp = payment
if old_user_id in user_id_map:
session.execute("""
INSERT INTO payments
(user_id, subscription_id, invoice_id, amount, timestamp)
VALUES (?, ?, ?, ?, ?)
""", [
user_id_map[old_user_id],
sub_id,
invoice_id,
amount,
timestamp
])
# Commit transaction
session.commit()
logger.info("Migration completed successfully")
except Exception as e:
session.rollback()
logger.error(f"Migration failed: {str(e)}")
raise
old_conn.close()
else:
logger.info("No existing database to migrate")
except Exception as e:
logger.error(f"Migration failed: {str(e)}")
raise
if __name__ == '__main__':
try:
migrate_database()
except Exception as e:
logger.error(f"Migration failed: {str(e)}")
sys.exit(1)

View File

@ -1,14 +1,12 @@
# scripts/subscription_checker.py
import json
import datetime
import logging import logging
import subprocess import subprocess
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from dateutil.parser import parse from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta from utils.db.operations import DatabaseManager
from utils.db.models import SubscriptionStatus
from app.handlers.payment_handler import BTCPayHandler
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@ -19,20 +17,19 @@ logger = logging.getLogger(__name__)
# Path setup # Path setup
SCRIPT_DIR = Path(__file__).resolve().parent SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = SCRIPT_DIR.parent PROJECT_ROOT = SCRIPT_DIR.parent
SUBSCRIPTION_DB = PROJECT_ROOT / 'data' / 'subscriptions.json'
CLEANUP_PLAYBOOK = PROJECT_ROOT / 'ansible' / 'playbooks' / 'vpn_cleanup.yml' CLEANUP_PLAYBOOK = PROJECT_ROOT / 'ansible' / 'playbooks' / 'vpn_cleanup.yml'
INVENTORY_FILE = PROJECT_ROOT / 'inventory.ini' INVENTORY_FILE = PROJECT_ROOT / 'inventory.ini'
# Notification thresholds configuration # Notification thresholds configuration
NOTIFICATION_THRESHOLDS = { NOTIFICATION_THRESHOLDS = {
'minimum_duration': datetime.timedelta(hours=1), # Minimum subscription duration 'minimum_duration': timedelta(hours=1), # Minimum subscription duration
'short_term': { 'short_term': {
'max_duration': datetime.timedelta(days=1), 'max_duration': timedelta(days=1),
'warning_fraction': 0.5, # Warn when 50% of time remains 'warning_fraction': 0.5, # Warn when 50% of time remains
'grace_fraction': 0.1 # Grace period of 10% of subscription length 'grace_fraction': 0.1 # Grace period of 10% of subscription length
}, },
'medium_term': { 'medium_term': {
'max_duration': datetime.timedelta(days=7), 'max_duration': timedelta(days=7),
'warning_fraction': 0.25, # Warn when 25% of time remains 'warning_fraction': 0.25, # Warn when 25% of time remains
'grace_hours': 12 # Fixed 12-hour grace period 'grace_hours': 12 # Fixed 12-hour grace period
}, },
@ -42,25 +39,9 @@ NOTIFICATION_THRESHOLDS = {
} }
} }
def load_subscriptions(): def run_cleanup_playbook(subscription_id):
"""Load subscription data from JSON file"""
if not SUBSCRIPTION_DB.parent.exists():
SUBSCRIPTION_DB.parent.mkdir(parents=True)
if not SUBSCRIPTION_DB.exists():
return {}
with open(SUBSCRIPTION_DB, 'r') as f:
return json.load(f)
def save_subscriptions(subscriptions):
"""Save subscriptions to JSON file"""
with open(SUBSCRIPTION_DB, 'w') as f:
json.dump(subscriptions, f, indent=2)
def run_cleanup_playbook(sub_id):
"""Run the VPN cleanup playbook""" """Run the VPN cleanup playbook"""
logger.info(f"Running cleanup playbook for subscription {sub_id}") logger.info(f"Running cleanup playbook for subscription {subscription_id}")
vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '') vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '')
if not vault_pass: if not vault_pass:
@ -74,7 +55,7 @@ def run_cleanup_playbook(sub_id):
'ansible-playbook', 'ansible-playbook',
str(CLEANUP_PLAYBOOK), str(CLEANUP_PLAYBOOK),
'-i', str(INVENTORY_FILE), '-i', str(INVENTORY_FILE),
'-e', f'subscription_id={sub_id}', '-e', f'subscription_id={subscription_id}',
'--vault-password-file', vault_pass_file.name, '--vault-password-file', vault_pass_file.name,
'-vvv' '-vvv'
] ]
@ -122,42 +103,36 @@ def calculate_notification_times(start_time, end_time):
warning_delta = duration * NOTIFICATION_THRESHOLDS['medium_term']['warning_fraction'] warning_delta = duration * NOTIFICATION_THRESHOLDS['medium_term']['warning_fraction']
grace_hours = NOTIFICATION_THRESHOLDS['medium_term']['grace_hours'] grace_hours = NOTIFICATION_THRESHOLDS['medium_term']['grace_hours']
warning_time = end_time - warning_delta warning_time = end_time - warning_delta
grace_end_time = end_time + datetime.timedelta(hours=grace_hours) grace_end_time = end_time + timedelta(hours=grace_hours)
# Long-term subscriptions (> 7 days) # Long-term subscriptions (> 7 days)
else: else:
warning_days = NOTIFICATION_THRESHOLDS['long_term']['warning_days'] warning_days = NOTIFICATION_THRESHOLDS['long_term']['warning_days']
grace_days = NOTIFICATION_THRESHOLDS['long_term']['grace_days'] grace_days = NOTIFICATION_THRESHOLDS['long_term']['grace_days']
warning_time = end_time - datetime.timedelta(days=warning_days) warning_time = end_time - timedelta(days=warning_days)
grace_end_time = end_time + datetime.timedelta(days=grace_days) grace_end_time = end_time + timedelta(days=grace_days)
return warning_time, grace_end_time return warning_time, grace_end_time
def get_notification_message(sub_data, remaining_time): def get_notification_message(subscription, remaining_time):
"""Generate appropriate notification message based on subscription duration""" """Generate appropriate notification message based on subscription duration"""
if remaining_time < datetime.timedelta(hours=1): if remaining_time < timedelta(hours=1):
return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 60)} minutes!" return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 60)} minutes!"
elif remaining_time < datetime.timedelta(days=1): elif remaining_time < timedelta(days=1):
return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 3600)} hours!" return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 3600)} hours!"
else: else:
return f"Your VPN subscription expires in {remaining_time.days} days!" return f"Your VPN subscription expires in {remaining_time.days} days!"
def notify_user(user_id, message): def notify_user(subscription, message):
"""Send notification to user about subscription status""" """Send notification to user about subscription status"""
try: try:
subscriptions = load_subscriptions() if not subscription.user or not subscription.user.email:
sub_data = subscriptions.get(user_id) logger.error(f"No email found for subscription {subscription.id}")
if not sub_data or 'email' not in sub_data:
logger.error(f"No email found for user {user_id}")
return False return False
# Import BTCPayHandler here to avoid circular imports
from app.handlers.payment_handler import BTCPayHandler
btcpay_handler = BTCPayHandler() btcpay_handler = BTCPayHandler()
email_sent = btcpay_handler.send_confirmation_email( email_sent = btcpay_handler.send_confirmation_email(
sub_data['email'], subscription.user.email,
f""" f"""
VPN Subscription Update VPN Subscription Update
@ -168,9 +143,9 @@ def notify_user(user_id, message):
) )
if email_sent: if email_sent:
logger.info(f"Sent notification to {sub_data['email']}: {message}") logger.info(f"Sent notification to {subscription.user.email}: {message}")
else: else:
logger.warning(f"Failed to send notification to {sub_data['email']}") logger.warning(f"Failed to send notification to {subscription.user.email}")
return email_sent return email_sent
@ -182,68 +157,64 @@ def check_subscriptions():
"""Check subscription status and clean up expired ones""" """Check subscription status and clean up expired ones"""
logger.info("Starting subscription check") logger.info("Starting subscription check")
subscriptions = load_subscriptions() try:
now = datetime.datetime.now() active_subscriptions = DatabaseManager.get_active_subscriptions()
modified = False logger.info(f"Checking {len(active_subscriptions)} active subscriptions")
logger.info(f"Checking {len(subscriptions)} subscriptions") now = datetime.utcnow()
for sub_id, sub_data in list(subscriptions.items()): for subscription in active_subscriptions:
try: try:
logger.debug(f"Processing subscription {sub_id}") logger.debug(f"Processing subscription {subscription.id}")
if sub_data.get('status') != 'Active':
logger.debug(f"Skipping inactive subscription {sub_id}")
continue
start_time = parse(sub_data['start_time']) warning_time, grace_end_time = calculate_notification_times(
expiry = parse(sub_data['expiry']) subscription.start_time,
warning_time, grace_end_time = calculate_notification_times(start_time, expiry) subscription.expiry_time
)
# Calculate remaining time
remaining_time = expiry - now
logger.debug(f"Subscription {sub_id} has {remaining_time} remaining")
# Handle warnings
if now >= warning_time and not sub_data.get('warning_sent'):
message = get_notification_message(sub_data, remaining_time)
logger.info(f"Sending notification for subscription {sub_id}: {message}")
notify_user(sub_id, message) # Calculate remaining time
remaining_time = subscription.expiry_time - now
logger.debug(f"Subscription {subscription.id} has {remaining_time} remaining")
sub_data['warning_sent'] = True # Handle warnings
modified = True if now >= warning_time and not subscription.warning_sent:
message = get_notification_message(subscription, remaining_time)
# Handle expiration logger.info(f"Sending notification for subscription {subscription.id}: {message}")
if now >= grace_end_time and sub_data.get('status') == 'Active':
logger.info(f"Processing expiration for subscription {sub_id}")
try:
logger.debug(f"Running cleanup playbook for {sub_id}")
result = run_cleanup_playbook(sub_id)
if result.returncode == 0: if notify_user(subscription, message):
sub_data['status'] = 'Expired' DatabaseManager.update_warning_sent(subscription.id)
sub_data['cleanup_date'] = now.isoformat()
modified = True # Handle expiration
logger.info(f"Successfully cleaned up subscription {sub_id}") if now >= grace_end_time:
else: logger.info(f"Processing expiration for subscription {subscription.id}")
logger.error(f"Cleanup failed: {result.stderr}")
try:
result = run_cleanup_playbook(subscription.invoice_id)
except Exception as e: if result.returncode == 0:
logger.error(f"Error during cleanup: {str(e)}") DatabaseManager.expire_subscription(subscription.id)
logger.error(traceback.format_exc()) logger.info(f"Successfully cleaned up subscription {subscription.id}")
except Exception as e: # Send final notification
logger.error(f"Error processing subscription {sub_id}: {str(e)}") notify_user(subscription, "Your VPN subscription has expired and been deactivated.")
logger.error(traceback.format_exc()) else:
continue logger.error(f"Cleanup failed: {result.stderr}")
if modified: except Exception as e:
save_subscriptions(subscriptions) logger.error(f"Error during cleanup: {str(e)}")
logger.info("Updated subscription database") logger.error(traceback.format_exc())
logger.info("Subscription check completed") except Exception as e:
logger.error(f"Error processing subscription {subscription.id}: {str(e)}")
logger.error(traceback.format_exc())
continue
logger.info("Subscription check completed")
except Exception as e:
logger.error(f"Subscription checker failed: {str(e)}")
logger.error(traceback.format_exc())
raise
if __name__ == '__main__': if __name__ == '__main__':
try: try:

View File

@ -0,0 +1,164 @@
/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
/* Greenlet object interface */
#ifndef Py_GREENLETOBJECT_H
#define Py_GREENLETOBJECT_H
#include <Python.h>
#ifdef __cplusplus
extern "C" {
#endif
/* This is deprecated and undocumented. It does not change. */
#define GREENLET_VERSION "1.0.0"
#ifndef GREENLET_MODULE
#define implementation_ptr_t void*
#endif
typedef struct _greenlet {
PyObject_HEAD
PyObject* weakreflist;
PyObject* dict;
implementation_ptr_t pimpl;
} PyGreenlet;
#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
/* C API functions */
/* Total number of symbols that are exported */
#define PyGreenlet_API_pointers 12
#define PyGreenlet_Type_NUM 0
#define PyExc_GreenletError_NUM 1
#define PyExc_GreenletExit_NUM 2
#define PyGreenlet_New_NUM 3
#define PyGreenlet_GetCurrent_NUM 4
#define PyGreenlet_Throw_NUM 5
#define PyGreenlet_Switch_NUM 6
#define PyGreenlet_SetParent_NUM 7
#define PyGreenlet_MAIN_NUM 8
#define PyGreenlet_STARTED_NUM 9
#define PyGreenlet_ACTIVE_NUM 10
#define PyGreenlet_GET_PARENT_NUM 11
#ifndef GREENLET_MODULE
/* This section is used by modules that uses the greenlet C API */
static void** _PyGreenlet_API = NULL;
# define PyGreenlet_Type \
(*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
# define PyExc_GreenletError \
((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
# define PyExc_GreenletExit \
((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
/*
* PyGreenlet_New(PyObject *args)
*
* greenlet.greenlet(run, parent=None)
*/
# define PyGreenlet_New \
(*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
_PyGreenlet_API[PyGreenlet_New_NUM])
/*
* PyGreenlet_GetCurrent(void)
*
* greenlet.getcurrent()
*/
# define PyGreenlet_GetCurrent \
(*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
/*
* PyGreenlet_Throw(
* PyGreenlet *greenlet,
* PyObject *typ,
* PyObject *val,
* PyObject *tb)
*
* g.throw(...)
*/
# define PyGreenlet_Throw \
(*(PyObject * (*)(PyGreenlet * self, \
PyObject * typ, \
PyObject * val, \
PyObject * tb)) \
_PyGreenlet_API[PyGreenlet_Throw_NUM])
/*
* PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
*
* g.switch(*args, **kwargs)
*/
# define PyGreenlet_Switch \
(*(PyObject * \
(*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
_PyGreenlet_API[PyGreenlet_Switch_NUM])
/*
* PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
*
* g.parent = new_parent
*/
# define PyGreenlet_SetParent \
(*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
_PyGreenlet_API[PyGreenlet_SetParent_NUM])
/*
* PyGreenlet_GetParent(PyObject* greenlet)
*
* return greenlet.parent;
*
* This could return NULL even if there is no exception active.
* If it does not return NULL, you are responsible for decrementing the
* reference count.
*/
# define PyGreenlet_GetParent \
(*(PyGreenlet* (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
/*
* deprecated, undocumented alias.
*/
# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
# define PyGreenlet_MAIN \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_MAIN_NUM])
# define PyGreenlet_STARTED \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_STARTED_NUM])
# define PyGreenlet_ACTIVE \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
/* Macro that imports greenlet and initializes C API */
/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
keep the older definition to be sure older code that might have a copy of
the header still works. */
# define PyGreenlet_Import() \
{ \
_PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
}
#endif /* GREENLET_MODULE */
#ifdef __cplusplus
}
#endif
#endif /* !Py_GREENLETOBJECT_H */