enki b3204ea07a
Some checks are pending
CI Pipeline / Run Tests (push) Waiting to run
CI Pipeline / Lint Code (push) Waiting to run
CI Pipeline / Security Scan (push) Waiting to run
CI Pipeline / Build Docker Images (push) Blocked by required conditions
CI Pipeline / E2E Tests (push) Blocked by required conditions
first commit
2025-08-18 00:40:15 -07:00

500 lines
14 KiB
Go

package cdn
import (
	"compress/gzip"
	"context"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// CDNFeatures bundles the CDN-style middlewares offered by the gateway:
// compression, per-client bandwidth throttling, cache headers and edge caching.
type CDNFeatures struct {
	// Feature switches, copied from the config by NewCDNFeatures.
	compressionEnabled, throttlingEnabled bool

	// bandwidthLimiters maps a client key (see getClientIP) to its token-bucket
	// limiter. Every access is made while holding throttleMutex.
	bandwidthLimiters map[string]*rate.Limiter
	throttleMutex     sync.RWMutex

	config *CDNConfig
}
// CDNConfig holds the tunables for CDNFeatures. NewCDNFeatures substitutes a
// complete set of defaults when it is given a nil *CDNConfig.
type CDNConfig struct {
	// --- Compression ---
	CompressionEnabled bool
	CompressionMinSize int64    // responses declaring a smaller Content-Length are not compressed
	CompressionTypes   []string // Content-Type prefixes eligible for gzip

	// --- Bandwidth throttling ---
	ThrottlingEnabled    bool
	DefaultBandwidthKbps int // per-client limit in KB/s (applied as kbps*1024 bytes/s)
	PremiumBandwidthKbps int // premium-tier limit in KB/s; currently only reported by GetStats

	// --- Cache control ---
	CacheMaxAge       time.Duration // max-age for /api/files/ responses
	StaticCacheMaxAge time.Duration // max-age for /static/ responses

	// --- Delivery optimisation ---
	EdgeCacheTTL    time.Duration // TTL handed to CacheInterface.Set by EdgeCacheMiddleware
	PrefetchEnabled bool          // emit Link prefetch hints for "/"

	// --- Maintenance ---
	ThrottlerCleanup time.Duration // interval between sweeps of idle bandwidth limiters
}
// NewCDNFeatures builds a CDNFeatures from config. A nil config selects the
// built-in defaults below. When throttling is enabled, a background goroutine
// (cleanupThrottlers) is started to sweep idle bandwidth limiters; it has no
// stop mechanism and runs for the rest of the process.
func NewCDNFeatures(config *CDNConfig) *CDNFeatures {
	if config == nil {
		config = &CDNConfig{
			// gzip text-like payloads of at least 1 KiB
			CompressionEnabled: true,
			CompressionMinSize: 1024,
			CompressionTypes:   []string{"text/", "application/json", "application/javascript", "application/xml"},

			// 1 MiB/s per client by default, 5 MiB/s for the premium tier
			ThrottlingEnabled:    true,
			DefaultBandwidthKbps: 1024,
			PremiumBandwidthKbps: 5120,

			CacheMaxAge:       24 * time.Hour,
			StaticCacheMaxAge: 7 * 24 * time.Hour,
			EdgeCacheTTL:      time.Hour,
			PrefetchEnabled:   false,
			ThrottlerCleanup:  10 * time.Minute,
		}
	}

	features := &CDNFeatures{
		config:             config,
		compressionEnabled: config.CompressionEnabled,
		throttlingEnabled:  config.ThrottlingEnabled,
		bandwidthLimiters:  map[string]*rate.Limiter{},
	}
	if features.throttlingEnabled {
		go features.cleanupThrottlers()
	}
	return features
}
// CompressionMiddleware gzip-compresses responses for clients that accept it.
//
// When compression is enabled and the request's Accept-Encoding admits gzip
// (see acceptsGzip), the middleware sets the Content-Encoding and Vary headers
// up front and routes the response through a gzipResponseWriter. That writer
// makes the final per-response decision (content type, minimum size) and
// removes those headers again if it declines to compress.
func (c *CDNFeatures) CompressionMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !c.compressionEnabled || !acceptsGzip(r.Header.Values("Accept-Encoding")) {
			next.ServeHTTP(w, r)
			return
		}

		gzipWriter := &gzipResponseWriter{
			ResponseWriter: w,
			cdnConfig:      c.config,
		}
		// Close terminates the gzip stream (if one was started) after next returns.
		defer gzipWriter.Close()

		w.Header().Set("Content-Encoding", "gzip")
		w.Header().Set("Vary", "Accept-Encoding")
		next.ServeHTTP(gzipWriter, r)
	})
}

