230 lines
5.5 KiB
Go
230 lines
5.5 KiB
Go
package middleware
|
|
|
|
import (
	"log"
	"net"
	"net/http"
	"strings"
	"time"

	"torrentGateway/internal/stats"
)
|
|
|
|
// statsResponseWriter wraps an http.ResponseWriter and records the status code
// that was actually sent plus the number of body bytes written through it.
//
// Callers initialise statusCode to http.StatusOK, the status net/http sends
// implicitly when a handler writes a body without calling WriteHeader.
type statsResponseWriter struct {
	http.ResponseWriter
	statusCode   int
	bytesWritten int64
	// wroteHeader is set once a final (non-1xx) status has gone out. net/http
	// ignores later WriteHeader calls, so the recorded status is frozen too.
	wroteHeader  bool
}

// WriteHeader records code (unless a final status was already sent) and
// forwards it to the wrapped writer.
func (w *statsResponseWriter) WriteHeader(code int) {
	if !w.wroteHeader {
		w.statusCode = code
		// Informational 1xx responses (e.g. 103 Early Hints) may precede the
		// final status, so they do not freeze the recorded code. 101 Switching
		// Protocols is final.
		if code >= 200 || code == http.StatusSwitchingProtocols {
			w.wroteHeader = true
		}
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards data to the wrapped writer and adds the number of bytes it
// accepted to bytesWritten.
func (w *statsResponseWriter) Write(data []byte) (int, error) {
	if !w.wroteHeader {
		// The first Write without a final WriteHeader makes net/http send 200.
		w.statusCode = http.StatusOK
		w.wroteHeader = true
	}
	n, err := w.ResponseWriter.Write(data)
	w.bytesWritten += int64(n)
	return n, err
}

// Flush forwards to the wrapped writer when it supports http.Flusher, so
// streaming handlers keep working behind this wrapper. It is a no-op otherwise.
func (w *statsResponseWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		if !w.wroteHeader {
			// Flushing before any header also makes net/http send 200.
			w.statusCode = http.StatusOK
			w.wroteHeader = true
		}
		f.Flush()
	}
}

// Unwrap returns the wrapped writer. http.ResponseController uses it to reach
// optional interfaces (Hijacker, deadlines) that this type does not implement.
func (w *statsResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
|
|
|
// StatsMiddleware creates middleware that collects performance metrics
|
|
func StatsMiddleware(statsDB *stats.StatsDB) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
|
|
// Wrap response writer to capture metrics
|
|
wrapper := &statsResponseWriter{
|
|
ResponseWriter: w,
|
|
statusCode: 200, // default status code
|
|
}
|
|
|
|
// Get request size
|
|
var bytesReceived int64
|
|
if r.ContentLength > 0 {
|
|
bytesReceived = r.ContentLength
|
|
}
|
|
|
|
// Process request
|
|
next.ServeHTTP(wrapper, r)
|
|
|
|
// Calculate metrics
|
|
duration := time.Since(start)
|
|
responseTimeMs := float64(duration.Nanoseconds()) / 1e6
|
|
|
|
// Clean endpoint path for grouping similar endpoints
|
|
endpoint := cleanEndpointPath(r.URL.Path)
|
|
|
|
// Record performance metrics
|
|
userAgent := r.Header.Get("User-Agent")
|
|
if len(userAgent) > 255 {
|
|
userAgent = userAgent[:255]
|
|
}
|
|
|
|
ipAddress := getClientIP(r)
|
|
|
|
// Record async to avoid blocking requests
|
|
go func() {
|
|
err := statsDB.RecordPerformance(
|
|
endpoint,
|
|
r.Method,
|
|
responseTimeMs,
|
|
wrapper.statusCode,
|
|
wrapper.bytesWritten,
|
|
bytesReceived,
|
|
userAgent,
|
|
ipAddress,
|
|
)
|
|
if err != nil {
|
|
// Log error but don't fail the request
|
|
// log.Printf("Failed to record performance metrics: %v", err)
|
|
}
|
|
}()
|
|
})
|
|
}
|
|
}
|
|
|
|
// cleanEndpointPath normalizes endpoint paths for better grouping
|
|
func cleanEndpointPath(path string) string {
|
|
// Remove file hashes and IDs for grouping
|
|
parts := strings.Split(path, "/")
|
|
for i, part := range parts {
|
|
// Replace 40-character hex strings (file hashes) with placeholder
|
|
if len(part) == 40 && isHexString(part) {
|
|
parts[i] = "{hash}"
|
|
}
|
|
// Replace UUIDs with placeholder
|
|
if len(part) == 36 && strings.Count(part, "-") == 4 {
|
|
parts[i] = "{id}"
|
|
}
|
|
// Replace numeric IDs with placeholder
|
|
if isNumericString(part) {
|
|
parts[i] = "{id}"
|
|
}
|
|
}
|
|
|
|
cleaned := strings.Join(parts, "/")
|
|
|
|
// Limit endpoint path length
|
|
if len(cleaned) > 100 {
|
|
cleaned = cleaned[:100]
|
|
}
|
|
|
|
return cleaned
|
|
}
|
|
|
|
// isHexString reports whether every character of s is an ASCII hexadecimal
// digit (0-9, a-f, A-F). It returns true for the empty string, so callers that
// need a specific width must check len(s) themselves.
func isHexString(s string) bool {
	notHex := func(r rune) bool {
		switch {
		case '0' <= r && r <= '9', 'a' <= r && r <= 'f', 'A' <= r && r <= 'F':
			return false
		}
		return true
	}
	return strings.IndexFunc(s, notHex) == -1
}
|
|
|
|
// isNumericString reports whether s is non-empty and made only of the ASCII
// digits 0-9. Signs, separators and non-ASCII digits are rejected.
func isNumericString(s string) bool {
	for i := 0; i < len(s); i++ {
		if b := s[i]; b < '0' || b > '9' {
			return false
		}
	}
	return len(s) > 0
}
|
|
|
|
// getClientIP returns a best-effort client IP address for r.
//
// Lookup order:
//  1. the first (left-most) entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr.
//
// Header values are trimmed and used only if net.ParseIP accepts them, so
// arbitrary client-supplied text is never returned as an address. RemoteAddr
// is split with net.SplitHostPort, which strips the port and IPv6 brackets;
// if it has no port, it is returned unchanged.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them, so the result must not be used for access control.
func getClientIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first := xff
		if comma := strings.IndexByte(xff, ','); comma >= 0 {
			first = xff[:comma]
		}
		if ip := strings.TrimSpace(first); net.ParseIP(ip) != nil {
			return ip
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); net.ParseIP(xri) != nil {
		return xri
	}

	// RemoteAddr is normally "host:port", or "[host]:port" for IPv6.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// BandwidthTrackingMiddleware creates middleware that tracks bandwidth usage
