Some checks are pending
CI Pipeline / Run Tests (push) Waiting to run
CI Pipeline / Lint Code (push) Waiting to run
CI Pipeline / Security Scan (push) Waiting to run
CI Pipeline / Build Docker Images (push) Blocked by required conditions
CI Pipeline / E2E Tests (push) Blocked by required conditions
339 lines
10 KiB
Go
339 lines
10 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
|
|
|
|
// RateLimiter manages different types of rate limits
|
|
type RateLimiter struct {
|
|
uploadLimiters map[string]*rate.Limiter // Per IP upload limits
|
|
downloadLimit *rate.Limiter // Global download limit
|
|
streamLimiters map[string]*rate.Limiter // Per file stream limits
|
|
uploadMutex sync.RWMutex
|
|
streamMutex sync.RWMutex
|
|
|
|
config *RateLimitConfig
|
|
}
|
|
|
|
// RateLimitConfig holds the tuning knobs for a RateLimiter. Rates are in
// requests (tokens) per second. Bursts are token-bucket capacities, i.e. how
// many requests a bucket can admit back to back.
type RateLimitConfig struct {
	// Upload limit, applied separately to each client IP.
	UploadRatePerIP  float64 // tokens added per second to each IP's bucket
	UploadBurstPerIP int     // capacity of each IP's bucket

	// Download limit, one bucket shared by all clients.
	DownloadRate  float64 // tokens added per second to the shared bucket
	DownloadBurst int     // capacity of the shared bucket

	// Stream limit, applied separately to each file.
	StreamRatePerFile  float64 // tokens added per second to each file's bucket
	StreamBurstPerFile int     // capacity of each file's bucket

	// Housekeeping for the per-IP and per-file limiter maps.
	CleanupInterval time.Duration // period between sweeps that drop idle limiters
	LimiterTTL      time.Duration // intended idle lifetime; NOTE(review): not read by the cleanup code in this file
}
|
|
|
|
// NewRateLimiter creates a new rate limiter with the given configuration
|
|
func NewRateLimiter(config *RateLimitConfig) *RateLimiter {
|
|
if config == nil {
|
|
config = &RateLimitConfig{
|
|
UploadRatePerIP: 1.0, // 1 upload per second per IP
|
|
UploadBurstPerIP: 5, // burst of 5
|
|
DownloadRate: 50.0, // 50 downloads per second globally
|
|
DownloadBurst: 100, // burst of 100
|
|
StreamRatePerFile: 10.0, // 10 streams per second per file
|
|
StreamBurstPerFile: 20, // burst of 20
|
|
CleanupInterval: 5 * time.Minute,
|
|
LimiterTTL: 15 * time.Minute,
|
|
}
|
|
}
|
|
|
|
// Validate configuration values
|
|
if config.UploadRatePerIP <= 0 {
|
|
config.UploadRatePerIP = 1.0
|
|
}
|
|
if config.UploadBurstPerIP <= 0 {
|
|
config.UploadBurstPerIP = 5
|
|
}
|
|
if config.DownloadRate <= 0 {
|
|
config.DownloadRate = 50.0
|
|
}
|
|
if config.DownloadBurst <= 0 {
|
|
config.DownloadBurst = 100
|
|
}
|
|
if config.StreamRatePerFile <= 0 {
|
|
config.StreamRatePerFile = 10.0
|
|
}
|
|
if config.StreamBurstPerFile <= 0 {
|
|
config.StreamBurstPerFile = 20
|
|
}
|
|
|
|
rl := &RateLimiter{
|
|
uploadLimiters: make(map[string]*rate.Limiter),
|
|
downloadLimit: rate.NewLimiter(rate.Limit(config.DownloadRate), config.DownloadBurst),
|
|
streamLimiters: make(map[string]*rate.Limiter),
|
|
config: config,
|
|
}
|
|
|
|
// Start cleanup routine
|
|
go rl.cleanupRoutine()
|
|
|
|
return rl
|
|
}
|
|
|
|
// UploadMiddleware applies per-IP upload rate limiting
|
|
func (rl *RateLimiter) UploadMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract IP address
|
|
ip := rl.getClientIP(r)
|
|
|
|
// Get or create limiter for this IP
|
|
limiter := rl.getUploadLimiter(ip)
|
|
|
|
// Check rate limit
|
|
if !limiter.Allow() {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.config.UploadBurstPerIP))
|
|
w.Header().Set("X-RateLimit-Remaining", "0")
|
|
w.Header().Set("Retry-After", "60")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error": "Upload rate limit exceeded. Please try again later."}`))
|
|
return
|
|
}
|
|
|
|
// Add rate limit headers
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.config.UploadBurstPerIP))
|
|
w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(int(limiter.Tokens())))
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// DownloadMiddleware applies global download rate limiting
|
|
func (rl *RateLimiter) DownloadMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Check global download rate limit
|
|
if !rl.downloadLimit.Allow() {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.config.DownloadBurst))
|
|
w.Header().Set("X-RateLimit-Remaining", "0")
|
|
w.Header().Set("Retry-After", "10")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error": "Global download rate limit exceeded. Please try again later."}`))
|
|
return
|
|
}
|
|
|
|
// Add rate limit headers
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.config.DownloadBurst))
|
|
w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(int(rl.downloadLimit.Tokens())))
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// StreamMiddleware applies per-file stream rate limiting
|
|
func (rl *RateLimiter) StreamMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
// Extract file hash from URL path or query parameters
|
|
fileHash := rl.extractFileHash(r)
|
|
if fileHash == "" {
|
|
// No file hash found, proceed without stream limiting
|
|
next(w, r)
|
|
return
|
|
}
|
|
|
|
// Get or create limiter for this file
|
|
limiter := rl.getStreamLimiter(fileHash)
|
|
|
|
// Check rate limit
|
|
if !limiter.Allow() {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.config.StreamBurstPerFile))
|
|
w.Header().Set("X-RateLimit-Remaining", "0")
|
|
w.Header().Set("Retry-After", "5")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error": "Stream rate limit exceeded for this file. Please try again later."}`))
|
|
return
|
|
}
|
|
|
|
// Add rate limit headers
|
|
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.config.StreamBurstPerFile))
|
|
w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(int(limiter.Tokens())))
|
|
|
|
next(w, r)
|
|
}
|
|
}
|
|
|
|
// getUploadLimiter gets or creates a rate limiter for the given IP
|
|
func (rl *RateLimiter) getUploadLimiter(ip string) *rate.Limiter {
|
|
rl.uploadMutex.Lock()
|
|
defer rl.uploadMutex.Unlock()
|
|
|
|
limiter, exists := rl.uploadLimiters[ip]
|
|
if !exists {
|
|
limiter = rate.NewLimiter(rate.Limit(rl.config.UploadRatePerIP), rl.config.UploadBurstPerIP)
|
|
rl.uploadLimiters[ip] = limiter
|
|
}
|
|
|
|
return limiter
|
|
}
|
|
|
|
// getStreamLimiter gets or creates a rate limiter for the given file
|
|
func (rl *RateLimiter) getStreamLimiter(fileHash string) *rate.Limiter {
|
|
rl.streamMutex.Lock()
|
|
defer rl.streamMutex.Unlock()
|
|
|
|
limiter, exists := rl.streamLimiters[fileHash]
|
|
if !exists {
|
|
limiter = rate.NewLimiter(rate.Limit(rl.config.StreamRatePerFile), rl.config.StreamBurstPerFile)
|
|
rl.streamLimiters[fileHash] = limiter
|
|
}
|
|
|
|
return limiter
|
|
}
|
|
|
|
// getClientIP extracts the client IP address from the request
|
|
func (rl *RateLimiter) getClientIP(r *http.Request) string {
|
|
// Check X-Forwarded-For header first (for proxy setups)
|
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
// Take the first IP in the chain
|
|
if idx := strings.Index(xff, ","); idx != -1 {
|
|
return strings.TrimSpace(xff[:idx])
|
|
}
|
|
return strings.TrimSpace(xff)
|
|
}
|
|
|
|
// Check X-Real-IP header
|
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
|
return strings.TrimSpace(xri)
|
|
}
|
|
|
|
// Fall back to RemoteAddr
|
|
ip := r.RemoteAddr
|
|
if idx := strings.LastIndex(ip, ":"); idx != -1 {
|
|
ip = ip[:idx]
|
|
}
|
|
|
|
return ip
|
|
}
|
|
|
|
// extractFileHash extracts file hash from request URL or parameters
|
|
func (rl *RateLimiter) extractFileHash(r *http.Request) string {
|
|
// Try URL path first (e.g., /api/files/{hash}/stream)
|
|
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
|
for i, part := range pathParts {
|
|
if part == "files" && i+1 < len(pathParts) {
|
|
return pathParts[i+1]
|
|
}
|
|
}
|
|
|
|
// Try query parameters
|
|
if hash := r.URL.Query().Get("hash"); hash != "" {
|
|
return hash
|
|
}
|
|
if hash := r.URL.Query().Get("file_hash"); hash != "" {
|
|
return hash
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// cleanupRoutine periodically removes inactive rate limiters to prevent memory leaks
|
|
func (rl *RateLimiter) cleanupRoutine() {
|
|
ticker := time.NewTicker(rl.config.CleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
rl.cleanup()
|
|
}
|
|
}
|
|
|
|
// cleanup removes inactive rate limiters
|
|
func (rl *RateLimiter) cleanup() {
|
|
// Clean upload limiters
|
|
rl.uploadMutex.Lock()
|
|
for ip, limiter := range rl.uploadLimiters {
|
|
// Check if limiter hasn't been used recently
|
|
if limiter.Tokens() >= float64(rl.config.UploadBurstPerIP) {
|
|
// Remove if it's been inactive
|
|
delete(rl.uploadLimiters, ip)
|
|
}
|
|
}
|
|
rl.uploadMutex.Unlock()
|
|
|
|
// Clean stream limiters
|
|
rl.streamMutex.Lock()
|
|
for fileHash, limiter := range rl.streamLimiters {
|
|
// Check if limiter hasn't been used recently
|
|
if limiter.Tokens() >= float64(rl.config.StreamBurstPerFile) {
|
|
// Remove if it's been inactive
|
|
delete(rl.streamLimiters, fileHash)
|
|
}
|
|
}
|
|
rl.streamMutex.Unlock()
|
|
}
|
|
|
|
// GetStats returns current rate limiting statistics
|
|
func (rl *RateLimiter) GetStats() map[string]interface{} {
|
|
rl.uploadMutex.RLock()
|
|
uploadLimiterCount := len(rl.uploadLimiters)
|
|
rl.uploadMutex.RUnlock()
|
|
|
|
rl.streamMutex.RLock()
|
|
streamLimiterCount := len(rl.streamLimiters)
|
|
rl.streamMutex.RUnlock()
|
|
|
|
return map[string]interface{}{
|
|
"upload_limiters": uploadLimiterCount,
|
|
"stream_limiters": streamLimiterCount,
|
|
"download_tokens": rl.downloadLimit.Tokens(),
|
|
"upload_rate_per_ip": rl.config.UploadRatePerIP,
|
|
"upload_burst_per_ip": rl.config.UploadBurstPerIP,
|
|
"download_rate": rl.config.DownloadRate,
|
|
"download_burst": rl.config.DownloadBurst,
|
|
"stream_rate_per_file": rl.config.StreamRatePerFile,
|
|
"stream_burst_per_file": rl.config.StreamBurstPerFile,
|
|
}
|
|
}
|
|
|
|
// UpdateConfig allows runtime configuration updates
|
|
func (rl *RateLimiter) UpdateConfig(config *RateLimitConfig) {
|
|
rl.config = config
|
|
|
|
// Update global download limiter
|
|
rl.downloadLimit.SetLimit(rate.Limit(config.DownloadRate))
|
|
rl.downloadLimit.SetBurst(config.DownloadBurst)
|
|
|
|
// Note: Existing per-IP and per-file limiters will use old config until they're recreated
|
|
// This is acceptable as they'll be cleaned up and recreated with new config over time
|
|
}
|
|
|
|
// WaitHandler wraps a handler to wait for rate limit availability instead of rejecting
|
|
func (rl *RateLimiter) WaitHandler(limiterFunc func(*http.Request) *rate.Limiter, timeout time.Duration) func(http.HandlerFunc) http.HandlerFunc {
|
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
limiter := limiterFunc(r)
|
|
|
|
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
|
defer cancel()
|
|
|
|
// Wait for rate limit availability
|
|
if err := limiter.Wait(ctx); err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusTooManyRequests)
|
|
w.Write([]byte(`{"error": "Rate limit timeout exceeded"}`))
|
|
return
|
|
}
|
|
|
|
next(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
} |