diff --git a/ansible/playbooks/templates/client.conf.j2 b/ansible/playbooks/templates/client.conf.j2 index ec46067..bad38be 100644 --- a/ansible/playbooks/templates/client.conf.j2 +++ b/ansible/playbooks/templates/client.conf.j2 @@ -1,5 +1,5 @@ [Interface] -PrivateKey = {{ client_private_key.stdout }} +PrivateKey = {{ client_private_key }} Address = {{ client_ip }}/24 DNS = 1.1.1.1 diff --git a/ansible/playbooks/vpn_cleanup.yml b/ansible/playbooks/vpn_cleanup.yml index 65ac0c9..ee54021 100644 --- a/ansible/playbooks/vpn_cleanup.yml +++ b/ansible/playbooks/vpn_cleanup.yml @@ -4,16 +4,25 @@ become: yes vars: client_dir: /etc/wireguard/clients + test_client_dir: /etc/wireguard/test_clients wg_interface: wg0 + is_test: false # Default to production mode tasks: - - name: Debug subscription ID + - name: Debug cleanup information debug: - msg: "Cleaning up subscription ID: {{ subscription_id }}" - + msg: + - "Cleaning up subscription ID: {{ subscription_id }}" + - "Test mode: {{ is_test }}" + + # Set working directory based on mode + - name: Set working directory based on mode + set_fact: + working_client_dir: "{{ test_client_dir if is_test else client_dir }}" + - name: Remove client configuration directory file: - path: "{{ client_dir }}/{{ subscription_id }}" + path: "{{ working_client_dir }}/{{ subscription_id }}" state: absent - name: Remove client from server config @@ -22,6 +31,17 @@ marker: "# {mark} ANSIBLE MANAGED BLOCK FOR {{ subscription_id }}" state: absent notify: restart wireguard + + # Remove cleanup cron job if it exists (for test configs) + - name: Remove cleanup cronjob + when: is_test + cron: + name: "cleanup_test_vpn_{{ subscription_id }}" + state: absent + + - name: Log cleanup + shell: | + logger -t vpn-cleanup "Cleaned up VPN configuration for {{ subscription_id }} ({{ 'test' if is_test else 'production' }})" handlers: - name: restart wireguard diff --git a/ansible/playbooks/vpn_provision.yml b/ansible/playbooks/vpn_provision.yml index d4dc662..2cb5e95 100644 --- a/ansible/playbooks/vpn_provision.yml +++ b/ansible/playbooks/vpn_provision.yml @@ -4,18 +4,74 @@ become: yes vars: client_dir: /etc/wireguard/clients + test_client_dir: /etc/wireguard/test_clients wg_interface: wg0 server_dir: /etc/wireguard server_ip: 10.8.0.1/24 server_port: 51820 server_endpoint: "{{ ansible_host | default(inventory_hostname) }}" - + is_test: false # Default to production mode + test_duration_minutes: 30 # Default test duration + + pre_tasks: + - name: Check if WireGuard is installed + package_facts: + manager: auto + + - name: Install WireGuard (Debian/Ubuntu) + apt: + name: + - wireguard + - wireguard-tools + state: present + update_cache: yes + when: + - ansible_facts['os_family'] == "Debian" + - "'wireguard' not in ansible_facts.packages" + + - name: Install WireGuard (RHEL/CentOS) + dnf: + name: + - wireguard-tools + - wireguard-dkms + state: present + when: + - ansible_facts['os_family'] == "RedHat" + - "'wireguard-tools' not in ansible_facts.packages" + + - name: Ensure WireGuard kernel module is loaded + modprobe: + name: wireguard + state: present + + - name: Verify WireGuard installation + command: which wg + register: wg_check + failed_when: wg_check.rc != 0 + changed_when: false + tasks: - - name: Debug invoice ID + - name: Debug invoice ID and test status debug: - msg: "Processing invoice ID: {{ invoice_id }}" + msg: + - "Processing invoice ID: {{ invoice_id }}" + - "Test mode: {{ is_test }}" + - "Test duration: {{ test_duration_minutes if is_test else 'N/A' }}" + + - name: Create required directories + file: + path: "{{ item }}" + state: directory + mode: '0700' + with_items: + - "{{ client_dir }}" + - "{{ test_client_dir }}" + - "{{ server_dir }}" + + - name: Set working directory based on mode + set_fact: + working_client_dir: "{{ test_client_dir if is_test else client_dir }}" - # Server Setup Tasks - name: Check if server keys exist stat: path: "{{ server_dir }}/{{ wg_interface }}.conf" @@ -45,58 +101,43 @@ mode: '0644' when: not server_config.stat.exists - - name: Update vault with server details - block: - - name: Read server public key - shell: "cat {{ server_dir }}/public.key" - register: pubkey_content - changed_when: false - - - name: Save server details to vault - copy: - content: | - wireguard_server_public_key: "{{ pubkey_content.stdout }}" - wireguard_server_endpoint: "{{ ansible_host }}" - dest: "{{ playbook_dir }}/../group_vars/vpn_servers/vault.yml" - mode: '0600' - when: not server_config.stat.exists - - name: Create initial server config template: src: templates/server.conf.j2 dest: "{{ server_dir }}/{{ wg_interface }}.conf" mode: '0600' when: not server_config.stat.exists + notify: restart wireguard - # Client Setup Tasks - name: Ensure client directory exists file: - path: "{{ client_dir }}/{{ invoice_id }}" + path: "{{ working_client_dir }}/{{ invoice_id }}" state: directory mode: '0700' - - - name: Generate client private key + + # Generate keys - no longer differentiating between test and production + - name: Generate private key shell: wg genkey - register: client_private_key - no_log: true - - - name: Save client private key + register: private_key + changed_when: false + + - name: Generate public key + shell: echo "{{ private_key.stdout }}" | wg pubkey + register: public_key + changed_when: false + + - name: Save private key copy: - content: "{{ client_private_key.stdout }}" - dest: "{{ client_dir }}/{{ invoice_id }}/private.key" + content: "{{ private_key.stdout }}" + dest: "{{ working_client_dir }}/{{ invoice_id }}/private.key" mode: '0600' - no_log: true - - - name: Generate client public key - shell: "echo '{{ client_private_key.stdout }}' | wg pubkey" - register: client_public_key - - - name: Save client public key + + - name: Save public key copy: - content: "{{ client_public_key.stdout }}" - dest: "{{ client_dir }}/{{ invoice_id }}/public.key" + content: "{{ public_key.stdout }}" + dest: "{{ working_client_dir }}/{{ invoice_id }}/public.key" mode: '0644' - + - name: Read server public key shell: "cat {{ server_dir }}/public.key" register: server_public_key_read @@ -104,18 +145,19 @@ - name: Get next available IP shell: | - last_ip=$(grep -h '^Address' {{ client_dir }}/*/wg0.conf 2>/dev/null | tail -n1 | grep -oE '[0-9]+$' || echo 1) + last_ip=$(grep -h '^Address' {{ working_client_dir }}/*/wg0.conf 2>/dev/null | tail -n1 | grep -oE '[0-9]+$' || echo 1) echo $((last_ip + 1)) register: next_ip - name: Generate client config template: src: templates/client.conf.j2 - dest: "{{ client_dir }}/{{ invoice_id }}/wg0.conf" + dest: "{{ working_client_dir }}/{{ invoice_id }}/wg0.conf" mode: '0600' vars: client_ip: "10.8.0.{{ next_ip.stdout }}" server_pubkey: "{{ server_public_key_read.stdout }}" + client_private_key: "{{ private_key.stdout }}" - name: Add client to server config blockinfile: @@ -123,12 +165,33 @@ marker: "# {mark} ANSIBLE MANAGED BLOCK FOR {{ invoice_id }}" block: | [Peer] - PublicKey = {{ client_public_key.stdout }} + PublicKey = {{ public_key.stdout }} AllowedIPs = 10.8.0.{{ next_ip.stdout }}/32 + {% if is_test %}# Test config expires: {{ ansible_date_time.iso8601 }}{% endif %} notify: restart wireguard - + + # Calculate cleanup time for test configurations + - name: Calculate cleanup time + when: is_test + set_fact: + cleanup_minute: "{{ (ansible_date_time.minute | int + (test_duration_minutes | int)) % 60 }}" + cleanup_hour: "{{ (ansible_date_time.hour | int + ((ansible_date_time.minute | int + (test_duration_minutes | int)) // 60)) % 24 }}" + + - name: Add cleanup cronjob for test configs + when: is_test + cron: + name: "cleanup_test_vpn_{{ invoice_id }}" + minute: "{{ cleanup_minute }}" + hour: "{{ cleanup_hour }}" + job: "ansible-playbook {{ playbook_dir }}/vpn_cleanup.yml -e 'invoice_id={{ invoice_id }} is_test=true'" + state: present + + - name: Log provision completion + shell: | + logger -t vpn-provision "Provisioned VPN for {{ invoice_id }} ({{ 'test' if is_test else 'production' }}){% if is_test %} - expires in {{ test_duration_minutes }} minutes{% endif %}" + handlers: - name: restart wireguard service: - name: wg-quick@{{ wg_interface }} + name: "wg-quick@{{ wg_interface }}" state: restarted \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py index 688ef97..6c5342c 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,7 +1,9 @@ from flask import Flask, request, jsonify, render_template import logging +from pathlib import Path from .handlers.webhook_handler import handle_payment_webhook from .handlers.payment_handler import BTCPayHandler +from .utils.db.operations import DatabaseManager # Set up logging logging.basicConfig( @@ -45,29 +47,37 @@ def calculate_price(): @app.route('/create-invoice', methods=['POST']) def create_invoice(): try: - logger.info("Received invoice creation request") + logger.info("=== Create Invoice Request Started ===") + logger.info(f"Received invoice creation request with data: {request.json}") data = request.json logger.debug(f"Request data: {data}") - + # Validate input data duration_hours = data.get('duration') - email = data.get('email') - - if not email: - logger.error("Email address missing from request") - return jsonify({'error': 'Email is required'}), 400 - + user_id = data.get('user_id') + public_key = data.get('public_key') + + logger.info(f"Validating request parameters: duration={duration_hours}, user_id={user_id}, has_public_key={bool(public_key)}") + + # Validate required fields if not duration_hours: logger.error("Duration missing from request") return jsonify({'error': 'Duration is required'}), 400 - + if not user_id: + logger.error("User ID missing from request") + return jsonify({'error': 'User ID is required'}), 400 + if not public_key: + logger.error("Public key missing from request") + return jsonify({'error': 'Public key is required'}), 400 + try: duration_hours = int(duration_hours) + logger.info(f"Converted duration to integer: {duration_hours}") except ValueError: logger.error(f"Invalid duration value: {duration_hours}") return jsonify({'error': 'Invalid duration value'}), 400 - - # Calculate price using same logic as calculate-price endpoint + + # Calculate price base_price = duration_hours * 100 # 100 sats per hour if duration_hours >= 720: # 1 month @@ -76,24 +86,69 @@ def create_invoice(): base_price = base_price * 0.90 # 10% discount elif duration_hours >= 24: # 1 day base_price = base_price * 0.95 # 5% discount - + amount_sats = int(base_price) logger.info(f"Calculated price: {amount_sats} sats for {duration_hours} hours") - + # Create BTCPay invoice - invoice_data = btcpay_handler.create_invoice(amount_sats, duration_hours, email) - + logger.info("Creating BTCPay invoice") + invoice_data = btcpay_handler.create_invoice( + amount_sats=amount_sats, + duration_hours=duration_hours, + user_id=user_id, + public_key=public_key + ) + if not invoice_data: logger.error("Failed to create invoice - no data returned from BTCPayHandler") return jsonify({'error': 'Failed to create invoice'}), 500 - + logger.info(f"Successfully created invoice with ID: {invoice_data.get('invoice_id')}") + logger.info("=== Create Invoice Request Completed ===") return jsonify(invoice_data) - + except Exception as e: logger.error(f"Error in create_invoice endpoint: {str(e)}") + logger.error(f"Traceback: ", exc_info=True) return jsonify({'error': str(e)}), 500 +@app.route('/api/vpn-config/') +def get_vpn_config(user_id): + try: + logger.info(f"Fetching VPN config for user: {user_id}") + subscription = DatabaseManager.get_active_subscription_for_user(user_id) + if not subscription: + logger.error(f"No active subscription found for user {user_id}") + return jsonify({"error": "No active subscription found"}), 404 + + # Get the config based on test or production path + base_path = Path('/etc/wireguard') + if subscription.invoice_id.startswith('__test__'): + config_path = base_path / 'test_clients' / subscription.invoice_id / 'wg0.conf' + else: + config_path = base_path / 'clients' / subscription.invoice_id / 'wg0.conf' + + logger.info(f"Looking for config at: {config_path}") + + if not config_path.exists(): + logger.error(f"Configuration file not found at {config_path}") + return jsonify({"error": "Configuration file not found"}), 404 + + with open(config_path) as f: + config_text = f.read() + + logger.info(f"Successfully retrieved config for user {user_id}") + return jsonify({ + "configText": config_text, + "status": "active", + "expiryTime": subscription.expiry_time.isoformat() if subscription.expiry_time else None + }) + + except Exception as e: + logger.error(f"Error retrieving VPN config: {str(e)}") + logger.error("Traceback:", exc_info=True) + return jsonify({"error": "Failed to retrieve configuration"}), 500 + @app.route('/payment/success') def payment_success(): return render_template('payment_success.html') diff --git a/app/data/vpn.db b/app/data/vpn.db index 6b7c17d..8afa83a 100644 Binary files a/app/data/vpn.db and b/app/data/vpn.db differ diff --git a/app/handlers/webhook_handler.py b/app/handlers/webhook_handler.py index 339262b..12e681c 100644 --- a/app/handlers/webhook_handler.py +++ b/app/handlers/webhook_handler.py @@ -7,11 +7,15 @@ import hmac import hashlib import yaml import datetime +import uuid import traceback from pathlib import Path from dotenv import load_dotenv +from sqlalchemy.orm import joinedload +from ..utils.db.models import Subscription, Payment from ..utils.db.operations import DatabaseManager from ..utils.db.models import SubscriptionStatus +from ..utils.ansible_logger import AnsibleLogger load_dotenv() @@ -26,6 +30,8 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) +ansible_logger = AnsibleLogger() + # Constants BASE_DIR = Path(__file__).resolve().parent.parent.parent PLAYBOOK_PATH = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_provision.yml' @@ -124,25 +130,14 @@ def verify_signature(payload_body: bytes, signature_header: str) -> bool: logger.error(f"Signature verification failed: {traceback.format_exc()}") return False -def run_ansible_playbook(invoice_id: str, cleanup: bool = False) -> subprocess.CompletedProcess: - """ - Run the appropriate Ansible playbook with proper error handling - - Args: - invoice_id: BTCPay invoice ID - cleanup: Whether to run cleanup playbook instead of provision - - Returns: - subprocess.CompletedProcess: Playbook execution result - - Raises: - WebhookError: If playbook execution fails - """ +def run_ansible_playbook(invoice_id: str, cleanup: bool = False, extra_vars: dict = None) -> subprocess.CompletedProcess: + """Run the appropriate Ansible playbook with logging""" try: + operation_type = 'cleanup' if cleanup else 'provision' vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD') if not vault_pass: raise WebhookError("Vault password not found in environment variables") - + with tempfile.NamedTemporaryFile(mode='w', delete=False) as vault_pass_file: vault_pass_file.write(vault_pass) vault_pass_file.flush() @@ -153,31 +148,73 @@ def run_ansible_playbook(invoice_id: str, cleanup: bool = False) -> subprocess.C 'ansible-playbook', str(playbook), '-i', str(BASE_DIR / 'inventory.ini'), - '-e', f'invoice_id={invoice_id}', + '-e', f'invoice_id={invoice_id}' + ] + + if extra_vars: + for key, value in extra_vars.items(): + cmd.extend(['-e', f'{key}={value}']) + + cmd.extend([ '--vault-password-file', vault_pass_file.name, '-vvv' - ] + ]) logger.info(f"Running ansible-playbook command: {' '.join(cmd)}") + # Run ansible-playbook without check=True to handle errors better result = subprocess.run( cmd, capture_output=True, - text=True, - check=True # This will raise CalledProcessError if playbook fails + text=True ) + # Log detailed output for debugging + logger.info("Ansible STDOUT:") + logger.info(result.stdout) + + if result.stderr: + logger.error("Ansible STDERR:") + logger.error(result.stderr) + + # Check return code manually + if result.returncode != 0: + logger.error(f"Ansible playbook failed with return code {result.returncode}") + logger.error(f"Error output: {result.stderr}") + raise WebhookError(f"Failed {operation_type} for subscription {invoice_id}") + + # Log successful operation + is_test = bool(extra_vars and extra_vars.get('is_test')) + ansible_logger.log_operation( + invoice_id, + operation_type, + result, + is_test=is_test + ) + + # Check for fatal errors in output if "fatal:" in result.stdout or "fatal:" in result.stderr: + logger.error("Fatal error detected in Ansible output") raise WebhookError("Ansible playbook reported fatal error") + logger.info(f"Successfully completed {operation_type} for {invoice_id}") return result except subprocess.CalledProcessError as e: logger.error(f"Playbook execution failed: {e.stderr}") - raise WebhookError(f"Ansible playbook failed with return code {e.returncode}") + # Log failed operation + ansible_logger.log_operation( + invoice_id, + operation_type, + e, + is_test=bool(extra_vars and extra_vars.get('is_test')) + ) + raise WebhookError(f"Failed {operation_type} for subscription {invoice_id}: {e.stderr}") + except Exception as e: logger.error(f"Error running playbook: {traceback.format_exc()}") raise WebhookError(f"Playbook execution failed: {str(e)}") + finally: if 'vault_pass_file' in locals(): os.unlink(vault_pass_file.name) @@ -220,13 +257,106 @@ def handle_subscription_status(data: dict) -> tuple: logger.error(f"Error handling subscription status: {traceback.format_exc()}") return jsonify({"error": str(e)}), 500 -def handle_payment_webhook(request) -> tuple: - """ - Handle BTCPay Server webhook for VPN provisioning +def handle_test_webhook(data, webhook_type): + """Handle test webhook with proper Ansible execution and logging""" + logger.info(f"Processing test webhook: {webhook_type}") + invoice_id = data.get('invoiceId', '') - Returns: - tuple: (response, status_code) - """ + if not invoice_id.startswith('__test__'): + logger.error("Invalid test invoice ID format") + return jsonify({"error": "Invalid test invoice ID"}), 400 + + # Process both types of invoice settlement webhooks + if webhook_type in ['InvoiceSettled', 'InvoicePaymentSettled']: + try: + # For test invoices, create a 30-minute subscription + test_duration = 30 # minutes + test_user_id = f"test_{uuid.uuid4()}" + test_pubkey = f"TEST_KEY_{uuid.uuid4()}" + + logger.info(f"Creating test subscription for {test_duration} minutes") + + # Create test subscription entry - now returns a dictionary + subscription_data = DatabaseManager.create_subscription( + user_id=test_user_id, + invoice_id=invoice_id, + public_key=test_pubkey, + duration_hours=0.5 # 30 minutes + ) + + if not subscription_data: + logger.error("Failed to create test subscription") + return jsonify({"error": "Failed to create test subscription"}), 500 + + logger.info(f"Created test subscription: {subscription_data['id']}") + + # Run the provisioning playbook with test flag + try: + logger.info("Running test VPN provision playbook") + result = run_ansible_playbook( + invoice_id=invoice_id, + cleanup=False, + extra_vars={ + "is_test": True, + "test_duration_minutes": test_duration, + "test_public_key": test_pubkey + } + ) + + if result.returncode == 0: + logger.info(f"Test VPN provisioned successfully for {test_duration} minutes") + + # Activate subscription and record payment + activated_data = DatabaseManager.activate_subscription(invoice_id) + if activated_data: + DatabaseManager.record_payment( + test_user_id, + subscription_data['id'], # Use dictionary key instead of object attribute + invoice_id, + data.get('amount', 0) + ) + + cleanup_time = datetime.datetime.utcnow() + datetime.timedelta(minutes=test_duration) + logger.info(f"Scheduling cleanup for {cleanup_time}") + + return jsonify({ + "status": "success", + "message": f"Test VPN provisioned for {test_duration} minutes", + "test_user_id": test_user_id, + "subscription_id": subscription_data['id'], # Include subscription ID in response + "assigned_ip": subscription_data['assigned_ip'], # Include assigned IP + "cleanup_scheduled": cleanup_time.isoformat() + }), 200 + else: + logger.error("Failed to activate subscription") + return jsonify({"error": "Failed to activate subscription"}), 500 + + logger.error(f"Test provisioning failed: {result.stderr}") + return jsonify({"error": "Test provisioning failed"}), 500 + + except WebhookError as e: + logger.error(f"Error in test provision playbook: {str(e)}") + return jsonify({"error": str(e)}), 500 + + except Exception as e: + logger.error(f"Error in test provisioning: {str(e)}") + logger.error(traceback.format_exc()) + return jsonify({"error": str(e)}), 500 + + # Handle test subscription status updates + elif webhook_type == 'SubscriptionStatusUpdated': + return handle_subscription_status(data) + + # For other test webhook types, just acknowledge + else: + logger.info(f"Acknowledged test webhook: {webhook_type}") + return jsonify({ + "status": "success", + "message": f"Test webhook {webhook_type} acknowledged" + }), 200 + +def handle_payment_webhook(request) -> tuple: + """Handle BTCPay Server webhook for VPN provisioning""" try: vault_values = get_vault_values() logger.info(f"Processing webhook on endpoint: {vault_values['webhook_full_url']}") @@ -254,57 +384,85 @@ def handle_payment_webhook(request) -> tuple: logger.info(f"Received webhook data: {data}") - # Handle test webhooks - invoice_id = data.get('invoiceId', '') - if invoice_id.startswith('__test__'): - logger.info(f"Received test webhook, acknowledging: {data.get('type')}") - return jsonify({ - "status": "success", - "message": "Test webhook acknowledged" - }), 200 - + # Extract webhook type and invoice ID webhook_type = data.get('type') + invoice_id = data.get('invoiceId', '') + if not webhook_type: return jsonify({"error": "Missing webhook type"}), 400 - - # Handle different webhook types + + # Handle test webhooks with special processing + if invoice_id.startswith('__test__'): + return handle_test_webhook(data, webhook_type) + + # Handle different webhook types for production if webhook_type == 'SubscriptionStatusUpdated': return handle_subscription_status(data) - elif webhook_type == 'InvoiceSettled' or webhook_type == 'InvoicePaymentSettled': + elif webhook_type in ['InvoiceSettled', 'InvoicePaymentSettled']: if not invoice_id: logger.error("Missing invoiceId in webhook data") return jsonify({"error": "Missing invoiceId"}), 400 - try: - # Run VPN provisioning - logger.info(f"Starting VPN provisioning for invoice {invoice_id}") - result = run_ansible_playbook(invoice_id) - - # Update subscription status - subscription = DatabaseManager.get_subscription_by_invoice(invoice_id) - if subscription: - subscription = DatabaseManager.activate_subscription(invoice_id) - DatabaseManager.record_payment( - subscription.user_id, - subscription.id, - invoice_id, - data.get('amount', 0) - ) + from ..utils.db import get_session + with get_session() as session: + try: + # Check if payment already exists + existing_payment = session.query(Payment).filter( + Payment.invoice_id == invoice_id + ).first() - logger.info(f"VPN provisioning completed for invoice {invoice_id}") - return jsonify({ - "status": "success", - "invoice_id": invoice_id, - "message": "VPN provisioning completed" - }), 200 - - except WebhookError as e: - logger.error(f"VPN provisioning failed: {str(e)}") - return jsonify({ - "error": "Provisioning failed", - "details": str(e) - }), 500 + if existing_payment: + logger.info(f"Payment already recorded for invoice {invoice_id}") + return jsonify({ + "status": "success", + "message": "Payment already processed", + "invoice_id": invoice_id + }), 200 + + # Run VPN provisioning + logger.info(f"Starting VPN provisioning for invoice {invoice_id}") + result = run_ansible_playbook(invoice_id) + + # Update subscription status within session + subscription = session.query(Subscription).filter( + Subscription.invoice_id == invoice_id + ).options(joinedload(Subscription.user)).first() + + if subscription: + # Activate subscription + subscription.status = SubscriptionStatus.ACTIVE + + # Record payment only if it doesn't exist + payment = Payment( + user_id=subscription.user_id, + subscription_id=subscription.id, + invoice_id=invoice_id, + amount=data.get('amount', 0) + ) + session.add(payment) + + # Commit all changes + session.commit() + + logger.info(f"VPN provisioning completed for invoice {invoice_id}") + return jsonify({ + "status": "success", + "invoice_id": invoice_id, + "message": "VPN provisioning completed" + }), 200 + else: + logger.error(f"Subscription not found for invoice {invoice_id}") + return jsonify({"error": "Subscription not found"}), 404 + + except Exception as e: + session.rollback() + logger.error(f"VPN provisioning failed: {str(e)}") + logger.error(traceback.format_exc()) + return jsonify({ + "error": "Provisioning failed", + "details": str(e) + }), 500 else: logger.info(f"Received {webhook_type} webhook - no action required") diff --git a/app/static/js/components/WireGuardPayment.jsx b/app/static/js/components/WireGuardPayment.jsx deleted file mode 100644 index 30c065b..0000000 --- a/app/static/js/components/WireGuardPayment.jsx +++ /dev/null @@ -1,213 +0,0 @@ -import React, { useState, useEffect } from 'react'; -import { generateKeys } from '../utils/wireguard'; -import { Card, CardHeader, CardTitle, CardContent } from '@/components/ui/card'; -import { Button } from '@/components/ui/button'; -import { Input } from '@/components/ui/input'; -import { AlertCircle } from 'lucide-react'; -import { Alert, AlertDescription } from '@/components/ui/alert'; - -const WireGuardPayment = () => { - const [keyData, setKeyData] = useState(null); - const [duration, setDuration] = useState(24); - const [price, setPrice] = useState(0); - const [loading, setLoading] = useState(false); - const [error, setError] = useState(''); - const [userId, setUserId] = useState(''); - - useEffect(() => { - // Generate random userId and keys when component mounts - const init = async () => { - try { - const randomId = crypto.randomUUID(); - setUserId(randomId); - const keys = await generateKeys(); - setKeyData(keys); - calculatePrice(duration); - } catch (err) { - setError('Failed to initialize keys'); - console.error(err); - } - }; - init(); - }, []); - - const calculatePrice = async (hours) => { - try { - const response = await fetch('/api/calculate-price', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ hours: parseInt(hours) }) - }); - const data = await response.json(); - setPrice(data.price); - } catch (err) { - setError('Failed to calculate price'); - } - }; - - const handleDurationChange = (event) => { - const newDuration = parseInt(event.target.value); - setDuration(newDuration); - calculatePrice(newDuration); - }; - - const handlePayment = async () => { - if (!keyData) { - setError('No keys generated. Please refresh the page.'); - return; - } - - try { - setLoading(true); - const response = await fetch('/create-invoice', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - duration, - userId, - publicKey: keyData.publicKey, - // Don't send private key to server! - configuration: { - type: 'wireguard', - publicKey: keyData.publicKey - } - }) - }); - - if (!response.ok) { - throw new Error('Failed to create payment'); - } - - const data = await response.json(); - - // Save private key to localStorage before redirecting - localStorage.setItem(`vpn_keys_${userId}`, JSON.stringify({ - privateKey: keyData.privateKey, - publicKey: keyData.publicKey, - createdAt: new Date().toISOString() - })); - - window.location.href = data.checkout_url; - } catch (err) { - setError('Failed to initiate payment'); - console.error(err); - } finally { - setLoading(false); - } - }; - - const handleRegenerateKeys = async () => { - try { - const keys = await generateKeys(); - setKeyData(keys); - } catch (err) { - setError('Failed to regenerate keys'); - } - }; - - return ( - - - WireGuard VPN Configuration - - - {error && ( - - - {error} - - )} - -
- - -

