enki 7c92aa3ded
Some checks are pending
CI Pipeline / Run Tests (push) Waiting to run
CI Pipeline / Lint Code (push) Waiting to run
CI Pipeline / Security Scan (push) Waiting to run
CI Pipeline / E2E Tests (push) Blocked by required conditions
Major DHT and Torrent fixes.
2025-08-29 21:18:36 -07:00

230 lines
5.5 KiB
Go

package middleware
import (
	"net"
	"net/http"
	"strings"
	"time"

	"torrentGateway/internal/stats"
)
// statsResponseWriter wraps an http.ResponseWriter to record the response
// status code and the number of body bytes written, for the middleware in
// this file.
type statsResponseWriter struct {
	http.ResponseWriter

	// statusCode is the status the client receives. Constructors in this
	// file initialize it to 200, the status net/http sends when a handler
	// writes a body without calling WriteHeader.
	statusCode int

	// bytesWritten counts body bytes accepted by the underlying writer.
	bytesWritten int64

	// wroteHeader is set once the final status is committed, so later
	// (superfluous) WriteHeader calls, which net/http ignores, do not
	// overwrite the recorded status.
	wroteHeader bool
}

// WriteHeader records code as the response status and forwards it.
// Informational 1xx codes other than 101 Switching Protocols do not commit
// the final status, so a subsequent WriteHeader still takes effect.
func (w *statsResponseWriter) WriteHeader(code int) {
	if !w.wroteHeader {
		w.statusCode = code
		w.wroteHeader = code < 100 || code > 199 || code == http.StatusSwitchingProtocols
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards data to the underlying writer and counts the bytes it
// accepted. A Write before any final WriteHeader commits an implicit 200.
func (w *statsResponseWriter) Write(data []byte) (int, error) {
	w.commitImplicitOK()
	n, err := w.ResponseWriter.Write(data)
	w.bytesWritten += int64(n)
	return n, err
}

// Flush forwards to the underlying writer when it implements http.Flusher,
// so streaming handlers that type-assert w.(http.Flusher) keep working
// behind this wrapper. It is a no-op otherwise.
func (w *statsResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		w.commitImplicitOK()
		f.Flush()
	}
}

// Unwrap returns the wrapped ResponseWriter so that
// http.NewResponseController can reach optional interfaces (hijacking,
// deadlines) that this wrapper does not expose itself.
func (w *statsResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// commitImplicitOK records the 200 status that net/http sends implicitly when
// the body is written (or flushed) before any final WriteHeader call.
func (w *statsResponseWriter) commitImplicitOK() {
	if !w.wroteHeader {
		w.statusCode = http.StatusOK
		w.wroteHeader = true
	}
}
// StatsMiddleware returns middleware that records performance metrics for
// every request via statsDB.RecordPerformance:
//   - the endpoint path normalized by cleanEndpointPath
//   - the method
//   - latency in milliseconds
//   - the response status
//   - response body bytes
//   - the request Content-Length
//   - the User-Agent, capped at 255 bytes
//   - the client IP from getClientIP
//
// Metrics are written from a new goroutine after the wrapped handler returns,
// so stats storage never delays the response. Recording is best-effort:
// errors from RecordPerformance are intentionally discarded.
func StatsMiddleware(statsDB *stats.StatsDB) func(http.Handler) http.Handler {
	const maxUserAgentLen = 255

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()

			// Wrap the writer to observe the status code and body size.
			wrapper := &statsResponseWriter{
				ResponseWriter: w,
				statusCode:     http.StatusOK, // status sent if the handler never calls WriteHeader
			}

			// ContentLength is -1 when unknown (e.g. chunked bodies); record 0 then.
			var bytesReceived int64
			if r.ContentLength > 0 {
				bytesReceived = r.ContentLength
			}

			next.ServeHTTP(wrapper, r)

			duration := time.Since(start)
			responseTimeMs := float64(duration.Nanoseconds()) / 1e6

			// Snapshot everything the background goroutine needs so it does
			// not retain or read r / wrapper after the handler has finished.
			endpoint := cleanEndpointPath(r.URL.Path)
			method := r.Method
			statusCode := wrapper.statusCode
			bytesSent := wrapper.bytesWritten
			ipAddress := getClientIP(r)

			userAgent := r.Header.Get("User-Agent")
			if len(userAgent) > maxUserAgentLen {
				// Byte slicing can split a multi-byte UTF-8 character;
				// ToValidUTF8 drops the dangling partial sequence (and any
				// other invalid bytes in the kept prefix).
				userAgent = strings.ToValidUTF8(userAgent[:maxUserAgentLen], "")
			}

			go func() {
				// Best-effort: a stats failure must never affect the request.
				_ = statsDB.RecordPerformance(
					endpoint,
					method,
					responseTimeMs,
					statusCode,
					bytesSent,
					bytesReceived,
					userAgent,
					ipAddress,
				)
			}()
		})
	}
}
// cleanEndpointPath normalizes a request path so that requests for different
// resources on the same route are grouped under one endpoint key.
//
// Each "/"-separated segment is classified once, and the first match wins:
//   - 40 hexadecimal characters          -> "{hash}"
//   - canonical 8-4-4-4-12 hex UUID      -> "{id}"
//   - non-empty, all ASCII digits        -> "{id}"
//
// The result is capped at 100 bytes. If the cap falls inside a multi-byte
// UTF-8 character, the partial character is dropped so the key stays valid
// UTF-8.
func cleanEndpointPath(path string) string {
	const maxLen = 100

	// isUUID reports whether s has the canonical hex UUID layout, with dashes
	// at byte offsets 8, 13, 18 and 23.
	isUUID := func(s string) bool {
		return len(s) == 36 &&
			s[8] == '-' && s[13] == '-' && s[18] == '-' && s[23] == '-' &&
			isHexString(s[:8]) && isHexString(s[9:13]) && isHexString(s[14:18]) &&
			isHexString(s[19:23]) && isHexString(s[24:])
	}

	parts := strings.Split(path, "/")
	for i, part := range parts {
		// switch ensures a segment already classified as a hash is not
		// re-labelled by a later check (e.g. an all-digit 40-char hash).
		switch {
		case len(part) == 40 && isHexString(part):
			parts[i] = "{hash}"
		case isUUID(part), isNumericString(part):
			parts[i] = "{id}"
		}
	}

	cleaned := strings.Join(parts, "/")
	if len(cleaned) > maxLen {
		// Byte slicing can split a multi-byte rune; ToValidUTF8 removes the
		// dangling partial sequence.
		cleaned = strings.ToValidUTF8(cleaned[:maxLen], "")
	}
	return cleaned
}
// isHexString reports whether every byte of s is an ASCII hexadecimal digit
// (0-9, a-f, A-F). Any non-ASCII byte makes it false; the empty string
// reports true.
func isHexString(s string) bool {
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case '0' <= c && c <= '9', 'a' <= c && c <= 'f', 'A' <= c && c <= 'F':
			// hex digit; keep scanning
		default:
			return false
		}
	}
	return true
}
// isNumericString reports whether s is non-empty and consists solely of the
// ASCII digits 0-9. Signs, separators and non-ASCII digits are rejected.
func isNumericString(s string) bool {
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	return s != "" && strings.IndexFunc(s, notDigit) == -1
}
// getClientIP returns a best-guess client IP address for r, used for
// statistics.
//
// Precedence:
//  1. the first non-empty entry of X-Forwarded-For
//  2. X-Real-IP
//  3. the host part of r.RemoteAddr
//
// If RemoteAddr cannot be split into host and port, it is returned unchanged.
//
// SECURITY: X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy in front of this server overwrites them. Do not use
// the result for authentication or rate limiting.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		for _, entry := range strings.Split(xff, ",") {
			if ip := strings.TrimSpace(entry); ip != "" {
				return ip
			}
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	// SplitHostPort strips the brackets from IPv6 literals ("[::1]:8080" ->
	// "::1"), which a naive split on the last ':' gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// BandwidthTrackingMiddleware returns middleware that counts response body
