package middleware

import (
	"context"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)

// RateLimiter manages three kinds of token-bucket limits: per-client-IP
// upload limits, one global download limit, and per-file stream limits.
//
// Create it with NewRateLimiter, which starts a background goroutine that
// prunes idle per-IP and per-file limiters until Stop is called.
type RateLimiter struct {
	uploadLimiters map[string]*rate.Limiter // per-client-IP upload limiters
	downloadLimit  *rate.Limiter            // single global download limiter
	streamLimiters map[string]*rate.Limiter // per-file stream limiters
	uploadMutex    sync.RWMutex             // guards uploadLimiters
	streamMutex    sync.RWMutex             // guards streamLimiters

	configMutex sync.RWMutex    // guards config
	config      RateLimitConfig // normalized private copy of the active settings

	stop     chan struct{} // closed by Stop to end the cleanup goroutine
	stopOnce sync.Once     // makes Stop safe to call more than once
}

// RateLimitConfig configures rate limiting behavior. NewRateLimiter and
// UpdateConfig replace every non-positive field with its default.
//
// LimiterTTL is defaulted but not currently consulted. Cleanup drops a
// limiter once its bucket has refilled to its burst size, because a full
// bucket is indistinguishable from a freshly created limiter.
type RateLimitConfig struct {
	// Upload limits (per IP)
	UploadRatePerIP  float64 // requests per second per IP
	UploadBurstPerIP int     // burst size per IP

	// Global download limits
	DownloadRate  float64 // requests per second globally
	DownloadBurst int     // global burst size

	// Stream limits (per file)
	StreamRatePerFile  float64 // requests per second per file
	StreamBurstPerFile int     // burst size per file

	// Cleanup settings
	CleanupInterval time.Duration // how often idle limiters are pruned
	LimiterTTL      time.Duration // reserved; see the type comment
}

// defaultRateLimitConfig returns the settings used when no config is given.
func defaultRateLimitConfig() RateLimitConfig {
	return RateLimitConfig{
		UploadRatePerIP:    1.0,  // 1 upload per second per IP
		UploadBurstPerIP:   5,    // burst of 5
		DownloadRate:       50.0, // 50 downloads per second globally
		DownloadBurst:      100,  // burst of 100
		StreamRatePerFile:  10.0, // 10 streams per second per file
		StreamBurstPerFile: 20,   // burst of 20
		CleanupInterval:    5 * time.Minute,
		LimiterTTL:         15 * time.Minute,
	}
}

// normalizeConfig returns a copy of config in which every non-positive field
// is replaced by its default. A nil config yields the defaults. The caller's
// struct is never modified.
func normalizeConfig(config *RateLimitConfig) RateLimitConfig {
	def := defaultRateLimitConfig()
	if config == nil {
		return def
	}
	cfg := *config

	// The float checks use !(x > 0) so that NaN is rejected as well.
	if !(cfg.UploadRatePerIP > 0) {
		cfg.UploadRatePerIP = def.UploadRatePerIP
	}
	if cfg.UploadBurstPerIP <= 0 {
		cfg.UploadBurstPerIP = def.UploadBurstPerIP
	}
	if !(cfg.DownloadRate > 0) {
		cfg.DownloadRate = def.DownloadRate
	}
	if cfg.DownloadBurst <= 0 {
		cfg.DownloadBurst = def.DownloadBurst
	}
	if !(cfg.StreamRatePerFile > 0) {
		cfg.StreamRatePerFile = def.StreamRatePerFile
	}
	if cfg.StreamBurstPerFile <= 0 {
		cfg.StreamBurstPerFile = def.StreamBurstPerFile
	}
	// time.NewTicker panics on a non-positive interval, so this must be
	// validated before the cleanup goroutine starts.
	if cfg.CleanupInterval <= 0 {
		cfg.CleanupInterval = def.CleanupInterval
	}
	if cfg.LimiterTTL <= 0 {
		cfg.LimiterTTL = def.LimiterTTL
	}
	return cfg
}

// NewRateLimiter creates a rate limiter from config and starts its cleanup
// goroutine. A nil config selects the defaults. Non-positive fields are
// defaulted. The config is copied, so later changes to *config have no effect;
// use UpdateConfig instead. Call Stop to end the cleanup goroutine.
func NewRateLimiter(config *RateLimitConfig) *RateLimiter {
	cfg := normalizeConfig(config)
	rl := &RateLimiter{
		uploadLimiters: make(map[string]*rate.Limiter),
		downloadLimit:  rate.NewLimiter(rate.Limit(cfg.DownloadRate), cfg.DownloadBurst),
		streamLimiters: make(map[string]*rate.Limiter),
		config:         cfg,
		stop:           make(chan struct{}),
	}
	go rl.cleanupRoutine(cfg.CleanupInterval)
	return rl
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once. The middlewares keep working after Stop, but idle per-IP and
// per-file limiters are no longer pruned.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() { close(rl.stop) })
}