// acceptsGzip reports whether the given Accept-Encoding header values list
// gzip (or its alias x-gzip, matched case-insensitively) with a quality value
// above zero. "gzip;q=0" is an explicit refusal (RFC 9110 §12.5.3) and yields
// false; a plain substring match would wrongly treat it as acceptance.
func acceptsGzip(values []string) bool {
	for _, value := range values {
		for _, item := range strings.Split(value, ",") {
			coding, params := item, ""
			if i := strings.IndexByte(item, ';'); i >= 0 {
				coding, params = item[:i], item[i+1:]
			}
			coding = strings.ToLower(strings.TrimSpace(coding))
			if coding != "gzip" && coding != "x-gzip" {
				continue
			}
			if codingQuality(params) > 0 {
				return true
			}
		}
	}
	return false
}

// codingQuality extracts the q parameter from the ";"-separated parameters of
// one Accept-Encoding entry. A missing or unparseable q defaults to 1.
func codingQuality(params string) float64 {
	q := 1.0
	for _, p := range strings.Split(params, ";") {
		p = strings.TrimSpace(p)
		if len(p) < 2 || (p[0] != 'q' && p[0] != 'Q') || p[1] != '=' {
			continue
		}
		if v, err := strconv.ParseFloat(strings.TrimSpace(p[2:]), 64); err == nil {
			q = v
		}
	}
	return q
}
// ThrottlingMiddleware paces response bodies with a per-client bandwidth
// limiter. The client key comes from getClientIP, and all concurrent requests
// with the same key share one limiter. When throttling is disabled, the
// request passes through untouched.
func (c *CDNFeatures) ThrottlingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if c.throttlingEnabled {
			w = &throttledResponseWriter{
				ResponseWriter: w,
				limiter:        c.getBandwidthLimiter(c.getClientIP(r), r),
			}
		}
		next.ServeHTTP(w, r)
	})
}
// CacheControlMiddleware sets a Cache-Control header chosen by URL path and,
// for file requests, an ETag that enables conditional revalidation.
//
// Cache-Control:
//   - /static/...    → public, max-age=StaticCacheMaxAge
//   - /api/files/... → public, max-age=CacheMaxAge
//   - anything else  → public, max-age=300 (5 minutes)
//
// For GET and HEAD requests whose path contains "/files/", the segment after
// the first "files" segment (e.g. the hash in /api/files/<hash>) becomes a
// strong ETag. If any If-None-Match value lists that ETag (weak comparison,
// RFC 9110 §13.1.2) or is "*", the middleware answers 304 Not Modified itself
// and does not call next.
func (c *CDNFeatures) CacheControlMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		switch {
		case strings.HasPrefix(path, "/static/"):
			w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(c.config.StaticCacheMaxAge.Seconds())))
		case strings.HasPrefix(path, "/api/files/"):
			w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(c.config.CacheMaxAge.Seconds())))
		default:
			w.Header().Set("Cache-Control", "public, max-age=300") // 5 minutes
		}

		if (r.Method == http.MethodGet || r.Method == http.MethodHead) && strings.Contains(path, "/files/") {
			if etag := fileETag(path); etag != "" {
				w.Header().Set("ETag", etag)
				if ifNoneMatchHit(r.Header.Values("If-None-Match"), etag) {
					w.WriteHeader(http.StatusNotModified)
					return
				}
			}
		}

		next.ServeHTTP(w, r)
	})
}

// fileETag returns the quoted path segment that follows the first "files"
// segment of urlPath. It returns "" when there is no such segment, when the
// segment is empty (e.g. "/api/files//x"), or when it contains a double quote,
// which cannot appear inside an entity-tag.
func fileETag(urlPath string) string {
	segments := strings.Split(strings.Trim(urlPath, "/"), "/")
	for i, seg := range segments {
		if seg != "files" || i+1 >= len(segments) {
			continue
		}
		id := segments[i+1]
		if id == "" || strings.Contains(id, `"`) {
			return ""
		}
		return `"` + id + `"`
	}
	return ""
}