// bytes for requests whose path satisfies shouldTrackBandwidth.
//
// When the path also contains a torrent hash (see extractTorrentHash), the
// byte count is passed as the second argument to statsDB.RecordBandwidth,
// together with the source from determineBandwidthSource. It is also passed
// to collector.RecordBandwidth. The two remaining counters are passed as 0.
//
// Recording happens in a new goroutine after the handler returns and is
// best-effort: the error from statsDB.RecordBandwidth is intentionally
// discarded.
func BandwidthTrackingMiddleware(statsDB *stats.StatsDB, collector *stats.StatsCollector) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Capture the path once so the filter, hash and source all agree.
			path := r.URL.Path

			if !shouldTrackBandwidth(path) {
				next.ServeHTTP(w, r)
				return
			}

			wrapper := &statsResponseWriter{
				ResponseWriter: w,
				statusCode:     http.StatusOK,
			}
			next.ServeHTTP(wrapper, r)

			hash := extractTorrentHash(path)
			if hash == "" {
				return
			}

			// Snapshot values so the goroutine does not retain r or wrapper.
			source := determineBandwidthSource(path)
			bytesSent := wrapper.bytesWritten

			go func() {
				// Best-effort: a stats failure must never affect the request.
				_ = statsDB.RecordBandwidth(hash, bytesSent, 0, 0, source)
				collector.RecordBandwidth(hash, bytesSent, 0, 0)
			}()
		})
	}
}
// shouldTrackBandwidth reports whether path contains, anywhere, one of the
// segments "/download/", "/webseed/", "/stream/" or "/torrent/".
func shouldTrackBandwidth(path string) bool {
	switch {
	case strings.Contains(path, "/download/"),
		strings.Contains(path, "/webseed/"),
		strings.Contains(path, "/stream/"),
		strings.Contains(path, "/torrent/"):
		return true
	default:
		return false
	}
}
// extractTorrentHash returns the first "/"-separated segment of path that is
// exactly 40 hexadecimal characters (the length of a hex-encoded SHA-1
// digest), with its original casing, or "" if no segment qualifies.
func extractTorrentHash(path string) string {
	const hashLen = 40
	isSlash := func(r rune) bool { return r == '/' }
	for _, segment := range strings.FieldsFunc(path, isSlash) {
		if len(segment) == hashLen && isHexString(segment) {
			return segment
		}
	}
	return ""
}
// determineBandwidthSource labels a tracked path by the first matching
// segment, checked in this order: "/webseed/" -> "webseed",
// "/stream/" -> "stream", "/download/" -> "download". Anything else yields
// "unknown".
//
// NOTE(review): paths accepted by shouldTrackBandwidth only because they
// contain "/torrent/" are labelled "unknown" here; confirm whether a
// dedicated "torrent" source was intended.
func determineBandwidthSource(path string) string {
	switch {
	case strings.Contains(path, "/webseed/"):
		return "webseed"
	case strings.Contains(path, "/stream/"):
		return "stream"
	case strings.Contains(path, "/download/"):
		return "download"
	default:
		return "unknown"
	}
}