// currentConfig returns a snapshot of the active configuration.
func (rl *RateLimiter) currentConfig() RateLimitConfig {
	rl.configMutex.RLock()
	defer rl.configMutex.RUnlock()
	return rl.config
}

// UploadMiddleware applies per-client-IP upload rate limiting. Rejected
// requests get a 429 JSON response with Retry-After: 60.
func (rl *RateLimiter) UploadMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		limiter := rl.getUploadLimiter(rl.getClientIP(r))
		if !limiter.Allow() {
			rl.rejectRequest(w, limiter, "60",
				`{"error": "Upload rate limit exceeded. Please try again later."}`)
			return
		}
		rl.setRateLimitHeaders(w, limiter)
		next(w, r)
	}
}

// DownloadMiddleware applies the single global download rate limit shared by
// all clients. Rejected requests get a 429 JSON response with Retry-After: 10.
func (rl *RateLimiter) DownloadMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !rl.downloadLimit.Allow() {
			rl.rejectRequest(w, rl.downloadLimit, "10",
				`{"error": "Global download rate limit exceeded. Please try again later."}`)
			return
		}
		rl.setRateLimitHeaders(w, rl.downloadLimit)
		next(w, r)
	}
}

// StreamMiddleware applies per-file stream rate limiting, keyed by the file
// hash found by extractFileHash. Requests with no identifiable file pass
// through unlimited. Rejected requests get a 429 JSON response with
// Retry-After: 5.
func (rl *RateLimiter) StreamMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		fileHash := rl.extractFileHash(r)
		if fileHash == "" {
			// No file identified, so there is no per-file bucket to charge.
			next(w, r)
			return
		}
		limiter := rl.getStreamLimiter(fileHash)
		if !limiter.Allow() {
			rl.rejectRequest(w, limiter, "5",
				`{"error": "Stream rate limit exceeded for this file. Please try again later."}`)
			return
		}
		rl.setRateLimitHeaders(w, limiter)
		next(w, r)
	}
}

// setRateLimitHeaders advertises the limiter's burst size and the number of
// whole tokens it has left.
func (rl *RateLimiter) setRateLimitHeaders(w http.ResponseWriter, limiter *rate.Limiter) {
	h := w.Header()
	h.Set("X-RateLimit-Limit", strconv.Itoa(limiter.Burst()))
	h.Set("X-RateLimit-Remaining", strconv.Itoa(int(limiter.Tokens())))
}

// rejectRequest writes a 429 Too Many Requests response with a JSON body and
// the rate-limit headers for limiter.
func (rl *RateLimiter) rejectRequest(w http.ResponseWriter, limiter *rate.Limiter, retryAfter, body string) {
	h := w.Header()
	h.Set("Content-Type", "application/json")
	h.Set("X-RateLimit-Limit", strconv.Itoa(limiter.Burst()))
	h.Set("X-RateLimit-Remaining", "0")
	h.Set("Retry-After", retryAfter)
	w.WriteHeader(http.StatusTooManyRequests)
	// Best effort: nothing useful can be done if the client has gone away.
	_, _ = w.Write([]byte(body))
}

// getUploadLimiter returns the limiter for ip. On first use it creates one
// from the current upload settings.
func (rl *RateLimiter) getUploadLimiter(ip string) *rate.Limiter {
	cfg := rl.currentConfig()
	return rl.getOrCreateLimiter(&rl.uploadMutex, rl.uploadLimiters, ip,
		cfg.UploadRatePerIP, cfg.UploadBurstPerIP)
}

// getStreamLimiter returns the limiter for fileHash. On first use it creates
// one from the current stream settings.
func (rl *RateLimiter) getStreamLimiter(fileHash string) *rate.Limiter {
	cfg := rl.currentConfig()
	return rl.getOrCreateLimiter(&rl.streamMutex, rl.streamLimiters, fileHash,
		cfg.StreamRatePerFile, cfg.StreamBurstPerFile)
}

// getOrCreateLimiter returns m[key]. If the key is absent, it first inserts a
// new limiter with the given per-second rate and burst. mu must be the mutex
// that guards m.
func (rl *RateLimiter) getOrCreateLimiter(mu *sync.RWMutex, m map[string]*rate.Limiter, key string, perSecond float64, burst int) *rate.Limiter {
	mu.Lock()
	defer mu.Unlock()
	if limiter, ok := m[key]; ok {
		return limiter
	}
	limiter := rate.NewLimiter(rate.Limit(perSecond), burst)
	m[key] = limiter
	return limiter
}