// ifNoneMatchHit reports whether any If-None-Match header value contains "*"
// or an entity-tag equal to etag. The comparison is weak: a "W/" prefix is
// ignored. Parsing stops at the first malformed element of a value.
func ifNoneMatchHit(values []string, etag string) bool {
	for _, v := range values {
		rest := strings.TrimSpace(v)
		for rest != "" {
			switch {
			case rest[0] == ',':
				rest = strings.TrimSpace(rest[1:])
				continue
			case rest[0] == '*':
				return true
			}
			rest = strings.TrimPrefix(rest, "W/")
			if rest == "" || rest[0] != '"' {
				break
			}
			end := strings.IndexByte(rest[1:], '"')
			if end < 0 {
				break
			}
			if rest[:end+2] == etag {
				return true
			}
			rest = strings.TrimSpace(rest[end+2:])
		}
	}
	return false
}
// gzipResponseWriter wraps an http.ResponseWriter and gzip-compresses the
// response body when the response qualifies (see shouldCompress).
//
// The compress/skip decision is made exactly once, on the first WriteHeader or
// Write call. That is before the status line and headers reach the client, so
// Content-Encoding and Content-Length can still be made to describe the bytes
// that are actually sent.
type gzipResponseWriter struct {
	http.ResponseWriter
	gzipWriter *gzip.Writer // non-nil only when compression was chosen
	cdnConfig  *CDNConfig
	written    bool // true once the compression decision has been made
}

// CompressionMiddleware defers Close on this writer; keep it an io.WriteCloser.
var _ io.WriteCloser = (*gzipResponseWriter)(nil)

// WriteHeader settles the compression decision so the headers match the body
// encoding, then forwards the status code. Informational (1xx) statuses are
// forwarded without deciding, because the final status is still to come.
func (g *gzipResponseWriter) WriteHeader(statusCode int) {
	if statusCode >= 100 && statusCode < 200 {
		g.ResponseWriter.WriteHeader(statusCode)
		return
	}
	g.decide(statusCode, nil)
	g.ResponseWriter.WriteHeader(statusCode)
}

// Write sends data through the gzip stream when compression was chosen, and
// directly otherwise. A Write without a prior WriteHeader decides as for an
// implicit 200 response.
func (g *gzipResponseWriter) Write(data []byte) (int, error) {
	g.decide(http.StatusOK, data)
	if g.gzipWriter != nil {
		return g.gzipWriter.Write(data)
	}
	return g.ResponseWriter.Write(data)
}

// Close finishes the gzip stream (flushing buffered data and the trailer) when
// compression was used. It is a no-op otherwise.
func (g *gzipResponseWriter) Close() error {
	if g.gzipWriter != nil {
		return g.gzipWriter.Close()
	}
	return nil
}

// decide runs once per response. It either sets up gzip compression and fixes
// the headers, or removes the gzip headers so the body is sent unchanged.
// firstChunk, when non-empty, is used to sniff a Content-Type the handler did
// not set, the same way net/http would on the first write.
func (g *gzipResponseWriter) decide(statusCode int, firstChunk []byte) {
	if g.written {
		return
	}
	g.written = true

	h := g.Header()
	if _, hasType := h["Content-Type"]; !hasType && len(firstChunk) > 0 {
		h.Set("Content-Type", http.DetectContentType(firstChunk))
	}

	if !g.shouldCompress(statusCode) {
		// Only undo the gzip headers set by CompressionMiddleware. A different
		// encoding chosen by the handler (e.g. "identity") is left alone.
		if h.Get("Content-Encoding") == "gzip" {
			h.Del("Content-Encoding")
			h.Del("Vary")
		}
		return
	}

	h.Set("Content-Encoding", "gzip")
	// Any declared length describes the uncompressed body and would be wrong.
	h.Del("Content-Length")
	g.gzipWriter = gzip.NewWriter(g.ResponseWriter)
}