- Save this ID - you'll need it to manage your subscription -

-
- - {keyData && ( -
- - - -
- )} - -
- -
- -
- - - -
-

{duration} hours

-
-
- -
-

{price} sats

-
- - - -
-

• Keys are generated securely in your browser

-

• Your private key never leaves your device

-

• Configuration will be available after payment

-
-
-
- ); -}; - -export default WireGuardPayment; \ No newline at end of file diff --git a/app/static/js/pricing.js b/app/static/js/pricing.js index 8cadf86..1c01668 100644 --- a/app/static/js/pricing.js +++ b/app/static/js/pricing.js @@ -1,116 +1,190 @@ +// Constants for pricing +const HOURLY_RATE = 100; // 100 sats per hour +const MIN_HOURS = 1; +const MAX_HOURS = 2160; // 3 months +const MIN_SATS = HOURLY_RATE * MIN_HOURS; +const MAX_SATS = 216000; // Maximum for 3 months + // Utility functions for duration formatting function formatDuration(hours) { - if (hours < 24) { - return `${hours} hour${hours === 1 ? '' : 's'}`; + const exactHours = `${hours} hour${hours === 1 ? '' : 's'}`; + + // Break down the time into components + const months = Math.floor(hours / 720); + const remainingAfterMonths = hours % 720; + const weeks = Math.floor(remainingAfterMonths / 168); + const remainingAfterWeeks = remainingAfterMonths % 168; + const days = Math.floor(remainingAfterWeeks / 24); + const remainingHours = remainingAfterWeeks % 24; + + // Build the detailed breakdown + const parts = []; + if (months > 0) { + parts.push(`${months} month${months === 1 ? '' : 's'}`); } - if (hours < 168) { - return `${hours / 24} day${hours === 24 ? '' : 's'}`; + if (weeks > 0) { + parts.push(`${weeks} week${weeks === 1 ? '' : 's'}`); } - if (hours < 720) { - return `${Math.floor(hours / 168)} week${hours === 168 ? '' : 's'}`; + if (days > 0) { + parts.push(`${days} day${days === 1 ? '' : 's'}`); } - return `${Math.floor(hours / 720)} month${hours === 720 ? '' : 's'}`; + if (remainingHours > 0 || parts.length === 0) { + parts.push(`${remainingHours} hour${remainingHours === 1 ? '' : 's'}`); + } + + // Combine all parts with proper grammar + let breakdown = ''; + if (parts.length > 1) { + const lastPart = parts.pop(); + breakdown = parts.join(', ') + ' and ' + lastPart; + } else { + breakdown = parts[0]; + } + + return `${exactHours} (${breakdown})`; } // Price calculation with volume discounts -async function calculatePrice(hours) { +function calculatePrice(hours) { try { - const response = await fetch('/api/calculate-price', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ hours: parseInt(hours) }) - }); - - if (!response.ok) { - throw new Error('Failed to calculate price'); + hours = parseInt(hours); + if (hours < MIN_HOURS) return MIN_SATS; + + let basePrice = hours * HOURLY_RATE; + if (hours >= 2160) { // 3 months + basePrice = basePrice * 0.75; + } else if (hours >= 720) { // 30 days + basePrice = basePrice * 0.85; + } else if (hours >= 168) { // 7 days + basePrice = basePrice * 0.90; + } else if (hours >= 24) { // 1 day + basePrice = basePrice * 0.95; } - - const data = await response.json(); - return { - price: data.price, - formattedDuration: formatDuration(hours) - }; + return Math.round(basePrice); } catch (error) { - console.error('Price calculation failed:', error); - throw error; + console.error('Error calculating price:', error); + return MIN_SATS; } } -// Form initialization and event handling -function initializeForm(config) { - const { - formId = 'subscription-form', - sliderId = 'duration-slider', - priceDisplayId = 'price-display', - durationDisplayId = 'duration-display', - presetButtonClass = 'duration-preset' - } = config; - - const form = document.getElementById(formId); - const slider = document.getElementById(sliderId); - const priceDisplay = document.getElementById(priceDisplayId); - const durationDisplay = document.getElementById(durationDisplayId); - const presetButtons = document.querySelectorAll(`.${presetButtonClass}`); - - if (!form || !slider || !priceDisplay || !durationDisplay) { - throw new Error('Required elements not found'); +// Calculate hours from price +function calculateHoursFromPrice(sats) { + try { + sats = parseInt(sats); + if (sats < MIN_SATS) return MIN_HOURS; + if (sats > MAX_SATS) return MAX_HOURS; + + // Binary search for the closest hour value + const binarySearchHours = (min, max, targetSats) => { + while (min <= max) { + const mid = Math.floor((min + max) / 2); + const price = calculatePrice(mid); + + if (price === targetSats) return mid; + if (price < targetSats) min = mid + 1; + else max = mid - 1; + } + return max; + }; + + let hours = 0; + if (sats >= calculatePrice(2160)) { + hours = Math.floor(sats / (HOURLY_RATE * 0.75)); + } else if (sats >= calculatePrice(720)) { + hours = binarySearchHours(720, 2159, sats); + } else if (sats >= calculatePrice(168)) { + hours = binarySearchHours(168, 719, sats); + } else if (sats >= calculatePrice(24)) { + hours = binarySearchHours(24, 167, sats); + } else { + hours = binarySearchHours(1, 23, sats); + } + + return Math.max(MIN_HOURS, Math.min(MAX_HOURS, hours)); + } catch (error) { + console.error('Error calculating hours from price:', error); + return MIN_HOURS; } +} - return { - form, - slider, - priceDisplay, - durationDisplay, - presetButtons +// Update all displays and inputs +function updateDisplays(hours, skipSource = null) { + const elements = { + priceDisplay: document.getElementById('price-display'), + durationDisplay: document.getElementById('duration-display'), + customHours: document.getElementById('custom-hours'), + customSats: document.getElementById('custom-sats') }; + + hours = Math.max(MIN_HOURS, Math.min(MAX_HOURS, hours)); + const price = calculatePrice(hours); + + // Update displays + if (elements.priceDisplay && elements.durationDisplay) { + elements.priceDisplay.textContent = price; + elements.durationDisplay.textContent = formatDuration(hours); + } + + // Update inputs (skip the source of the update) + if (skipSource !== 'hours' && elements.customHours) { + elements.customHours.value = hours; + } + if (skipSource !== 'sats' && elements.customSats) { + elements.customSats.value = price; + } } // Main pricing interface export const Pricing = { - async init(config = {}) { - try { - const elements = initializeForm(config); - const { form, slider, priceDisplay, durationDisplay, presetButtons } = elements; + init() { + console.log('Initializing pricing system...'); + const elements = { + customHours: document.getElementById('custom-hours'), + customSats: document.getElementById('custom-sats'), + presetButtons: document.querySelectorAll('.duration-preset') + }; - // Update price when duration changes - const updateDisplay = async (hours) => { - try { - const { price, formattedDuration } = await calculatePrice(hours); - priceDisplay.textContent = price; - durationDisplay.textContent = formattedDuration; - } catch (error) { - console.error('Failed to update price display:', error); - priceDisplay.textContent = 'Error'; - durationDisplay.textContent = 'Error calculating duration'; - } - }; + // Initial display + updateDisplays(24); // Start with 24 hours as default - // Set up event listeners - slider.addEventListener('input', () => updateDisplay(slider.value)); - - presetButtons.forEach(button => { - button.addEventListener('click', (e) => { - const hours = e.target.dataset.hours; - slider.value = hours; - updateDisplay(hours); - }); + // Event listeners for custom inputs + elements.customHours?.addEventListener('input', (e) => { + let hours = parseInt(e.target.value) || MIN_HOURS; + hours = Math.max(MIN_HOURS, Math.min(MAX_HOURS, hours)); + updateDisplays(hours, 'hours'); + }); + + elements.customSats?.addEventListener('input', (e) => { + let sats = parseInt(e.target.value) || MIN_SATS; + sats = Math.max(MIN_SATS, Math.min(MAX_SATS, sats)); + const hours = calculateHoursFromPrice(sats); + updateDisplays(hours, 'sats'); + }); + + // Add blur events to enforce minimums + elements.customHours?.addEventListener('blur', (e) => { + if (!e.target.value || parseInt(e.target.value) < MIN_HOURS) { + updateDisplays(MIN_HOURS, 'hours'); + } + }); + + elements.customSats?.addEventListener('blur', (e) => { + if (!e.target.value || parseInt(e.target.value) < MIN_SATS) { + updateDisplays(MIN_HOURS, 'sats'); + } + }); + + // Handle preset buttons + elements.presetButtons.forEach(button => { + button.addEventListener('click', () => { + const hours = parseInt(button.getAttribute('data-hours')); + updateDisplays(hours); }); - - // Initial price calculation - await updateDisplay(slider.value); - - return { - updatePrice: updateDisplay, - getCurrentDuration: () => parseInt(slider.value) - }; - } catch (error) { - console.error('Failed to initialize pricing:', error); - throw error; - } - }, - - formatDuration, - calculatePrice + }); + } }; -export default Pricing; \ No newline at end of file +// Auto-initialize on script load +document.addEventListener('DOMContentLoaded', () => { + Pricing.init(); +}); \ No newline at end of file diff --git a/app/static/js/utils/wireguard.js b/app/static/js/utils/wireguard.js index eb6d189..fc7b32d 100644 --- a/app/static/js/utils/wireguard.js +++ b/app/static/js/utils/wireguard.js @@ -1,134 +1,117 @@ -// Base64 encoding/decoding utilities with error handling +// Base64 encoding/decoding utilities const b64 = { - encode: (array) => { - try { - return btoa(String.fromCharCode.apply(null, array)); - } catch (error) { - console.error('Base64 encoding failed:', error); - throw new Error('Failed to encode key data'); - } - }, - decode: (str) => { - try { - return Uint8Array.from(atob(str), c => c.charCodeAt(0)); - } catch (error) { - console.error('Base64 decoding failed:', error); - throw new Error('Failed to decode key data'); - } - } + encode: (array) => { + try { + return btoa(String.fromCharCode.apply(null, array)) + .replace(/[+/]/g, char => char === '+' ? '-' : '_') + .replace(/=+$/, ''); + } catch (error) { + console.error('Base64 encoding failed:', error); + throw new Error('Failed to encode key data'); + } + }, + decode: (str) => { + try { + str = str.replace(/[-_]/g, char => char === '-' ? '+' : '/'); + while (str.length % 4) str += '='; + return Uint8Array.from(atob(str), c => c.charCodeAt(0)); + } catch (error) { + console.error('Base64 decoding failed:', error); + throw new Error('Failed to decode key data'); + } + } }; -// Key storage management -const keyStorage = { - store: (userId, keyData) => { - try { - const data = { - privateKey: keyData.privateKey, - publicKey: keyData.publicKey, - createdAt: new Date().toISOString() - }; - localStorage.setItem(`vpn_keys_${userId}`, JSON.stringify(data)); - } catch (error) { - console.error('Failed to store keys:', error); - throw new Error('Failed to save key data'); - } - }, +// Check if we're in a secure context (HTTPS) or development mode +const isDevelopment = window.location.hostname === 'localhost' || + window.location.hostname === '127.0.0.1' || + /^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$/.test(window.location.hostname); - retrieve: (userId) => { - try { - const data = localStorage.getItem(`vpn_keys_${userId}`); - return data ? JSON.parse(data) : null; - } catch (error) { - console.error('Failed to retrieve keys:', error); - throw new Error('Failed to retrieve key data'); - } - }, +// Generate secure random bytes +async function getRandomBytes(length) { + const array = new Uint8Array(length); + crypto.getRandomValues(array); + return array; +} - remove: (userId) => { - try { - localStorage.removeItem(`vpn_keys_${userId}`); - } catch (error) { - console.error('Failed to remove keys:', error); - } - } -}; - -// Main key generation function +// Generate a WireGuard key pair async function generateKeyPair() { - try { - const keyPair = await window.crypto.subtle.generateKey( - { - name: 'X25519', - namedCurve: 'X25519', - }, - true, - ['deriveKey', 'deriveBits'] - ); + try { + console.log('Generating WireGuard keys...'); + console.log('Environment:', isDevelopment ? 'Development' : 'Production'); - const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey); - const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey); + // Generate private key (32 random bytes) + const privateKeyBytes = await getRandomBytes(32); + const privateKey = b64.encode(privateKeyBytes); + console.log('Private key generated'); - return { - privateKey: b64.encode(new Uint8Array(privateKey)), - publicKey: b64.encode(new Uint8Array(publicKey)) - }; - } catch (error) { - console.error('Key generation failed:', error); - throw new Error('Failed to generate WireGuard keys'); - } + let publicKey; + let publicKeyBytes; + + // Use Web Crypto API in production/HTTPS, fallback for development/HTTP + if (!isDevelopment && window.crypto.subtle) { + const keyPair = await window.crypto.subtle.generateKey( + { + name: 'ECDH', + namedCurve: 'P-256', + }, + true, + ['deriveKey', 'deriveBits'] + ); + publicKeyBytes = await window.crypto.subtle.exportKey( + 'raw', + keyPair.publicKey + ); + publicKey = b64.encode(new Uint8Array(publicKeyBytes)); + } else { + // Development fallback + console.log('Using development key generation mode'); + publicKeyBytes = await getRandomBytes(32); + publicKey = b64.encode(publicKeyBytes); + } + console.log('Public key generated'); + + // Generate preshared key + const presharedKeyBytes = await getRandomBytes(32); + const presharedKey = b64.encode(presharedKeyBytes); + console.log('Preshared key generated'); + + return { privateKey, publicKey, presharedKey }; + } catch (error) { + console.error('Key generation failed:', error); + throw new Error('Failed to generate WireGuard keys'); + } } -// Key validation function -function validateKey(key) { - try { - const decoded = b64.decode(key); - return decoded.length === 32; - } catch { - return false; - } -} - -// WireGuard config generation -function generateConfig(keys, serverPublicKey, serverEndpoint, clientIp) { - if (!keys || !serverPublicKey || !serverEndpoint || !clientIp) { - throw new Error('Missing required configuration parameters'); - } - - return `[Interface] -PrivateKey = ${keys.privateKey} -Address = ${clientIp}/24 -DNS = 1.1.1.1 - -[Peer] -PublicKey = ${serverPublicKey} -Endpoint = ${serverEndpoint}:51820 -AllowedIPs = 0.0.0.0/0 -PersistentKeepalive = 25`; -} - -// Main interface for key management +// Export WireGuard interface export const WireGuard = { - generateKeys: async () => { - return await generateKeyPair(); - }, + generateKeys: async () => { + try { + console.log('Starting key generation process...'); + const keys = await generateKeyPair(); + console.log('Keys generated successfully:', { + privateKeyLength: keys.privateKey.length, + publicKeyLength: keys.publicKey.length, + presharedKeyLength: keys.presharedKey.length + }); + return keys; + } catch (error) { + console.error('Error in generateKeys:', error); + throw error; + } + }, - saveKeys: (userId, keyPair) => { - if (!validateKey(keyPair.publicKey) || !validateKey(keyPair.privateKey)) { - throw new Error('Invalid key data'); - } - keyStorage.store(userId, keyPair); - }, + validateKey: (key) => { + try { + const decoded = b64.decode(key); + return decoded.length === 32; + } catch { + return false; + } + }, - getKeys: (userId) => { - return keyStorage.retrieve(userId); - }, - - removeKeys: (userId) => { - keyStorage.remove(userId); - }, - - generateConfig, - validateKey + // Expose environment information + isDevelopment }; export default WireGuard; \ No newline at end of file diff --git a/app/templates/index.html b/app/templates/index.html index 24c1115..6f9e818 100644 --- a/app/templates/index.html +++ b/app/templates/index.html @@ -6,104 +6,261 @@
- -

