package middleware

import (
	"log"
	"net"
	"net/http"
	"strings"
	"time"
	"unicode/utf8"

	"torrentGateway/internal/stats"
)

// statsResponseWriter wraps an http.ResponseWriter to capture the status code
// actually sent to the client and the number of body bytes written.
type statsResponseWriter struct {
	http.ResponseWriter
	statusCode   int   // final status sent to the client; constructors default it to 200
	bytesWritten int64 // body bytes accepted by the wrapped writer
	wroteHeader  bool  // set once a final (non-1xx) status has been sent
}

// WriteHeader records the status code and forwards it to the wrapped writer.
// Only the first final status is recorded: net/http ignores later calls, so
// recording them would misreport what the client received. 1xx informational
// codes other than 101 may precede the final header, so they do not lock in
// the recorded status.
func (w *statsResponseWriter) WriteHeader(code int) {
	if !w.wroteHeader {
		w.statusCode = code
		if code >= 200 || code == http.StatusSwitchingProtocols {
			w.wroteHeader = true
		}
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards body bytes and counts how many the wrapped writer accepted.
// Writing before any final WriteHeader implicitly sends 200 OK, as net/http does.
func (w *statsResponseWriter) Write(data []byte) (int, error) {
	w.markImplicitOK()
	n, err := w.ResponseWriter.Write(data)
	w.bytesWritten += int64(n)
	return n, err
}

// Flush forwards to the wrapped writer when it implements http.Flusher, so
// wrapping does not break handlers that type-assert for http.Flusher to
// stream responses. It is a no-op when the wrapped writer cannot flush.
func (w *statsResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		w.markImplicitOK()
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional interfaces (hijacking, deadlines) this wrapper does not expose.
func (w *statsResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// markImplicitOK records the implicit 200 OK that net/http sends when the body
// is written or flushed before a final status has been set.
func (w *statsResponseWriter) markImplicitOK() {
	if !w.wroteHeader {
		w.wroteHeader = true
		w.statusCode = http.StatusOK
	}
}

// StatsMiddleware returns middleware that measures each request's latency,
// status code, response bytes, declared request size, user agent and client
// IP, and records them via statsDB.RecordPerformance. Recording runs in a
// separate goroutine so it does not delay the response; failures are logged
// and never affect the request.
func StatsMiddleware(statsDB *stats.StatsDB) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()

			wrapper := &statsResponseWriter{
				ResponseWriter: w,
				statusCode:     http.StatusOK, // status used if the handler never writes a header
			}

			// ContentLength is -1 when unknown; only a positive declared length is counted.
			var bytesReceived int64
			if r.ContentLength > 0 {
				bytesReceived = r.ContentLength
			}

			next.ServeHTTP(wrapper, r)

			responseTimeMs := float64(time.Since(start).Nanoseconds()) / 1e6

			// Copy everything the goroutine needs so it does not retain r or
			// wrapper after this handler invocation has returned.
			endpoint := cleanEndpointPath(r.URL.Path)
			method := r.Method
			statusCode := wrapper.statusCode
			bytesSent := wrapper.bytesWritten
			userAgent := truncateUTF8Bytes(r.Header.Get("User-Agent"), 255)
			ipAddress := getClientIP(r)

			// NOTE(review): one unbounded goroutine per request; if
			// RecordPerformance is slow under load these can accumulate.
			go func() {
				if err := statsDB.RecordPerformance(
					endpoint,
					method,
					responseTimeMs,
					statusCode,
					bytesSent,
					bytesReceived,
					userAgent,
					ipAddress,
				); err != nil {
					log.Printf("Failed to record performance metrics: %v", err)
				}
			}()
		})
	}
}

// cleanEndpointPath normalizes a URL path so requests for different resources
// on the same route group together. 40-character hex segments become "{hash}".
// Segments that are 36 characters with four hyphens (UUID-shaped) become "{id}".
// All-digit segments also become "{id}". The result is capped at 100 bytes,
// cut on a UTF-8 rune boundary.
func cleanEndpointPath(path string) string {
	parts := strings.Split(path, "/")
	for i, part := range parts {
		switch {
		case len(part) == 40 && isHexString(part):
			// Checked first so an all-digit 40-char segment counts as a hash,
			// matching extractTorrentHash.
			parts[i] = "{hash}"
		case len(part) == 36 && strings.Count(part, "-") == 4:
			parts[i] = "{id}"
		case isNumericString(part):
			parts[i] = "{id}"
		}
	}
	return truncateUTF8Bytes(strings.Join(parts, "/"), 100)
}

