package middleware

import (
	"net/http"
	"strings"
)

// SecurityHeaders adds comprehensive security headers to responses.
func SecurityHeaders(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Content Security Policy
		csp := strings.Join([]string{
			"default-src 'self'",
			"script-src 'self' 'unsafe-inline'", // Allow inline scripts for our JS
			"style-src 'self' 'unsafe-inline'",  // Allow inline styles for our CSS
			"img-src 'self' data: https:",       // Allow images from self, data URLs, and HTTPS
			"media-src 'self'",                  // Media files from self only
			"font-src 'self'",                   // Fonts from self only
			"connect-src 'self' wss: ws:",       // Allow WebSocket connections for Nostr
			"object-src 'none'",                 // Block objects/embeds
			"frame-src 'none'",                  // Block frames
			"base-uri 'self'",                   // Restrict <base> URIs to self
			"form-action 'self'",                // Form submissions to self only
		}, "; ")
		w.Header().Set("Content-Security-Policy", csp)

		// HTTP Strict Transport Security (only meaningful over HTTPS)
		if r.TLS != nil {
			w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		}

		// Prevent MIME type sniffing
		w.Header().Set("X-Content-Type-Options", "nosniff")

		// XSS protection (legacy header, kept for older browsers)
		w.Header().Set("X-XSS-Protection", "1; mode=block")

		// Prevent clickjacking
		w.Header().Set("X-Frame-Options", "DENY")

		// Referrer Policy - privacy-focused
		w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")

		// Permissions Policy - restrict potentially dangerous features
		permissions := strings.Join([]string{
			"camera=()",
			"microphone=()",
			"geolocation=()",
			"payment=()",
			"usb=()",
			"magnetometer=()",
			"gyroscope=()",
			"accelerometer=()",
		}, ", ")
		w.Header().Set("Permissions-Policy", permissions)

		// Blank out the Server header to avoid leaking server information
		w.Header().Set("Server", "")

		// Cache control for sensitive endpoints
		if strings.Contains(r.URL.Path, "/api/users/") ||
			strings.Contains(r.URL.Path, "/api/auth/") ||
			strings.Contains(r.URL.Path, "/admin") {
			w.Header().Set("Cache-Control", "no-store, no-cache, must-revalidate, private")
			w.Header().Set("Pragma", "no-cache")
			w.Header().Set("Expires", "0")
		}

		next.ServeHTTP(w, r)
	})
}

// InputSanitization validates and sanitizes request inputs.
func InputSanitization(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Validate Content-Length for uploads: reject explicitly negative
		// values (oversized bodies are handled by RequestSizeLimit below).
		if r.Method == http.MethodPost || r.Method == http.MethodPut {
			if contentLength := r.Header.Get("Content-Length"); contentLength != "" {
				if strings.HasPrefix(contentLength, "-") {
					http.Error(w, "Invalid Content-Length", http.StatusBadRequest)
					return
				}
			}
		}

		// Sanitize query parameters: strip NUL and control characters,
		// keeping tab, LF, and CR.
		query := r.URL.Query()
		for key, values := range query {
			for i, value := range values {
				cleaned := strings.Map(func(c rune) rune {
					if c == 0 || (c < 32 && c != '\t' && c != '\n' && c != '\r') {
						return -1
					}
					return c
				}, value)
				query[key][i] = cleaned
			}
		}
		r.URL.RawQuery = query.Encode()

		// Normalize the User-Agent: replace empty values and truncate
		// extremely long ones.
		userAgent := r.Header.Get("User-Agent")
		if userAgent == "" {
			r.Header.Set("User-Agent", "unknown")
		} else if len(userAgent) > 500 {
			r.Header.Set("User-Agent", userAgent[:500])
		}

		next.ServeHTTP(w, r)
	})
}
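// As a quick illustration of the query sanitization above — a sketch, not part
// of the original file; it assumes net/http/httptest and fmt in a test:
//
//	req := httptest.NewRequest("GET", "/search?q=a%00b%01c", nil)
//	rec := httptest.NewRecorder()
//	InputSanitization(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//		fmt.Println(r.URL.Query().Get("q")) // prints "abc": NUL and 0x01 are stripped
//	})).ServeHTTP(rec, req)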
// RequestSizeLimit limits the request body to maxBytes.
func RequestSizeLimit(maxBytes int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// MaxBytesReader makes reads past the limit fail and close
			// the connection, so handlers that read the body see an error.
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		})
	}
}

// AntiCrawler discourages automated scraping of download and stream endpoints.
func AntiCrawler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		userAgent := strings.ToLower(r.Header.Get("User-Agent"))

		// Substrings found in known bot/crawler user agents
		crawlerPatterns := []string{
			"bot", "crawler", "spider", "scraper", "archive",
			"wget", "curl", "python-requests", "go-http-client",
			"facebookexternalhit", "twitterbot", "linkedinbot",
		}

		for _, pattern := range crawlerPatterns {
			if strings.Contains(userAgent, pattern) {
				// Block automated file downloads and streams; other paths
				// (e.g. pages fetched for link previews) pass through.
				if strings.Contains(r.URL.Path, "/download/") ||
					strings.Contains(r.URL.Path, "/stream/") {
					http.Error(w, "Automated access not allowed", http.StatusForbidden)
					return
				}
				break // matched a crawler; no need to check further patterns
			}
		}

		next.ServeHTTP(w, r)
	})
}
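// chainExample shows one way to wire these middlewares together. It is an
// illustrative sketch, not part of the original middleware set: the ordering
// and the 10 MiB body cap are assumptions, not requirements of this package.
func chainExample(app http.Handler) http.Handler {
	// Outermost middleware runs first: security headers are always set,
	// then crawlers are filtered, the body is capped, and inputs are
	// sanitized before the application handler runs.
	return SecurityHeaders(
		AntiCrawler(
			RequestSizeLimit(10<<20)( // 10 MiB cap (illustrative)
				InputSanitization(app))))
}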