diff --git a/ansible/playbooks/vpn_provision.yml b/ansible/playbooks/vpn_provision.yml index 4d210d8..d4dc662 100644 --- a/ansible/playbooks/vpn_provision.yml +++ b/ansible/playbooks/vpn_provision.yml @@ -44,7 +44,23 @@ dest: "{{ server_dir }}/public.key" mode: '0644' when: not server_config.stat.exists - + + - name: Update vault with server details + block: + - name: Read server public key + shell: "cat {{ server_dir }}/public.key" + register: pubkey_content + changed_when: false + + - name: Save server details to vault + copy: + content: | + wireguard_server_public_key: "{{ pubkey_content.stdout }}" + wireguard_server_endpoint: "{{ ansible_host }}" + dest: "{{ playbook_dir }}/../group_vars/vpn_servers/vault.yml" + mode: '0600' + when: not server_config.stat.exists + - name: Create initial server config template: src: templates/server.conf.j2 diff --git a/app/__init__.py b/app/__init__.py index f99f452..688ef97 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,5 +1,3 @@ -# app/__init__.py - from flask import Flask, request, jsonify, render_template import logging from .handlers.webhook_handler import handle_payment_webhook diff --git a/app/data/vpn.db b/app/data/vpn.db new file mode 100644 index 0000000..6b7c17d Binary files /dev/null and b/app/data/vpn.db differ diff --git a/app/handlers/payment_handler.py b/app/handlers/payment_handler.py index 7fc8ec7..2bb3984 100644 --- a/app/handlers/payment_handler.py +++ b/app/handlers/payment_handler.py @@ -1,15 +1,11 @@ -# app/handlers/payment_handler.py - import logging import requests import os -from flask import jsonify -import smtplib -from email.mime.text import MIMEText -from email.mime.multipart import MIMEMultipart from pathlib import Path import traceback from .webhook_handler import get_vault_values +from ..utils.db.operations import DatabaseManager +from ..utils.db.models import SubscriptionStatus logger = logging.getLogger(__name__) @@ -25,9 +21,6 @@ class BTCPayHandler: self.api_key = vault_values['btcpay_api_key'] self.store_id = vault_values['btcpay_store_id'] - # Email configuration - self.smtp_config = vault_values.get('smtp_config', {}) - logger.info(f"BTCPayHandler initialized with base URL: {self.base_url}") except Exception as e: @@ -35,10 +28,16 @@ class BTCPayHandler: logger.error(traceback.format_exc()) raise - def create_invoice(self, amount_sats, duration_hours, email): + def create_invoice(self, amount_sats, duration_hours, user_id, public_key): """Create BTCPay invoice for VPN subscription""" try: - logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for {email}") + logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for user {user_id}") + + # First, get or create user + user = DatabaseManager.get_user_by_uuid(user_id) + if not user: + user = DatabaseManager.create_user(user_id) + logger.info(f"Created new user with ID: {user_id}") headers = { 'Authorization': f'token {self.api_key}', @@ -54,11 +53,12 @@ class BTCPayHandler: 'currency': 'SATS', 'metadata': { 'duration_hours': duration_hours, - 'email': email, - 'orderId': f'vpn_sub_{duration_hours}h', + 'userId': user_id, + 'publicKey': public_key, + 'orderId': f'vpn_sub_{duration_hours}h' }, 'checkout': { - 'redirectURL': f'{app_url}/payment/success', + 'redirectURL': f'{app_url}/payment/success?userId={user_id}', 'redirectAutomatically': True } } @@ -80,11 +80,22 @@ class BTCPayHandler: return None invoice_data = response.json() - logger.info(f"Successfully created invoice {invoice_data.get('id')}") + invoice_id = invoice_data['id'] + logger.info(f"Successfully created invoice {invoice_id}") + + # Create pending subscription + subscription = DatabaseManager.create_subscription( + user_id, + invoice_id, + public_key, + duration_hours + ) + logger.info(f"Created pending subscription for invoice {invoice_id}") return { - 'invoice_id': invoice_data['id'], - 'checkout_url': invoice_data['checkoutLink'] + 'invoice_id': invoice_id, + 'checkout_url': invoice_data['checkoutLink'], + 'user_id': user_id } except Exception as e: @@ -92,53 +103,27 @@ class BTCPayHandler: logger.error(traceback.format_exc()) return None - def send_confirmation_email(self, email, config_data): - """Send VPN configuration details via email""" + def get_subscription_config(self, user_id): + """Get WireGuard configuration details for a subscription""" try: - logger.info(f"Sending confirmation email to {email}") - - if not self.smtp_config: - logger.warning("SMTP configuration not found in vault") - return False - - msg = MIMEMultipart() - msg['From'] = self.smtp_config['sender_email'] - msg['To'] = email - msg['Subject'] = "Your VPN Configuration" - - body = f""" - Thank you for subscribing to our VPN service! - - Please find your WireGuard configuration below: - - {config_data} - - Installation instructions: - 1. Install WireGuard client for your platform from https://www.wireguard.com/install/ - 2. Save the above configuration to a file named 'wg0.conf' - 3. Import the configuration file into your WireGuard client - - Need help? Reply to this email for support. - """ - - msg.attach(MIMEText(body, 'plain')) - - logger.debug("Connecting to SMTP server") - with smtplib.SMTP( - self.smtp_config['server'], - self.smtp_config.get('port', 587) - ) as server: - server.starttls() - server.login( - self.smtp_config['username'], - self.smtp_config['password'] - ) - server.send_message(msg) - - logger.info("Confirmation email sent successfully") - return True - + logger.info(f"Fetching subscription config for user {user_id}") + user = DatabaseManager.get_user_by_uuid(user_id) + if not user: + logger.error(f"User {user_id} not found") + return None + + subscription = DatabaseManager.get_active_subscription_for_user(user.id) + if not subscription: + logger.error(f"No active subscription found for user {user_id}") + return None + + return { + 'serverPublicKey': os.getenv('WIREGUARD_SERVER_PUBLIC_KEY'), + 'serverEndpoint': os.getenv('WIREGUARD_SERVER_ENDPOINT'), + 'clientIp': subscription.assigned_ip + } + except Exception as e: - logger.error("Error sending confirmation email:") + logger.error(f"Error getting subscription config: {str(e)}") logger.error(traceback.format_exc()) - return False \ No newline at end of file + return None \ No newline at end of file diff --git a/app/handlers/webhook_handler.py b/app/handlers/webhook_handler.py index b3295f9..9791a40 100644 --- a/app/handlers/webhook_handler.py +++ b/app/handlers/webhook_handler.py @@ -6,11 +6,12 @@ import logging import hmac import hashlib import yaml -import json import datetime import traceback from pathlib import Path from dotenv import load_dotenv +from ..utils.db.operations import DatabaseManager +from ..utils.db.models import SubscriptionStatus load_dotenv() @@ -23,7 +24,6 @@ logger = logging.getLogger(__name__) BASE_DIR = Path(__file__).resolve().parent.parent.parent PLAYBOOK_PATH = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_provision.yml' CLEANUP_PLAYBOOK = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_cleanup.yml' -SUBSCRIPTION_DB = BASE_DIR / 'data' / 'subscriptions.json' def get_vault_values(): """Get decrypted values from Ansible vault""" @@ -80,31 +80,6 @@ def verify_signature(payload_body, signature_header): logger.error(f"Signature verification failed: {str(e)}") return False -def load_subscriptions(): - """Load subscription data from JSON file""" - if not SUBSCRIPTION_DB.parent.exists(): - SUBSCRIPTION_DB.parent.mkdir(parents=True) - - if not SUBSCRIPTION_DB.exists(): - return {} - - with open(SUBSCRIPTION_DB, 'r') as f: - return json.load(f) - -def save_subscription(subscription_data): - """Save subscription data to JSON file""" - subscriptions = load_subscriptions() - sub_id = subscription_data['subscriptionId'] - subscriptions[sub_id] = subscription_data - - with open(SUBSCRIPTION_DB, 'w') as f: - json.dump(subscriptions, f, indent=2) - -def calculate_expiry(duration_hours): - """Calculate expiry date based on subscription duration""" - return (datetime.datetime.now() + - datetime.timedelta(hours=duration_hours)).isoformat() - def run_ansible_playbook(invoice_id): """Run the VPN provisioning playbook""" vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '') @@ -142,12 +117,15 @@ def handle_subscription_status(data): logger.info(f"Processing subscription status update: {sub_id} -> {status}") - # Store subscription data - data['last_updated'] = datetime.datetime.now().isoformat() - save_subscription(data) - - if status != 'Active': - # Run cleanup playbook for inactive subscriptions + subscription = DatabaseManager.get_subscription_by_invoice(sub_id) + if not subscription: + logger.error(f"Subscription {sub_id} not found") + return jsonify({"error": "Subscription not found"}), 404 + + if status == 'Active': + DatabaseManager.activate_subscription(sub_id) + else: + # Run cleanup for inactive subscriptions result = subprocess.run([ 'ansible-playbook', str(CLEANUP_PLAYBOOK), @@ -158,7 +136,8 @@ def handle_subscription_status(data): if result.returncode != 0: logger.error(f"Failed to clean up subscription {sub_id}: {result.stderr}") - + + DatabaseManager.expire_subscription(subscription.id) logger.info(f"Subscription {sub_id} is no longer active") return jsonify({ @@ -171,9 +150,10 @@ def handle_subscription_renewal(data): sub_id = data['subscriptionId'] logger.info(f"Processing subscription renewal request: {sub_id}") - # Update subscription data - data['renewal_requested'] = datetime.datetime.now().isoformat() - save_subscription(data) + subscription = DatabaseManager.get_subscription_by_invoice(sub_id) + if not subscription: + logger.error(f"Subscription {sub_id} not found") + return jsonify({"error": "Subscription not found"}), 404 # TODO: Send renewal notification to user return jsonify({ @@ -199,6 +179,16 @@ def handle_payment_webhook(request): data = request.json logger.info(f"Received webhook data: {data}") + + # Handle test webhooks + invoice_id = data.get('invoiceId', '') + if invoice_id.startswith('__test__'): + logger.info(f"Received test webhook, acknowledging: {data.get('type')}") + return jsonify({ + "status": "success", + "message": "Test webhook acknowledged" + }) + webhook_type = data.get('type') if webhook_type == 'SubscriptionStatusUpdated': @@ -207,20 +197,13 @@ def handle_payment_webhook(request): elif webhook_type == 'SubscriptionRenewalRequested': return handle_subscription_renewal(data) - elif webhook_type == 'InvoiceSettled': - # Handle regular invoice payment + elif webhook_type in ['InvoiceSettled', 'InvoicePaymentSettled']: invoice_id = data.get('invoiceId') - metadata = data.get('metadata', {}) - if not invoice_id: logger.error("Missing invoiceId in webhook data") return jsonify({"error": "Missing invoiceId"}), 400 - - if invoice_id.startswith('__test__') and invoice_id.endswith('__test__'): - invoice_id = invoice_id[8:-8] - logger.info(f"Stripped test markers from invoice ID: {invoice_id}") - # Run Ansible playbook with enhanced logging + # Get subscription and run Ansible playbook logger.info(f"Starting VPN provisioning for invoice {invoice_id}") result = run_ansible_playbook(invoice_id) @@ -236,46 +219,18 @@ def handle_payment_webhook(request): "stderr": result.stderr }), 500 + # Get subscription and activate it + subscription = DatabaseManager.get_subscription_by_invoice(invoice_id) + if subscription: + subscription = DatabaseManager.activate_subscription(invoice_id) + DatabaseManager.record_payment( + subscription.user_id, + subscription.id, + invoice_id, + data.get('amount', 0) + ) + logger.info(f"VPN provisioning completed for invoice {invoice_id}") - - # Update subscription database - try: - duration_hours = metadata.get('duration_hours', 24) - subscriptions = load_subscriptions() - subscriptions[invoice_id] = { - 'email': metadata.get('email'), - 'duration_hours': duration_hours, - 'start_time': datetime.datetime.now().isoformat(), - 'expiry': calculate_expiry(duration_hours), - 'status': 'Active' - } - with open(SUBSCRIPTION_DB, 'w') as f: - json.dump(subscriptions, f, indent=2) - logger.info(f"Updated subscription database for invoice {invoice_id}") - except Exception as e: - logger.error(f"Error updating subscription database: {str(e)}") - - # Send email confirmation if email is provided - try: - email = metadata.get('email') - if email: - config_path = f"/etc/wireguard/clients/{invoice_id}/wg0.conf" - if os.path.exists(config_path): - with open(config_path, 'r') as f: - config_data = f.read() - - btcpay_handler = BTCPayHandler() - email_sent = btcpay_handler.send_confirmation_email(email, config_data) - if email_sent: - logger.info(f"Sent configuration email to {email}") - else: - logger.warning(f"Failed to send configuration email to {email}") - else: - logger.warning(f"Config file not found at {config_path}") - except Exception as e: - logger.error(f"Error sending confirmation email: {str(e)}") - - logger.info(f"Successfully processed invoice {invoice_id}") return jsonify({ "status": "success", "invoice_id": invoice_id, @@ -283,7 +238,6 @@ def handle_payment_webhook(request): }) else: - # Log other webhook types as info instead of warning logger.info(f"Received {webhook_type} webhook - no action required") return jsonify({ "status": "success", diff --git a/app/static/js/components/WireGuardPayment.jsx b/app/static/js/components/WireGuardPayment.jsx new file mode 100644 index 0000000..30c065b --- /dev/null +++ b/app/static/js/components/WireGuardPayment.jsx @@ -0,0 +1,213 @@ +import React, { useState, useEffect } from 'react'; +import { generateKeys } from '../utils/wireguard'; +import { Card, CardHeader, CardTitle, CardContent } from '@/components/ui/card'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { AlertCircle } from 'lucide-react'; +import { Alert, AlertDescription } from '@/components/ui/alert'; + +const WireGuardPayment = () => { + const [keyData, setKeyData] = useState(null); + const [duration, setDuration] = useState(24); + const [price, setPrice] = useState(0); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(''); + const [userId, setUserId] = useState(''); + + useEffect(() => { + // Generate random userId and keys when component mounts + const init = async () => { + try { + const randomId = crypto.randomUUID(); + setUserId(randomId); + const keys = await generateKeys(); + setKeyData(keys); + calculatePrice(duration); + } catch (err) { + setError('Failed to initialize keys'); + console.error(err); + } + }; + init(); + }, []); + + const calculatePrice = async (hours) => { + try { + const response = await fetch('/api/calculate-price', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ hours: parseInt(hours) }) + }); + const data = await response.json(); + setPrice(data.price); + } catch (err) { + setError('Failed to calculate price'); + } + }; + + const handleDurationChange = (event) => { + const newDuration = parseInt(event.target.value); + setDuration(newDuration); + calculatePrice(newDuration); + }; + + const handlePayment = async () => { + if (!keyData) { + setError('No keys generated. Please refresh the page.'); + return; + } + + try { + setLoading(true); + const response = await fetch('/create-invoice', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + duration, + userId, + publicKey: keyData.publicKey, + // Don't send private key to server! + configuration: { + type: 'wireguard', + publicKey: keyData.publicKey + } + }) + }); + + if (!response.ok) { + throw new Error('Failed to create payment'); + } + + const data = await response.json(); + + // Save private key to localStorage before redirecting + localStorage.setItem(`vpn_keys_${userId}`, JSON.stringify({ + privateKey: keyData.privateKey, + publicKey: keyData.publicKey, + createdAt: new Date().toISOString() + })); + + window.location.href = data.checkout_url; + } catch (err) { + setError('Failed to initiate payment'); + console.error(err); + } finally { + setLoading(false); + } + }; + + const handleRegenerateKeys = async () => { + try { + const keys = await generateKeys(); + setKeyData(keys); + } catch (err) { + setError('Failed to regenerate keys'); + } + }; + + return ( + + + WireGuard VPN Configuration + + + {error && ( + + + {error} + + )} + +
+ + +

