DB update.
This commit is contained in:
parent
bcd6c435f7
commit
72972efe5e
@ -45,6 +45,22 @@
|
||||
mode: '0644'
|
||||
when: not server_config.stat.exists
|
||||
|
||||
- name: Update vault with server details
|
||||
block:
|
||||
- name: Read server public key
|
||||
shell: "cat {{ server_dir }}/public.key"
|
||||
register: pubkey_content
|
||||
changed_when: false
|
||||
|
||||
- name: Save server details to vault
|
||||
copy:
|
||||
content: |
|
||||
wireguard_server_public_key: "{{ pubkey_content.stdout }}"
|
||||
wireguard_server_endpoint: "{{ ansible_host }}"
|
||||
dest: "{{ playbook_dir }}/../group_vars/vpn_servers/vault.yml"
|
||||
mode: '0600'
|
||||
when: not server_config.stat.exists
|
||||
|
||||
- name: Create initial server config
|
||||
template:
|
||||
src: templates/server.conf.j2
|
||||
|
@ -1,5 +1,3 @@
|
||||
# app/__init__.py
|
||||
|
||||
from flask import Flask, request, jsonify, render_template
|
||||
import logging
|
||||
from .handlers.webhook_handler import handle_payment_webhook
|
||||
|
BIN
app/data/vpn.db
Normal file
BIN
app/data/vpn.db
Normal file
Binary file not shown.
@ -1,15 +1,11 @@
|
||||
# app/handlers/payment_handler.py
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import os
|
||||
from flask import jsonify
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
from .webhook_handler import get_vault_values
|
||||
from ..utils.db.operations import DatabaseManager
|
||||
from ..utils.db.models import SubscriptionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -25,9 +21,6 @@ class BTCPayHandler:
|
||||
self.api_key = vault_values['btcpay_api_key']
|
||||
self.store_id = vault_values['btcpay_store_id']
|
||||
|
||||
# Email configuration
|
||||
self.smtp_config = vault_values.get('smtp_config', {})
|
||||
|
||||
logger.info(f"BTCPayHandler initialized with base URL: {self.base_url}")
|
||||
|
||||
except Exception as e:
|
||||
@ -35,10 +28,16 @@ class BTCPayHandler:
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
def create_invoice(self, amount_sats, duration_hours, email):
|
||||
def create_invoice(self, amount_sats, duration_hours, user_id, public_key):
|
||||
"""Create BTCPay invoice for VPN subscription"""
|
||||
try:
|
||||
logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for {email}")
|
||||
logger.info(f"Creating invoice: {amount_sats} sats, {duration_hours}h for user {user_id}")
|
||||
|
||||
# First, get or create user
|
||||
user = DatabaseManager.get_user_by_uuid(user_id)
|
||||
if not user:
|
||||
user = DatabaseManager.create_user(user_id)
|
||||
logger.info(f"Created new user with ID: {user_id}")
|
||||
|
||||
headers = {
|
||||
'Authorization': f'token {self.api_key}',
|
||||
@ -54,11 +53,12 @@ class BTCPayHandler:
|
||||
'currency': 'SATS',
|
||||
'metadata': {
|
||||
'duration_hours': duration_hours,
|
||||
'email': email,
|
||||
'orderId': f'vpn_sub_{duration_hours}h',
|
||||
'userId': user_id,
|
||||
'publicKey': public_key,
|
||||
'orderId': f'vpn_sub_{duration_hours}h'
|
||||
},
|
||||
'checkout': {
|
||||
'redirectURL': f'{app_url}/payment/success',
|
||||
'redirectURL': f'{app_url}/payment/success?userId={user_id}',
|
||||
'redirectAutomatically': True
|
||||
}
|
||||
}
|
||||
@ -80,11 +80,22 @@ class BTCPayHandler:
|
||||
return None
|
||||
|
||||
invoice_data = response.json()
|
||||
logger.info(f"Successfully created invoice {invoice_data.get('id')}")
|
||||
invoice_id = invoice_data['id']
|
||||
logger.info(f"Successfully created invoice {invoice_id}")
|
||||
|
||||
# Create pending subscription
|
||||
subscription = DatabaseManager.create_subscription(
|
||||
user_id,
|
||||
invoice_id,
|
||||
public_key,
|
||||
duration_hours
|
||||
)
|
||||
logger.info(f"Created pending subscription for invoice {invoice_id}")
|
||||
|
||||
return {
|
||||
'invoice_id': invoice_data['id'],
|
||||
'checkout_url': invoice_data['checkoutLink']
|
||||
'invoice_id': invoice_id,
|
||||
'checkout_url': invoice_data['checkoutLink'],
|
||||
'user_id': user_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@ -92,53 +103,27 @@ class BTCPayHandler:
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def send_confirmation_email(self, email, config_data):
|
||||
"""Send VPN configuration details via email"""
|
||||
def get_subscription_config(self, user_id):
|
||||
"""Get WireGuard configuration details for a subscription"""
|
||||
try:
|
||||
logger.info(f"Sending confirmation email to {email}")
|
||||
logger.info(f"Fetching subscription config for user {user_id}")
|
||||
user = DatabaseManager.get_user_by_uuid(user_id)
|
||||
if not user:
|
||||
logger.error(f"User {user_id} not found")
|
||||
return None
|
||||
|
||||
if not self.smtp_config:
|
||||
logger.warning("SMTP configuration not found in vault")
|
||||
return False
|
||||
subscription = DatabaseManager.get_active_subscription_for_user(user.id)
|
||||
if not subscription:
|
||||
logger.error(f"No active subscription found for user {user_id}")
|
||||
return None
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = self.smtp_config['sender_email']
|
||||
msg['To'] = email
|
||||
msg['Subject'] = "Your VPN Configuration"
|
||||
|
||||
body = f"""
|
||||
Thank you for subscribing to our VPN service!
|
||||
|
||||
Please find your WireGuard configuration below:
|
||||
|
||||
{config_data}
|
||||
|
||||
Installation instructions:
|
||||
1. Install WireGuard client for your platform from https://www.wireguard.com/install/
|
||||
2. Save the above configuration to a file named 'wg0.conf'
|
||||
3. Import the configuration file into your WireGuard client
|
||||
|
||||
Need help? Reply to this email for support.
|
||||
"""
|
||||
|
||||
msg.attach(MIMEText(body, 'plain'))
|
||||
|
||||
logger.debug("Connecting to SMTP server")
|
||||
with smtplib.SMTP(
|
||||
self.smtp_config['server'],
|
||||
self.smtp_config.get('port', 587)
|
||||
) as server:
|
||||
server.starttls()
|
||||
server.login(
|
||||
self.smtp_config['username'],
|
||||
self.smtp_config['password']
|
||||
)
|
||||
server.send_message(msg)
|
||||
|
||||
logger.info("Confirmation email sent successfully")
|
||||
return True
|
||||
return {
|
||||
'serverPublicKey': os.getenv('WIREGUARD_SERVER_PUBLIC_KEY'),
|
||||
'serverEndpoint': os.getenv('WIREGUARD_SERVER_ENDPOINT'),
|
||||
'clientIp': subscription.assigned_ip
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending confirmation email:")
|
||||
logger.error(f"Error getting subscription config: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
return None
|
@ -6,11 +6,12 @@ import logging
|
||||
import hmac
|
||||
import hashlib
|
||||
import yaml
|
||||
import json
|
||||
import datetime
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
from ..utils.db.operations import DatabaseManager
|
||||
from ..utils.db.models import SubscriptionStatus
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@ -23,7 +24,6 @@ logger = logging.getLogger(__name__)
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
PLAYBOOK_PATH = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_provision.yml'
|
||||
CLEANUP_PLAYBOOK = BASE_DIR / 'ansible' / 'playbooks' / 'vpn_cleanup.yml'
|
||||
SUBSCRIPTION_DB = BASE_DIR / 'data' / 'subscriptions.json'
|
||||
|
||||
def get_vault_values():
|
||||
"""Get decrypted values from Ansible vault"""
|
||||
@ -80,31 +80,6 @@ def verify_signature(payload_body, signature_header):
|
||||
logger.error(f"Signature verification failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def load_subscriptions():
|
||||
"""Load subscription data from JSON file"""
|
||||
if not SUBSCRIPTION_DB.parent.exists():
|
||||
SUBSCRIPTION_DB.parent.mkdir(parents=True)
|
||||
|
||||
if not SUBSCRIPTION_DB.exists():
|
||||
return {}
|
||||
|
||||
with open(SUBSCRIPTION_DB, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def save_subscription(subscription_data):
|
||||
"""Save subscription data to JSON file"""
|
||||
subscriptions = load_subscriptions()
|
||||
sub_id = subscription_data['subscriptionId']
|
||||
subscriptions[sub_id] = subscription_data
|
||||
|
||||
with open(SUBSCRIPTION_DB, 'w') as f:
|
||||
json.dump(subscriptions, f, indent=2)
|
||||
|
||||
def calculate_expiry(duration_hours):
|
||||
"""Calculate expiry date based on subscription duration"""
|
||||
return (datetime.datetime.now() +
|
||||
datetime.timedelta(hours=duration_hours)).isoformat()
|
||||
|
||||
def run_ansible_playbook(invoice_id):
|
||||
"""Run the VPN provisioning playbook"""
|
||||
vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '')
|
||||
@ -142,12 +117,15 @@ def handle_subscription_status(data):
|
||||
|
||||
logger.info(f"Processing subscription status update: {sub_id} -> {status}")
|
||||
|
||||
# Store subscription data
|
||||
data['last_updated'] = datetime.datetime.now().isoformat()
|
||||
save_subscription(data)
|
||||
subscription = DatabaseManager.get_subscription_by_invoice(sub_id)
|
||||
if not subscription:
|
||||
logger.error(f"Subscription {sub_id} not found")
|
||||
return jsonify({"error": "Subscription not found"}), 404
|
||||
|
||||
if status != 'Active':
|
||||
# Run cleanup playbook for inactive subscriptions
|
||||
if status == 'Active':
|
||||
DatabaseManager.activate_subscription(sub_id)
|
||||
else:
|
||||
# Run cleanup for inactive subscriptions
|
||||
result = subprocess.run([
|
||||
'ansible-playbook',
|
||||
str(CLEANUP_PLAYBOOK),
|
||||
@ -159,6 +137,7 @@ def handle_subscription_status(data):
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Failed to clean up subscription {sub_id}: {result.stderr}")
|
||||
|
||||
DatabaseManager.expire_subscription(subscription.id)
|
||||
logger.info(f"Subscription {sub_id} is no longer active")
|
||||
|
||||
return jsonify({
|
||||
@ -171,9 +150,10 @@ def handle_subscription_renewal(data):
|
||||
sub_id = data['subscriptionId']
|
||||
logger.info(f"Processing subscription renewal request: {sub_id}")
|
||||
|
||||
# Update subscription data
|
||||
data['renewal_requested'] = datetime.datetime.now().isoformat()
|
||||
save_subscription(data)
|
||||
subscription = DatabaseManager.get_subscription_by_invoice(sub_id)
|
||||
if not subscription:
|
||||
logger.error(f"Subscription {sub_id} not found")
|
||||
return jsonify({"error": "Subscription not found"}), 404
|
||||
|
||||
# TODO: Send renewal notification to user
|
||||
return jsonify({
|
||||
@ -199,6 +179,16 @@ def handle_payment_webhook(request):
|
||||
|
||||
data = request.json
|
||||
logger.info(f"Received webhook data: {data}")
|
||||
|
||||
# Handle test webhooks
|
||||
invoice_id = data.get('invoiceId', '')
|
||||
if invoice_id.startswith('__test__'):
|
||||
logger.info(f"Received test webhook, acknowledging: {data.get('type')}")
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"message": "Test webhook acknowledged"
|
||||
})
|
||||
|
||||
webhook_type = data.get('type')
|
||||
|
||||
if webhook_type == 'SubscriptionStatusUpdated':
|
||||
@ -207,20 +197,13 @@ def handle_payment_webhook(request):
|
||||
elif webhook_type == 'SubscriptionRenewalRequested':
|
||||
return handle_subscription_renewal(data)
|
||||
|
||||
elif webhook_type == 'InvoiceSettled':
|
||||
# Handle regular invoice payment
|
||||
elif webhook_type in ['InvoiceSettled', 'InvoicePaymentSettled']:
|
||||
invoice_id = data.get('invoiceId')
|
||||
metadata = data.get('metadata', {})
|
||||
|
||||
if not invoice_id:
|
||||
logger.error("Missing invoiceId in webhook data")
|
||||
return jsonify({"error": "Missing invoiceId"}), 400
|
||||
|
||||
if invoice_id.startswith('__test__') and invoice_id.endswith('__test__'):
|
||||
invoice_id = invoice_id[8:-8]
|
||||
logger.info(f"Stripped test markers from invoice ID: {invoice_id}")
|
||||
|
||||
# Run Ansible playbook with enhanced logging
|
||||
# Get subscription and run Ansible playbook
|
||||
logger.info(f"Starting VPN provisioning for invoice {invoice_id}")
|
||||
result = run_ansible_playbook(invoice_id)
|
||||
|
||||
@ -236,46 +219,18 @@ def handle_payment_webhook(request):
|
||||
"stderr": result.stderr
|
||||
}), 500
|
||||
|
||||
# Get subscription and activate it
|
||||
subscription = DatabaseManager.get_subscription_by_invoice(invoice_id)
|
||||
if subscription:
|
||||
subscription = DatabaseManager.activate_subscription(invoice_id)
|
||||
DatabaseManager.record_payment(
|
||||
subscription.user_id,
|
||||
subscription.id,
|
||||
invoice_id,
|
||||
data.get('amount', 0)
|
||||
)
|
||||
|
||||
logger.info(f"VPN provisioning completed for invoice {invoice_id}")
|
||||
|
||||
# Update subscription database
|
||||
try:
|
||||
duration_hours = metadata.get('duration_hours', 24)
|
||||
subscriptions = load_subscriptions()
|
||||
subscriptions[invoice_id] = {
|
||||
'email': metadata.get('email'),
|
||||
'duration_hours': duration_hours,
|
||||
'start_time': datetime.datetime.now().isoformat(),
|
||||
'expiry': calculate_expiry(duration_hours),
|
||||
'status': 'Active'
|
||||
}
|
||||
with open(SUBSCRIPTION_DB, 'w') as f:
|
||||
json.dump(subscriptions, f, indent=2)
|
||||
logger.info(f"Updated subscription database for invoice {invoice_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating subscription database: {str(e)}")
|
||||
|
||||
# Send email confirmation if email is provided
|
||||
try:
|
||||
email = metadata.get('email')
|
||||
if email:
|
||||
config_path = f"/etc/wireguard/clients/{invoice_id}/wg0.conf"
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r') as f:
|
||||
config_data = f.read()
|
||||
|
||||
btcpay_handler = BTCPayHandler()
|
||||
email_sent = btcpay_handler.send_confirmation_email(email, config_data)
|
||||
if email_sent:
|
||||
logger.info(f"Sent configuration email to {email}")
|
||||
else:
|
||||
logger.warning(f"Failed to send configuration email to {email}")
|
||||
else:
|
||||
logger.warning(f"Config file not found at {config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending confirmation email: {str(e)}")
|
||||
|
||||
logger.info(f"Successfully processed invoice {invoice_id}")
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
"invoice_id": invoice_id,
|
||||
@ -283,7 +238,6 @@ def handle_payment_webhook(request):
|
||||
})
|
||||
|
||||
else:
|
||||
# Log other webhook types as info instead of warning
|
||||
logger.info(f"Received {webhook_type} webhook - no action required")
|
||||
return jsonify({
|
||||
"status": "success",
|
||||
|
213
app/static/js/components/WireGuardPayment.jsx
Normal file
213
app/static/js/components/WireGuardPayment.jsx
Normal file
@ -0,0 +1,213 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { generateKeys } from '../utils/wireguard';
|
||||
import { Card, CardHeader, CardTitle, CardContent } from '@/components/ui/card';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { AlertCircle } from 'lucide-react';
|
||||
import { Alert, AlertDescription } from '@/components/ui/alert';
|
||||
|
||||
const WireGuardPayment = () => {
|
||||
const [keyData, setKeyData] = useState(null);
|
||||
const [duration, setDuration] = useState(24);
|
||||
const [price, setPrice] = useState(0);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState('');
|
||||
const [userId, setUserId] = useState('');
|
||||
|
||||
useEffect(() => {
|
||||
// Generate random userId and keys when component mounts
|
||||
const init = async () => {
|
||||
try {
|
||||
const randomId = crypto.randomUUID();
|
||||
setUserId(randomId);
|
||||
const keys = await generateKeys();
|
||||
setKeyData(keys);
|
||||
calculatePrice(duration);
|
||||
} catch (err) {
|
||||
setError('Failed to initialize keys');
|
||||
console.error(err);
|
||||
}
|
||||
};
|
||||
init();
|
||||
}, []);
|
||||
|
||||
const calculatePrice = async (hours) => {
|
||||
try {
|
||||
const response = await fetch('/api/calculate-price', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ hours: parseInt(hours) })
|
||||
});
|
||||
const data = await response.json();
|
||||
setPrice(data.price);
|
||||
} catch (err) {
|
||||
setError('Failed to calculate price');
|
||||
}
|
||||
};
|
||||
|
||||
const handleDurationChange = (event) => {
|
||||
const newDuration = parseInt(event.target.value);
|
||||
setDuration(newDuration);
|
||||
calculatePrice(newDuration);
|
||||
};
|
||||
|
||||
const handlePayment = async () => {
|
||||
if (!keyData) {
|
||||
setError('No keys generated. Please refresh the page.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setLoading(true);
|
||||
const response = await fetch('/create-invoice', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
duration,
|
||||
userId,
|
||||
publicKey: keyData.publicKey,
|
||||
// Don't send private key to server!
|
||||
configuration: {
|
||||
type: 'wireguard',
|
||||
publicKey: keyData.publicKey
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to create payment');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
// Save private key to localStorage before redirecting
|
||||
localStorage.setItem(`vpn_keys_${userId}`, JSON.stringify({
|
||||
privateKey: keyData.privateKey,
|
||||
publicKey: keyData.publicKey,
|
||||
createdAt: new Date().toISOString()
|
||||
}));
|
||||
|
||||
window.location.href = data.checkout_url;
|
||||
} catch (err) {
|
||||
setError('Failed to initiate payment');
|
||||
console.error(err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRegenerateKeys = async () => {
|
||||
try {
|
||||
const keys = await generateKeys();
|
||||
setKeyData(keys);
|
||||
} catch (err) {
|
||||
setError('Failed to regenerate keys');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card className="w-full max-w-xl mx-auto">
|
||||
<CardHeader>
|
||||
<CardTitle>WireGuard VPN Configuration</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{error && (
|
||||
<Alert variant="destructive">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Your User ID:</label>
|
||||
<Input value={userId} readOnly className="font-mono text-sm" />
|
||||
<p className="text-sm text-gray-500">
|
||||
Save this ID - you'll need it to manage your subscription
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{keyData && (
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Your Public Key:</label>
|
||||
<Input value={keyData.publicKey} readOnly className="font-mono text-sm" />
|
||||
<Button
|
||||
onClick={handleRegenerateKeys}
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="w-full"
|
||||
>
|
||||
Regenerate Keys
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm font-medium">Duration:</label>
|
||||
<div className="space-y-1">
|
||||
<Input
|
||||
type="range"
|
||||
min="1"
|
||||
max="720"
|
||||
value={duration}
|
||||
onChange={handleDurationChange}
|
||||
className="w-full"
|
||||
/>
|
||||
<div className="flex justify-between text-sm text-gray-500">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setDuration(24);
|
||||
calculatePrice(24);
|
||||
}}
|
||||
>
|
||||
1 Day
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setDuration(168);
|
||||
calculatePrice(168);
|
||||
}}
|
||||
>
|
||||
1 Week
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setDuration(720);
|
||||
calculatePrice(720);
|
||||
}}
|
||||
>
|
||||
30 Days
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-center font-medium">{duration} hours</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="text-center py-4">
|
||||
<p className="text-3xl font-bold text-blue-500">{price} sats</p>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
onClick={handlePayment}
|
||||
disabled={loading || !keyData}
|
||||
className="w-full"
|
||||
>
|
||||
{loading ? 'Processing...' : 'Pay with Bitcoin'}
|
||||
</Button>
|
||||
|
||||
<div className="mt-4 text-sm text-gray-500 space-y-1">
|
||||
<p>• Keys are generated securely in your browser</p>
|
||||
<p>• Your private key never leaves your device</p>
|
||||
<p>• Configuration will be available after payment</p>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
};
|
||||
|
||||
export default WireGuardPayment;
|
@ -1,11 +1,39 @@
|
||||
// app/static/js/pricing.js
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
// Base64 encoding/decoding utilities
|
||||
const b64 = {
|
||||
encode: array => btoa(String.fromCharCode.apply(null, array)),
|
||||
decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0))
|
||||
};
|
||||
|
||||
async function generateKeyPair() {
|
||||
const keyPair = await window.crypto.subtle.generateKey(
|
||||
{
|
||||
name: 'X25519',
|
||||
namedCurve: 'X25519',
|
||||
},
|
||||
true,
|
||||
['deriveKey', 'deriveBits']
|
||||
);
|
||||
|
||||
const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey);
|
||||
const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey);
|
||||
|
||||
return {
|
||||
privateKey: b64.encode(new Uint8Array(privateKey)),
|
||||
publicKey: b64.encode(new Uint8Array(publicKey))
|
||||
};
|
||||
}
|
||||
|
||||
document.addEventListener('DOMContentLoaded', async function() {
|
||||
const form = document.getElementById('subscription-form');
|
||||
const slider = document.getElementById('duration-slider');
|
||||
const durationDisplay = document.getElementById('duration-display');
|
||||
const priceDisplay = document.getElementById('price-display');
|
||||
const presetButtons = document.querySelectorAll('.duration-preset');
|
||||
const emailInput = document.getElementById('email');
|
||||
const userIdInput = document.getElementById('user-id');
|
||||
const publicKeyInput = document.getElementById('public-key');
|
||||
const regenerateButton = document.getElementById('regenerate-keys');
|
||||
|
||||
let currentKeyPair = null;
|
||||
|
||||
function formatDuration(hours) {
|
||||
if (hours < 24) return `${hours} hour${hours === 1 ? '' : 's'}`;
|
||||
@ -14,6 +42,32 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
return `${Math.floor(hours / 720)} month${hours === 720 ? '' : 's'}`;
|
||||
}
|
||||
|
||||
async function generateNewKeys() {
|
||||
try {
|
||||
currentKeyPair = await generateKeyPair();
|
||||
publicKeyInput.value = currentKeyPair.publicKey;
|
||||
|
||||
// Save private key to localStorage
|
||||
const keyData = {
|
||||
privateKey: currentKeyPair.privateKey,
|
||||
publicKey: currentKeyPair.publicKey,
|
||||
createdAt: new Date().toISOString()
|
||||
};
|
||||
localStorage.setItem(`vpn_keys_${userIdInput.value}`, JSON.stringify(keyData));
|
||||
} catch (error) {
|
||||
console.error('Failed to generate keys:', error);
|
||||
alert('Failed to generate WireGuard keys. Please try again.');
|
||||
}
|
||||
}
|
||||
|
||||
async function initializeForm() {
|
||||
// Generate user ID
|
||||
userIdInput.value = crypto.randomUUID();
|
||||
|
||||
// Generate initial keys
|
||||
await generateNewKeys();
|
||||
}
|
||||
|
||||
async function updatePrice(hours) {
|
||||
try {
|
||||
const response = await fetch('/api/calculate-price', {
|
||||
@ -29,7 +83,27 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
}
|
||||
}
|
||||
|
||||
async function createInvoice(duration, email) {
|
||||
// Event listeners
|
||||
slider.addEventListener('input', () => updatePrice(slider.value));
|
||||
|
||||
regenerateButton.addEventListener('click', generateNewKeys);
|
||||
|
||||
presetButtons.forEach(button => {
|
||||
button.addEventListener('click', (e) => {
|
||||
const hours = e.target.dataset.hours;
|
||||
slider.value = hours;
|
||||
updatePrice(hours);
|
||||
});
|
||||
});
|
||||
|
||||
form.addEventListener('submit', async (e) => {
|
||||
e.preventDefault();
|
||||
|
||||
if (!currentKeyPair) {
|
||||
alert('No keys generated. Please refresh the page.');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch('/create-invoice', {
|
||||
method: 'POST',
|
||||
@ -37,8 +111,9 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
duration: parseInt(duration),
|
||||
email: email
|
||||
duration: parseInt(slider.value),
|
||||
userId: userIdInput.value,
|
||||
publicKey: currentKeyPair.publicKey
|
||||
})
|
||||
});
|
||||
|
||||
@ -53,32 +128,10 @@ document.addEventListener('DOMContentLoaded', function() {
|
||||
console.error('Error creating invoice:', error);
|
||||
alert('Failed to create payment invoice. Please try again.');
|
||||
}
|
||||
}
|
||||
|
||||
// Event listeners
|
||||
slider.addEventListener('input', () => updatePrice(slider.value));
|
||||
|
||||
presetButtons.forEach(button => {
|
||||
button.addEventListener('click', (e) => {
|
||||
const hours = e.target.dataset.hours;
|
||||
slider.value = hours;
|
||||
updatePrice(hours);
|
||||
});
|
||||
});
|
||||
|
||||
form.addEventListener('submit', async (e) => {
|
||||
e.preventDefault();
|
||||
|
||||
const email = emailInput.value.trim();
|
||||
if (!email) {
|
||||
alert('Please enter your email address');
|
||||
return;
|
||||
}
|
||||
|
||||
const duration = slider.value;
|
||||
await createInvoice(duration, email);
|
||||
});
|
||||
|
||||
// Initialize the form
|
||||
await initializeForm();
|
||||
// Initial price calculation
|
||||
updatePrice(slider.value);
|
||||
});
|
54
app/static/js/utils/wireguard.js
Normal file
54
app/static/js/utils/wireguard.js
Normal file
@ -0,0 +1,54 @@
|
||||
// Base64 encoding/decoding utilities
|
||||
const b64 = {
|
||||
encode: array => btoa(String.fromCharCode.apply(null, array)),
|
||||
decode: str => Uint8Array.from(atob(str), c => c.charCodeAt(0))
|
||||
};
|
||||
|
||||
async function generateKeyPair() {
|
||||
// Generate a random key pair using Web Crypto API
|
||||
const keyPair = await window.crypto.subtle.generateKey(
|
||||
{
|
||||
name: 'X25519',
|
||||
namedCurve: 'X25519',
|
||||
},
|
||||
true,
|
||||
['deriveKey', 'deriveBits']
|
||||
);
|
||||
|
||||
// Export keys in raw format
|
||||
const privateKey = await window.crypto.subtle.exportKey('raw', keyPair.privateKey);
|
||||
const publicKey = await window.crypto.subtle.exportKey('raw', keyPair.publicKey);
|
||||
|
||||
// Convert to base64
|
||||
return {
|
||||
privateKey: b64.encode(new Uint8Array(privateKey)),
|
||||
publicKey: b64.encode(new Uint8Array(publicKey))
|
||||
};
|
||||
}
|
||||
|
||||
export async function generateWireGuardConfig(serverPublicKey, serverEndpoint, address) {
|
||||
const keys = await generateKeyPair();
|
||||
|
||||
return {
|
||||
keys,
|
||||
config: `[Interface]
|
||||
PrivateKey = ${keys.privateKey}
|
||||
Address = ${address}
|
||||
DNS = 1.1.1.1
|
||||
|
||||
[Peer]
|
||||
PublicKey = ${serverPublicKey}
|
||||
Endpoint = ${serverEndpoint}
|
||||
AllowedIPs = 0.0.0.0/0
|
||||
PersistentKeepalive = 25`
|
||||
};
|
||||
}
|
||||
|
||||
export async function generateKeys() {
|
||||
try {
|
||||
return await generateKeyPair();
|
||||
} catch (error) {
|
||||
console.error('Failed to generate WireGuard keys:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
@ -1,20 +1,35 @@
|
||||
{% extends "base.html" %}
|
||||
|
||||
{% block content %}
|
||||
<div class="min-h-screen bg-dark py-8 px-4">
|
||||
<div class="max-w-xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6">
|
||||
<h1 class="text-2xl font-bold mb-6 text-center">Subscribe to VPN Service</h1>
|
||||
|
||||
<form id="subscription-form" class="space-y-6">
|
||||
<div>
|
||||
<label for="email" class="block text-sm font-medium mb-2">Email Address</label>
|
||||
<label class="block text-sm font-medium mb-2">User ID</label>
|
||||
<input
|
||||
type="email"
|
||||
id="email"
|
||||
required
|
||||
class="w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white focus:outline-none focus:border-blue-500"
|
||||
placeholder="your@email.com"
|
||||
type="text"
|
||||
id="user-id"
|
||||
readonly
|
||||
class="w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white focus:outline-none focus:border-blue-500 font-mono text-sm"
|
||||
>
|
||||
<p class="mt-1 text-sm text-gray-400">Save this ID - you'll need it to manage your subscription</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label class="block text-sm font-medium mb-2">WireGuard Public Key</label>
|
||||
<input
|
||||
type="text"
|
||||
id="public-key"
|
||||
readonly
|
||||
class="w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white focus:outline-none focus:border-blue-500 font-mono text-sm"
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
id="regenerate-keys"
|
||||
class="mt-2 w-full px-3 py-2 bg-dark border border-gray-600 rounded-md text-white hover:bg-gray-800 transition-colors"
|
||||
>
|
||||
Regenerate Keys
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
@ -50,6 +65,12 @@
|
||||
>
|
||||
Pay with Bitcoin
|
||||
</button>
|
||||
|
||||
<div class="mt-4 text-sm text-gray-400 space-y-1">
|
||||
<p>• Keys are generated securely in your browser</p>
|
||||
<p>• Your private key never leaves your device</p>
|
||||
<p>• Configuration will be available after payment</p>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -2,17 +2,110 @@
|
||||
|
||||
{% block content %}
|
||||
<div class="min-h-screen bg-dark py-8 px-4">
|
||||
<div class="max-w-xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6 text-center">
|
||||
<h1 class="text-2xl font-bold mb-4">Payment Successful!</h1>
|
||||
<p class="text-gray-300 mb-6">
|
||||
Thank you for your payment. Your VPN configuration will be sent to your email shortly.
|
||||
<div class="max-w-2xl mx-auto bg-dark-lighter rounded-lg shadow-lg p-6">
|
||||
<h1 class="text-2xl font-bold mb-4 text-center">Payment Successful!</h1>
|
||||
|
||||
<div class="mb-8">
|
||||
<p class="text-gray-300 text-center mb-4">
|
||||
Your VPN subscription is now active. Please follow the instructions below to set up your VPN connection.
|
||||
</p>
|
||||
<p class="text-gray-400 mb-4">
|
||||
Please check your email for further instructions on setting up your VPN connection.
|
||||
</p>
|
||||
<a href="/" class="inline-block bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded transition-colors">
|
||||
</div>
|
||||
|
||||
<div class="space-y-6">
|
||||
<!-- WireGuard Configuration Section -->
|
||||
<div id="config-section" class="hidden">
|
||||
<h2 class="text-xl font-semibold mb-2">Your WireGuard Configuration</h2>
|
||||
<pre id="wireguard-config" class="bg-dark p-4 rounded-md font-mono text-sm overflow-x-auto"></pre>
|
||||
<button
|
||||
id="copy-config"
|
||||
class="mt-2 w-full bg-blue-600 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded transition-colors"
|
||||
>
|
||||
Copy Configuration
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Installation Instructions -->
|
||||
<div class="space-y-4">
|
||||
<h2 class="text-xl font-semibold">Installation Instructions</h2>
|
||||
<ol class="list-decimal list-inside space-y-2 text-gray-300">
|
||||
<li>Download WireGuard for your platform:
|
||||
<div class="ml-6 mt-2 space-y-2">
|
||||
<a href="https://www.wireguard.com/install/"
|
||||
target="_blank"
|
||||
class="text-blue-400 hover:text-blue-300 block">
|
||||
Official WireGuard Downloads
|
||||
</a>
|
||||
</div>
|
||||
</li>
|
||||
<li>Create a new tunnel in WireGuard</li>
|
||||
<li>Copy the configuration above and paste it into the new tunnel</li>
|
||||
<li>Activate the tunnel to connect</li>
|
||||
</ol>
|
||||
</div>
|
||||
|
||||
<div class="space-y-2 text-sm text-gray-400">
|
||||
<p>• Save your configuration securely - you'll need it to reconnect</p>
|
||||
<p>• Your private key is stored in your browser's local storage</p>
|
||||
<p>• For security, clear your browser data after saving the configuration</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mt-8 text-center">
|
||||
<a href="/" class="text-blue-400 hover:text-blue-300">
|
||||
Return to Home
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const userId = new URLSearchParams(window.location.search).get('userId');
|
||||
if (userId) {
|
||||
// Retrieve keys from localStorage
|
||||
const keyData = localStorage.getItem(`vpn_keys_${userId}`);
|
||||
if (keyData) {
|
||||
const keys = JSON.parse(keyData);
|
||||
|
||||
// Make config section visible
|
||||
document.getElementById('config-section').classList.remove('hidden');
|
||||
|
||||
// Get server details from response
|
||||
fetch(`/api/subscription/config?userId=${userId}`)
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
const config = `[Interface]
|
||||
PrivateKey = ${keys.privateKey}
|
||||
Address = ${data.clientIp}/24
|
||||
DNS = 1.1.1.1
|
||||
|
||||
[Peer]
|
||||
PublicKey = ${data.serverPublicKey}
|
||||
Endpoint = ${data.serverEndpoint}:51820
|
||||
AllowedIPs = 0.0.0.0/0
|
||||
PersistentKeepalive = 25`;
|
||||
|
||||
document.getElementById('wireguard-config').textContent = config;
|
||||
})
|
||||
.catch(error => {
|
||||
console.error('Error fetching configuration:', error);
|
||||
});
|
||||
|
||||
// Setup copy button
|
||||
document.getElementById('copy-config').addEventListener('click', function() {
|
||||
const config = document.getElementById('wireguard-config').textContent;
|
||||
navigator.clipboard.writeText(config).then(() => {
|
||||
this.textContent = 'Copied!';
|
||||
setTimeout(() => {
|
||||
this.textContent = 'Copy Configuration';
|
||||
}, 2000);
|
||||
});
|
||||
});
|
||||
|
||||
// Clear keys from localStorage after showing config
|
||||
localStorage.removeItem(`vpn_keys_${userId}`);
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
{% endblock %}
|
23
app/utils/db/__init__.py
Normal file
23
app/utils/db/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pathlib import Path
|
||||
|
||||
def get_db_path():
|
||||
base_dir = Path(__file__).resolve().parent.parent.parent
|
||||
data_dir = base_dir / 'data'
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
return data_dir / 'vpn.db'
|
||||
|
||||
def init_db():
|
||||
"""Initialize the database"""
|
||||
from .models import Base
|
||||
db_url = f"sqlite:///{get_db_path()}"
|
||||
engine = create_engine(db_url)
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
def get_session():
|
||||
"""Get a database session"""
|
||||
engine = init_db()
|
||||
Session = sessionmaker(bind=engine)
|
||||
return Session()
|
52
app/utils/db/models.py
Normal file
52
app/utils/db/models.py
Normal file
@ -0,0 +1,52 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, DateTime, ForeignKey, Enum, Text
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
import enum
|
||||
import datetime
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class SubscriptionStatus(enum.Enum):
|
||||
ACTIVE = "active"
|
||||
EXPIRED = "expired"
|
||||
PENDING = "pending"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'users'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(String, unique=True, nullable=False) # UUID generated in frontend
|
||||
created_at = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
|
||||
subscriptions = relationship("Subscription", back_populates="user")
|
||||
payments = relationship("Payment", back_populates="user")
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = 'subscriptions'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'))
|
||||
invoice_id = Column(String, unique=True)
|
||||
public_key = Column(Text, nullable=False) # WireGuard public key
|
||||
start_time = Column(DateTime, nullable=False)
|
||||
expiry_time = Column(DateTime, nullable=False)
|
||||
status = Column(Enum(SubscriptionStatus), default=SubscriptionStatus.PENDING)
|
||||
warning_sent = Column(Integer, default=0)
|
||||
assigned_ip = Column(String) # WireGuard IP address assigned to this subscription
|
||||
|
||||
user = relationship("User", back_populates="subscriptions")
|
||||
payments = relationship("Payment", back_populates="subscription")
|
||||
|
||||
class Payment(Base):
|
||||
__tablename__ = 'payments'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'))
|
||||
subscription_id = Column(Integer, ForeignKey('subscriptions.id'))
|
||||
invoice_id = Column(String, unique=True)
|
||||
amount = Column(Integer, nullable=False) # Amount in sats
|
||||
timestamp = Column(DateTime, default=datetime.datetime.utcnow)
|
||||
|
||||
user = relationship("User", back_populates="payments")
|
||||
subscription = relationship("Subscription", back_populates="payments")
|
152
app/utils/db/operations.py
Normal file
152
app/utils/db/operations.py
Normal file
@ -0,0 +1,152 @@
|
||||
from datetime import datetime
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from . import get_session
|
||||
from .models import User, Subscription, Payment, SubscriptionStatus
|
||||
import logging
|
||||
import ipaddress
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DatabaseManager:
|
||||
@staticmethod
|
||||
def get_next_available_ip():
|
||||
"""Get the next available IP from the WireGuard subnet"""
|
||||
with get_session() as session:
|
||||
# Get all assigned IPs
|
||||
assigned_ips = session.query(Subscription.assigned_ip)\
|
||||
.filter(Subscription.assigned_ip.isnot(None))\
|
||||
.all()
|
||||
assigned_ips = [ip[0] for ip in assigned_ips]
|
||||
|
||||
# Start from 10.8.0.2 (10.8.0.1 is server)
|
||||
network = ipaddress.IPv4Network('10.8.0.0/24')
|
||||
for ip in network.hosts():
|
||||
str_ip = str(ip)
|
||||
if str_ip == '10.8.0.1': # Skip server IP
|
||||
continue
|
||||
if str_ip not in assigned_ips:
|
||||
return str_ip
|
||||
|
||||
raise Exception("No available IPs in the subnet")
|
||||
|
||||
@staticmethod
|
||||
def create_user(user_id):
|
||||
"""Create a new user with UUID"""
|
||||
with get_session() as session:
|
||||
user = User(user_id=user_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_uuid(user_id):
|
||||
"""Get user by UUID"""
|
||||
with get_session() as session:
|
||||
return session.query(User).filter(User.user_id == user_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_subscription_by_invoice(invoice_id):
|
||||
"""Get subscription by invoice ID"""
|
||||
with get_session() as session:
|
||||
return session.query(Subscription)\
|
||||
.filter(Subscription.invoice_id == invoice_id)\
|
||||
.first()
|
||||
|
||||
@staticmethod
|
||||
def create_subscription(user_id, invoice_id, public_key, duration_hours):
|
||||
"""Create a new subscription"""
|
||||
with get_session() as session:
|
||||
try:
|
||||
# Get user or create if doesn't exist
|
||||
user = DatabaseManager.get_user_by_uuid(user_id)
|
||||
if not user:
|
||||
user = DatabaseManager.create_user(user_id)
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
expiry_time = start_time + datetime.timedelta(hours=duration_hours)
|
||||
|
||||
# Get next available IP
|
||||
assigned_ip = DatabaseManager.get_next_available_ip()
|
||||
|
||||
subscription = Subscription(
|
||||
user_id=user.id,
|
||||
invoice_id=invoice_id,
|
||||
public_key=public_key,
|
||||
start_time=start_time,
|
||||
expiry_time=expiry_time,
|
||||
status=SubscriptionStatus.PENDING,
|
||||
assigned_ip=assigned_ip
|
||||
)
|
||||
session.add(subscription)
|
||||
session.commit()
|
||||
return subscription
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating subscription: {str(e)}")
|
||||
session.rollback()
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def activate_subscription(invoice_id):
|
||||
"""Activate a subscription after payment"""
|
||||
with get_session() as session:
|
||||
subscription = session.query(Subscription)\
|
||||
.filter(Subscription.invoice_id == invoice_id)\
|
||||
.first()
|
||||
|
||||
if subscription:
|
||||
subscription.status = SubscriptionStatus.ACTIVE
|
||||
session.commit()
|
||||
return subscription
|
||||
|
||||
@staticmethod
|
||||
def record_payment(user_id, subscription_id, invoice_id, amount):
|
||||
"""Record a payment"""
|
||||
with get_session() as session:
|
||||
user = DatabaseManager.get_user_by_uuid(user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User {user_id} not found")
|
||||
|
||||
payment = Payment(
|
||||
user_id=user.id,
|
||||
subscription_id=subscription_id,
|
||||
invoice_id=invoice_id,
|
||||
amount=amount
|
||||
)
|
||||
session.add(payment)
|
||||
session.commit()
|
||||
return payment
|
||||
|
||||
@staticmethod
|
||||
def get_active_subscriptions():
|
||||
"""Get all active subscriptions"""
|
||||
with get_session() as session:
|
||||
return session.query(Subscription)\
|
||||
.filter(Subscription.status == SubscriptionStatus.ACTIVE)\
|
||||
.all()
|
||||
|
||||
@staticmethod
|
||||
def expire_subscription(subscription_id):
|
||||
"""Mark a subscription as expired"""
|
||||
with get_session() as session:
|
||||
subscription = session.query(Subscription)\
|
||||
.filter(Subscription.id == subscription_id)\
|
||||
.first()
|
||||
|
||||
if subscription:
|
||||
subscription.status = SubscriptionStatus.EXPIRED
|
||||
session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def update_warning_sent(subscription_id):
|
||||
"""Update the warning_sent flag for a subscription"""
|
||||
with get_session() as session:
|
||||
subscription = session.query(Subscription)\
|
||||
.filter(Subscription.id == subscription_id)\
|
||||
.first()
|
||||
if subscription:
|
||||
subscription.warning_sent = 1
|
||||
session.commit()
|
||||
return True
|
||||
return False
|
59
app/utils/db/utils.py
Normal file
59
app/utils/db/utils.py
Normal file
@ -0,0 +1,59 @@
|
||||
import sqlite3
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from .. import get_db_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def backup_database():
|
||||
"""Create a backup of the SQLite database"""
|
||||
try:
|
||||
db_path = get_db_path()
|
||||
if not db_path.exists():
|
||||
logger.error("Database file not found")
|
||||
return False
|
||||
|
||||
backup_dir = db_path.parent / 'backups'
|
||||
backup_dir.mkdir(exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
backup_path = backup_dir / f'vpn_db_backup_{timestamp}.db'
|
||||
|
||||
shutil.copy2(db_path, backup_path)
|
||||
logger.info(f"Database backed up successfully to {backup_path}")
|
||||
|
||||
# Keep only last 5 backups
|
||||
backups = sorted(backup_dir.glob('vpn_db_backup_*.db'))
|
||||
if len(backups) > 5:
|
||||
for old_backup in backups[:-5]:
|
||||
old_backup.unlink()
|
||||
logger.info(f"Removed old backup: {old_backup}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Backup failed: {str(e)}")
|
||||
return False
|
||||
|
||||
def verify_database():
|
||||
"""Check database integrity"""
|
||||
try:
|
||||
db_path = get_db_path()
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("PRAGMA integrity_check")
|
||||
result = cursor.fetchone()[0]
|
||||
|
||||
if result != "ok":
|
||||
logger.error(f"Database integrity check failed: {result}")
|
||||
return False
|
||||
|
||||
logger.info("Database integrity verified")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Database verification failed: {str(e)}")
|
||||
return False
|
BIN
data/vpn.db
Normal file
BIN
data/vpn.db
Normal file
Binary file not shown.
@ -4,3 +4,4 @@ python-dotenv==1.0.0
|
||||
cryptography==41.0.7
|
||||
ansible==9.1.0
|
||||
requests==2.31.0
|
||||
SQLAlchemy==2.0.25
|
28
scripts/init_db.py
Normal file
28
scripts/init_db.py
Normal file
@ -0,0 +1,28 @@
|
||||
# scripts/init_db.py
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the project root to Python path
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
from app.utils.db.models import Base
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
def init_db():
|
||||
# Create data directory if it doesn't exist
|
||||
data_dir = project_root / 'data'
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Create database
|
||||
db_path = data_dir / 'vpn.db'
|
||||
db_url = f"sqlite:///{db_path}"
|
||||
engine = create_engine(db_url)
|
||||
|
||||
# Create all tables
|
||||
Base.metadata.create_all(engine)
|
||||
print(f"Database initialized at: {db_path}")
|
||||
return engine
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_db()
|
151
scripts/migrate_db.py
Normal file
151
scripts/migrate_db.py
Normal file
@ -0,0 +1,151 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
# Add the project root to Python path
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
from app.utils.db.models import Base, SubscriptionStatus
|
||||
from app.utils.db import get_db_path, get_session
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def backup_database():
|
||||
"""Create a backup of the existing database"""
|
||||
db_path = get_db_path()
|
||||
if not db_path.exists():
|
||||
logger.info("No existing database found, skipping backup")
|
||||
return
|
||||
|
||||
backup_path = db_path.parent / f"{db_path.stem}_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}{db_path.suffix}"
|
||||
db_path.rename(backup_path)
|
||||
logger.info(f"Created database backup at {backup_path}")
|
||||
return backup_path
|
||||
|
||||
def migrate_database():
|
||||
"""Perform database migration"""
|
||||
try:
|
||||
# Create backup
|
||||
backup_path = backup_database()
|
||||
|
||||
# Initialize new database
|
||||
engine = create_engine(f"sqlite:///{get_db_path()}")
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
# If there's an old database, migrate the data
|
||||
if backup_path and backup_path.exists():
|
||||
logger.info("Migrating data from old database...")
|
||||
|
||||
# Connect to old database
|
||||
old_conn = sqlite3.connect(backup_path)
|
||||
old_cursor = old_conn.cursor()
|
||||
|
||||
# Get schema information
|
||||
old_cursor.execute("SELECT sql FROM sqlite_master WHERE type='table'")
|
||||
old_schema = old_cursor.fetchall()
|
||||
logger.info("Old database schema:")
|
||||
for table in old_schema:
|
||||
logger.info(table[0])
|
||||
|
||||
with get_session() as session:
|
||||
try:
|
||||
# Start transaction
|
||||
session.begin()
|
||||
|
||||
# Migrate users
|
||||
logger.info("Migrating users...")
|
||||
old_cursor.execute("SELECT id, email, created_at FROM users")
|
||||
users = old_cursor.fetchall()
|
||||
user_id_map = {} # Map old IDs to new UUIDs
|
||||
|
||||
for old_id, email, created_at in users:
|
||||
new_user_id = str(uuid.uuid4())
|
||||
user_id_map[old_id] = new_user_id
|
||||
session.execute(
|
||||
"INSERT INTO users (user_id, created_at) VALUES (?, ?)",
|
||||
[new_user_id, created_at or datetime.utcnow()]
|
||||
)
|
||||
|
||||
# Migrate subscriptions
|
||||
logger.info("Migrating subscriptions...")
|
||||
old_cursor.execute("""
|
||||
SELECT id, user_id, invoice_id, start_time, expiry_time,
|
||||
status, warning_sent
|
||||
FROM subscriptions
|
||||
""")
|
||||
subscriptions = old_cursor.fetchall()
|
||||
|
||||
for sub in subscriptions:
|
||||
old_id, old_user_id, invoice_id, start_time, expiry_time, status, warning_sent = sub
|
||||
if old_user_id in user_id_map:
|
||||
# Generate a placeholder public key for existing subscriptions
|
||||
placeholder_pubkey = f"MIGRATED_{uuid.uuid4()}"
|
||||
session.execute("""
|
||||
INSERT INTO subscriptions
|
||||
(user_id, invoice_id, public_key, start_time, expiry_time,
|
||||
status, warning_sent, assigned_ip)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", [
|
||||
user_id_map[old_user_id],
|
||||
invoice_id,
|
||||
placeholder_pubkey,
|
||||
start_time,
|
||||
expiry_time,
|
||||
status,
|
||||
warning_sent,
|
||||
f"10.8.0.{2 + old_id}" # Simple IP assignment
|
||||
])
|
||||
|
||||
# Migrate payments
|
||||
logger.info("Migrating payments...")
|
||||
old_cursor.execute("""
|
||||
SELECT user_id, subscription_id, invoice_id, amount, timestamp
|
||||
FROM payments
|
||||
""")
|
||||
payments = old_cursor.fetchall()
|
||||
|
||||
for payment in payments:
|
||||
old_user_id, sub_id, invoice_id, amount, timestamp = payment
|
||||
if old_user_id in user_id_map:
|
||||
session.execute("""
|
||||
INSERT INTO payments
|
||||
(user_id, subscription_id, invoice_id, amount, timestamp)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", [
|
||||
user_id_map[old_user_id],
|
||||
sub_id,
|
||||
invoice_id,
|
||||
amount,
|
||||
timestamp
|
||||
])
|
||||
|
||||
# Commit transaction
|
||||
session.commit()
|
||||
logger.info("Migration completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"Migration failed: {str(e)}")
|
||||
raise
|
||||
|
||||
old_conn.close()
|
||||
|
||||
else:
|
||||
logger.info("No existing database to migrate")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed: {str(e)}")
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
migrate_database()
|
||||
except Exception as e:
|
||||
logger.error(f"Migration failed: {str(e)}")
|
||||
sys.exit(1)
|
@ -1,14 +1,12 @@
|
||||
# scripts/subscription_checker.py
|
||||
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import subprocess
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from dateutil.parser import parse
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from datetime import datetime, timedelta
|
||||
from utils.db.operations import DatabaseManager
|
||||
from utils.db.models import SubscriptionStatus
|
||||
from app.handlers.payment_handler import BTCPayHandler
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@ -19,20 +17,19 @@ logger = logging.getLogger(__name__)
|
||||
# Path setup
|
||||
SCRIPT_DIR = Path(__file__).resolve().parent
|
||||
PROJECT_ROOT = SCRIPT_DIR.parent
|
||||
SUBSCRIPTION_DB = PROJECT_ROOT / 'data' / 'subscriptions.json'
|
||||
CLEANUP_PLAYBOOK = PROJECT_ROOT / 'ansible' / 'playbooks' / 'vpn_cleanup.yml'
|
||||
INVENTORY_FILE = PROJECT_ROOT / 'inventory.ini'
|
||||
|
||||
# Notification thresholds configuration
|
||||
NOTIFICATION_THRESHOLDS = {
|
||||
'minimum_duration': datetime.timedelta(hours=1), # Minimum subscription duration
|
||||
'minimum_duration': timedelta(hours=1), # Minimum subscription duration
|
||||
'short_term': {
|
||||
'max_duration': datetime.timedelta(days=1),
|
||||
'max_duration': timedelta(days=1),
|
||||
'warning_fraction': 0.5, # Warn when 50% of time remains
|
||||
'grace_fraction': 0.1 # Grace period of 10% of subscription length
|
||||
},
|
||||
'medium_term': {
|
||||
'max_duration': datetime.timedelta(days=7),
|
||||
'max_duration': timedelta(days=7),
|
||||
'warning_fraction': 0.25, # Warn when 25% of time remains
|
||||
'grace_hours': 12 # Fixed 12-hour grace period
|
||||
},
|
||||
@ -42,25 +39,9 @@ NOTIFICATION_THRESHOLDS = {
|
||||
}
|
||||
}
|
||||
|
||||
def load_subscriptions():
|
||||
"""Load subscription data from JSON file"""
|
||||
if not SUBSCRIPTION_DB.parent.exists():
|
||||
SUBSCRIPTION_DB.parent.mkdir(parents=True)
|
||||
|
||||
if not SUBSCRIPTION_DB.exists():
|
||||
return {}
|
||||
|
||||
with open(SUBSCRIPTION_DB, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def save_subscriptions(subscriptions):
|
||||
"""Save subscriptions to JSON file"""
|
||||
with open(SUBSCRIPTION_DB, 'w') as f:
|
||||
json.dump(subscriptions, f, indent=2)
|
||||
|
||||
def run_cleanup_playbook(sub_id):
|
||||
def run_cleanup_playbook(subscription_id):
|
||||
"""Run the VPN cleanup playbook"""
|
||||
logger.info(f"Running cleanup playbook for subscription {sub_id}")
|
||||
logger.info(f"Running cleanup playbook for subscription {subscription_id}")
|
||||
|
||||
vault_pass = os.getenv('ANSIBLE_VAULT_PASSWORD', '')
|
||||
if not vault_pass:
|
||||
@ -74,7 +55,7 @@ def run_cleanup_playbook(sub_id):
|
||||
'ansible-playbook',
|
||||
str(CLEANUP_PLAYBOOK),
|
||||
'-i', str(INVENTORY_FILE),
|
||||
'-e', f'subscription_id={sub_id}',
|
||||
'-e', f'subscription_id={subscription_id}',
|
||||
'--vault-password-file', vault_pass_file.name,
|
||||
'-vvv'
|
||||
]
|
||||
@ -122,42 +103,36 @@ def calculate_notification_times(start_time, end_time):
|
||||
warning_delta = duration * NOTIFICATION_THRESHOLDS['medium_term']['warning_fraction']
|
||||
grace_hours = NOTIFICATION_THRESHOLDS['medium_term']['grace_hours']
|
||||
warning_time = end_time - warning_delta
|
||||
grace_end_time = end_time + datetime.timedelta(hours=grace_hours)
|
||||
grace_end_time = end_time + timedelta(hours=grace_hours)
|
||||
|
||||
# Long-term subscriptions (> 7 days)
|
||||
else:
|
||||
warning_days = NOTIFICATION_THRESHOLDS['long_term']['warning_days']
|
||||
grace_days = NOTIFICATION_THRESHOLDS['long_term']['grace_days']
|
||||
warning_time = end_time - datetime.timedelta(days=warning_days)
|
||||
grace_end_time = end_time + datetime.timedelta(days=grace_days)
|
||||
warning_time = end_time - timedelta(days=warning_days)
|
||||
grace_end_time = end_time + timedelta(days=grace_days)
|
||||
|
||||
return warning_time, grace_end_time
|
||||
|
||||
def get_notification_message(sub_data, remaining_time):
|
||||
def get_notification_message(subscription, remaining_time):
|
||||
"""Generate appropriate notification message based on subscription duration"""
|
||||
if remaining_time < datetime.timedelta(hours=1):
|
||||
if remaining_time < timedelta(hours=1):
|
||||
return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 60)} minutes!"
|
||||
elif remaining_time < datetime.timedelta(days=1):
|
||||
elif remaining_time < timedelta(days=1):
|
||||
return f"Your VPN subscription expires in {int(remaining_time.total_seconds() / 3600)} hours!"
|
||||
else:
|
||||
return f"Your VPN subscription expires in {remaining_time.days} days!"
|
||||
|
||||
def notify_user(user_id, message):
|
||||
def notify_user(subscription, message):
|
||||
"""Send notification to user about subscription status"""
|
||||
try:
|
||||
subscriptions = load_subscriptions()
|
||||
sub_data = subscriptions.get(user_id)
|
||||
|
||||
if not sub_data or 'email' not in sub_data:
|
||||
logger.error(f"No email found for user {user_id}")
|
||||
if not subscription.user or not subscription.user.email:
|
||||
logger.error(f"No email found for subscription {subscription.id}")
|
||||
return False
|
||||
|
||||
# Import BTCPayHandler here to avoid circular imports
|
||||
from app.handlers.payment_handler import BTCPayHandler
|
||||
|
||||
btcpay_handler = BTCPayHandler()
|
||||
email_sent = btcpay_handler.send_confirmation_email(
|
||||
sub_data['email'],
|
||||
subscription.user.email,
|
||||
f"""
|
||||
VPN Subscription Update
|
||||
|
||||
@ -168,9 +143,9 @@ def notify_user(user_id, message):
|
||||
)
|
||||
|
||||
if email_sent:
|
||||
logger.info(f"Sent notification to {sub_data['email']}: {message}")
|
||||
logger.info(f"Sent notification to {subscription.user.email}: {message}")
|
||||
else:
|
||||
logger.warning(f"Failed to send notification to {sub_data['email']}")
|
||||
logger.warning(f"Failed to send notification to {subscription.user.email}")
|
||||
|
||||
return email_sent
|
||||
|
||||
@ -182,51 +157,46 @@ def check_subscriptions():
|
||||
"""Check subscription status and clean up expired ones"""
|
||||
logger.info("Starting subscription check")
|
||||
|
||||
subscriptions = load_subscriptions()
|
||||
now = datetime.datetime.now()
|
||||
modified = False
|
||||
|
||||
logger.info(f"Checking {len(subscriptions)} subscriptions")
|
||||
|
||||
for sub_id, sub_data in list(subscriptions.items()):
|
||||
try:
|
||||
logger.debug(f"Processing subscription {sub_id}")
|
||||
active_subscriptions = DatabaseManager.get_active_subscriptions()
|
||||
logger.info(f"Checking {len(active_subscriptions)} active subscriptions")
|
||||
|
||||
if sub_data.get('status') != 'Active':
|
||||
logger.debug(f"Skipping inactive subscription {sub_id}")
|
||||
continue
|
||||
now = datetime.utcnow()
|
||||
|
||||
start_time = parse(sub_data['start_time'])
|
||||
expiry = parse(sub_data['expiry'])
|
||||
warning_time, grace_end_time = calculate_notification_times(start_time, expiry)
|
||||
for subscription in active_subscriptions:
|
||||
try:
|
||||
logger.debug(f"Processing subscription {subscription.id}")
|
||||
|
||||
warning_time, grace_end_time = calculate_notification_times(
|
||||
subscription.start_time,
|
||||
subscription.expiry_time
|
||||
)
|
||||
|
||||
# Calculate remaining time
|
||||
remaining_time = expiry - now
|
||||
logger.debug(f"Subscription {sub_id} has {remaining_time} remaining")
|
||||
remaining_time = subscription.expiry_time - now
|
||||
logger.debug(f"Subscription {subscription.id} has {remaining_time} remaining")
|
||||
|
||||
# Handle warnings
|
||||
if now >= warning_time and not sub_data.get('warning_sent'):
|
||||
message = get_notification_message(sub_data, remaining_time)
|
||||
logger.info(f"Sending notification for subscription {sub_id}: {message}")
|
||||
if now >= warning_time and not subscription.warning_sent:
|
||||
message = get_notification_message(subscription, remaining_time)
|
||||
logger.info(f"Sending notification for subscription {subscription.id}: {message}")
|
||||
|
||||
notify_user(sub_id, message)
|
||||
|
||||
sub_data['warning_sent'] = True
|
||||
modified = True
|
||||
if notify_user(subscription, message):
|
||||
DatabaseManager.update_warning_sent(subscription.id)
|
||||
|
||||
# Handle expiration
|
||||
if now >= grace_end_time and sub_data.get('status') == 'Active':
|
||||
logger.info(f"Processing expiration for subscription {sub_id}")
|
||||
if now >= grace_end_time:
|
||||
logger.info(f"Processing expiration for subscription {subscription.id}")
|
||||
|
||||
try:
|
||||
logger.debug(f"Running cleanup playbook for {sub_id}")
|
||||
result = run_cleanup_playbook(sub_id)
|
||||
result = run_cleanup_playbook(subscription.invoice_id)
|
||||
|
||||
if result.returncode == 0:
|
||||
sub_data['status'] = 'Expired'
|
||||
sub_data['cleanup_date'] = now.isoformat()
|
||||
modified = True
|
||||
logger.info(f"Successfully cleaned up subscription {sub_id}")
|
||||
DatabaseManager.expire_subscription(subscription.id)
|
||||
logger.info(f"Successfully cleaned up subscription {subscription.id}")
|
||||
|
||||
# Send final notification
|
||||
notify_user(subscription, "Your VPN subscription has expired and been deactivated.")
|
||||
else:
|
||||
logger.error(f"Cleanup failed: {result.stderr}")
|
||||
|
||||
@ -235,16 +205,17 @@ def check_subscriptions():
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing subscription {sub_id}: {str(e)}")
|
||||
logger.error(f"Error processing subscription {subscription.id}: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
continue
|
||||
|
||||
if modified:
|
||||
save_subscriptions(subscriptions)
|
||||
logger.info("Updated subscription database")
|
||||
|
||||
logger.info("Subscription check completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Subscription checker failed: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
check_subscriptions()
|
||||
|
164
venv/include/site/python3.11/greenlet/greenlet.h
Normal file
164
venv/include/site/python3.11/greenlet/greenlet.h
Normal file
@ -0,0 +1,164 @@
|
||||
/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
|
||||
|
||||
/* Greenlet object interface */
|
||||
|
||||
#ifndef Py_GREENLETOBJECT_H
|
||||
#define Py_GREENLETOBJECT_H
|
||||
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* This is deprecated and undocumented. It does not change. */
|
||||
#define GREENLET_VERSION "1.0.0"
|
||||
|
||||
#ifndef GREENLET_MODULE
|
||||
#define implementation_ptr_t void*
|
||||
#endif
|
||||
|
||||
typedef struct _greenlet {
|
||||
PyObject_HEAD
|
||||
PyObject* weakreflist;
|
||||
PyObject* dict;
|
||||
implementation_ptr_t pimpl;
|
||||
} PyGreenlet;
|
||||
|
||||
#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
|
||||
|
||||
|
||||
/* C API functions */
|
||||
|
||||
/* Total number of symbols that are exported */
|
||||
#define PyGreenlet_API_pointers 12
|
||||
|
||||
#define PyGreenlet_Type_NUM 0
|
||||
#define PyExc_GreenletError_NUM 1
|
||||
#define PyExc_GreenletExit_NUM 2
|
||||
|
||||
#define PyGreenlet_New_NUM 3
|
||||
#define PyGreenlet_GetCurrent_NUM 4
|
||||
#define PyGreenlet_Throw_NUM 5
|
||||
#define PyGreenlet_Switch_NUM 6
|
||||
#define PyGreenlet_SetParent_NUM 7
|
||||
|
||||
#define PyGreenlet_MAIN_NUM 8
|
||||
#define PyGreenlet_STARTED_NUM 9
|
||||
#define PyGreenlet_ACTIVE_NUM 10
|
||||
#define PyGreenlet_GET_PARENT_NUM 11
|
||||
|
||||
#ifndef GREENLET_MODULE
|
||||
/* This section is used by modules that uses the greenlet C API */
|
||||
static void** _PyGreenlet_API = NULL;
|
||||
|
||||
# define PyGreenlet_Type \
|
||||
(*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
|
||||
|
||||
# define PyExc_GreenletError \
|
||||
((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
|
||||
|
||||
# define PyExc_GreenletExit \
|
||||
((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
|
||||
|
||||
/*
|
||||
* PyGreenlet_New(PyObject *args)
|
||||
*
|
||||
* greenlet.greenlet(run, parent=None)
|
||||
*/
|
||||
# define PyGreenlet_New \
|
||||
(*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
|
||||
_PyGreenlet_API[PyGreenlet_New_NUM])
|
||||
|
||||
/*
|
||||
* PyGreenlet_GetCurrent(void)
|
||||
*
|
||||
* greenlet.getcurrent()
|
||||
*/
|
||||
# define PyGreenlet_GetCurrent \
|
||||
(*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
|
||||
|
||||
/*
|
||||
* PyGreenlet_Throw(
|
||||
* PyGreenlet *greenlet,
|
||||
* PyObject *typ,
|
||||
* PyObject *val,
|
||||
* PyObject *tb)
|
||||
*
|
||||
* g.throw(...)
|
||||
*/
|
||||
# define PyGreenlet_Throw \
|
||||
(*(PyObject * (*)(PyGreenlet * self, \
|
||||
PyObject * typ, \
|
||||
PyObject * val, \
|
||||
PyObject * tb)) \
|
||||
_PyGreenlet_API[PyGreenlet_Throw_NUM])
|
||||
|
||||
/*
|
||||
* PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
|
||||
*
|
||||
* g.switch(*args, **kwargs)
|
||||
*/
|
||||
# define PyGreenlet_Switch \
|
||||
(*(PyObject * \
|
||||
(*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
|
||||
_PyGreenlet_API[PyGreenlet_Switch_NUM])
|
||||
|
||||
/*
|
||||
* PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
|
||||
*
|
||||
* g.parent = new_parent
|
||||
*/
|
||||
# define PyGreenlet_SetParent \
|
||||
(*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
|
||||
_PyGreenlet_API[PyGreenlet_SetParent_NUM])
|
||||
|
||||
/*
|
||||
* PyGreenlet_GetParent(PyObject* greenlet)
|
||||
*
|
||||
* return greenlet.parent;
|
||||
*
|
||||
* This could return NULL even if there is no exception active.
|
||||
* If it does not return NULL, you are responsible for decrementing the
|
||||
* reference count.
|
||||
*/
|
||||
# define PyGreenlet_GetParent \
|
||||
(*(PyGreenlet* (*)(PyGreenlet*)) \
|
||||
_PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
|
||||
|
||||
/*
|
||||
* deprecated, undocumented alias.
|
||||
*/
|
||||
# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
|
||||
|
||||
# define PyGreenlet_MAIN \
|
||||
(*(int (*)(PyGreenlet*)) \
|
||||
_PyGreenlet_API[PyGreenlet_MAIN_NUM])
|
||||
|
||||
# define PyGreenlet_STARTED \
|
||||
(*(int (*)(PyGreenlet*)) \
|
||||
_PyGreenlet_API[PyGreenlet_STARTED_NUM])
|
||||
|
||||
# define PyGreenlet_ACTIVE \
|
||||
(*(int (*)(PyGreenlet*)) \
|
||||
_PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
|
||||
|
||||
|
||||
|
||||
|
||||
/* Macro that imports greenlet and initializes C API */
|
||||
/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
|
||||
keep the older definition to be sure older code that might have a copy of
|
||||
the header still works. */
|
||||
# define PyGreenlet_Import() \
|
||||
{ \
|
||||
_PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
|
||||
}
|
||||
|
||||
#endif /* GREENLET_MODULE */
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif /* !Py_GREENLETOBJECT_H */
|
Loading…
Reference in New Issue
Block a user