package middleware import ( "context" "fmt" "net/http" "strings" "git.sovbit.dev/enki/torrentGateway/internal/auth" ) // UserContextKey is the key for storing user info in request context type UserContextKey string const UserKey UserContextKey = "user" // AuthMiddleware provides authentication middleware type AuthMiddleware struct { nostrAuth *auth.NostrAuth } // NewAuthMiddleware creates a new authentication middleware func NewAuthMiddleware(nostrAuth *auth.NostrAuth) *AuthMiddleware { return &AuthMiddleware{ nostrAuth: nostrAuth, } } // RequireAuth middleware that requires valid authentication func (am *AuthMiddleware) RequireAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pubkey, err := am.extractAndValidateAuth(r) if err != nil { http.Error(w, "Unauthorized", http.StatusUnauthorized) return } // Add user to context ctx := context.WithValue(r.Context(), UserKey, pubkey) next.ServeHTTP(w, r.WithContext(ctx)) }) } // OptionalAuth middleware that extracts auth if present but doesn't require it func (am *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pubkey, _ := am.extractAndValidateAuth(r) if pubkey != "" { ctx := context.WithValue(r.Context(), UserKey, pubkey) r = r.WithContext(ctx) } next.ServeHTTP(w, r) }) } // extractAndValidateAuth extracts and validates authentication from request func (am *AuthMiddleware) extractAndValidateAuth(r *http.Request) (string, error) { // Try to get session token from Authorization header authHeader := r.Header.Get("Authorization") var token string if authHeader != "" { if strings.HasPrefix(authHeader, "Bearer ") { token = strings.TrimPrefix(authHeader, "Bearer ") } } // If not in header, try cookie if token == "" { if cookie, err := r.Cookie("session_token"); err == nil { token = cookie.Value } } if token == "" { return "", fmt.Errorf("no session token found") } // Validate session token pubkey, err := am.nostrAuth.ValidateSession(token) if err != nil { return "", fmt.Errorf("invalid session: %w", err) } return pubkey, nil } // GetUserFromContext extracts user pubkey from request context func GetUserFromContext(ctx context.Context) string { if pubkey, ok := ctx.Value(UserKey).(string); ok { return pubkey } return "" } // IsAuthenticated checks if the request has valid authentication func IsAuthenticated(ctx context.Context) bool { return GetUserFromContext(ctx) != "" } // CORS middleware for handling cross-origin requests func CORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-Requested-With") w.Header().Set("Access-Control-Allow-Credentials", "true") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusOK) return } next.ServeHTTP(w, r) }) }