+ Save this ID - you'll need it to manage your subscription +

+
+ + {keyData && ( +
+ + + +
+ )} + +
+ +
+ +
+ + + +
+

{duration} hours

+
+
+ +
+

{price} sats

+
+ + + +
+

• Keys are generated securely in your browser

+

• Your private key never leaves your device

+

• Configuration will be available after payment

+
+
+
+ ); +}; + +export default WireGuardPayment; \ No newline at end of file diff --git a/app/static/js/pricing.js b/app/static/js/pricing.js index 5169466..698d9a7 100644 --- a/app/static/js/pricing.js +++ b/app/static/js/pricing.js @@ -1,11 +1,39 @@ -// app/static/js/pricing.js -document.addEventListener('DOMContentLoaded', function() { +// Base64 encoding/decoding utilities +const b64 = { + encode: array => btoa(String.fromCharCode.apply(null, array)), + decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0)) +}; + +async function generateKeyPair() { + const keyPair = await window.crypto.subtle.generateKey( + { + name: 'X25519', + namedCurve: 'X25519', + }, + true, + ['deriveKey', 'deriveBits'] + ); + + const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey); + const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey); + + return { + privateKey: b64.encode(new Uint8Array(privateKey)), + publicKey: b64.encode(new Uint8Array(publicKey)) + }; +} + +document.addEventListener('DOMContentLoaded', async function() { const form = document.getElementById('subscription-form'); const slider = document.getElementById('duration-slider'); const durationDisplay = document.getElementById('duration-display'); const priceDisplay = document.getElementById('price-display'); const presetButtons = document.querySelectorAll('.duration-preset'); - const emailInput = document.getElementById('email'); + const userIdInput = document.getElementById('user-id'); + const publicKeyInput = document.getElementById('public-key'); + const regenerateButton = document.getElementById('regenerate-keys'); + + let currentKeyPair = null; function formatDuration(hours) { if (hours < 24) return `${hours} hour${hours === 1 ? '' : 's'}`; @@ -14,6 +42,32 @@ document.addEventListener('DOMContentLoaded', function() { return `${Math.floor(hours / 720)} month${hours === 720 ? '' : 's'}`; } + async function generateNewKeys() { + try { + currentKeyPair = await generateKeyPair(); + publicKeyInput.value = currentKeyPair.publicKey; + + // Save private key to localStorage + const keyData = { + privateKey: currentKeyPair.privateKey, + publicKey: currentKeyPair.publicKey, + createdAt: new Date().toISOString() + }; + localStorage.setItem(`vpn_keys_${userIdInput.value}`, JSON.stringify(keyData)); + } catch (error) { + console.error('Failed to generate keys:', error); + alert('Failed to generate WireGuard keys. Please try again.'); + } + } + + async function initializeForm() { + // Generate user ID + userIdInput.value = crypto.randomUUID(); + + // Generate initial keys + await generateNewKeys(); + } + async function updatePrice(hours) { try { const response = await fetch('/api/calculate-price', { @@ -29,7 +83,27 @@ document.addEventListener('DOMContentLoaded', function() { } } - async function createInvoice(duration, email) { + // Event listeners + slider.addEventListener('input', () => updatePrice(slider.value)); + + regenerateButton.addEventListener('click', generateNewKeys); + + presetButtons.forEach(button => { + button.addEventListener('click', (e) => { + const hours = e.target.dataset.hours; + slider.value = hours; + updatePrice(hours); + }); + }); + + form.addEventListener('submit', async (e) => { + e.preventDefault(); + + if (!currentKeyPair) { + alert('No keys generated. Please refresh the page.'); + return; + } + try { const response = await fetch('/create-invoice', { method: 'POST', @@ -37,8 +111,9 @@ document.addEventListener('DOMContentLoaded', function() { 'Content-Type': 'application/json' }, body: JSON.stringify({ - duration: parseInt(duration), - email: email + duration: parseInt(slider.value), + userId: userIdInput.value, + publicKey: currentKeyPair.publicKey }) }); @@ -53,32 +128,10 @@ document.addEventListener('DOMContentLoaded', function() { console.error('Error creating invoice:', error); alert('Failed to create payment invoice. Please try again.'); } - } - - // Event listeners - slider.addEventListener('input', () => updatePrice(slider.value)); - - presetButtons.forEach(button => { - button.addEventListener('click', (e) => { - const hours = e.target.dataset.hours; - slider.value = hours; - updatePrice(hours); - }); - }); - - form.addEventListener('submit', async (e) => { - e.preventDefault(); - - const email = emailInput.value.trim(); - if (!email) { - alert('Please enter your email address'); - return; - } - - const duration = slider.value; - await createInvoice(duration, email); }); + // Initialize the form + await initializeForm(); // Initial price calculation updatePrice(slider.value); }); \ No newline at end of file diff --git a/app/static/js/utils/wireguard.js b/app/static/js/utils/wireguard.js new file mode 100644 index 0000000..21985ba --- /dev/null +++ b/app/static/js/utils/wireguard.js @@ -0,0 +1,54 @@ +// Base64 encoding/decoding utilities +const b64 = { + encode: array => btoa(String.fromCharCode.apply(null, array)), + decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0)) + }; + + async function generateKeyPair() { + // Generate a random key pair using Web Crypto API + const keyPair = await window.crypto.subtle.generateKey( + { + name: 'X25519', + namedCurve: 'X25519', + }, + true, + ['deriveKey', 'deriveBits'] + ); + + // Export keys in raw format + const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey); + const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey); + + // Convert to base64 + return { + privateKey: b64.encode(new Uint8Array(privateKey)), + publicKey: b64.encode(new Uint8Array(publicKey)) + }; + } + + export async function generateWireGuardConfig(serverPublicKey, serverEndpoint, address) { + const keys = await generateKeyPair(); + + return { + keys, + config: `[Interface] + PrivateKey = ${keys.privateKey} + Address = ${address} + DNS = 1.1.1.1 + + [Peer] + PublicKey = ${serverPublicKey} + Endpoint = ${serverEndpoint} + AllowedIPs = 0.0.0.0/0 + PersistentKeepalive = 25` + }; + } + + export async function generateKeys() { + try { + return await generateKeyPair(); + } catch (error) { + console.error('Failed to generate WireGuard keys:', error); + throw error; + } + } \ No newline at end of file diff --git a/app/templates/index.html b/app/templates/index.html index 2125312..284d890 100644 --- a/app/templates/index.html +++ b/app/templates/index.html @@ -1,30 +1,45 @@ {% extends "base.html" %} - {% block content %}

