diff --git a/internal/dht/node.go b/internal/dht/node.go index 79a64cf..16de262 100644 --- a/internal/dht/node.go +++ b/internal/dht/node.go @@ -1,6 +1,7 @@ package dht import ( + "context" "crypto/rand" "crypto/sha1" "database/sql" @@ -2217,41 +2218,67 @@ func (d *DHT) queryNodesParallel(candidates []*Node, targetID NodeID, queryType resultChan := make(chan queryResult, len(candidates)) var wg sync.WaitGroup + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() // Launch parallel queries for _, candidate := range candidates { wg.Add(1) go func(node *Node) { - defer wg.Done() + defer func() { + if r := recover(); r != nil { + log.Printf("DHT query panic recovered: %v", r) + } + wg.Done() + }() key := fmt.Sprintf("%s:%d", node.Addr.IP, node.Addr.Port) queried[key] = true var discoveredNodes []*Node - if queryType == "find_node" { - nodes, err := d.FindNode(targetID, node.Addr) - if err == nil { - discoveredNodes = nodes + select { + case <-ctx.Done(): + return + default: + if queryType == "find_node" { + nodes, err := d.FindNode(targetID, node.Addr) + if err == nil { + discoveredNodes = nodes + } + } else if queryType == "get_peers" { + _, nodes, err := d.GetPeers(targetID[:], node.Addr) + if err == nil && nodes != nil { + discoveredNodes = nodes + } } - } else if queryType == "get_peers" { - _, nodes, err := d.GetPeers(targetID[:], node.Addr) - if err == nil && nodes != nil { - discoveredNodes = nodes + + select { + case resultChan <- queryResult{nodes: discoveredNodes, node: node}: + case <-ctx.Done(): + return } } - - resultChan <- queryResult{nodes: discoveredNodes, node: node} }(candidate) } - // Close channel when all queries complete + // Close channel when all queries complete or timeout go func() { - wg.Wait() + done := make(chan struct{}) + go func() { + defer close(done) + wg.Wait() + }() + + select { + case <-done: + case <-ctx.Done(): + log.Printf("DHT parallel query timeout after 30s") + } close(resultChan) }() - // Collect results + // Collect results with timeout var allNodes []*Node for result := range resultChan { if len(result.nodes) > 0 { diff --git a/internal/p2p/gateway.go b/internal/p2p/gateway.go index 99a0bb8..1d2ad4a 100644 --- a/internal/p2p/gateway.go +++ b/internal/p2p/gateway.go @@ -107,6 +107,11 @@ func NewUnifiedP2PGateway(config *config.Config, db *sql.DB) *UnifiedP2PGateway func (g *UnifiedP2PGateway) Initialize() error { log.Printf("P2P Gateway: Initializing unified P2P system") + // Create database tables + if err := g.CreateP2PTables(); err != nil { + return fmt.Errorf("failed to create P2P database tables: %w", err) + } + // Initialize tracker if err := g.initializeTracker(); err != nil { return fmt.Errorf("failed to initialize tracker: %w", err) diff --git a/torrent-gateway b/torrent-gateway index 94554eb..2f40dd0 100755 Binary files a/torrent-gateway and b/torrent-gateway differ