// truncateUTF8Bytes returns s shortened to at most maxBytes bytes. If the limit
// falls inside a multi-byte UTF-8 sequence, the cut moves back to the start of
// that rune, so valid UTF-8 input yields valid UTF-8 output.
func truncateUTF8Bytes(s string, maxBytes int) string {
	if len(s) <= maxBytes {
		return s
	}
	if maxBytes <= 0 {
		return ""
	}
	cut := maxBytes
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut]
}

// isHexString reports whether every character of s is an ASCII hex digit
// (either case). It returns true for the empty string; callers check length.
func isHexString(s string) bool {
	for _, c := range s {
		if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')) {
			return false
		}
	}
	return true
}

// isNumericString reports whether s is non-empty and consists only of ASCII
// decimal digits.
func isNumericString(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, c := range s {
		if c < '0' || c > '9' {
			return false
		}
	}
	return true
}

// getClientIP returns the client address for r. It tries, in order:
//   - the first non-empty entry of X-Forwarded-For,
//   - X-Real-IP,
//   - the host part of RemoteAddr.
//
// NOTE(review): both headers are client-supplied unless a trusted proxy
// overwrites them, so the result can be spoofed; treat it as informational.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := strings.SplitN(xff, ",", 2)[0]
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// SplitHostPort strips the brackets from IPv6 literals ("[::1]:8080" -> "::1"),
	// which splitting on the last colon gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// BandwidthTrackingMiddleware returns middleware that records response body
// bytes for selected requests. A request is recorded only when its path is
// accepted by shouldTrackBandwidth and contains a 40-character hex hash.
// Bytes go to both statsDB.RecordBandwidth and collector.RecordBandwidth.
// Recording runs in a separate goroutine; database errors are logged.
func BandwidthTrackingMiddleware(statsDB *stats.StatsDB, collector *stats.StatsCollector) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !shouldTrackBandwidth(r.URL.Path) {
				next.ServeHTTP(w, r)
				return
			}

			wrapper := &statsResponseWriter{
				ResponseWriter: w,
				statusCode:     http.StatusOK,
			}

			next.ServeHTTP(wrapper, r)

			hash := extractTorrentHash(r.URL.Path)
			if hash == "" {
				return
			}

			// Copy the values the goroutine needs so it does not retain r or
			// wrapper after this handler invocation has returned.
			bytesSent := wrapper.bytesWritten
			source := determineBandwidthSource(r.URL.Path)

			go func() {
				// NOTE(review): the two zero arguments are passed as in the
				// original; presumably counters not measured here — confirm
				// against the stats package signatures.
				if err := statsDB.RecordBandwidth(hash, bytesSent, 0, 0, source); err != nil {
					log.Printf("Failed to record bandwidth: %v", err)
				}
				collector.RecordBandwidth(hash, bytesSent, 0, 0)
			}()
		})
	}
}

// shouldTrackBandwidth reports whether path contains one of the tracked route
// segments: /download/, /webseed/, /stream/ or /torrent/.
//
// NOTE(review): /torrent/ is tracked here but determineBandwidthSource has no
// case for it, so such requests are recorded with source "unknown".
func shouldTrackBandwidth(path string) bool {
	trackPaths := []string{"/download/", "/webseed/", "/stream/", "/torrent/"}
	for _, trackPath := range trackPaths {
		if strings.Contains(path, trackPath) {
			return true
		}
	}
	return false
}

// extractTorrentHash returns the first path segment that is exactly 40 hex
// characters, or "" if there is none.
func extractTorrentHash(path string) string {
	for _, part := range strings.Split(path, "/") {
		if len(part) == 40 && isHexString(part) {
			return part
		}
	}
	return ""
}

// determineBandwidthSource maps a URL path to the source label stored with
// bandwidth records. The first match wins, checked in this order: "webseed",
// "stream", "download"; anything else is "unknown".
func determineBandwidthSource(path string) string {
	switch {
	case strings.Contains(path, "/webseed/"):
		return "webseed"
	case strings.Contains(path, "/stream/"):
		return "stream"
	case strings.Contains(path, "/download/"):
		return "download"
	default:
		return "unknown"
	}
}