// shouldCompress reports whether a response with this status and the current
// headers should be gzip-encoded. All of the following must hold:
//   - the status allows a body (not 204 or 304);
//   - Content-Encoding is still the "gzip" set by CompressionMiddleware;
//   - Content-Type starts with one of the configured CompressionTypes;
//   - any declared Content-Length is at least CompressionMinSize.
func (g *gzipResponseWriter) shouldCompress(statusCode int) bool {
	if statusCode == http.StatusNoContent || statusCode == http.StatusNotModified {
		return false
	}
	h := g.Header()
	if h.Get("Content-Encoding") != "gzip" {
		return false
	}

	contentType := h.Get("Content-Type")
	typeMatches := false
	for _, prefix := range g.cdnConfig.CompressionTypes {
		if strings.HasPrefix(contentType, prefix) {
			typeMatches = true
			break
		}
	}
	if !typeMatches {
		return false
	}

	if contentLength := h.Get("Content-Length"); contentLength != "" {
		if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size < g.cdnConfig.CompressionMinSize {
			return false
		}
	}
	return true
}
// throttledResponseWriter wraps an http.ResponseWriter and paces every Write
// through a token-bucket limiter in which one token represents one byte.
type throttledResponseWriter struct {
	http.ResponseWriter
	limiter *rate.Limiter // nil disables throttling
}

// Write waits for enough limiter tokens and then forwards data.
//
// rate.Limiter refuses any single request for more than Burst() tokens, so a
// large buffer is sent in chunks of at most Burst() bytes, each waited for
// separately. Requesting the whole buffer at once would be rejected and leave
// big writes effectively unthrottled.
//
// NOTE(review): waiting uses context.Background() because the request context
// is not available here. A client that disconnects mid-wait is therefore only
// noticed once the wait finishes. Plumbing r.Context() through
// ThrottlingMiddleware would make waits cancellable.
func (t *throttledResponseWriter) Write(data []byte) (int, error) {
	if t.limiter == nil || len(data) == 0 || t.limiter.Limit() == rate.Inf {
		return t.ResponseWriter.Write(data)
	}

	burst := t.limiter.Burst()
	if burst <= 0 || t.limiter.Limit() <= 0 {
		// Without a positive rate and burst there is nothing sensible to wait
		// for; fall back to a short fixed pause instead of blocking
		// indefinitely or failing the write.
		time.Sleep(100 * time.Millisecond)
		return t.ResponseWriter.Write(data)
	}

	ctx := context.Background()
	written := 0
	for written < len(data) {
		chunk := len(data) - written
		if chunk > burst {
			chunk = burst
		}
		if err := t.limiter.WaitN(ctx, chunk); err != nil {
			return written, err
		}
		n, err := t.ResponseWriter.Write(data[written : written+chunk])
		written += n
		if err != nil {
			return written, err
		}
	}
	return written, nil
}
// getBandwidthLimiter gets or creates a bandwidth limiter for the client
func (c *CDNFeatures) getBandwidthLimiter(clientIP string, r *http.Request) *rate.Limiter {
c.throttleMutex.Lock()
defer c.throttleMutex.Unlock()
limiter, exists := c.bandwidthLimiters[clientIP]
if !exists {
// Determine bandwidth limit based on user tier
bandwidthKbps := c.config.DefaultBandwidthKbps
// Check if user is premium (this would need integration with user management)
// For now, use default bandwidth
// Convert KB/s to bytes per second for rate limiter
bytesPerSecond := rate.Limit(bandwidthKbps * 1024)
burstSize := bandwidthKbps * 1024 * 2 // 2 second burst
limiter = rate.NewLimiter(bytesPerSecond, burstSize)
c.bandwidthLimiters[clientIP] = limiter
}
return limiter
}
// getClientIP returns the address used to key per-client bandwidth limiters.
//
// Sources, in order of precedence:
//  1. the first entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr.
//
// Header values are used only when they parse as an IP address (and are
// returned in canonical form). Anything else, such as "unknown", falls through
// to the next source. A RemoteAddr without a port is returned as-is.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
// unless a trusted reverse proxy overwrites them. A client connecting directly
// can rotate these headers to get a fresh limiter per request, which bypasses
// throttling and grows bandwidthLimiters. Confirm the deployment always sits
// behind a proxy that sets them, or trust them only from known proxy addresses.
func (c *CDNFeatures) getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if idx := strings.IndexByte(xff, ','); idx != -1 {
			first = xff[:idx]
		}
		if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
			return ip.String()
		}
	}

	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
			return ip.String()
		}
	}

	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" → "::1"), which a
	// naive split on the last ':' gets wrong.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// cleanupThrottlers runs forever. Every ThrottlerCleanup interval it removes