Save this ID - you'll need it to manage your subscription

-
- -
- - - + +

Only used for subscription management

- +
- +
- - + + + +
+ +
+
+ +
+ +
+
+
+ +
+ +
+
-

+
+
+ 24 hours +
+ - + sats +
-
-

- - sats -

+
+ + +
- - -
-

• Keys are generated securely in your browser

-

• Your private key never leaves your device

-

• Configuration will be available after payment

-
+ + + {% endblock %} \ No newline at end of file diff --git a/app/templates/payment_success.html b/app/templates/payment_success.html index c335c9f..ffd373e 100644 --- a/app/templates/payment_success.html +++ b/app/templates/payment_success.html @@ -1,5 +1,4 @@ {% extends "base.html" %} - {% block content %}
@@ -65,84 +64,82 @@
{% endblock %} \ No newline at end of file diff --git a/app/utils/ansible_logger.py b/app/utils/ansible_logger.py new file mode 100644 index 0000000..7060ff6 --- /dev/null +++ b/app/utils/ansible_logger.py @@ -0,0 +1,112 @@ +# app/utils/ansible_logger.py +import logging +import json +from datetime import datetime +from pathlib import Path +from .db.operations import DatabaseManager + +class AnsibleLogger: + def __init__(self, log_dir=None): + """Initialize the Ansible logger""" + # Use data directory from project structure + self.base_dir = Path(__file__).resolve().parent.parent.parent + self.log_dir = log_dir or (self.base_dir / 'data' / 'logs') + self.log_dir.mkdir(parents=True, exist_ok=True) + + # Set up file handler + self.logger = logging.getLogger('ansible_operations') + self.logger.setLevel(logging.DEBUG) + + # Create a detailed log file + detailed_log = self.log_dir / 'ansible_operations.log' + file_handler = logging.FileHandler(detailed_log) + file_handler.setLevel(logging.DEBUG) + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + file_handler.setFormatter(formatter) + self.logger.addHandler(file_handler) + + def log_operation(self, subscription_id, operation_type, result, is_test=False): + """Log an Ansible operation""" + try: + # Get the subscription + subscription = DatabaseManager.get_subscription_by_invoice(subscription_id) + if not subscription: + self.logger.error(f"Subscription {subscription_id} not found") + return + + # Create detailed log entry + log_entry = { + 'timestamp': datetime.utcnow().isoformat(), + 'subscription_id': subscription_id, + 'operation_type': operation_type, + 'is_test': is_test, + 'return_code': result.returncode, + 'stdout': result.stdout, + 'stderr': result.stderr, + 'assigned_ip': subscription.assigned_ip + } + + # Create log filename with timestamp + log_file = self.log_dir / f"{operation_type}_{subscription_id}_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}.json" + + # Write detailed JSON log + with open(log_file, 'w') as f: + json.dump(log_entry, f, indent=2) + + # Create provision log in database + DatabaseManager.create_provision_log({ + 'subscription_id': subscription.id, + 'action': operation_type, + 'status': 'success' if result.returncode == 0 else 'failure', + 'ansible_output': result.stdout, + 'error_message': result.stderr if result.returncode != 0 else None + }) + + # Log summary + if result.returncode == 0: + self.logger.info(f"Successfully completed {operation_type} for subscription {subscription_id}") + else: + self.logger.error(f"Failed {operation_type} for subscription {subscription_id}: {result.stderr}") + + except Exception as e: + self.logger.error(f"Error logging operation: {str(e)}") + + def get_logs(self, subscription_id=None, hours=24, operation_type=None): + """Get recent Ansible operation logs""" + try: + log_files = [] + pattern = f"*{subscription_id if subscription_id else ''}*.json" + + for log_file in self.log_dir.glob(pattern): + if operation_type and operation_type not in log_file.name: + continue + log_files.append(log_file) + + # Sort by modification time and return most recent first + log_files.sort(key=lambda x: x.stat().st_mtime, reverse=True) + + logs = [] + for log_file in log_files: + with open(log_file) as f: + logs.append(json.load(f)) + + return logs + + except Exception as e: + self.logger.error(f"Error retrieving logs: {str(e)}") + return [] + + def cleanup_old_logs(self, days=30): + """Clean up logs older than specified days""" + try: + cutoff = datetime.now().timestamp() - (days * 24 * 60 * 60) + + for log_file in self.log_dir.glob('*.json'): + if log_file.stat().st_mtime < cutoff: + log_file.unlink() + self.logger.info(f"Cleaned up old log file: {log_file}") + + except Exception as e: + self.logger.error(f"Error cleaning up logs: {str(e)}") \ No newline at end of file diff --git a/app/utils/db/models.py b/app/utils/db/models.py index 9390fda..746fc7f 100644 --- a/app/utils/db/models.py +++ b/app/utils/db/models.py @@ -1,4 +1,7 @@ -from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Enum, Text +from sqlalchemy import ( + create_engine, Column, Integer, String, DateTime, + ForeignKey, Enum, Text, JSON, Boolean, Float +) from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship import enum @@ -11,6 +14,14 @@ class SubscriptionStatus(enum.Enum): EXPIRED = "expired" PENDING = "pending" CANCELLED = "cancelled" + FAILED = "failed" # New status for failed provisions + SUSPENDED = "suspended" # New status for temp suspension + +class LogLevel(enum.Enum): + DEBUG = "debug" + INFO = "info" + WARNING = "warning" + ERROR = "error" class User(Base): __tablename__ = 'users' @@ -18,9 +29,13 @@ class User(Base): id = Column(Integer, primary_key=True) user_id = Column(String, unique=True, nullable=False) # UUID generated in frontend created_at = Column(DateTime, default=datetime.datetime.utcnow) + last_login = Column(DateTime, nullable=True) # New: Track last login time + is_active = Column(Boolean, default=True, nullable=True) # New: User account status + user_data = Column(JSON, nullable=True) # New: Optional user metadata subscriptions = relationship("Subscription", back_populates="user") payments = relationship("Payment", back_populates="user") + provision_logs = relationship("ProvisionLog", back_populates="user") # New relationship class Subscription(Base): __tablename__ = 'subscriptions' @@ -35,8 +50,17 @@ class Subscription(Base): warning_sent = Column(Integer, default=0) assigned_ip = Column(String) # WireGuard IP address assigned to this subscription + # New fields for monitoring + last_connection = Column(DateTime, nullable=True) # Track last connection time + data_usage = Column(Float, default=0.0) # Track data usage in MB + is_test = Column(Boolean, default=False) # Flag for test subscriptions + provision_attempts = Column(Integer, default=0) # Count provision attempts + cleanup_attempts = Column(Integer, default=0) # Count cleanup attempts + config_data = Column(JSON, nullable=True) # Store additional config data + user = relationship("User", back_populates="subscriptions") payments = relationship("Payment", back_populates="subscription") + provision_logs = relationship("ProvisionLog", back_populates="subscription") class Payment(Base): __tablename__ = 'payments' @@ -48,5 +72,26 @@ class Payment(Base): amount = Column(Integer, nullable=False) # Amount in sats timestamp = Column(DateTime, default=datetime.datetime.utcnow) + # New payment tracking fields + payment_method = Column(String, nullable=True) # Payment method used + payment_status = Column(String, nullable=True) # Payment status + confirmations = Column(Integer, default=0) # Number of confirmations + payment_data = Column(JSON, nullable=True) # Additional payment data + user = relationship("User", back_populates="payments") - subscription = relationship("Subscription", back_populates="payments") \ No newline at end of file + subscription = relationship("Subscription", back_populates="payments") + +class ProvisionLog(Base): + __tablename__ = 'provision_logs' + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('users.id')) + subscription_id = Column(Integer, ForeignKey('subscriptions.id')) + timestamp = Column(DateTime, default=datetime.datetime.utcnow) + action = Column(String, nullable=False) # 'provision' or 'cleanup' + status = Column(String, nullable=False) # 'success' or 'failure' + ansible_output = Column(Text, nullable=True) # Store Ansible output + error_message = Column(Text, nullable=True) # Store error messages + + user = relationship("User", back_populates="provision_logs") + subscription = relationship("Subscription", back_populates="provision_logs") \ No newline at end of file diff --git a/app/utils/db/operations.py b/app/utils/db/operations.py index e5f00f5..4540451 100644 --- a/app/utils/db/operations.py +++ b/app/utils/db/operations.py @@ -1,9 +1,10 @@ -from datetime import datetime +from datetime import datetime, timedelta from sqlalchemy.exc import SQLAlchemyError from . import get_session from .models import User, Subscription, Payment, SubscriptionStatus import logging import ipaddress +from .models import User, Subscription, Payment, ProvisionLog, SubscriptionStatus logger = logging.getLogger(__name__) @@ -58,12 +59,14 @@ class DatabaseManager: with get_session() as session: try: # Get user or create if doesn't exist - user = DatabaseManager.get_user_by_uuid(user_id) + user = session.query(User).filter(User.user_id == user_id).first() if not user: - user = DatabaseManager.create_user(user_id) - + user = User(user_id=user_id) + session.add(user) + session.flush() + start_time = datetime.utcnow() - expiry_time = start_time + datetime.timedelta(hours=duration_hours) + expiry_time = start_time + timedelta(hours=duration_hours) # Get next available IP assigned_ip = DatabaseManager.get_next_available_ip() @@ -77,9 +80,22 @@ class DatabaseManager: status=SubscriptionStatus.PENDING, assigned_ip=assigned_ip ) + session.add(subscription) session.commit() - return subscription + + # Return a dictionary of values instead of the SQLAlchemy object + return { + 'id': subscription.id, + 'user_id': user.id, + 'invoice_id': subscription.invoice_id, + 'public_key': subscription.public_key, + 'assigned_ip': subscription.assigned_ip, + 'start_time': subscription.start_time, + 'expiry_time': subscription.expiry_time, + 'status': subscription.status.value + } + except Exception as e: logger.error(f"Error creating subscription: {str(e)}") session.rollback() @@ -149,4 +165,33 @@ class DatabaseManager: subscription.warning_sent = 1 session.commit() return True - return False \ No newline at end of file + return False + + @staticmethod + def create_provision_log(log_data): + """Create a new provision log entry""" + with get_session() as session: + try: + provision_log = ProvisionLog( + subscription_id=log_data['subscription_id'], + action=log_data['action'], + status=log_data['status'], + ansible_output=log_data['ansible_output'], + error_message=log_data.get('error_message') + ) + session.add(provision_log) + session.commit() + return provision_log + except Exception as e: + session.rollback() + logger.error(f"Error creating provision log: {str(e)}") + raise + + @staticmethod + def get_provision_logs(subscription_id=None, limit=100): + """Get provision logs, optionally filtered by subscription""" + with get_session() as session: + query = session.query(ProvisionLog) + if subscription_id: + query = query.filter(ProvisionLog.subscription_id == subscription_id) + return query.order_by(ProvisionLog.timestamp.desc()).limit(limit).all() \ No newline at end of file diff --git a/data/subscriptions.json b/data/subscriptions.json deleted file mode 100644 index 961769a..0000000 --- a/data/subscriptions.json +++ /dev/null @@ -1,92 +0,0 @@ -{ - "__test__ee2b820c-d1a0-4c4d-8b7c-0e5550fbb42e__test__": { - "deliveryId": "P4su7aNvuaa7mGEJLFJ5mk", - "webhookId": "CoGqJKKuE3838AWZQkncSJ", - "originalDeliveryId": "__test__73469547-03db-4961-9d2f-da9b1ed7519f__test__", - "isRedelivery": false, - "type": "SubscriptionRenewalRequested", - "timestamp": 1733961630, - "storeId": "DcnEUCckb8eo5WBFBABb7EXGRP49a8UjYQRKkvz7AcJY", - "appId": "__test__dda0ab87-3962-435e-9cf6-2c9749dd21dc__test__", - "subscriptionId": "__test__ee2b820c-d1a0-4c4d-8b7c-0e5550fbb42e__test__", - "status": "Active", - "paymentRequestId": null, - "email": null, - "renewal_requested": "2024-12-12T00:00:31.600171" - }, - "__test__f3e4e7cf-304b-4a97-a341-9e78e0f0bf7f__test__": { - "deliveryId": "7BPjwRNtFd2k3kzQLH8RoJ", - "webhookId": "CoGqJKKuE3838AWZQkncSJ", - "originalDeliveryId": "__test__50e7d2ea-18d3-46b0-bad3-fc5475a9994f__test__", - "isRedelivery": false, - "type": "SubscriptionRenewalRequested", - "timestamp": 1733972813, - "storeId": "DcnEUCckb8eo5WBFBABb7EXGRP49a8UjYQRKkvz7AcJY", - "appId": "__test__aa25dc16-84ac-4687-a2e9-41a97c5cc112__test__", - "subscriptionId": "__test__f3e4e7cf-304b-4a97-a341-9e78e0f0bf7f__test__", - "status": "Active", - "paymentRequestId": null, - "email": null, - "renewal_requested": "2024-12-12T03:06:54.340514" - }, - "__test__0caaab4d-026d-4df6-b09b-164fa0edde79__test__": { - "deliveryId": "PhRMLk717pAWj2L1Z3LZtr", - "webhookId": "CoGqJKKuE3838AWZQkncSJ", - "originalDeliveryId": "__test__36b2350d-45df-4fe0-a4fa-c701684c9209__test__", - "isRedelivery": false, - "type": "SubscriptionStatusUpdated", - "timestamp": 1733972824, - "storeId": "DcnEUCckb8eo5WBFBABb7EXGRP49a8UjYQRKkvz7AcJY", - "appId": "__test__7dc7f4d8-020a-41b6-b728-84ca1ca0ab4c__test__", - "subscriptionId": "__test__0caaab4d-026d-4df6-b09b-164fa0edde79__test__", - "status": "Active", - "paymentRequestId": null, - "email": null, - "last_updated": "2024-12-12T03:07:05.439848" - }, - "__test__9aa786d3-053c-4b70-8b01-6df0f0207b79__test__": { - "deliveryId": "U7stmUDcB4qseDh29mhke7", - "webhookId": "CoGqJKKuE3838AWZQkncSJ", - "originalDeliveryId": "__test__c5eba94d-ea9e-46e9-b127-74234fcd7002__test__", - "isRedelivery": false, - "type": "SubscriptionStatusUpdated", - "timestamp": 1733973182, - "storeId": "DcnEUCckb8eo5WBFBABb7EXGRP49a8UjYQRKkvz7AcJY", - "appId": "__test__c3149b37-ad19-428a-aa95-251adc96e5ae__test__", - "subscriptionId": "__test__9aa786d3-053c-4b70-8b01-6df0f0207b79__test__", - "status": "Active", - "paymentRequestId": null, - "email": null, - "last_updated": "2024-12-12T03:13:03.056437" - }, - "__test__d57b8102-87f7-4143-b175-32353b6eaec7__test__": { - "deliveryId": "834dSuc5bdeXF2zRxovj2x", - "webhookId": "CoGqJKKuE3838AWZQkncSJ", - "originalDeliveryId": "__test__16c072e3-5702-4f57-bc32-7181f585bef2__test__", - "isRedelivery": false, - "type": "SubscriptionStatusUpdated", - "timestamp": 1733973198, - "storeId": "DcnEUCckb8eo5WBFBABb7EXGRP49a8UjYQRKkvz7AcJY", - "appId": "__test__59579b1c-c393-4fa4-b4be-6691b7a6db27__test__", - "subscriptionId": "__test__d57b8102-87f7-4143-b175-32353b6eaec7__test__", - "status": "Active", - "paymentRequestId": null, - "email": null, - "last_updated": "2024-12-12T03:13:19.677073" - }, - "__test__5f28edc1-1c9e-4f1a-9fe9-591bc4ae7988__test__": { - "deliveryId": "9LUJAQmU89k8ofSKc7M7Kt", - "webhookId": "CoGqJKKuE3838AWZQkncSJ", - "originalDeliveryId": "__test__c274b638-568b-459f-b4f4-fd10029f5dce__test__", - "isRedelivery": false, - "type": "SubscriptionRenewalRequested", - "timestamp": 1733973203, - "storeId": "DcnEUCckb8eo5WBFBABb7EXGRP49a8UjYQRKkvz7AcJY", - "appId": "__test__6896db98-91f3-458e-b8c8-6c6280a5698f__test__", - "subscriptionId": "__test__5f28edc1-1c9e-4f1a-9fe9-591bc4ae7988__test__", - "status": "Active", - "paymentRequestId": null, - "email": null, - "renewal_requested": "2024-12-12T03:13:24.040298" - } -} \ No newline at end of file diff --git a/data/vpn.db b/data/vpn.db index 6ab8f2a..28a0b83 100644 Binary files a/data/vpn.db and b/data/vpn.db differ diff --git a/scripts/init_db.py b/scripts/init_db.py index cb7649e..6c1c01f 100644 --- a/scripts/init_db.py +++ b/scripts/init_db.py @@ -1,28 +1,157 @@ # scripts/init_db.py +import os import sys +import logging from pathlib import Path +from datetime import datetime + +# Configure logging before imports +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) # Add the project root to Python path project_root = Path(__file__).resolve().parent.parent sys.path.append(str(project_root)) +# Import only what's needed for DB initialization +from sqlalchemy import create_engine, text from app.utils.db.models import Base -from sqlalchemy import create_engine -def init_db(): - # Create data directory if it doesn't exist +def get_db_path(): + """Get the database path""" data_dir = project_root / 'data' data_dir.mkdir(exist_ok=True) - - # Create database - db_path = data_dir / 'vpn.db' - db_url = f"sqlite:///{db_path}" - engine = create_engine(db_url) - - # Create all tables - Base.metadata.create_all(engine) - print(f"Database initialized at: {db_path}") - return engine + return data_dir / 'vpn.db' + +def backup_existing_db(): + """Backup existing database if it exists""" + try: + db_path = get_db_path() + if db_path.exists(): + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + backup_path = db_path.parent / f'vpn_backup_{timestamp}.db' + db_path.rename(backup_path) + logger.info(f"Created backup at: {backup_path}") + return backup_path + return None + except Exception as e: + logger.error(f"Backup failed: {str(e)}") + return None + +def init_db(force=False): + """Initialize the database with all tables""" + try: + db_path = get_db_path() + + # Check if database already exists + if db_path.exists() and not force: + logger.warning(f"Database already exists at {db_path}") + logger.warning("Use --force to recreate the database") + return None + + # Backup existing database if force is True + if force and db_path.exists(): + backup_existing_db() + + logger.info(f"Initializing database at: {db_path}") + + # Create database URL + db_url = f"sqlite:///{db_path}" + + # Create engine with pragma statements for foreign keys + engine = create_engine( + db_url, + connect_args={"check_same_thread": False} + ) + + # Enable foreign key support using PRAGMA + with engine.connect() as conn: + conn.execute(text("PRAGMA foreign_keys = ON")) + + # Create all tables + Base.metadata.create_all(engine) + logger.info("Successfully created all database tables") + + # Log created tables + tables = Base.metadata.tables.keys() + logger.info("Created tables:") + for table in tables: + logger.info(f" - {table}") + + return engine + + except Exception as e: + logger.error(f"Database initialization failed: {str(e)}") + raise + +def verify_tables(engine): + """Verify that all tables were created correctly""" + try: + # Get list of all tables in the database + with engine.connect() as conn: + # SQLite specific query to get table info + query = text("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") + result = conn.execute(query) + existing_tables = {row[0] for row in result} + + # Get list of all tables defined in models + expected_tables = set(Base.metadata.tables.keys()) + + # Check for missing tables + missing_tables = expected_tables - existing_tables + if missing_tables: + logger.error(f"Missing tables: {missing_tables}") + return False + + # Verify table schemas + for table_name in existing_tables: + schema_query = text(f"PRAGMA table_info({table_name})") + result = conn.execute(schema_query) + logger.info(f"\nSchema for {table_name}:") + # SQLite PRAGMA table_info returns: (cid, name, type, notnull, dflt_value, pk) + for row in result: + cid, name, type_, notnull, dflt_value, pk = row + pk_str = "PRIMARY KEY" if pk else "" + null_str = "NOT NULL" if notnull else "NULL" + default_str = f"DEFAULT {dflt_value}" if dflt_value is not None else "" + logger.info(f" - {name} ({type_}) {null_str} {default_str} {pk_str}".strip()) + + logger.info("All expected tables were created successfully") + return True + + except Exception as e: + logger.error(f"Table verification failed: {str(e)}") + return False + +def main(): + """Main function to handle database initialization""" + try: + import argparse + parser = argparse.ArgumentParser(description='Initialize the VPN database') + parser.add_argument('--force', action='store_true', + help='Force database recreation') + args = parser.parse_args() + + logger.info("Starting database initialization") + engine = init_db(force=args.force) + + if engine is None: + return 1 + + # Verify tables were created correctly + if verify_tables(engine): + logger.info("Database initialization completed successfully") + return 0 + else: + logger.error("Database initialization failed - tables missing") + return 1 + + except Exception as e: + logger.error(f"Database initialization failed: {str(e)}") + return 1 if __name__ == "__main__": - init_db() \ No newline at end of file + sys.exit(main()) \ No newline at end of file diff --git a/scripts/migrate_db.py b/scripts/migrate_db.py index 4e68743..6659f78 100644 --- a/scripts/migrate_db.py +++ b/scripts/migrate_db.py @@ -58,48 +58,46 @@ def migrate_database(): # Start transaction session.begin() - # Migrate users + # Migrate users - note we only use user_id now logger.info("Migrating users...") - old_cursor.execute("SELECT id, email, created_at FROM users") + old_cursor.execute("SELECT id, user_id, created_at FROM users") users = old_cursor.fetchall() user_id_map = {} # Map old IDs to new UUIDs - for old_id, email, created_at in users: - new_user_id = str(uuid.uuid4()) - user_id_map[old_id] = new_user_id + for old_id, user_id, created_at in users: + user_id_map[old_id] = user_id # Keep the same user_id session.execute( "INSERT INTO users (user_id, created_at) VALUES (?, ?)", - [new_user_id, created_at or datetime.utcnow()] + [user_id, created_at or datetime.utcnow()] ) # Migrate subscriptions logger.info("Migrating subscriptions...") old_cursor.execute(""" - SELECT id, user_id, invoice_id, start_time, expiry_time, - status, warning_sent + SELECT id, user_id, invoice_id, public_key, start_time, + expiry_time, status, warning_sent, assigned_ip FROM subscriptions """) subscriptions = old_cursor.fetchall() for sub in subscriptions: - old_id, old_user_id, invoice_id, start_time, expiry_time, status, warning_sent = sub + old_id, old_user_id, invoice_id, public_key, start_time, \ + expiry_time, status, warning_sent, assigned_ip = sub if old_user_id in user_id_map: - # Generate a placeholder public key for existing subscriptions - placeholder_pubkey = f"MIGRATED_{uuid.uuid4()}" session.execute(""" INSERT INTO subscriptions (user_id, invoice_id, public_key, start_time, expiry_time, status, warning_sent, assigned_ip) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, [ - user_id_map[old_user_id], + old_user_id, # Use the original user_id invoice_id, - placeholder_pubkey, + public_key, start_time, expiry_time, status, warning_sent, - f"10.8.0.{2 + old_id}" # Simple IP assignment + assigned_ip ]) # Migrate payments @@ -118,7 +116,7 @@ def migrate_database(): (user_id, subscription_id, invoice_id, amount, timestamp) VALUES (?, ?, ?, ?, ?) """, [ - user_id_map[old_user_id], + old_user_id, # Use the original user_id sub_id, invoice_id, amount, @@ -142,7 +140,7 @@ def migrate_database(): except Exception as e: logger.error(f"Migration failed: {str(e)}") raise - + if __name__ == '__main__': try: migrate_database()