Subscribe to VPN Service

-
- - User ID + +

Save this ID - you'll need it to manage your subscription

+
+ + + +
+
- @@ -44,12 +59,18 @@

- + +
+

• Keys are generated securely in your browser

+

• Your private key never leaves your device

+

• Configuration will be available after payment

+
diff --git a/app/templates/payment_success.html b/app/templates/payment_success.html index 05bd9ba..d309da0 100644 --- a/app/templates/payment_success.html +++ b/app/templates/payment_success.html @@ -2,17 +2,110 @@ {% block content %}
-
-

Payment Successful!

-

- Thank you for your payment. Your VPN configuration will be sent to your email shortly. -

-

- Please check your email for further instructions on setting up your VPN connection. -

- - Return to Home - +
+

Payment Successful!

+ +
+

+ Your VPN subscription is now active. Please follow the instructions below to set up your VPN connection. +

+
+ +
+ + + + +
+

Installation Instructions

+
    +
  1. Download WireGuard for your platform: + +
  2. +
  3. Create a new tunnel in WireGuard
  4. +
  5. Copy the configuration above and paste it into the new tunnel
  6. +
  7. Activate the tunnel to connect
  8. +
+
+ +
+

• Save your configuration securely - you'll need it to reconnect

+

• Your private key is stored in your browser's local storage