// bandwidth limiters whose token bucket has refilled to its burst size, a sign
// the client has not consumed bandwidth recently. A removed client gets a
// fresh DefaultBandwidthKbps limiter from getBandwidthLimiter on its next
// request.
//
// NOTE(review): there is no stop signal, so this goroutine (started by
// NewCDNFeatures) lives for the rest of the process.
func (c *CDNFeatures) cleanupThrottlers() {
	interval := c.config.ThrottlerCleanup
	if interval <= 0 {
		// time.NewTicker panics on a non-positive interval. This runs in its own
		// goroutine, so that panic would crash the whole process. Fall back to
		// the default sweep interval used by NewCDNFeatures.
		interval = 10 * time.Minute
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for range ticker.C {
		c.throttleMutex.Lock()
		for ip, limiter := range c.bandwidthLimiters {
			// A full bucket means no tokens were spent recently (inactive client).
			if limiter.Tokens() >= float64(limiter.Burst()) {
				delete(c.bandwidthLimiters, ip) // deleting while ranging is safe in Go
			}
		}
		c.throttleMutex.Unlock()
	}
}
// ContentOptimizationMiddleware decorates responses with standard headers:
//   - security headers on every response;
//   - CORS headers on /api/ paths;
//   - X-Served-By and a Server-Timing "cdn" metric on non-OPTIONS requests.
//
// Every OPTIONS request is answered with 200 here without calling next.
//
// Server-Timing's dur must be a duration in milliseconds. It is measured from
// entry into this middleware until the response headers are written, which
// requires wrapping the ResponseWriter (see serverTimingWriter).
func (c *CDNFeatures) ContentOptimizationMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		h := w.Header()

		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-Frame-Options", "DENY")
		// NOTE(review): X-XSS-Protection is deprecated in modern browsers and
		// OWASP now recommends "0" or omitting it; left unchanged here.
		h.Set("X-XSS-Protection", "1; mode=block")

		if strings.HasPrefix(r.URL.Path, "/api/") {
			h.Set("Access-Control-Allow-Origin", "*")
			h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
		}

		// Preflight (and any other OPTIONS) requests stop here.
		if r.Method == "OPTIONS" {
			w.WriteHeader(http.StatusOK)
			return
		}

		h.Set("X-Served-By", "TorrentGateway")
		tw := &serverTimingWriter{ResponseWriter: w, start: start}
		next.ServeHTTP(tw, r)
		// A handler that wrote nothing gets an implicit 200 from net/http after
		// we return; add the header now so it is still sent.
		tw.stamp()
	})
}

// serverTimingWriter adds a "Server-Timing: cdn;dur=<ms>" header just before
// the response headers are first written.
type serverTimingWriter struct {
	http.ResponseWriter
	start   time.Time
	stamped bool
}

// stamp adds the Server-Timing header once, using the elapsed time since start.
func (s *serverTimingWriter) stamp() {
	if s.stamped {
		return
	}
	s.stamped = true
	elapsedMs := float64(time.Since(s.start).Microseconds()) / 1000
	s.ResponseWriter.Header().Add("Server-Timing", fmt.Sprintf("cdn;dur=%.3f", elapsedMs))
}

// WriteHeader stamps the timing header and forwards the status code.
func (s *serverTimingWriter) WriteHeader(statusCode int) {
	s.stamp()
	s.ResponseWriter.WriteHeader(statusCode)
}

// Write stamps the timing header (implicit 200) and forwards data.
func (s *serverTimingWriter) Write(data []byte) (int, error) {
	s.stamp()
	return s.ResponseWriter.Write(data)
}