// getClientIP returns the key used for per-client upload limits. It checks,
// in order:
//  1. the first non-empty entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of RemoteAddr (RemoteAddr verbatim if it has no port).
//
// SECURITY NOTE: the forwarding headers are trusted verbatim. A client that
// reaches this server directly can pick, and keep rotating, its own key, which
// evades the per-IP limit. Only rely on this behind a proxy that overwrites
// these headers.
func (rl *RateLimiter) getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if idx := strings.IndexByte(xff, ','); idx != -1 {
			first = xff[:idx]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	// SplitHostPort handles bracketed IPv6 addresses such as "[::1]:8080".
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// extractFileHash returns the file identifier for per-file stream limiting.
// It checks, in order:
//  1. the path segment after the first "files" segment that has a non-empty
//     successor (e.g. /api/files/{hash}/stream);
//  2. the "hash" query parameter;
//  3. the "file_hash" query parameter.
//
// It returns "" if none of these is present. The value comes straight from
// the request, so each distinct value gets its own limiter until cleanup
// prunes it.
func (rl *RateLimiter) extractFileHash(r *http.Request) string {
	parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
	for i := 0; i+1 < len(parts); i++ {
		if parts[i] == "files" && parts[i+1] != "" {
			return parts[i+1]
		}
	}
	query := r.URL.Query()
	if hash := query.Get("hash"); hash != "" {
		return hash
	}
	return query.Get("file_hash")
}

// cleanupRoutine prunes idle limiters every interval until Stop is called.
func (rl *RateLimiter) cleanupRoutine(interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			rl.cleanup()
		case <-rl.stop:
			return
		}
	}
}

// cleanup removes idle per-IP and per-file limiters to bound memory use.
func (rl *RateLimiter) cleanup() {
	rl.pruneFullLimiters(&rl.uploadMutex, rl.uploadLimiters)
	rl.pruneFullLimiters(&rl.streamMutex, rl.streamLimiters)
}

// pruneFullLimiters deletes every limiter in m whose bucket has refilled to
// its own burst size. Such a limiter has been idle long enough to be
// indistinguishable from a new one, so dropping it does not loosen any limit.
// It compares against the limiter's own burst rather than the config, so
// limiters created before an UpdateConfig are still pruned. mu must guard m.
func (rl *RateLimiter) pruneFullLimiters(mu *sync.RWMutex, m map[string]*rate.Limiter) {
	mu.Lock()
	defer mu.Unlock()
	for key, limiter := range m {
		if limiter.Tokens() >= float64(limiter.Burst()) {
			delete(m, key)
		}
	}
}

// GetStats returns the current limiter counts, the global download tokens,
// and the active configuration.
func (rl *RateLimiter) GetStats() map[string]interface{} {
	rl.uploadMutex.RLock()
	uploadLimiterCount := len(rl.uploadLimiters)
	rl.uploadMutex.RUnlock()

	rl.streamMutex.RLock()
	streamLimiterCount := len(rl.streamLimiters)
	rl.streamMutex.RUnlock()

	cfg := rl.currentConfig()
	return map[string]interface{}{
		"upload_limiters":       uploadLimiterCount,
		"stream_limiters":       streamLimiterCount,
		"download_tokens":       rl.downloadLimit.Tokens(),
		"upload_rate_per_ip":    cfg.UploadRatePerIP,
		"upload_burst_per_ip":   cfg.UploadBurstPerIP,
		"download_rate":         cfg.DownloadRate,
		"download_burst":        cfg.DownloadBurst,
		"stream_rate_per_file":  cfg.StreamRatePerFile,
		"stream_burst_per_file": cfg.StreamBurstPerFile,
	}
}

// UpdateConfig replaces the configuration at runtime. It is safe to call
// while requests are in flight. The config is copied and normalized exactly
// as in NewRateLimiter: nil selects the defaults and non-positive fields are
// defaulted.
//
// The global download limiter changes immediately. Existing per-IP and
// per-file limiters keep their old rate and burst until cleanup prunes them;
// they are then recreated with the new settings. A new CleanupInterval does
// not affect the already-running cleanup goroutine.
func (rl *RateLimiter) UpdateConfig(config *RateLimitConfig) {
	cfg := normalizeConfig(config)

	rl.configMutex.Lock()
	defer rl.configMutex.Unlock()
	rl.config = cfg
	rl.downloadLimit.SetLimit(rate.Limit(cfg.DownloadRate))
	rl.downloadLimit.SetBurst(cfg.DownloadBurst)
}

// WaitHandler returns middleware that waits for the limiter chosen by
// limiterFunc, rather than rejecting immediately. If a token is not
// available within timeout (or the client goes away), it responds 429.
//
// The timeout bounds only the wait. The wrapped handler runs with the
// request's original context. A nil limiter from limiterFunc means the
// request is not limited.
func (rl *RateLimiter) WaitHandler(limiterFunc func(*http.Request) *rate.Limiter, timeout time.Duration) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			limiter := limiterFunc(r)
			if limiter == nil {
				next(w, r)
				return
			}

			waitCtx, cancel := context.WithTimeout(r.Context(), timeout)
			err := limiter.Wait(waitCtx)
			cancel() // release the timer now; the handler must not inherit this deadline
			if err != nil {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusTooManyRequests)
				// Best effort: nothing useful can be done if the write fails.
				_, _ = w.Write([]byte(`{"error": "Rate limit timeout exceeded"}`))
				return
			}
			next(w, r)
		}
	}
}