enki b3204ea07a
Some checks are pending
CI Pipeline / Run Tests (push) Waiting to run
CI Pipeline / Lint Code (push) Waiting to run
CI Pipeline / Security Scan (push) Waiting to run
CI Pipeline / Build Docker Images (push) Blocked by required conditions
CI Pipeline / E2E Tests (push) Blocked by required conditions
first commit
2025-08-18 00:40:15 -07:00

339 lines
10 KiB
Go

package middleware
import (
	"context"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// RateLimiter manages different types of rate limits
type RateLimiter struct {
uploadLimiters map[string]*rate.Limiter // Per IP upload limits
downloadLimit *rate.Limiter // Global download limit
streamLimiters map[string]*rate.Limiter // Per file stream limits
uploadMutex sync.RWMutex
streamMutex sync.RWMutex
config *RateLimitConfig
}
// RateLimitConfig configures rate limiting behavior.
//
// Each rate is a token-bucket refill rate in requests per second, and each
// burst is the bucket capacity passed to rate.NewLimiter (how many requests
// can be admitted back-to-back). NewRateLimiter replaces any non-positive
// rate or burst field with a built-in default.
type RateLimitConfig struct {
	// Upload limits, tracked separately for each client IP.
	UploadRatePerIP  float64 // requests per second per IP
	UploadBurstPerIP int     // burst size per IP

	// Global download limit, one bucket shared by all clients.
	DownloadRate  float64 // requests per second globally
	DownloadBurst int     // global burst size

	// Stream limits, tracked separately for each file.
	StreamRatePerFile  float64 // requests per second per file
	StreamBurstPerFile int     // burst size per file

	// Cleanup settings.
	// CleanupInterval is how often idle per-IP and per-file limiters are
	// swept. It should be positive: it is used as a time.NewTicker interval,
	// which panics on a non-positive duration.
	CleanupInterval time.Duration // how often to clean old limiters
	// LimiterTTL is meant to bound how long inactive limiters are kept.
	// NOTE(review): the cleanup sweep does not consult LimiterTTL; it decides
	// idleness from each limiter's token count instead.
	LimiterTTL time.Duration // how long to keep inactive limiters
}
// NewRateLimiter creates a RateLimiter from config and starts its background
// cleanup goroutine.
//
// A nil config selects the built-in defaults: 1 upload/s per IP (burst 5),
// 50 downloads/s globally (burst 100), 10 streams/s per file (burst 20),
// a 5-minute cleanup interval and a 15-minute limiter TTL. In a non-nil
// config, every non-positive rate, burst, or duration field falls back to
// its default. config is copied rather than modified, so the caller's struct
// is left untouched and later edits to it have no effect.
//
// NOTE(review): the cleanup goroutine has no stop mechanism and runs for the
// rest of the process lifetime.
func NewRateLimiter(config *RateLimitConfig) *RateLimiter {
	defaults := RateLimitConfig{
		UploadRatePerIP:    1.0,  // 1 upload per second per IP
		UploadBurstPerIP:   5,    // burst of 5
		DownloadRate:       50.0, // 50 downloads per second globally
		DownloadBurst:      100,  // burst of 100
		StreamRatePerFile:  10.0, // 10 streams per second per file
		StreamBurstPerFile: 20,   // burst of 20
		CleanupInterval:    5 * time.Minute,
		LimiterTTL:         15 * time.Minute,
	}

	cfg := defaults
	if config != nil {
		cfg = *config
	}

	if cfg.UploadRatePerIP <= 0 {
		cfg.UploadRatePerIP = defaults.UploadRatePerIP
	}
	if cfg.UploadBurstPerIP <= 0 {
		cfg.UploadBurstPerIP = defaults.UploadBurstPerIP
	}
	if cfg.DownloadRate <= 0 {
		cfg.DownloadRate = defaults.DownloadRate
	}
	if cfg.DownloadBurst <= 0 {
		cfg.DownloadBurst = defaults.DownloadBurst
	}
	if cfg.StreamRatePerFile <= 0 {
		cfg.StreamRatePerFile = defaults.StreamRatePerFile
	}
	if cfg.StreamBurstPerFile <= 0 {
		cfg.StreamBurstPerFile = defaults.StreamBurstPerFile
	}
	// time.NewTicker panics on a non-positive interval, and cleanupRoutine
	// runs on its own goroutine, so a zero CleanupInterval (the field's zero
	// value) would otherwise crash the whole process.
	if cfg.CleanupInterval <= 0 {
		cfg.CleanupInterval = defaults.CleanupInterval
	}
	if cfg.LimiterTTL <= 0 {
		cfg.LimiterTTL = defaults.LimiterTTL
	}

	rl := &RateLimiter{
		uploadLimiters: make(map[string]*rate.Limiter),
		downloadLimit:  rate.NewLimiter(rate.Limit(cfg.DownloadRate), cfg.DownloadBurst),
		streamLimiters: make(map[string]*rate.Limiter),
		config:         &cfg,
	}

	go rl.cleanupRoutine()
	return rl
}
// UploadMiddleware wraps next with a per-client upload limit, keyed by the
// address returned from getClientIP.
//
// A request that finds its client's bucket empty is answered with
// 429 Too Many Requests, a JSON error body, X-RateLimit-Remaining: 0 and
// Retry-After: 60. An admitted request gets X-RateLimit-Limit (the per-IP
// burst) and X-RateLimit-Remaining (whole tokens left) before next runs.
func (rl *RateLimiter) UploadMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		clientLimiter := rl.getUploadLimiter(rl.getClientIP(r))
		hdr := w.Header()

		if clientLimiter.Allow() {
			hdr.Set("X-RateLimit-Limit", strconv.Itoa(rl.config.UploadBurstPerIP))
			hdr.Set("X-RateLimit-Remaining", strconv.Itoa(int(clientLimiter.Tokens())))
			next(w, r)
			return
		}

		hdr.Set("Content-Type", "application/json")
		hdr.Set("X-RateLimit-Limit", strconv.Itoa(rl.config.UploadBurstPerIP))
		hdr.Set("X-RateLimit-Remaining", "0")
		hdr.Set("Retry-After", "60")
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte(`{"error": "Upload rate limit exceeded. Please try again later."}`))
	}
}
// DownloadMiddleware wraps next with the single global download limit that
// all clients share.
//
// When the global bucket is empty the request is answered with
// 429 Too Many Requests, a JSON error body, X-RateLimit-Remaining: 0 and
// Retry-After: 10. Otherwise X-RateLimit-Limit (the global burst) and
// X-RateLimit-Remaining (whole tokens left) are set and next runs.
func (rl *RateLimiter) DownloadMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		global := rl.downloadLimit
		admitted := global.Allow()
		limitValue := strconv.Itoa(rl.config.DownloadBurst)

		if !admitted {
			h := w.Header()
			h.Set("Content-Type", "application/json")
			h.Set("X-RateLimit-Limit", limitValue)
			h.Set("X-RateLimit-Remaining", "0")
			h.Set("Retry-After", "10")
			w.WriteHeader(http.StatusTooManyRequests)
			w.Write([]byte(`{"error": "Global download rate limit exceeded. Please try again later."}`))
			return
		}

		remaining := int(global.Tokens())
		w.Header().Set("X-RateLimit-Limit", limitValue)
		w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
		next(w, r)
	}
}
// StreamMiddleware wraps next with a per-file stream limit, keyed by the
// identifier returned from extractFileHash.
//
// Requests with no recognizable file identifier bypass the limit and reach
// next without any rate-limit headers. A request for a file whose bucket is
// empty gets 429 Too Many Requests, a JSON error body,
// X-RateLimit-Remaining: 0 and Retry-After: 5. Admitted requests carry
// X-RateLimit-Limit (the per-file burst) and X-RateLimit-Remaining.
func (rl *RateLimiter) StreamMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if hash := rl.extractFileHash(r); hash != "" {
			perFile := rl.getStreamLimiter(hash)
			h := w.Header()

			if !perFile.Allow() {
				h.Set("Content-Type", "application/json")
				h.Set("X-RateLimit-Limit", strconv.Itoa(rl.config.StreamBurstPerFile))
				h.Set("X-RateLimit-Remaining", "0")
				h.Set("Retry-After", "5")
				w.WriteHeader(http.StatusTooManyRequests)
				w.Write([]byte(`{"error": "Stream rate limit exceeded for this file. Please try again later."}`))
				return
			}

			h.Set("X-RateLimit-Limit", strconv.Itoa(rl.config.StreamBurstPerFile))
			h.Set("X-RateLimit-Remaining", strconv.Itoa(int(perFile.Tokens())))
		}
		next(w, r)
	}
}
// getUploadLimiter returns the upload limiter for ip, lazily creating one
// from the current per-IP rate and burst on first use.
//
// The exclusive lock is taken even on the lookup path because a miss inserts
// into uploadLimiters under the same critical section.
func (rl *RateLimiter) getUploadLimiter(ip string) *rate.Limiter {
	rl.uploadMutex.Lock()
	defer rl.uploadMutex.Unlock()

	if existing, ok := rl.uploadLimiters[ip]; ok {
		return existing
	}
	created := rate.NewLimiter(rate.Limit(rl.config.UploadRatePerIP), rl.config.UploadBurstPerIP)
	rl.uploadLimiters[ip] = created
	return created
}
// getStreamLimiter returns the stream limiter for fileHash, lazily creating
// one from the current per-file rate and burst on first use.
//
// Lookup and insertion share one exclusive critical section so two
// concurrent first requests for the same file end up with the same limiter.
func (rl *RateLimiter) getStreamLimiter(fileHash string) *rate.Limiter {
	rl.streamMutex.Lock()
	defer rl.streamMutex.Unlock()

	if found, ok := rl.streamLimiters[fileHash]; ok {
		return found
	}
	fresh := rate.NewLimiter(rate.Limit(rl.config.StreamRatePerFile), rl.config.StreamBurstPerFile)
	rl.streamLimiters[fileHash] = fresh
	return fresh
}
// getClientIP returns the address used as the per-IP upload rate-limit key.
//
// Precedence:
//  1. the first entry of X-Forwarded-For, if non-empty after trimming;
//  2. X-Real-IP, if non-empty after trimming;
//  3. the host part of r.RemoteAddr (brackets stripped for IPv6), or
//     RemoteAddr unchanged if it carries no port.
//
// SECURITY NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled
// unless a trusted reverse proxy overwrites them. If the server is reachable
// directly, a client can send a different value on every request and bypass
// the per-IP upload limit. Consider honoring these headers only when
// RemoteAddr belongs to a known proxy.
func (rl *RateLimiter) getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if comma := strings.IndexByte(xff, ','); comma != -1 {
			xff = xff[:comma]
		}
		// An empty first hop (e.g. ", 1.2.3.4") would otherwise make every
		// such client share the key "".
		if ip := strings.TrimSpace(xff); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// SplitHostPort understands bracketed IPv6 ("[::1]:8080" -> "::1"), so the
	// key matches the unbracketed form proxies put in X-Forwarded-For. Cutting
	// at the last ':' instead would keep the brackets and would truncate a
	// port-less IPv6 address such as "2001:db8::1".
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// extractFileHash returns the file identifier used as the per-file stream
// rate-limit key, or "" when the request names no file.
//
// Lookup order:
//  1. the path segment right after a "files" segment, e.g. "{hash}" in
//     "/api/files/{hash}/stream". An empty segment ("/api/files//stream") is
//     skipped rather than returned, so the query parameters are still tried;
//  2. the "hash" query parameter;
//  3. the "file_hash" query parameter.
//
// NOTE(review): the value is client-supplied and not validated; every
// distinct value creates a new per-file limiter until cleanup evicts it.
func (rl *RateLimiter) extractFileHash(r *http.Request) string {
	segments := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
	for i := 0; i+1 < len(segments); i++ {
		if segments[i] == "files" && segments[i+1] != "" {
			return segments[i+1]
		}
	}

	// URL.Query re-parses RawQuery on every call; parse once.
	query := r.URL.Query()
	for _, key := range []string{"hash", "file_hash"} {
		if v := query.Get(key); v != "" {
			return v
		}
	}
	return ""
}
// cleanupRoutine calls cleanup once every CleanupInterval, indefinitely, to
// keep the per-IP and per-file limiter maps from growing without bound.
//
// A non-positive CleanupInterval falls back to five minutes. time.NewTicker
// panics on a non-positive duration, and because this function runs on its
// own goroutine that panic would terminate the whole process.
//
// NOTE(review): there is no stop signal, so this goroutine (and the
// RateLimiter it references) lives until the process exits.
func (rl *RateLimiter) cleanupRoutine() {
	interval := rl.config.CleanupInterval
	if interval <= 0 {
		interval = 5 * time.Minute
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for range ticker.C {
		rl.cleanup()
	}
}
// cleanup evicts idle per-IP upload limiters and per-file stream limiters.
//
// A limiter counts as idle once its bucket has refilled to its own burst
// size. At that point it holds the same tokens as a freshly created limiter,
// so dropping it does not let that client or file get past its limit.
//
// The comparison uses limiter.Burst() rather than the configured burst. After
// UpdateConfig raises a burst, limiters created earlier keep the old, smaller
// capacity and can never reach the new value. Comparing against the config
// would leave them in the map forever. Using the limiter's own burst also
// avoids unsynchronized reads of rl.config here.
//
// NOTE(review): RateLimitConfig.LimiterTTL is not consulted; idleness is
// inferred purely from the token count.
func (rl *RateLimiter) cleanup() {
	rl.uploadMutex.Lock()
	for ip, limiter := range rl.uploadLimiters {
		if limiter.Tokens() >= float64(limiter.Burst()) {
			// Deleting the current key while ranging over a map is safe in Go.
			delete(rl.uploadLimiters, ip)
		}
	}
	rl.uploadMutex.Unlock()

	rl.streamMutex.Lock()
	for fileHash, limiter := range rl.streamLimiters {
		if limiter.Tokens() >= float64(limiter.Burst()) {
			delete(rl.streamLimiters, fileHash)
		}
	}
	rl.streamMutex.Unlock()
}
// GetStats returns a snapshot of the limiter state for diagnostics.
//
// The map holds the number of live per-IP and per-file limiters
// ("upload_limiters", "stream_limiters"; int), the tokens currently in the
// global download bucket ("download_tokens"; float64), and the configured
// rates (float64) and bursts (int). Each map count is read under that map's
// read lock. The values are not captured atomically as a group.
func (rl *RateLimiter) GetStats() map[string]interface{} {
	var uploads, streams int

	rl.uploadMutex.RLock()
	uploads = len(rl.uploadLimiters)
	rl.uploadMutex.RUnlock()

	rl.streamMutex.RLock()
	streams = len(rl.streamLimiters)
	rl.streamMutex.RUnlock()

	stats := make(map[string]interface{}, 9)
	stats["upload_limiters"] = uploads
	stats["stream_limiters"] = streams
	stats["download_tokens"] = rl.downloadLimit.Tokens()

	cfg := rl.config
	stats["upload_rate_per_ip"] = cfg.UploadRatePerIP
	stats["upload_burst_per_ip"] = cfg.UploadBurstPerIP
	stats["download_rate"] = cfg.DownloadRate
	stats["download_burst"] = cfg.DownloadBurst
	stats["stream_rate_per_file"] = cfg.StreamRatePerFile
	stats["stream_burst_per_file"] = cfg.StreamBurstPerFile
	return stats
}
// UpdateConfig swaps in a new configuration at runtime.
//
// A nil config is ignored; previously it caused a nil-pointer panic.
// Any non-positive rate, burst, or duration field keeps the value currently
// in effect. In particular, a zero DownloadRate or DownloadBurst is never
// pushed into the global limiter, where it would starve all downloads.
// config is copied, so later edits to the caller's struct have no effect.
//
// The global download limiter's rate and burst change immediately. Existing
// per-IP and per-file limiters are not modified; only limiters created after
// this call pick up the new per-IP / per-file settings.
//
// NOTE(review): rl.config is replaced without synchronization while request
// handlers read it (e.g. for X-RateLimit-Limit headers), which is a data race.
// Fixing it needs a lock or atomic.Pointer shared with those readers.
func (rl *RateLimiter) UpdateConfig(config *RateLimitConfig) {
	if config == nil {
		return
	}

	updated := *config
	if cur := rl.config; cur != nil {
		if updated.UploadRatePerIP <= 0 {
			updated.UploadRatePerIP = cur.UploadRatePerIP
		}
		if updated.UploadBurstPerIP <= 0 {
			updated.UploadBurstPerIP = cur.UploadBurstPerIP
		}
		if updated.DownloadRate <= 0 {
			updated.DownloadRate = cur.DownloadRate
		}
		if updated.DownloadBurst <= 0 {
			updated.DownloadBurst = cur.DownloadBurst
		}
		if updated.StreamRatePerFile <= 0 {
			updated.StreamRatePerFile = cur.StreamRatePerFile
		}
		if updated.StreamBurstPerFile <= 0 {
			updated.StreamBurstPerFile = cur.StreamBurstPerFile
		}
		if updated.CleanupInterval <= 0 {
			updated.CleanupInterval = cur.CleanupInterval
		}
		if updated.LimiterTTL <= 0 {
			updated.LimiterTTL = cur.LimiterTTL
		}
	}

	rl.config = &updated

	rl.downloadLimit.SetLimit(rate.Limit(updated.DownloadRate))
	rl.downloadLimit.SetBurst(updated.DownloadBurst)
}
// WaitHandler returns middleware that delays over-limit requests instead of
// rejecting them immediately.
//
// For each request, limiterFunc selects the limiter to use. If it returns nil,
// the request is passed through unthrottled. Otherwise the handler blocks in
// limiter.Wait for at most timeout. If Wait fails, the request is answered
// with 429 Too Many Requests and a JSON error body. Wait fails when the
// deadline passes, when the client's context is cancelled, or when the
// required wait would exceed the deadline.
//
// The timeout bounds only the wait. next receives the original request, not
// the timeout-bound context; otherwise long-running handlers such as
// downloads or streams would be cancelled once timeout elapsed.
func (rl *RateLimiter) WaitHandler(limiterFunc func(*http.Request) *rate.Limiter, timeout time.Duration) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			limiter := limiterFunc(r)
			if limiter == nil {
				next(w, r)
				return
			}

			ctx, cancel := context.WithTimeout(r.Context(), timeout)
			err := limiter.Wait(ctx)
			cancel() // release the timer now rather than holding it for the whole handler

			if err != nil {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusTooManyRequests)
				w.Write([]byte(`{"error": "Rate limit timeout exceeded"}`))
				return
			}

			next(w, r)
		}
	}
}