// Flush forwards to the wrapped writer when it implements http.Flusher, so
// streaming handlers are not held back by this wrapper.
func (s *serverTimingWriter) Flush() {
	s.stamp()
	if f, ok := s.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Unwrap exposes the wrapped writer to http.ResponseController.
func (s *serverTimingWriter) Unwrap() http.ResponseWriter {
	return s.ResponseWriter
}
// GetStats returns a snapshot of CDN settings plus the number of bandwidth
// limiters currently tracked, keyed by snake_case names.
func (c *CDNFeatures) GetStats() map[string]interface{} {
	// The limiter map is mutated by request handling and by the cleanup
	// goroutine, so its size is read under throttleMutex.
	c.throttleMutex.RLock()
	limiterCount := len(c.bandwidthLimiters)
	c.throttleMutex.RUnlock()

	cfg := c.config
	stats := map[string]interface{}{}
	stats["compression_enabled"] = c.compressionEnabled
	stats["throttling_enabled"] = c.throttlingEnabled
	stats["active_bandwidth_limiters"] = limiterCount
	stats["default_bandwidth_kbps"] = cfg.DefaultBandwidthKbps
	stats["premium_bandwidth_kbps"] = cfg.PremiumBandwidthKbps
	stats["cache_max_age_seconds"] = int(cfg.CacheMaxAge.Seconds())
	stats["compression_min_size"] = cfg.CompressionMinSize
	return stats
}
// UpdateBandwidthLimit sets the bandwidth limit for clientIP to kbps KB/s
// (kbps*1024 bytes per second, with a two-second burst).
//
// If the client already has a limiter, it is retuned in place, so responses
// already being paced by it (ThrottlingMiddleware captures the *rate.Limiter
// per request) pick up the new rate immediately. Otherwise a new limiter is
// registered for the client's future requests.
//
// NOTE(review): cleanupThrottlers drops any limiter whose bucket has refilled
// to its burst, so an idle client's custom limit is eventually replaced by a
// DefaultBandwidthKbps limiter on its next request. kbps <= 0 is not validated.
func (c *CDNFeatures) UpdateBandwidthLimit(clientIP string, kbps int) {
	bytesPerSecond := kbps * 1024
	limit := rate.Limit(bytesPerSecond)
	burst := bytesPerSecond * 2 // 2 second burst

	c.throttleMutex.Lock()
	defer c.throttleMutex.Unlock()

	if existing, ok := c.bandwidthLimiters[clientIP]; ok {
		existing.SetLimit(limit)
		existing.SetBurst(burst)
		return
	}
	c.bandwidthLimiters[clientIP] = rate.NewLimiter(limit, burst)
}
// StreamingOptimizationMiddleware marks streaming responses as uncompressed
// binary content that supports byte ranges. It applies to paths containing
// "/stream" and to requests whose Accept header is exactly
// "application/octet-stream". The request is always passed on to next.
//
// Range requests are deliberately not answered here. Only the downstream
// handler knows the content and its length, which a valid 206 response
// (body plus "Content-Range: bytes start-end/size") needs; http.ServeContent,
// for example, does this. Short-circuiting produced an empty 206 with an
// invalid Content-Range.
func (c *CDNFeatures) StreamingOptimizationMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/stream") || r.Header.Get("Accept") == "application/octet-stream" {
			h := w.Header()
			// Disable compression for streaming
			h.Set("Content-Encoding", "identity")
			h.Set("Accept-Ranges", "bytes")
			h.Set("Content-Type", "application/octet-stream")
		}
		next.ServeHTTP(w, r)
	})
}
// handleRangeRequest handles HTTP range requests for efficient streaming
func (c *CDNFeatures) handleRangeRequest(w http.ResponseWriter, r *http.Request, rangeHeader string) {
// Parse range header (e.g., "bytes=0-1023")
if !strings.HasPrefix(rangeHeader, "bytes=") {
http.Error(w, "Invalid range header", http.StatusBadRequest)
return
}
rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=")
rangeParts := strings.Split(rangeSpec, "-")
if len(rangeParts) != 2 {
http.Error(w, "Invalid range format", http.StatusBadRequest)
return
}
// For now, let the next handler deal with the actual range processing
// This middleware just sets up the headers
w.Header().Set("Content-Range", "bytes */")
w.WriteHeader(http.StatusPartialContent)
}
// PrefetchMiddleware adds Link prefetch hints for /static/style.css and
// /static/app.js to requests for the root path "/", when PrefetchEnabled is set.
func (c *CDNFeatures) PrefetchMiddleware(next http.Handler) http.Handler {
	prefetchLinks := [...]string{
		"</static/style.css>; rel=prefetch",
		"</static/app.js>; rel=prefetch",
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		enabled := c.config.PrefetchEnabled
		if enabled && r.URL.Path == "/" {
			for _, link := range prefetchLinks {
				w.Header().Add("Link", link)
			}
		}
		next.ServeHTTP(w, r)
	})
}
// EdgeCacheMiddleware returns middleware that simulates an edge cache in
// front of next. Only plain GET requests without a Range header are cached.
//
// Cache hits are written directly with X-Cache: HIT and
// Content-Type: application/octet-stream. The cache stores only the body, so
// the original Content-Type is not preserved.
//
// On a miss, X-Cache: MISS is set before the handler runs. The response is
// stored for EdgeCacheTTL only if all of the following hold:
//   - the status is 200 and the body is non-empty;
//   - every write to the client succeeded;
//   - the request did not carry Authorization;
//   - the response did not mark itself no-store, no-cache or private.
//
// NOTE(review): the full body is buffered in memory on every miss, including
// large file downloads; consider a size cap.
func (c *CDNFeatures) EdgeCacheMiddleware(cache CacheInterface) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// A whole-body cache cannot answer partial (Range) requests correctly.
			if r.Method != http.MethodGet || r.Header.Get("Range") != "" {
				next.ServeHTTP(w, r)
				return
			}

			cacheKey := edgeCacheKey(r)
			if cached, found := cache.Get(cacheKey); found {
				w.Header().Set("X-Cache", "HIT")
				w.Header().Set("Content-Type", "application/octet-stream")
				_, _ = w.Write(cached) // a failed write to the client is not actionable here
				return
			}

			// Must be set before the handler writes, or it is never sent.
			w.Header().Set("X-Cache", "MISS")
			recorder := &responseRecorder{ResponseWriter: w}
			next.ServeHTTP(recorder, r)

			if recorder.statusCode == http.StatusOK && recorder.writeErr == nil &&
				len(recorder.body) > 0 && edgeCacheable(r, w.Header()) {
				// Best effort: a failed store only costs a future miss.
				_ = cache.Set(cacheKey, recorder.body, c.config.EdgeCacheTTL)
			}
		})
	}
}