+

• For security, clear your browser data after saving the configuration

+
+
+ +
+ + {% endblock %} \ No newline at end of file diff --git a/app/utils/db/__init__.py b/app/utils/db/__init__.py new file mode 100644 index 0000000..8fdadd7 --- /dev/null +++ b/app/utils/db/__init__.py @@ -0,0 +1,23 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from pathlib import Path + +def get_db_path(): + base_dir = Path(__file__).resolve().parent.parent.parent + data_dir = base_dir / 'data' + data_dir.mkdir(exist_ok=True) + return data_dir / 'vpn.db' + +def init_db(): + """Initialize the database""" + from .models import Base + db_url = f"sqlite:///{get_db_path()}" + engine = create_engine(db_url) + Base.metadata.create_all(engine) + return engine + +def get_session(): + """Get a database session""" + engine = init_db() + Session = sessionmaker(bind=engine) + return Session() \ No newline at end of file diff --git a/app/utils/db/models.py b/app/utils/db/models.py new file mode 100644 index 0000000..9390fda --- /dev/null +++ b/app/utils/db/models.py @@ -0,0 +1,52 @@ +from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Enum, Text +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import relationship +import enum +import datetime + +Base = declarative_base() + +class SubscriptionStatus(enum.Enum): + ACTIVE = "active" + EXPIRED = "expired" + PENDING = "pending" + CANCELLED = "cancelled" + +class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + user_id = Column(String, unique=True, nullable=False) # UUID generated in frontend + created_at = Column(DateTime, default=datetime.datetime.utcnow) + + subscriptions = relationship("Subscription", back_populates="user") + payments = relationship("Payment", back_populates="user") + +class Subscription(Base): + __tablename__ = 'subscriptions' + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('users.id')) + invoice_id = Column(String, unique=True) + public_key = Column(Text, nullable=False) # WireGuard public key + start_time = Column(DateTime, nullable=False) + expiry_time = Column(DateTime, nullable=False) + status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.PENDING) + warning_sent = Column(Integer, default=0) + assigned_ip = Column(String) # WireGuard IP address assigned to this subscription + + user = relationship("User", back_populates="subscriptions") + payments = relationship("Payment", back_populates="subscription") + +class Payment(Base): + __tablename__ = 'payments' + + id = Column(Integer, primary_key=True) + user_id = Column(Integer, ForeignKey('users.id')) + subscription_id = Column(Integer, ForeignKey('subscriptions.id')) + invoice_id = Column(String, unique=True) + amount = Column(Integer, nullable=False) # Amount in sats + timestamp = Column(DateTime, default=datetime.datetime.utcnow) + + user = relationship("User", back_populates="payments") + subscription = relationship("Subscription", back_populates="payments") \ No newline at end of file diff --git a/app/utils/db/operations.py b/app/utils/db/operations.py new file mode 100644 index 0000000..e5f00f5 --- /dev/null +++ b/app/utils/db/operations.py @@ -0,0 +1,152 @@ +from datetime import datetime +from sqlalchemy.exc import SQLAlchemyError +from . import get_session +from .models import User, Subscription, Payment, SubscriptionStatus +import logging +import ipaddress + +logger = logging.getLogger(__name__) + +class DatabaseManager: + @staticmethod + def get_next_available_ip(): + """Get the next available IP from the WireGuard subnet""" + with get_session() as session: + # Get all assigned IPs + assigned_ips = session.query(Subscription.assigned_ip)\ + .filter(Subscription.assigned_ip.isnot(None))\ + .all() + assigned_ips = [ip[0] for ip in assigned_ips] + + # Start from 10.8.0.2 (10.8.0.1 is server) + network = ipaddress.IPv4Network('10.8.0.0/24') + for ip in network.hosts(): + str_ip = str(ip) + if str_ip == '10.8.0.1': # Skip server IP + continue + if str_ip not in assigned_ips: + return str_ip + + raise Exception("No available IPs in the subnet") + + @staticmethod + def create_user(user_id): + """Create a new user with UUID""" + with get_session() as session: + user = User(user_id=user_id) + session.add(user) + session.commit() + return user + + @staticmethod + def get_user_by_uuid(user_id): + """Get user by UUID""" + with get_session() as session: + return session.query(User).filter(User.user_id == user_id).first() + + @staticmethod + def get_subscription_by_invoice(invoice_id): + """Get subscription by invoice ID""" + with get_session() as session: + return session.query(Subscription)\ + .filter(Subscription.invoice_id == invoice_id)\ + .first() + + @staticmethod + def create_subscription(user_id, invoice_id, public_key, duration_hours): + """Create a new subscription""" + with get_session() as session: + try: + # Get user or create if doesn't exist + user = DatabaseManager.get_user_by_uuid(user_id) + if not user: + user = DatabaseManager.create_user(user_id) + + start_time = datetime.utcnow() + expiry_time = start_time + datetime.timedelta(hours=duration_hours) + + # Get next available IP + assigned_ip = DatabaseManager.get_next_available_ip() + + subscription = Subscription( + user_id=user.id, + invoice_id=invoice_id, + public_key=public_key, + start_time=start_time, + expiry_time=expiry_time, + status=SubscriptionStatus.PENDING, + assigned_ip=assigned_ip + ) + session.add(subscription) + session.commit() + return subscription + except Exception as e: + logger.error(f"Error creating subscription: {str(e)}") + session.rollback() + raise + + @staticmethod + def activate_subscription(invoice_id): + """Activate a subscription after payment""" + with get_session() as session: + subscription = session.query(Subscription)\ + .filter(Subscription.invoice_id == invoice_id)\ + .first() + + if subscription: + subscription.status = SubscriptionStatus.ACTIVE + session.commit() + return subscription + + @staticmethod + def record_payment(user_id, subscription_id, invoice_id, amount): + """Record a payment""" + with get_session() as session: + user = DatabaseManager.get_user_by_uuid(user_id) + if not user: + raise ValueError(f"User {user_id} not found") + + payment = Payment( + user_id=user.id, + subscription_id=subscription_id, + invoice_id=invoice_id, + amount=amount + ) + session.add(payment) + session.commit() + return payment + + @staticmethod + def get_active_subscriptions(): + """Get all active subscriptions""" + with get_session() as session: + return session.query(Subscription)\ + .filter(Subscription.status == SubscriptionStatus.ACTIVE)\ + .all() + + @staticmethod + def expire_subscription(subscription_id): + """Mark a subscription as expired""" + with get_session() as session: + subscription = session.query(Subscription)\ + .filter(Subscription.id == subscription_id)\ + .first() + + if subscription: + subscription.status = SubscriptionStatus.EXPIRED + session.commit() + return True + return False + + @staticmethod + def update_warning_sent(subscription_id): + """Update the warning_sent flag for a subscription""" + with get_session() as session: + subscription = session.query(Subscription)\ + .filter(Subscription.id == subscription_id)\ + .first() + if subscription: + subscription.warning_sent = 1 + session.commit() + return True + return False \ No newline at end of file diff --git a/app/utils/db/utils.py b/app/utils/db/utils.py new file mode 100644 index 0000000..17f1e58 --- /dev/null +++ b/app/utils/db/utils.py @@ -0,0 +1,59 @@ +import sqlite3 +import shutil +from datetime import datetime +from pathlib import Path +import logging +from .. import get_db_path + +logger = logging.getLogger(__name__) + +def backup_database(): + """Create a backup of the SQLite database""" + try: + db_path = get_db_path() + if not db_path.exists(): + logger.error("Database file not found") + return False + + backup_dir = db_path.parent / 'backups' + backup_dir.mkdir(exist_ok=True) + + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + backup_path = backup_dir / f'vpn_db_backup_{timestamp}.db' + + shutil.copy2(db_path, backup_path) + logger.info(f"Database backed up successfully to {backup_path}") + + # Keep only last 5 backups + backups = sorted(backup_dir.glob('vpn_db_backup_*.db')) + if len(backups) > 5: + for old_backup in backups[:-5]: + old_backup.unlink() + logger.info(f"Removed old backup: {old_backup}") + + return True + + except Exception as e: + logger.error(f"Backup failed: {str(e)}") + return False + +def verify_database(): + """Check database integrity""" + try: + db_path = get_db_path() + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + cursor.execute("PRAGMA integrity_check") + result = cursor.fetchone()[0] + + if result != "ok": + logger.error(f"Database integrity check failed: {result}") + return False + + logger.info("Database integrity verified") + return True + + except Exception as e: + logger.error(f"Database verification failed: {str(e)}") + return False \ No newline at end of file diff --git a/data/vpn.db b/data/vpn.db new file mode 100644 index 0000000..6ab8f2a Binary files /dev/null and b/data/vpn.db differ diff --git a/requirements.txt b/requirements.txt index 99ca526..5e36ecb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ pyyaml==6.0.1 python-dotenv==1.0.0 cryptography==41.0.7 ansible==9.1.0 -requests==2.31.0 \ No newline at end of file +requests==2.31.0 +SQLAlchemy==2.0.25 \ No newline at end of file diff --git a/scripts/init_db.py b/scripts/init_db.py new file mode 100644 index 0000000..cb7649e --- /dev/null +++ b/scripts/init_db.py @@ -0,0 +1,28 @@ +# scripts/init_db.py +import sys +from pathlib import Path + +# Add the project root to Python path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from app.utils.db.models import Base +from sqlalchemy import create_engine + +def init_db(): + # Create data directory if it doesn't exist + data_dir = project_root / 'data' + data_dir.mkdir(exist_ok=True) + + # Create database + db_path = data_dir / 'vpn.db' + db_url = f"sqlite:///{db_path}" + engine = create_engine(db_url) + + # Create all tables + Base.metadata.create_all(engine) + print(f"Database initialized at: {db_path}") + return engine + +if __name__ == "__main__": + init_db() \ No newline at end of file diff --git a/scripts/migrate_db.py b/scripts/migrate_db.py new file mode 100644 index 0000000..4e68743 --- /dev/null +++ b/scripts/migrate_db.py @@ -0,0 +1,151 @@ +import sys +from pathlib import Path +import sqlite3 +import uuid +import logging +from datetime import datetime + +# Add the project root to Python path +project_root = Path(__file__).resolve().parent.parent +sys.path.append(str(project_root)) + +from app.utils.db.models import Base, SubscriptionStatus +from app.utils.db import get_db_path, get_session +from sqlalchemy import create_engine + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def backup_database(): + """Create a backup of the existing database""" + db_path = get_db_path() + if not db_path.exists(): + logger.info("No existing database found, skipping backup") + return + + backup_path = db_path.parent / f"{db_path.stem}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}{db_path.suffix}" + db_path.rename(backup_path) + logger.info(f"Created database backup at {backup_path}") + return backup_path + +def migrate_database(): + """Perform database migration""" + try: + # Create backup + backup_path = backup_database() + + # Initialize new database + engine = create_engine(f"sqlite:///{get_db_path()}") + Base.metadata.create_all(engine) + + # If there's an old database, migrate the data + if backup_path and backup_path.exists(): + logger.info("Migrating data from old database...") + + # Connect to old database + old_conn = sqlite3.connect(backup_path) + old_cursor = old_conn.cursor() + + # Get schema information + old_cursor.execute("SELECT sql FROM sqlite_master WHERE type='table'") + old_schema = old_cursor.fetchall() + logger.info("Old database schema:") + for table in old_schema: + logger.info(table[0]) + + with get_session() as session: + try: + # Start transaction + session.begin() + + # Migrate users + logger.info("Migrating users...") + old_cursor.execute("SELECT id, email, created_at FROM users") + users = old_cursor.fetchall() + user_id_map = {} # Map old IDs to new UUIDs + + for old_id, email, created_at in users: + new_user_id = str(uuid.uuid4()) + user_id_map[old_id] = new_user_id + session.execute( + "INSERT INTO users (user_id, created_at) VALUES (?, ?)", + [new_user_id, created_at or datetime.utcnow()] + ) + + # Migrate subscriptions + logger.info("Migrating subscriptions...") + old_cursor.execute(""" + SELECT id, user_id, invoice_id, start_time, expiry_time, + status, warning_sent + FROM subscriptions + """) + subscriptions = old_cursor.fetchall() + + for sub in subscriptions: + old_id, old_user_id, invoice_id, start_time, expiry_time, status, warning_sent = sub + if old_user_id in user_id_map: + # Generate a placeholder public key for existing subscriptions + placeholder_pubkey = f"MIGRATED_{uuid.uuid4()}" + session.execute(""" + INSERT INTO subscriptions + (user_id, invoice_id, public_key, start_time, expiry_time, + status, warning_sent, assigned_ip) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, [ + user_id_map[old_user_id], + invoice_id, + placeholder_pubkey, + start_time, + expiry_time, + status, + warning_sent, + f"10.8.0.{2 + old_id}" # Simple IP assignment + ]) + + # Migrate payments + logger.info("Migrating payments...") + old_cursor.execute(""" + SELECT user_id, subscription_id, invoice_id, amount, timestamp + FROM payments + """) + payments = old_cursor.fetchall() + + for payment in payments: + old_user_id, sub_id, invoice_id, amount, timestamp = payment + if old_user_id in user_id_map: + session.execute(""" + INSERT INTO payments + (user_id, subscription_id, invoice_id, amount, timestamp) + VALUES (?, ?, ?, ?, ?) + """, [ + user_id_map[old_user_id], + sub_id, + invoice_id, + amount, + timestamp + ]) + + # Commit transaction + session.commit() + logger.info("Migration completed successfully") + + except Exception as e: + session.rollback() + logger.error(f"Migration failed: {str(e)}") + raise + + old_conn.close() + + else: + logger.info("No existing database to migrate") + + except Exception as e: + logger.error(f"Migration failed: {str(e)}") + raise + +if __name__ == '__main__': + try: + migrate_database() + except Exception as e: + logger.error(f"Migration failed: {str(e)}") + sys.exit(1) \ No newline at end of file diff --git a/scripts/subscription_checker.py b/scripts/subscription_checker.py index e3c3acc..9e6f368 100644 --- a/scripts/subscription_checker.py +++ b/scripts/subscription_checker.py @@ -1,14 +1,12 @@ -# scripts/subscription_checker.py - -import json -import datetime import logging import subprocess import os import tempfile from pathlib import Path -from dateutil.parser import parse -from dateutil.relativedelta import relativedelta +from datetime import datetime, timedelta +from utils.db.operations import DatabaseManager +from utils.db.models import SubscriptionStatus +from app.handlers.payment_handler import BTCPayHandler logging.basicConfig( level=logging.INFO, @@ -19,20 +17,19 @@ logger = logging.getLogger(__name__) # Path setup SCRIPT_DIR = Path(__file__).resolve().parent PROJECT_ROOT = SCRIPT_DIR.parent -SUBSCRIPTION_DB = PROJECT_ROOT / 'data' / 'subscriptions.json' CLEANUP_PLAYBOOK = PROJECT_ROOT / 'ansible' / 'playbooks' / 'vpn_cleanup.yml' INVENTORY_FILE = PROJECT_ROOT / 'inventory.ini' # Notification thresholds configuration NOTIFICATION_THRESHOLDS = { - 'minimum_duration': datetime.timedelta(hours=1), # Minimum subscription duration + 'minimum_duration': timedelta(hours=1), # Minimum subscription duration 'short_term': { - 'max_duration': datetime.timedelta(days=1), + 'max_duration': timedelta(days=1), 'warning_fraction': 0.5, # Warn when 50% of time remains 'grace_fraction': 0.1 # Grace period of 10% of subscription length }, 'medium_term': { - 'max_duration': datetime.timedelta(days=7), + 'max_duration': timedelta(days=7), 'warning_fraction': 0.25, # Warn when 25% of time remains 'grace_hours': 12 # Fixed 12-hour grace period }, @@ -42,25 +39,9 @@ NOTIFICATION_THRESHOLDS = { } } -def load_subscriptions(): - """Load subscription data from JSON file""" - if not SUBSCRIPTION_DB.parent.exists(): - SUBSCRIPTION_DB.parent.mkdir(parents=True) - - if not SUBSCRIPTION_DB.exists(): - return {} - - with open(SUBSCRIPTION_DB, 'r') as f: - return json.load(f) - -def save_subscriptions(subscriptions): - """Save subscriptions to JSON file""" - with open(SUBSCRIPTION_DB, 'w') as f: - json.dump(subscriptions, f, indent=2) - -def run_cleanup_playbook(sub_id): +def run_cleanup_playbook(subscription_id): """Run the VPN cleanup playbook""" - logger.info(f"Running cleanup playbook for subscription {sub_id}") + logger.info(f"Running cleanup playbook for subscription {subscription_id}") vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '') if not vault_pass: @@ -74,7 +55,7 @@ def run_cleanup_playbook(sub_id): 'ansible-playbook', str(CLEANUP_PLAYBOOK), '-i', str(INVENTORY_FILE), - '-e', f'subscription_id={sub_id}', + '-e', f'subscription_id={subscription_id}', '--vault-password-file', vault_pass_file.name, '-vvv' ] @@ -122,42 +103,36 @@ def calculate_notification_times(start_time, end_time): warning_delta = duration * NOTIFICATION_THRESHOLDS['medium_term']['warning_fraction'] grace_hours = NOTIFICATION_THRESHOLDS['medium_term']['grace_hours'] warning_time = end_time - warning_delta - grace_end_time = end_time + datetime.timedelta(hours=grace_hours) + grace_end_time = end_time + timedelta(hours=grace_hours) # Long-term subscriptions (> 7 days) else: warning_days = NOTIFICATION_THRESHOLDS['long_term']['warning_days'] grace_days = NOTIFICATION_THRESHOLDS['long_term']['grace_days'] - warning_time = end_time - datetime.timedelta(days=warning_days) - grace_end_time = end_time + datetime.timedelta(days=grace_days) + warning_time = end_time - timedelta(days=warning_days) + grace_end_time = end_time + timedelta(days=grace_days) return warning_time, grace_end_time -def get_notification_message(sub_data, remaining_time): +def get_notification_message(subscription, remaining_time): """Generate appropriate notification message based on subscription duration""" - if remaining_time < datetime.timedelta(hours=1): + if remaining_time < timedelta(hours=1): return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 60)} minutes!" - elif remaining_time < datetime.timedelta(days=1): + elif remaining_time < timedelta(days=1): return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 3600)} hours!" else: return f"Your VPN subscription expires in {remaining_time.days} days!" -def notify_user(user_id, message): +def notify_user(subscription, message): """Send notification to user about subscription status""" try: - subscriptions = load_subscriptions() - sub_data = subscriptions.get(user_id) - - if not sub_data or 'email' not in sub_data: - logger.error(f"No email found for user {user_id}") + if not subscription.user or not subscription.user.email: + logger.error(f"No email found for subscription {subscription.id}") return False - - # Import BTCPayHandler here to avoid circular imports - from app.handlers.payment_handler import BTCPayHandler btcpay_handler = BTCPayHandler() email_sent = btcpay_handler.send_confirmation_email( - sub_data['email'], + subscription.user.email, f""" VPN Subscription Update @@ -168,9 +143,9 @@ def notify_user(user_id, message): ) if email_sent: - logger.info(f"Sent notification to {sub_data['email']}: {message}") + logger.info(f"Sent notification to {subscription.user.email}: {message}") else: - logger.warning(f"Failed to send notification to {sub_data['email']}") + logger.warning(f"Failed to send notification to {subscription.user.email}") return email_sent @@ -182,68 +157,64 @@ def check_subscriptions(): """Check subscription status and clean up expired ones""" logger.info("Starting subscription check") - subscriptions = load_subscriptions() - now = datetime.datetime.now() - modified = False - - logger.info(f"Checking {len(subscriptions)} subscriptions") - - for sub_id, sub_data in list(subscriptions.items()): - try: - logger.debug(f"Processing subscription {sub_id}") - - if sub_data.get('status') != 'Active': - logger.debug(f"Skipping inactive subscription {sub_id}") - continue + try: + active_subscriptions = DatabaseManager.get_active_subscriptions() + logger.info(f"Checking {len(active_subscriptions)} active subscriptions") + + now = datetime.utcnow() + + for subscription in active_subscriptions: + try: + logger.debug(f"Processing subscription {subscription.id}") - start_time = parse(sub_data['start_time']) - expiry = parse(sub_data['expiry']) - warning_time, grace_end_time = calculate_notification_times(start_time, expiry) - - # Calculate remaining time - remaining_time = expiry - now - logger.debug(f"Subscription {sub_id} has {remaining_time} remaining") - - # Handle warnings - if now >= warning_time and not sub_data.get('warning_sent'): - message = get_notification_message(sub_data, remaining_time) - logger.info(f"Sending notification for subscription {sub_id}: {message}") + warning_time, grace_end_time = calculate_notification_times( + subscription.start_time, + subscription.expiry_time + ) - notify_user(sub_id, message) + # Calculate remaining time + remaining_time = subscription.expiry_time - now + logger.debug(f"Subscription {subscription.id} has {remaining_time} remaining") - sub_data['warning_sent'] = True - modified = True - - # Handle expiration - if now >= grace_end_time and sub_data.get('status') == 'Active': - logger.info(f"Processing expiration for subscription {sub_id}") - - try: - logger.debug(f"Running cleanup playbook for {sub_id}") - result = run_cleanup_playbook(sub_id) + # Handle warnings + if now >= warning_time and not subscription.warning_sent: + message = get_notification_message(subscription, remaining_time) + logger.info(f"Sending notification for subscription {subscription.id}: {message}") - if result.returncode == 0: - sub_data['status'] = 'Expired' - sub_data['cleanup_date'] = now.isoformat() - modified = True - logger.info(f"Successfully cleaned up subscription {sub_id}") - else: - logger.error(f"Cleanup failed: {result.stderr}") + if notify_user(subscription, message): + DatabaseManager.update_warning_sent(subscription.id) + + # Handle expiration + if now >= grace_end_time: + logger.info(f"Processing expiration for subscription {subscription.id}") + + try: + result = run_cleanup_playbook(subscription.invoice_id) - except Exception as e: - logger.error(f"Error during cleanup: {str(e)}") - logger.error(traceback.format_exc()) - - except Exception as e: - logger.error(f"Error processing subscription {sub_id}: {str(e)}") - logger.error(traceback.format_exc()) - continue - - if modified: - save_subscriptions(subscriptions) - logger.info("Updated subscription database") - - logger.info("Subscription check completed") + if result.returncode == 0: + DatabaseManager.expire_subscription(subscription.id) + logger.info(f"Successfully cleaned up subscription {subscription.id}") + + # Send final notification + notify_user(subscription, "Your VPN subscription has expired and been deactivated.") + else: + logger.error(f"Cleanup failed: {result.stderr}") + + except Exception as e: + logger.error(f"Error during cleanup: {str(e)}") + logger.error(traceback.format_exc()) + + except Exception as e: + logger.error(f"Error processing subscription {subscription.id}: {str(e)}") + logger.error(traceback.format_exc()) + continue + + logger.info("Subscription check completed") + + except Exception as e: + logger.error(f"Subscription checker failed: {str(e)}") + logger.error(traceback.format_exc()) + raise if __name__ == '__main__': try: diff --git a/venv/include/site/python3.11/greenlet/greenlet.h b/venv/include/site/python3.11/greenlet/greenlet.h new file mode 100644 index 0000000..d02a16e --- /dev/null +++ b/venv/include/site/python3.11/greenlet/greenlet.h @@ -0,0 +1,164 @@ +/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */ + +/* Greenlet object interface */ + +#ifndef Py_GREENLETOBJECT_H +#define Py_GREENLETOBJECT_H + + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* This is deprecated and undocumented. It does not change. */ +#define GREENLET_VERSION "1.0.0" + +#ifndef GREENLET_MODULE +#define implementation_ptr_t void* +#endif + +typedef struct _greenlet { + PyObject_HEAD + PyObject* weakreflist; + PyObject* dict; + implementation_ptr_t pimpl; +} PyGreenlet; + +#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type)) + + +/* C API functions */ + +/* Total number of symbols that are exported */ +#define PyGreenlet_API_pointers 12 + +#define PyGreenlet_Type_NUM 0 +#define PyExc_GreenletError_NUM 1 +#define PyExc_GreenletExit_NUM 2 + +#define PyGreenlet_New_NUM 3 +#define PyGreenlet_GetCurrent_NUM 4 +#define PyGreenlet_Throw_NUM 5 +#define PyGreenlet_Switch_NUM 6 +#define PyGreenlet_SetParent_NUM 7 + +#define PyGreenlet_MAIN_NUM 8 +#define PyGreenlet_STARTED_NUM 9 +#define PyGreenlet_ACTIVE_NUM 10 +#define PyGreenlet_GET_PARENT_NUM 11 + +#ifndef GREENLET_MODULE +/* This section is used by modules that uses the greenlet C API */ +static void** _PyGreenlet_API = NULL; + +# define PyGreenlet_Type \ + (*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM]) + +# define PyExc_GreenletError \ + ((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM]) + +# define PyExc_GreenletExit \ + ((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM]) + +/* + * PyGreenlet_New(PyObject *args) + * + * greenlet.greenlet(run, parent=None) + */ +# define PyGreenlet_New \ + (*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \ + _PyGreenlet_API[PyGreenlet_New_NUM]) + +/* + * PyGreenlet_GetCurrent(void) + * + * greenlet.getcurrent() + */ +# define PyGreenlet_GetCurrent \ + (*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM]) + +/* + * PyGreenlet_Throw( + * PyGreenlet *greenlet, + * PyObject *typ, + * PyObject *val, + * PyObject *tb) + * + * g.throw(...) + */ +# define PyGreenlet_Throw \ + (*(PyObject * (*)(PyGreenlet * self, \ + PyObject * typ, \ + PyObject * val, \ + PyObject * tb)) \ + _PyGreenlet_API[PyGreenlet_Throw_NUM]) + +/* + * PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args) + * + * g.switch(*args, **kwargs) + */ +# define PyGreenlet_Switch \ + (*(PyObject * \ + (*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \ + _PyGreenlet_API[PyGreenlet_Switch_NUM]) + +/* + * PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent) + * + * g.parent = new_parent + */ +# define PyGreenlet_SetParent \ + (*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \ + _PyGreenlet_API[PyGreenlet_SetParent_NUM]) + +/* + * PyGreenlet_GetParent(PyObject* greenlet) + * + * return greenlet.parent; + * + * This could return NULL even if there is no exception active. + * If it does not return NULL, you are responsible for decrementing the + * reference count. + */ +# define PyGreenlet_GetParent \ + (*(PyGreenlet* (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_GET_PARENT_NUM]) + +/* + * deprecated, undocumented alias. + */ +# define PyGreenlet_GET_PARENT PyGreenlet_GetParent + +# define PyGreenlet_MAIN \ + (*(int (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_MAIN_NUM]) + +# define PyGreenlet_STARTED \ + (*(int (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_STARTED_NUM]) + +# define PyGreenlet_ACTIVE \ + (*(int (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_ACTIVE_NUM]) + + + + +/* Macro that imports greenlet and initializes C API */ +/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we + keep the older definition to be sure older code that might have a copy of + the header still works. */ +# define PyGreenlet_Import() \ + { \ + _PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \ + } + +#endif /* GREENLET_MODULE */ + +#ifdef __cplusplus +} +#endif +#endif /* !Py_GREENLETOBJECT_H */