|
|
func BandwidthTrackingMiddleware(statsDB *stats.StatsDB, collector *stats.StatsCollector) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Only track bandwidth for specific endpoints
|
|
if !shouldTrackBandwidth(r.URL.Path) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
wrapper := &statsResponseWriter{
|
|
ResponseWriter: w,
|
|
statusCode: 200,
|
|
}
|
|
|
|
next.ServeHTTP(wrapper, r)
|
|
|
|
// Extract torrent hash from path if applicable
|
|
if hash := extractTorrentHash(r.URL.Path); hash != "" {
|
|
// Record bandwidth usage async
|
|
go func() {
|
|
source := determineBandwidthSource(r.URL.Path)
|
|
err := statsDB.RecordBandwidth(hash, wrapper.bytesWritten, 0, 0, source)
|
|
if err != nil {
|
|
// log.Printf("Failed to record bandwidth: %v", err)
|
|
}
|
|
|
|
// Also update in-memory collector
|
|
collector.RecordBandwidth(hash, wrapper.bytesWritten, 0, 0)
|
|
}()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// shouldTrackBandwidth reports whether path contains one of the tracked route
// segments "/download/", "/webseed/", "/stream/" or "/torrent/" anywhere in it
// (substring match, so "/api/download/x" also qualifies).
func shouldTrackBandwidth(path string) bool {
	switch {
	case strings.Contains(path, "/download/"),
		strings.Contains(path, "/webseed/"),
		strings.Contains(path, "/stream/"),
		strings.Contains(path, "/torrent/"):
		return true
	default:
		return false
	}
}
|
|
|
|
// extractTorrentHash extracts torrent hash from URL path
|
|
func extractTorrentHash(path string) string {
|
|
parts := strings.Split(path, "/")
|
|
for _, part := range parts {
|
|
if len(part) == 40 && isHexString(part) {
|
|
return part
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// determineBandwidthSource labels path by the first matching route segment,
// checked in priority order "/webseed/", "/stream/", "/download/". Anything
// else, including "/torrent/" paths that shouldTrackBandwidth does track, is
// labelled "unknown".
func determineBandwidthSource(path string) string {
	switch {
	case strings.Contains(path, "/webseed/"):
		return "webseed"
	case strings.Contains(path, "/stream/"):
		return "stream"
	case strings.Contains(path, "/download/"):
		return "download"
	default:
		return "unknown"
	}
}