// edgeCacheKey builds the cache key for r. It is "edge:" plus the path, with
// "?" and the raw query appended when present, so different query strings do
// not share an entry.
func edgeCacheKey(r *http.Request) string {
	if r.URL.RawQuery == "" {
		return "edge:" + r.URL.Path
	}
	return "edge:" + r.URL.Path + "?" + r.URL.RawQuery
}

// edgeCacheable reports whether a successful response may be stored in the
// shared edge cache. It returns false for requests carrying Authorization
// (RFC 9111 §3.5), and for responses whose Cache-Control contains no-store,
// no-cache (this cache never revalidates) or private.
func edgeCacheable(r *http.Request, respHeader http.Header) bool {
	if r.Header.Get("Authorization") != "" {
		return false
	}
	for _, value := range respHeader.Values("Cache-Control") {
		for _, directive := range strings.Split(value, ",") {
			name := strings.ToLower(strings.TrimSpace(directive))
			if i := strings.IndexByte(name, '='); i >= 0 {
				name = strings.TrimSpace(name[:i])
			}
			switch name {
			case "no-store", "no-cache", "private":
				return false
			}
		}
	}
	return true
}

// CacheInterface defines cache operations for CDN
type CacheInterface interface {
	Get(key string) ([]byte, bool)
	Set(key string, value []byte, ttl time.Duration) error
}

// responseRecorder forwards a response to the client while keeping a copy of
// the body, the final status code and the first write error, so the caller
// can decide afterwards whether to cache it.
type responseRecorder struct {
	http.ResponseWriter
	body       []byte
	statusCode int   // final status sent; 0 until WriteHeader or the first Write
	writeErr   error // first error from the underlying writer; body is then incomplete
}

// Write forwards data and records the bytes the client actually accepted.
// A Write without a prior final WriteHeader implies status 200, as in net/http.
func (r *responseRecorder) Write(data []byte) (int, error) {
	if r.statusCode < http.StatusOK {
		r.statusCode = http.StatusOK
	}
	n, err := r.ResponseWriter.Write(data)
	r.body = append(r.body, data[:n]...)
	if err != nil && r.writeErr == nil {
		r.writeErr = err
	}
	return n, err
}

// WriteHeader records the first final status code and forwards every call.
// Informational 1xx codes are not final, and later superfluous calls do not
// change what the client received.
func (r *responseRecorder) WriteHeader(statusCode int) {
	if r.statusCode < http.StatusOK {
		r.statusCode = statusCode
	}
	r.ResponseWriter.WriteHeader(statusCode)
}