enki 2b7f44b2d7
Some checks are pending
CI Pipeline / Run Tests (push) Waiting to run
CI Pipeline / Lint Code (push) Waiting to run
CI Pipeline / Security Scan (push) Waiting to run
CI Pipeline / E2E Tests (push) Blocked by required conditions
DHT Deadlockfix
2025-09-02 19:29:45 -07:00

2920 lines
78 KiB
Go

package dht
import (
"context"
"crypto/rand"
"crypto/sha1"
"database/sql"
"encoding/binary"
"encoding/hex"
"fmt"
"log"
"net"
"sync"
"time"
"torrentGateway/internal/config"
"github.com/anacrolix/torrent/bencode"
)
const (
// Node ID length in bytes (160 bits for SHA-1)
NodeIDLength = 20
// K-bucket size: the capacity each bucket's node slice is allocated with in NewRoutingTable
BucketSize = 8
// Number of buckets (160 bits = 160 buckets max)
NumBuckets = 160
// DHT protocol constants
// NOTE(review): Alpha is not referenced in the code visible here; presumably the lookup fan-out.
Alpha = 3 // Parallelism parameter
// BEP-5 DHT protocol constants
// TransactionIDLength is the byte length of generated transaction IDs (a big-endian uint16 counter).
TransactionIDLength = 2
// TokenLength is how many SHA-1 digest bytes make up a token; tokens are hex-encoded, so 2*TokenLength chars.
TokenLength = 8
// MaxNodesReturn caps the nodes returned in a find_node response.
MaxNodesReturn = 8
// Message types (values of the KRPC "y" key)
QueryMessage = "q"
ResponseMessage = "r"
ErrorMessage = "e"
// Storage limits
// MaxPeersPerTorrent is the per-infohash peer cap enforced in PeerStorage.AddPeer.
MaxPeersPerTorrent = 500
// PeerExpiryMinutes: cached peers whose LastAnnounced is older than this are filtered/cleaned.
PeerExpiryMinutes = 30
// TokenRotationMinutes is both the token secret rotation interval and the token time-window length.
TokenRotationMinutes = 5
// MaxResponsePeers caps the peers placed in a get_peers "values" list.
MaxResponsePeers = 50
// MaxResponseNodes caps the nodes placed in a get_peers "nodes" string.
MaxResponseNodes = 8
// Rate limiting
// MaxAnnouncesPerIPPerHour: reaching this many announces in the trailing hour blocks the IP.
MaxAnnouncesPerIPPerHour = 100
// MaxStoragePerIP: an IP may only store more peers while its stored count is below this.
MaxStoragePerIP = 50
// BlocklistDurationMinutes is how long a rate-limited IP stays blocked.
BlocklistDurationMinutes = 60
// Peer verification
// NOTE(review): the three values below are not used in the code visible here; presumably
// consumed by the peer-verification logic elsewhere in this file.
PeerVerificationTimeout = 5 * time.Second
MinReliabilityScore = 20
MaxReliabilityScore = 100
)
// NodeID represents a 160-bit node identifier
// (also used to hold 20-byte info hashes when looking up the closest nodes).
type NodeID [NodeIDLength]byte
// NodeHealth represents the health state of a node
type NodeHealth int
const (
// Nodes added from incoming queries and responses are created with NodeGood.
NodeGood NodeHealth = iota // Responded recently
NodeQuestionable // Old, needs verification
NodeBad // Failed to respond, should be removed
)
// Node represents a DHT node
type Node struct {
ID NodeID
// Addr is the UDP address the node's message arrived from.
Addr *net.UDPAddr
LastSeen time.Time
LastQuery time.Time // When we last queried this node
Health NodeHealth
FailedQueries int // Count of consecutive failed queries
RTT time.Duration // Round trip time for last successful query
}
// Bucket represents a k-bucket in the routing table
// mu guards nodes and the timestamps; performMaintenance takes it after RoutingTable.mu.
type Bucket struct {
mu sync.RWMutex
nodes []*Node
lastChanged time.Time // When bucket was last modified
lastRefresh time.Time // When bucket was last refreshed
}
// RoutingTable implements Kademlia routing table
// It holds NumBuckets buckets, all allocated up front by NewRoutingTable.
type RoutingTable struct {
mu sync.RWMutex
selfID NodeID
buckets [NumBuckets]*Bucket
}
// PendingQuery tracks outgoing queries
// Entries live in DHT.pendingQueries, keyed by TransactionID, for the duration of sendQuery.
type PendingQuery struct {
TransactionID string
QueryType string
Target *net.UDPAddr
SentTime time.Time
Retries int
// ResponseChan (buffered, size 1) receives the matching response or error message.
ResponseChan chan *DHTMessage
}
// RateLimiter tracks announce rates per IP
// All maps are keyed by the textual IP (net.IP.String()) and guarded by mu.
type RateLimiter struct {
mu sync.RWMutex
announces map[string][]time.Time // IP -> announce times
storage map[string]int // IP -> peer count (only ever incremented in the code visible here)
blocklist map[string]time.Time // IP -> unblock time
}
// DHT represents a DHT node with enhanced peer storage
type DHT struct {
config *config.DHTConfig
nodeID NodeID
routingTable *RoutingTable
// conn is the UDP socket opened by Start; nil until then.
conn *net.UDPConn
storage map[string][]byte // Simple in-memory storage for demo (written by processAnnounce)
storageMu sync.RWMutex
peerStorage *PeerStorage // Enhanced peer storage with database backing
// Channels for communication
// stopCh is closed by Stop to signal every background goroutine to exit.
stopCh chan struct{}
// announceQueue is drained by handleAnnouncements.
announceQueue chan AnnounceRequest
// verificationQueue is filled by handleAnnouncePeerQuery (non-blocking send; requests are
// dropped when full). NOTE(review): presumably drained by handlePeerVerification.
verificationQueue chan *PeerVerificationRequest
// Transaction tracking
// pendingQueries maps transaction ID -> in-flight query; guarded by pendingMu.
pendingQueries map[string]*PendingQuery
pendingMu sync.RWMutex
// transactionCounter is the source of 2-byte transaction IDs (see generateTransactionID).
transactionCounter uint16
// Token generation with rotation
// The three fields below are guarded by tokenMu.
tokenSecret []byte
previousSecret []byte
lastSecretRotation time.Time
tokenMu sync.RWMutex
// Rate limiting and security
rateLimiter *RateLimiter
peerVerification bool // Whether to verify peers
// Statistics
// stats is guarded by statsMu.
stats Stats
statsMu sync.RWMutex
}
// PeerVerificationRequest represents a peer verification request
// InfoHash is the lowercase hex info hash the peer announced for.
type PeerVerificationRequest struct {
InfoHash string
Peer *Peer
Retries int
}
// Stats tracks DHT statistics
type Stats struct {
PacketsSent int64
PacketsReceived int64
// NodesInTable counts every node in the routing table at the last maintenance pass.
NodesInTable int
// StoredItems mirrors PeerStorage's TotalPeers at the last maintenance pass.
StoredItems int
}
// AnnounceRequest represents a request to announce a torrent
type AnnounceRequest struct {
InfoHash string
Port int
}
// Peer represents a BitTorrent peer with enhanced metadata
type Peer struct {
IP net.IP
Port int
// NOTE(review): for peers from announce_peer this holds the hex-encoded DHT node ID of the
// announcer, not a BitTorrent peer ID.
PeerID string // Peer ID if available
Added time.Time // When this peer was added
LastAnnounced time.Time // When this peer last announced
Source string // How we got this peer (values set in this file: "announce", "database")
ReliabilityScore int // Reliability score (0-100)
AnnounceCount int // Number of successful announces
Verified bool // Whether peer has been verified
LastVerified time.Time // When peer was last verified
}
// PeerList holds peers for a specific infohash
type PeerList struct {
mu sync.RWMutex
peers map[string]*Peer // Key is "IP:Port"
}
// DatabaseInterface defines the database operations for peer storage
// Infohash arguments are the lowercase hex strings used as PeerStorage keys.
type DatabaseInterface interface {
StorePeer(infohash, ip string, port int, peerID string, lastAnnounced time.Time) error
GetPeersForTorrent(infohash string) ([]*StoredPeer, error)
CleanExpiredPeers() error
GetPeerCount(infohash string) (int, error)
UpdatePeerReliability(infohash, ip string, port int, reliabilityScore int, verified bool, lastVerified time.Time) error
}
// StoredPeer represents a peer stored in the database
type StoredPeer struct {
InfoHash string
IP string
Port int
PeerID string
LastAnnounced time.Time
ReliabilityScore int
AnnounceCount int
Verified bool
LastVerified time.Time
}
// PeerStorage manages peer storage with expiration and database backing
// mu guards infohashes and stats; each PeerList additionally has its own lock.
type PeerStorage struct {
mu sync.RWMutex
infohashes map[string]*PeerList // Key is infohash (in-memory cache), lowercase hex
db DatabaseInterface // Database backing (nil means cache only)
stats PeerStorageStats
}
// PeerStorageStats tracks storage statistics
type PeerStorageStats struct {
TotalPeers int64
TotalTorrents int64
CacheHits int64
CacheMisses int64
DatabaseWrites int64
DatabaseReads int64
}
// DHT message types
// DHTMessage represents the base structure of a DHT message
// Incoming packets are decoded into this type; Response/Arguments stay generic maps.
// NOTE(review): bencode integers decoded into interface{} values are expected to surface as
// int64, not int — confirm against the bencode package before type-asserting.
type DHTMessage struct {
TransactionID string `bencode:"t"`
MessageType string `bencode:"y"`
QueryType string `bencode:"q,omitempty"`
Response map[string]interface{} `bencode:"r,omitempty"`
Arguments map[string]interface{} `bencode:"a,omitempty"`
Error []interface{} `bencode:"e,omitempty"`
}
// The typed query/response structs below are not referenced by the handlers visible in this
// file (they work on the generic maps above); they document the expected BEP-5 fields.
// PingQuery represents a ping query
type PingQuery struct {
ID string `bencode:"id"`
}
// PingResponse represents a ping response
type PingResponse struct {
ID string `bencode:"id"`
}
// FindNodeQuery represents a find_node query
type FindNodeQuery struct {
ID string `bencode:"id"`
Target string `bencode:"target"`
}
// FindNodeResponse represents a find_node response
type FindNodeResponse struct {
ID string `bencode:"id"`
Nodes string `bencode:"nodes"`
}
// GetPeersQuery represents a get_peers query
type GetPeersQuery struct {
ID string `bencode:"id"`
InfoHash string `bencode:"info_hash"`
}
// GetPeersResponse represents a get_peers response
type GetPeersResponse struct {
ID string `bencode:"id"`
Token string `bencode:"token"`
Values []string `bencode:"values,omitempty"`
Nodes string `bencode:"nodes,omitempty"`
}
// AnnouncePeerQuery represents an announce_peer query
type AnnouncePeerQuery struct {
ID string `bencode:"id"`
InfoHash string `bencode:"info_hash"`
Port int `bencode:"port"`
Token string `bencode:"token"`
}
// AnnouncePeerResponse represents an announce_peer response
type AnnouncePeerResponse struct {
ID string `bencode:"id"`
}
// Msg* duplicate QueryMessage/ResponseMessage/ErrorMessage above.
const (
MsgQuery = "q"
MsgResponse = "r"
MsgError = "e"
)
// NewDHT creates a new DHT node with a random node ID and freshly generated
// announce-token secrets. No network I/O happens until Start is called, and the
// peer database is attached later via SetDatabase.
//
// It returns an error only if the system random source fails.
func NewDHT(config *config.DHTConfig) (*DHT, error) {
	var nodeID NodeID
	if _, err := rand.Read(nodeID[:]); err != nil {
		return nil, fmt.Errorf("failed to generate node ID: %w", err)
	}

	tokenSecret := make([]byte, 32)
	if _, err := rand.Read(tokenSecret); err != nil {
		return nil, fmt.Errorf("failed to generate token secret: %w", err)
	}

	// validateToken also accepts tokens derived from previousSecret, so it must be
	// unpredictable from the start. An all-zero secret would let anyone compute a
	// valid announce_peer token until the first rotation.
	previousSecret := make([]byte, 32)
	if _, err := rand.Read(previousSecret); err != nil {
		return nil, fmt.Errorf("failed to generate previous token secret: %w", err)
	}

	dht := &DHT{
		config:             config,
		nodeID:             nodeID,
		routingTable:       NewRoutingTable(nodeID),
		storage:            make(map[string][]byte),
		stopCh:             make(chan struct{}),
		announceQueue:      make(chan AnnounceRequest, 100),
		verificationQueue:  make(chan *PeerVerificationRequest, 100),
		pendingQueries:     make(map[string]*PendingQuery),
		tokenSecret:        tokenSecret,
		previousSecret:     previousSecret,
		lastSecretRotation: time.Now(),
		peerStorage:        NewPeerStorage(nil), // Database will be set later
		rateLimiter:        NewRateLimiter(),
		peerVerification:   true, // Enable peer verification by default
	}
	return dht, nil
}
// NewRateLimiter returns a RateLimiter whose per-IP announce history, storage
// counters and blocklist all start out empty.
func NewRateLimiter() *RateLimiter {
	rl := new(RateLimiter)
	rl.announces = make(map[string][]time.Time)
	rl.storage = make(map[string]int)
	rl.blocklist = make(map[string]time.Time)
	return rl
}
// IsBlocked reports whether ip is currently on the blocklist. An entry whose
// unblock time has passed is removed and the IP is reported as not blocked.
//
// The whole check runs under the write lock. sync.RWMutex cannot be upgraded,
// and releasing the read lock to re-acquire the write lock leaves a window in
// which CheckAnnounceRate could re-block the IP, only for that fresh entry to be
// deleted here.
func (rl *RateLimiter) IsBlocked(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	unblockTime, exists := rl.blocklist[ip]
	if !exists {
		return false
	}
	if time.Now().Before(unblockTime) {
		return true
	}
	// Block has expired; drop the stale entry.
	delete(rl.blocklist, ip)
	return false
}
// CheckAnnounceRate reports whether ip may announce right now.
//
// Announce timestamps older than one hour are pruned first. If the IP already
// has MaxAnnouncesPerIPPerHour announces within the hour, it is added to the
// blocklist for BlocklistDurationMinutes and false is returned without
// recording the attempt. Otherwise the current time is recorded and true is
// returned.
func (rl *RateLimiter) CheckAnnounceRate(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	windowStart := now.Add(-time.Hour)

	if history, seen := rl.announces[ip]; seen {
		var kept []time.Time
		for _, ts := range history {
			if !ts.After(windowStart) {
				continue
			}
			kept = append(kept, ts)
		}
		rl.announces[ip] = kept

		if n := len(kept); n >= MaxAnnouncesPerIPPerHour {
			rl.blocklist[ip] = now.Add(BlocklistDurationMinutes * time.Minute)
			log.Printf("DHT: Rate limited IP %s (%d announces in last hour)", ip, n)
			return false
		}
	}

	rl.announces[ip] = append(rl.announces[ip], now)
	return true
}
// CheckStorageLimit reports whether ip is still below MaxStoragePerIP stored
// peers. An IP with no recorded storage reads as zero, which is under the limit.
func (rl *RateLimiter) CheckStorageLimit(ip string) bool {
	rl.mu.RLock()
	stored := rl.storage[ip]
	rl.mu.RUnlock()
	return stored < MaxStoragePerIP
}
// IncrementStorage records one more stored peer for ip.
func (rl *RateLimiter) IncrementStorage(ip string) {
	rl.mu.Lock()
	rl.storage[ip] += 1
	rl.mu.Unlock()
}
// NewPeerStorage builds an empty peer cache. db may be nil, in which case
// peers are kept in memory only.
func NewPeerStorage(db DatabaseInterface) *PeerStorage {
	ps := &PeerStorage{db: db}
	ps.infohashes = make(map[string]*PeerList)
	return ps
}
// validateTokenForTimeWindow reports whether token equals the token derived
// from secret for addr in the given time window: the first TokenLength bytes of
// SHA-1(secret || IP || big-endian uint16 port || big-endian int64 window),
// hex-encoded in lowercase.
func (d *DHT) validateTokenForTimeWindow(token string, addr *net.UDPAddr, timeWindow int64, secret []byte) bool {
	digest := sha1.New()
	digest.Write(secret)
	digest.Write(addr.IP)
	binary.Write(digest, binary.BigEndian, uint16(addr.Port))
	binary.Write(digest, binary.BigEndian, timeWindow)
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum[:TokenLength]) == token
}
// rotateTokenSecretIfNeeded rotates the token secret once TokenRotationMinutes
// have passed since the last rotation. The current secret becomes the previous
// one, which validateToken still accepts.
//
// Callers must hold d.tokenMu for writing, because this mutates tokenSecret,
// previousSecret and lastSecretRotation.
//
// The new secret is generated into a fresh buffer before anything is touched.
// If the random source fails, the existing secrets stay exactly as they were
// and no half-rotated state is left behind.
func (d *DHT) rotateTokenSecretIfNeeded() {
	if time.Since(d.lastSecretRotation) < TokenRotationMinutes*time.Minute {
		return
	}

	fresh := make([]byte, len(d.tokenSecret))
	if _, err := rand.Read(fresh); err != nil {
		log.Printf("DHT: Failed to rotate token secret: %v", err)
		return
	}

	// Swap the slices rather than copying. This also works if previousSecret was
	// never allocated with the same length.
	d.previousSecret = d.tokenSecret
	d.tokenSecret = fresh
	d.lastSecretRotation = time.Now()
	log.Printf("DHT: Token secret rotated")
}
// SetDatabase attaches db as the persistent backing store for peer storage.
// The assignment happens under the peer storage lock so it does not race with
// AddPeer or GetPeers running in packet-handler goroutines.
func (d *DHT) SetDatabase(db DatabaseInterface) {
	d.peerStorage.mu.Lock()
	d.peerStorage.db = db
	d.peerStorage.mu.Unlock()
	log.Printf("DHT: Database interface configured")
}
// AddPeer records or refreshes peer for infohash in the in-memory cache and,
// when a database is configured, persists it there too.
//
// Cap and refresh rules:
//   - A peer not yet cached is rejected with an error once the torrent already
//     holds MaxPeersPerTorrent peers. The count comes from the database when one
//     is configured, otherwise from the cache. A failed count query is not
//     treated as "full".
//   - A peer that is already cached is always accepted, so a re-announce
//     refreshes LastAnnounced instead of letting the peer expire. Its original
//     Added time is kept.
//
// peer.Added and peer.LastAnnounced are overwritten. TotalPeers counts distinct
// cache entries, matching the decrement done in CleanExpiredPeers. If the
// database write fails, the error is logged and returned, but the cache entry
// is kept.
func (ps *PeerStorage) AddPeer(infohash string, peer *Peer) error {
	key := fmt.Sprintf("%s:%d", peer.IP.String(), peer.Port)
	// Short prefix for log lines; slicing directly would panic on short input.
	shortHash := infohash
	if len(shortHash) > 8 {
		shortHash = shortHash[:8]
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	peerList := ps.infohashes[infohash]
	if peerList == nil {
		peerList = &PeerList{peers: make(map[string]*Peer)}
	}

	peerList.mu.RLock()
	_, known := peerList.peers[key]
	cachedCount := len(peerList.peers)
	peerList.mu.RUnlock()

	if !known {
		count := cachedCount
		var err error
		if ps.db != nil {
			count, err = ps.db.GetPeerCount(infohash)
		}
		if err == nil && count >= MaxPeersPerTorrent {
			log.Printf("DHT: Max peers reached for torrent %s (%d peers)", shortHash, count)
			return fmt.Errorf("max peers reached for torrent")
		}
	}

	ps.infohashes[infohash] = peerList

	now := time.Now()
	peerList.mu.Lock()
	if prev, ok := peerList.peers[key]; ok {
		peer.Added = prev.Added
	} else {
		peer.Added = now
		ps.stats.TotalPeers++
	}
	peer.LastAnnounced = now
	peerList.peers[key] = peer
	peerList.mu.Unlock()

	if ps.db != nil {
		if err := ps.db.StorePeer(infohash, peer.IP.String(), peer.Port, peer.PeerID, peer.LastAnnounced); err != nil {
			log.Printf("DHT: Failed to store peer in database: %v", err)
			return err
		}
		ps.stats.DatabaseWrites++
	}
	return nil
}
// GetPeers returns the peers known for infohash (lowercase hex). It prefers the
// in-memory cache and falls back to the database when the cache has no entries
// for the torrent.
//
// Cache behavior: peers whose LastAnnounced is older than PeerExpiryMinutes are
// filtered out. If the torrent has cached entries, the answer comes from the
// cache even when every entry turns out to be expired.
//
// Database behavior: rows are converted with their reliability and verification
// metadata, and rows whose IP cannot be parsed are skipped.
//
// Returns nil when nothing is found or the database query fails. Stats counters
// are updated under ps.mu, because this is called concurrently from
// packet-handler goroutines.
func (ps *PeerStorage) GetPeers(infohash string) []*Peer {
	ps.mu.RLock()
	peerList, cached := ps.infohashes[infohash]
	db := ps.db
	ps.mu.RUnlock()

	if cached {
		peerList.mu.RLock()
		hasEntries := len(peerList.peers) > 0
		var peers []*Peer
		for _, peer := range peerList.peers {
			if time.Since(peer.LastAnnounced) <= PeerExpiryMinutes*time.Minute {
				peers = append(peers, peer)
			}
		}
		peerList.mu.RUnlock()

		if hasEntries {
			ps.mu.Lock()
			ps.stats.CacheHits++
			ps.mu.Unlock()
			return peers
		}
	}

	if db == nil {
		ps.mu.Lock()
		ps.stats.CacheMisses++
		ps.mu.Unlock()
		return nil
	}

	storedPeers, err := db.GetPeersForTorrent(infohash)
	if err != nil {
		log.Printf("DHT: Failed to get peers from database: %v", err)
		return nil
	}

	var peers []*Peer
	for _, sp := range storedPeers {
		ip := net.ParseIP(sp.IP)
		if ip == nil {
			// Malformed row: a nil IP cannot be encoded or contacted.
			continue
		}
		peers = append(peers, &Peer{
			IP:               ip,
			Port:             sp.Port,
			PeerID:           sp.PeerID,
			Added:            sp.LastAnnounced,
			LastAnnounced:    sp.LastAnnounced,
			Source:           "database",
			ReliabilityScore: sp.ReliabilityScore,
			AnnounceCount:    sp.AnnounceCount,
			Verified:         sp.Verified,
			LastVerified:     sp.LastVerified,
		})
	}

	ps.mu.Lock()
	ps.stats.DatabaseReads++
	ps.stats.CacheMisses++
	ps.mu.Unlock()
	return peers
}
// CleanExpiredPeers drops peers whose LastAnnounced is not newer than
// PeerExpiryMinutes ago.
//
// The database cleanup runs first, when a database is configured; its failures
// are only logged. Then each cached peer list is rebuilt with only the fresh
// peers, and torrents left with no peers are removed from the cache. TotalPeers
// is reduced by the number of cache entries dropped.
func (ps *PeerStorage) CleanExpiredPeers() {
	cutoff := time.Now().Add(-PeerExpiryMinutes * time.Minute)

	if ps.db != nil {
		if err := ps.db.CleanExpiredPeers(); err != nil {
			log.Printf("DHT: Failed to clean expired peers from database: %v", err)
		}
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	removed := 0
	for infohash, list := range ps.infohashes {
		list.mu.Lock()
		survivors := make(map[string]*Peer, len(list.peers))
		for key, p := range list.peers {
			if !p.LastAnnounced.After(cutoff) {
				removed++
				continue
			}
			survivors[key] = p
		}
		list.peers = survivors
		list.mu.Unlock()

		// Deleting the current key while ranging over a map is allowed in Go.
		if len(survivors) == 0 {
			delete(ps.infohashes, infohash)
		}
	}

	if removed == 0 {
		return
	}
	log.Printf("DHT: Cleaned %d expired peers from cache", removed)
	ps.stats.TotalPeers -= int64(removed)
}
// NewRoutingTable returns a routing table centred on selfID. All NumBuckets
// buckets are allocated up front, each empty with capacity BucketSize, and
// their change/refresh timestamps are set to the creation time.
func NewRoutingTable(selfID NodeID) *RoutingTable {
	createdAt := time.Now()
	table := &RoutingTable{selfID: selfID}
	for i := 0; i < NumBuckets; i++ {
		table.buckets[i] = &Bucket{
			nodes:       make([]*Node, 0, BucketSize),
			lastChanged: createdAt,
			lastRefresh: createdAt,
		}
	}
	return table
}
// Start binds the UDP socket on the configured port and launches the background
// goroutines: packet reader, bootstrap, maintenance, announce processing and
// peer verification. It returns an error if the address cannot be resolved or
// the socket cannot be opened.
func (d *DHT) Start() error {
	listenAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", d.config.Port))
	if err != nil {
		return fmt.Errorf("failed to resolve UDP address: %w", err)
	}

	conn, err := net.ListenUDP("udp", listenAddr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP: %w", err)
	}
	d.conn = conn
	log.Printf("DHT node started on port %d with ID %x", d.config.Port, d.nodeID)

	workers := []func(){
		d.handlePackets,
		d.bootstrap,
		d.maintenance,
		d.handleAnnouncements,
		d.handlePeerVerification,
	}
	for _, run := range workers {
		go run()
	}
	return nil
}
// Stop signals all background goroutines to exit by closing stopCh, then
// closes the UDP socket if one was opened. It does not wait for the goroutines
// to finish.
//
// Calling Stop again after it has returned is a no-op; closing stopCh twice
// would panic. Concurrent calls to Stop are still not synchronized.
func (d *DHT) Stop() error {
	select {
	case <-d.stopCh:
		return nil // already stopped
	default:
		close(d.stopCh)
	}
	if d.conn != nil {
		return d.conn.Close()
	}
	return nil
}
// handlePackets reads UDP datagrams until stopCh is closed and dispatches each
// one to handlePacket in its own goroutine.
//
// A 1-second read deadline lets the loop notice stopCh even when no traffic
// arrives. Read errors are logged and the loop continues.
func (d *DHT) handlePackets() {
	buffer := make([]byte, 2048)
	for {
		select {
		case <-d.stopCh:
			return
		default:
		}

		d.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
		n, addr, err := d.conn.ReadFromUDP(buffer)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue
			}
			log.Printf("Error reading UDP packet: %v", err)
			continue
		}

		d.statsMu.Lock()
		d.stats.PacketsReceived++
		d.statsMu.Unlock()

		// buffer is overwritten by the next ReadFromUDP while handlePacket may
		// still be decoding it, so hand the goroutine its own copy.
		packet := make([]byte, n)
		copy(packet, buffer[:n])
		go d.handlePacket(packet, addr)
	}
}
// handlePacket decodes one bencoded KRPC message and routes it by its "y"
// field to the query, response or error handler. Undecodable packets and
// unknown message types are logged and dropped.
func (d *DHT) handlePacket(data []byte, addr *net.UDPAddr) {
	msg := new(DHTMessage)
	if err := bencode.Unmarshal(data, msg); err != nil {
		log.Printf("Failed to unmarshal DHT message from %s: %v", addr, err)
		return
	}

	switch kind := msg.MessageType; kind {
	case QueryMessage:
		d.handleQuery(msg, addr)
	case ResponseMessage:
		d.handleResponse(msg, addr)
	case ErrorMessage:
		d.handleError(msg, addr)
	default:
		log.Printf("Unknown message type '%s' from %s", kind, addr)
	}
}
// handleQuery picks the handler for a query's method name ("q"). Unsupported
// methods are answered with KRPC error 204 "Method Unknown".
func (d *DHT) handleQuery(msg *DHTMessage, addr *net.UDPAddr) {
	var handler func(*DHTMessage, *net.UDPAddr)
	switch msg.QueryType {
	case "ping":
		handler = d.handlePingQuery
	case "find_node":
		handler = d.handleFindNodeQuery
	case "get_peers":
		handler = d.handleGetPeersQuery
	case "announce_peer":
		handler = d.handleAnnouncePeerQuery
	default:
		d.sendError(msg.TransactionID, addr, 204, "Method Unknown")
		return
	}
	handler(msg, addr)
}
// handleResponse delivers a response to the sendQuery call waiting on the same
// transaction ID. It then adds the responder to the routing table when the
// response carries a well-formed 20-byte "id".
//
// Responses for unknown transactions are logged and ignored. Delivery is
// non-blocking: if the waiter's buffered channel is already full, the message
// is dropped.
func (d *DHT) handleResponse(msg *DHTMessage, addr *net.UDPAddr) {
	d.pendingMu.RLock()
	waiter, found := d.pendingQueries[msg.TransactionID]
	d.pendingMu.RUnlock()

	if !found {
		log.Printf("Received response for unknown transaction %s from %s", msg.TransactionID, addr)
		return
	}

	select {
	case waiter.ResponseChan <- msg:
	default:
		// Channel might be closed
	}

	rawID, isString := msg.Response["id"].(string)
	if !isString || len(rawID) != NodeIDLength {
		return
	}
	var id NodeID
	copy(id[:], rawID)
	d.routingTable.AddNode(&Node{
		ID:       id,
		Addr:     addr,
		LastSeen: time.Now(),
		Health:   NodeGood,
	})
}
// handleError logs a KRPC error message. If the error matches a pending
// transaction, it is also handed to the waiting sendQuery call (non-blocking).
func (d *DHT) handleError(msg *DHTMessage, addr *net.UDPAddr) {
	log.Printf("Received error from %s: %v", addr, msg.Error)

	d.pendingMu.RLock()
	waiter, found := d.pendingQueries[msg.TransactionID]
	d.pendingMu.RUnlock()
	if !found {
		return
	}

	select {
	case waiter.ResponseChan <- msg:
	default:
		// Channel might be closed
	}
}
// bootstrap pings each configured bootstrap node, pausing 500ms after every
// resolved node. It then runs an iterative find_node lookup for our own ID,
// using the nodes that answered, and adds the discovered nodes to the routing
// table.
//
// The pauses and the lookup are abandoned as soon as stopCh is closed, so
// bootstrap does not keep running after Stop. If no bootstrap node responds, a
// warning is logged and nothing else happens.
func (d *DHT) bootstrap() {
	log.Printf("Starting DHT bootstrap with %d nodes", len(d.config.BootstrapNodes))

	var activeBootstrapNodes []*net.UDPAddr
	for _, bootstrapAddr := range d.config.BootstrapNodes {
		addr, err := net.ResolveUDPAddr("udp", bootstrapAddr)
		if err != nil {
			log.Printf("Failed to resolve bootstrap node %s: %v", bootstrapAddr, err)
			continue
		}

		if _, err := d.Ping(addr); err == nil {
			activeBootstrapNodes = append(activeBootstrapNodes, addr)
			log.Printf("Connected to bootstrap node: %s", addr)
		} else {
			log.Printf("Failed to ping bootstrap node %s: %v", addr, err)
		}

		// Pace the pings, but stop waiting immediately on shutdown.
		select {
		case <-d.stopCh:
			return
		case <-time.After(500 * time.Millisecond):
		}
	}

	if len(activeBootstrapNodes) == 0 {
		log.Printf("Warning: No bootstrap nodes responded")
		return
	}

	select {
	case <-d.stopCh:
		return
	default:
	}

	log.Printf("Performing bootstrap lookup for own node ID")
	closestNodes := d.iterativeFindNode(d.nodeID, activeBootstrapNodes)
	for _, node := range closestNodes {
		d.routingTable.AddNode(node)
	}
	log.Printf("Bootstrap completed: discovered %d nodes", len(closestNodes))
}
// generateTransactionID returns the next 2-byte transaction ID: a big-endian
// uint16 counter that wraps after 65535.
//
// It is called from concurrent goroutines (bootstrap, maintenance, lookups).
// The counter is therefore advanced under pendingMu, the lock that already
// guards transaction tracking. Without it, racing increments could hand out
// duplicate IDs and mis-route responses.
func (d *DHT) generateTransactionID() string {
	d.pendingMu.Lock()
	d.transactionCounter++
	counter := d.transactionCounter
	d.pendingMu.Unlock()

	tid := make([]byte, TransactionIDLength)
	binary.BigEndian.PutUint16(tid, counter)
	return string(tid)
}
// sendQuery sends a KRPC query of the given type to target and blocks until a
// matching response or error message arrives. It gives up after 5 seconds or
// when the DHT stops.
//
// The transaction is registered in pendingQueries before sending, so
// handleResponse/handleError can find it. It is always removed again before
// returning.
func (d *DHT) sendQuery(queryType string, args map[string]interface{}, target *net.UDPAddr) (*DHTMessage, error) {
	tid := d.generateTransactionID()

	payload, err := bencode.Marshal(map[string]interface{}{
		"t": tid,
		"y": QueryMessage,
		"q": queryType,
		"a": args,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal query: %w", err)
	}

	replies := make(chan *DHTMessage, 1)
	d.pendingMu.Lock()
	d.pendingQueries[tid] = &PendingQuery{
		TransactionID: tid,
		QueryType:     queryType,
		Target:        target,
		SentTime:      time.Now(),
		ResponseChan:  replies,
	}
	d.pendingMu.Unlock()

	defer func() {
		d.pendingMu.Lock()
		delete(d.pendingQueries, tid)
		d.pendingMu.Unlock()
	}()

	if _, err := d.conn.WriteToUDP(payload, target); err != nil {
		return nil, fmt.Errorf("failed to send query: %w", err)
	}

	d.statsMu.Lock()
	d.stats.PacketsSent++
	d.statsMu.Unlock()

	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	select {
	case reply := <-replies:
		return reply, nil
	case <-deadline.C:
		return nil, fmt.Errorf("query timeout")
	case <-d.stopCh:
		return nil, fmt.Errorf("DHT stopping")
	}
}
// sendError replies to addr with a KRPC error message ("y" = "e") carrying
// [code, message] for the given transaction.
//
// Marshal and write failures are logged. A successfully written error counts
// toward PacketsSent, the same as responses and queries.
func (d *DHT) sendError(transactionID string, addr *net.UDPAddr, code int, message string) {
	errorMsg := map[string]interface{}{
		"t": transactionID,
		"y": ErrorMessage,
		"e": []interface{}{code, message},
	}
	data, err := bencode.Marshal(errorMsg)
	if err != nil {
		log.Printf("Failed to marshal error message: %v", err)
		return
	}
	if _, err := d.conn.WriteToUDP(data, addr); err != nil {
		log.Printf("Failed to send error message to %s: %v", addr, err)
		return
	}
	d.statsMu.Lock()
	d.stats.PacketsSent++
	d.statsMu.Unlock()
}
// sendResponse replies to addr with a KRPC response ("y" = "r") for the given
// transaction. Our node ID is injected into response["id"]; the caller's map is
// modified in place, and a nil map is treated as empty.
//
// Marshal and write failures are logged. Only packets that were actually
// written count toward PacketsSent.
func (d *DHT) sendResponse(transactionID string, addr *net.UDPAddr, response map[string]interface{}) {
	if response == nil {
		response = make(map[string]interface{})
	}
	response["id"] = string(d.nodeID[:])
	respMsg := map[string]interface{}{
		"t": transactionID,
		"y": ResponseMessage,
		"r": response,
	}
	data, err := bencode.Marshal(respMsg)
	if err != nil {
		log.Printf("Failed to marshal response message: %v", err)
		return
	}
	if _, err := d.conn.WriteToUDP(data, addr); err != nil {
		log.Printf("Failed to send response message to %s: %v", addr, err)
		return
	}
	d.statsMu.Lock()
	d.stats.PacketsSent++
	d.statsMu.Unlock()
}
// generateToken returns the announce_peer token given to addr in get_peers
// responses. The token is the first TokenLength bytes of
// SHA-1(secret || IP || big-endian uint16 port || big-endian int64 window),
// hex-encoded. The window is the current TokenRotationMinutes-long time slot.
//
// It may first rotate the token secret, which mutates tokenSecret,
// previousSecret and lastSecretRotation. That requires the write lock:
// holding only the read lock let concurrent get_peers handlers race on those
// fields.
func (d *DHT) generateToken(addr *net.UDPAddr) string {
	d.tokenMu.Lock()
	defer d.tokenMu.Unlock()

	d.rotateTokenSecretIfNeeded()

	timeWindow := time.Now().Unix() / (TokenRotationMinutes * 60)
	h := sha1.New()
	h.Write(d.tokenSecret)
	h.Write(addr.IP)
	binary.Write(h, binary.BigEndian, uint16(addr.Port))
	binary.Write(h, binary.BigEndian, timeWindow)
	return fmt.Sprintf("%x", h.Sum(nil)[:TokenLength])
}
// validateToken reports whether token was issued to addr recently. Candidates
// are tried in this order:
//   - current secret, current window
//   - current secret, previous window
//   - previous secret (if non-empty), current window
//   - previous secret (if non-empty), previous window
//
// This lets tokens survive both a window boundary and one secret rotation.
func (d *DHT) validateToken(token string, addr *net.UDPAddr) bool {
	d.tokenMu.RLock()
	defer d.tokenMu.RUnlock()

	window := time.Now().Unix() / (TokenRotationMinutes * 60)

	secrets := [][]byte{d.tokenSecret}
	if len(d.previousSecret) > 0 {
		secrets = append(secrets, d.previousSecret)
	}

	for _, secret := range secrets {
		for _, w := range []int64{window, window - 1} {
			if d.validateTokenForTimeWindow(token, addr, w, secret) {
				return true
			}
		}
	}
	return false
}
// handlePingQuery answers a ping with an empty response body; sendResponse
// adds our node ID. If the query carries a valid 20-byte "id", the querier is
// added to the routing table first.
func (d *DHT) handlePingQuery(msg *DHTMessage, addr *net.UDPAddr) {
	// Indexing a nil Arguments map yields nil, so no separate nil check is needed.
	if rawID, ok := msg.Arguments["id"].(string); ok && len(rawID) == NodeIDLength {
		var id NodeID
		copy(id[:], rawID)
		d.routingTable.AddNode(&Node{
			ID:       id,
			Addr:     addr,
			LastSeen: time.Now(),
			Health:   NodeGood,
		})
	}
	d.sendResponse(msg.TransactionID, addr, map[string]interface{}{})
}
// handleFindNodeQuery answers find_node with up to MaxNodesReturn routing-table
// nodes closest to the 20-byte "target", in compact form under "nodes".
//
// A valid 20-byte "id" adds the querier to the routing table. A missing or
// malformed target is answered with KRPC error 203 "Protocol Error".
func (d *DHT) handleFindNodeQuery(msg *DHTMessage, addr *net.UDPAddr) {
	args := msg.Arguments

	if rawID, ok := args["id"].(string); ok && len(rawID) == NodeIDLength {
		var id NodeID
		copy(id[:], rawID)
		d.routingTable.AddNode(&Node{
			ID:       id,
			Addr:     addr,
			LastSeen: time.Now(),
			Health:   NodeGood,
		})
	}

	rawTarget, ok := args["target"].(string)
	if !ok || len(rawTarget) != NodeIDLength {
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}

	var target NodeID
	copy(target[:], rawTarget)
	d.sendResponse(msg.TransactionID, addr, map[string]interface{}{
		"nodes": d.encodeNodes(d.findClosestNodes(target, MaxNodesReturn)),
	})
}
// handleGetPeersQuery answers a get_peers query.
//
// Validation:
//   - A valid 20-byte "id" adds the querier to the routing table.
//   - "info_hash" must be exactly 20 bytes; otherwise KRPC error 203
//     "Protocol Error" is returned.
//
// The response always carries an announce_peer token. It also includes:
//   - under "values": up to MaxResponsePeers stored peers chosen by
//     selectReliablePeers, in compact form (6 bytes for IPv4, 18 for IPv6);
//   - under "nodes": the MaxResponseNodes closest routing-table nodes, whenever
//     there are any, so the querier can continue its lookup.
func (d *DHT) handleGetPeersQuery(msg *DHTMessage, addr *net.UDPAddr) {
	args := msg.Arguments
	if args == nil {
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}

	if nodeIDBytes, ok := args["id"].(string); ok && len(nodeIDBytes) == NodeIDLength {
		var nodeID NodeID
		copy(nodeID[:], nodeIDBytes)
		d.routingTable.AddNode(&Node{
			ID:       nodeID,
			Addr:     addr,
			LastSeen: time.Now(),
			Health:   NodeGood,
		})
	}

	// Info hashes are 20 bytes. A shorter, attacker-supplied value would also make
	// the infoHash[:8] log slices below panic, and a panic in a handler goroutine
	// takes down the whole process.
	infoHashBytes, ok := args["info_hash"].(string)
	if !ok || len(infoHashBytes) != NodeIDLength {
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}
	infoHash := fmt.Sprintf("%x", infoHashBytes)

	response := map[string]interface{}{
		"token": d.generateToken(addr),
	}

	if allPeers := d.peerStorage.GetPeers(infoHash); len(allPeers) > 0 {
		var values []string
		for _, peer := range d.selectReliablePeers(allPeers, MaxResponsePeers) {
			if ip4 := peer.IP.To4(); ip4 != nil {
				// IPv4 compact peer: 4 IP bytes + 2 port bytes.
				peerBytes := make([]byte, 6)
				copy(peerBytes[:4], ip4)
				binary.BigEndian.PutUint16(peerBytes[4:], uint16(peer.Port))
				values = append(values, string(peerBytes))
			} else if len(peer.IP) == net.IPv6len {
				// IPv6 compact peer: 16 IP bytes + 2 port bytes.
				peerBytes := make([]byte, 18)
				copy(peerBytes[:16], peer.IP)
				binary.BigEndian.PutUint16(peerBytes[16:], uint16(peer.Port))
				values = append(values, string(peerBytes))
			}
		}
		if len(values) > 0 {
			response["values"] = values
			log.Printf("DHT: Returned %d peers for %s to %s", len(values), infoHash[:8], addr)
		}
	}

	var targetID NodeID
	copy(targetID[:], infoHashBytes)
	if closestNodes := d.findClosestNodes(targetID, MaxResponseNodes); len(closestNodes) > 0 {
		response["nodes"] = d.encodeNodes(closestNodes)
		log.Printf("DHT: Also returned %d nodes for %s to %s", len(closestNodes), infoHash[:8], addr)
	}

	d.sendResponse(msg.TransactionID, addr, response)
}
// min returns whichever of a and b is smaller (either one when they are equal).
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// handleAnnouncePeerQuery handles an announce_peer query. Checks run in this
// order:
//  1. blocklist and per-IP announce rate limit (error 201 "Generic Error");
//  2. the token, which must come from one of our get_peers responses
//     (error 203 "Protocol Error");
//  3. a 20-byte "info_hash" and a usable port (error 203).
//
// The port is the UDP source port when "implied_port" is non-zero (BEP-5);
// otherwise it is the "port" argument, which must be in 1..65535.
//
// The peer is then stored, subject to the per-IP storage limit (error 201).
// Successfully stored peers are queued for verification when it is enabled.
// Storage failures are logged but still acknowledged, so internal errors are
// not revealed to the remote node.
func (d *DHT) handleAnnouncePeerQuery(msg *DHTMessage, addr *net.UDPAddr) {
	ipStr := addr.IP.String()

	if d.rateLimiter.IsBlocked(ipStr) {
		log.Printf("DHT: Blocked IP %s attempted announce", ipStr)
		d.sendError(msg.TransactionID, addr, 201, "Generic Error")
		return
	}
	if !d.rateLimiter.CheckAnnounceRate(ipStr) {
		log.Printf("DHT: Rate limited IP %s", ipStr)
		d.sendError(msg.TransactionID, addr, 201, "Generic Error")
		return
	}

	args := msg.Arguments
	if args == nil {
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}

	token, ok := args["token"].(string)
	if !ok {
		log.Printf("DHT: Missing token from %s", addr)
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}
	if !d.validateToken(token, addr) {
		// %q: the token is untrusted bytes; quote it so it cannot corrupt the log.
		log.Printf("DHT: Invalid token from %s (token: %q)", addr, token[:min(len(token), 16)])
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}

	var peerID string
	if nodeIDBytes, ok := args["id"].(string); ok && len(nodeIDBytes) == NodeIDLength {
		peerID = hex.EncodeToString([]byte(nodeIDBytes))
		var nodeID NodeID
		copy(nodeID[:], nodeIDBytes)
		d.routingTable.AddNode(&Node{
			ID:       nodeID,
			Addr:     addr,
			LastSeen: time.Now(),
			Health:   NodeGood,
		})
	}

	// Info hashes are 20 bytes; shorter values would also make infoHash[:8] panic.
	infoHashBytes, ok := args["info_hash"].(string)
	if !ok || len(infoHashBytes) != NodeIDLength {
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}

	// Decoded bencode integers arrive as int64 when the target is interface{}.
	// A plain .(int) assertion never matches real traffic. int is also accepted
	// for messages built in-process.
	asInt := func(v interface{}) (int64, bool) {
		switch n := v.(type) {
		case int64:
			return n, true
		case int:
			return int64(n), true
		default:
			return 0, false
		}
	}

	var port int
	if implied, ok := asInt(args["implied_port"]); ok && implied != 0 {
		port = addr.Port
	} else if p, ok := asInt(args["port"]); ok && p > 0 && p <= 65535 {
		port = int(p)
	} else {
		d.sendError(msg.TransactionID, addr, 203, "Protocol Error")
		return
	}

	infoHash := fmt.Sprintf("%x", infoHashBytes)

	if !d.rateLimiter.CheckStorageLimit(ipStr) {
		log.Printf("DHT: Storage limit exceeded for IP %s", ipStr)
		d.sendError(msg.TransactionID, addr, 201, "Generic Error")
		return
	}

	now := time.Now()
	peer := &Peer{
		IP:               addr.IP,
		Port:             port,
		PeerID:           peerID,
		Added:            now,
		LastAnnounced:    now,
		Source:           "announce",
		ReliabilityScore: 50, // Start with neutral score
		AnnounceCount:    1,
		Verified:         false,
	}

	if err := d.peerStorage.AddPeer(infoHash, peer); err != nil {
		log.Printf("DHT: Failed to store peer %s:%d for %s: %v", addr.IP, port, infoHash[:8], err)
		// Still send success to avoid revealing internal errors.
	} else {
		d.rateLimiter.IncrementStorage(ipStr)
		if d.peerVerification {
			select {
			case d.verificationQueue <- &PeerVerificationRequest{
				InfoHash: infoHash,
				Peer:     peer,
				Retries:  0,
			}:
			default:
				// Verification queue full, skip verification.
			}
		}
		log.Printf("DHT: Successfully stored peer %s:%d for infohash %s", addr.IP, port, infoHash[:8])
	}

	d.sendResponse(msg.TransactionID, addr, map[string]interface{}{})
}
// findClosestNodes returns up to count routing-table nodes nearest to
// targetID. It simply delegates to the routing table.
func (d *DHT) findClosestNodes(targetID NodeID, count int) []*Node {
	table := d.routingTable
	return table.FindClosestNodes(targetID, count)
}
// compareNodeIDs orders two node IDs byte-by-byte, most significant byte
// first. It returns -1 if a < b, 0 if they are equal, and 1 if a > b.
func compareNodeIDs(a, b NodeID) int {
	for i := range a {
		switch {
		case a[i] < b[i]:
			return -1
		case a[i] > b[i]:
			return 1
		}
	}
	return 0
}
// encodeNodes serializes nodes into the BEP-5 compact node format. Each entry
// is 26 bytes: 20-byte node ID, 4-byte IPv4 address, 2-byte big-endian port.
//
// The format has no room for IPv6. Nodes without an IPv4 address (or with no
// address at all) are therefore skipped. Appending a zero-length IP would shift
// every later entry and corrupt the whole string.
func (d *DHT) encodeNodes(nodes []*Node) string {
	const compactNodeLen = NodeIDLength + 4 + 2
	result := make([]byte, 0, len(nodes)*compactNodeLen)
	for _, node := range nodes {
		if node == nil || node.Addr == nil {
			continue
		}
		ip4 := node.Addr.IP.To4()
		if ip4 == nil {
			continue
		}
		var port [2]byte
		binary.BigEndian.PutUint16(port[:], uint16(node.Addr.Port))
		result = append(result, node.ID[:]...)
		result = append(result, ip4...)
		result = append(result, port[:]...)
	}
	return string(result)
}
// sendPing pings addr and discards the reply, returning only Ping's error.
func (d *DHT) sendPing(addr *net.UDPAddr) error {
	if _, err := d.Ping(addr); err != nil {
		return err
	}
	return nil
}
// maintenance runs performMaintenance every five minutes until stopCh is closed.
func (d *DHT) maintenance() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			d.performMaintenance()
		case <-d.stopCh:
			return
		}
	}
}
// performMaintenance cleans up old nodes and refreshes buckets
// It is invoked on every tick of maintenance(). Steps, in order:
//  1. rotate the token secret (under the tokenMu write lock);
//  2. routing-table health check, bad-node cleanup, expired/unreliable peer cleanup,
//     bucket refresh (helpers defined outside this view);
//  3. count nodes per health state and publish NodesInTable/StoredItems into d.stats;
//  4. log node, peer-cache and verification summaries.
func (d *DHT) performMaintenance() {
// Rotate token secret if needed
d.tokenMu.Lock()
d.rotateTokenSecretIfNeeded()
d.tokenMu.Unlock()
// Perform health checks on all nodes
d.routingTable.PerformHealthCheck(d)
// Clean up nodes marked as bad
d.routingTable.CleanupBadNodes()
// Clean up expired peers
d.peerStorage.CleanExpiredPeers()
// Clean up unreliable peers
d.CleanUnreliablePeers()
// Refresh stale buckets
d.routingTable.RefreshBuckets(d)
// Update statistics
// totalNodes counts every node still present, whatever its health.
totalNodes := 0
goodNodes := 0
questionableNodes := 0
badNodes := 0
// Lock order here: routing-table lock, then each bucket's lock.
// NOTE(review): other code taking both locks should use the same order to avoid deadlock.
d.routingTable.mu.RLock()
for _, bucket := range d.routingTable.buckets {
if bucket == nil {
continue
}
bucket.mu.RLock()
for _, node := range bucket.nodes {
totalNodes++
switch node.Health {
case NodeGood:
goodNodes++
case NodeQuestionable:
questionableNodes++
case NodeBad:
badNodes++
}
}
bucket.mu.RUnlock()
}
d.routingTable.mu.RUnlock()
// Get peer storage stats
peerStats := d.peerStorage.GetStats()
torrentCount := d.peerStorage.GetTorrentCount()
// NOTE(review): the log verbs below expect a float for "verification_rate" and integers for
// the other keys; confirm GetVerificationStats returns values of those types.
verificationStats := d.GetVerificationStats()
d.statsMu.Lock()
d.stats.NodesInTable = totalNodes
d.stats.StoredItems = int(peerStats.TotalPeers)
d.statsMu.Unlock()
log.Printf("DHT maintenance: %d nodes (%d good, %d questionable, %d bad), %d peers across %d torrents",
totalNodes, goodNodes, questionableNodes, badNodes, peerStats.TotalPeers, torrentCount)
log.Printf("DHT peer cache: %d hits, %d misses, %d DB writes, %d DB reads",
peerStats.CacheHits, peerStats.CacheMisses, peerStats.DatabaseWrites, peerStats.DatabaseReads)
log.Printf("DHT verification: %.1f%% verified (%d/%d), %d high reliability, %d low reliability",
verificationStats["verification_rate"], verificationStats["verified_peers"],
verificationStats["total_peers"], verificationStats["high_reliability_peers"],
verificationStats["low_reliability_peers"])
}
// handleAnnouncements drains d.announceQueue, handing each request to
// processAnnounce in arrival order, until a receive from d.stopCh succeeds.
func (d *DHT) handleAnnouncements() {
	for {
		select {
		case req := <-d.announceQueue:
			d.processAnnounce(req)
		case <-d.stopCh:
			return
		}
	}
}
// processAnnounce records this node in d.storage for req.InfoHash. The record
// uses the 6-byte compact peer format: a 4-byte IPv4 address followed by a
// 2-byte big-endian port.
//
// NOTE(review): the advertised IP is hard-coded to 127.0.0.1; detecting the
// real external address is not implemented here.
func (d *DHT) processAnnounce(req AnnounceRequest) {
	loopback := net.ParseIP("127.0.0.1").To4()
	compact := make([]byte, 6)
	copy(compact, loopback)
	binary.BigEndian.PutUint16(compact[4:], uint16(req.Port))

	d.storageMu.Lock()
	d.storage[req.InfoHash] = compact
	d.storageMu.Unlock()

	log.Printf("Announced torrent %s on port %d", req.InfoHash, req.Port)
}
// Announce enqueues a request to announce infoHash on port. It never blocks.
// If the announce queue is full, the request is dropped and a log line
// records it.
func (d *DHT) Announce(infoHash string, port int) {
	req := AnnounceRequest{InfoHash: infoHash, Port: port}
	select {
	case d.announceQueue <- req:
		log.Printf("Queued announce for torrent %s", infoHash)
	default:
		log.Printf("Announce queue full, dropping announce for %s", infoHash)
	}
}
// FindPeers returns the peers held in local peer storage for infoHash. It does
// not contact the network, and the returned error is always nil.
func (d *DHT) FindPeers(infoHash string) ([]*Peer, error) {
	return d.peerStorage.GetPeers(infoHash), nil
}
// FindPeersFromNetwork asks the network for peers of infohash.
//
// It sends get_peers to the Alpha closest nodes in the routing table. When a
// node answers with nodes instead of peers, each of those nodes is queried
// once more. That is a single extra hop, not a full iterative lookup.
// Per-node errors are logged (first hop) or ignored (second hop) and skipped,
// so the returned error is always nil.
//
// Several nodes commonly report the same peer, so results are de-duplicated by
// IP:port. The first occurrence is kept, in query order.
func (d *DHT) FindPeersFromNetwork(infohash []byte) ([]*Peer, error) {
	var targetID NodeID
	copy(targetID[:], infohash)

	var allPeers []*Peer
	seen := make(map[string]bool)
	collect := func(peers []*Peer) {
		for _, p := range peers {
			if p == nil {
				continue
			}
			key := fmt.Sprintf("%s:%d", p.IP, p.Port)
			if seen[key] {
				continue
			}
			seen[key] = true
			allPeers = append(allPeers, p)
		}
	}

	for _, node := range d.routingTable.FindClosestNodes(targetID, Alpha) {
		peers, nodes, err := d.GetPeers(infohash, node.Addr)
		if err != nil {
			log.Printf("GetPeers failed for node %s: %v", node.Addr, err)
			continue
		}
		if peers != nil {
			collect(peers)
			continue
		}
		// No peers from this node: follow the nodes it pointed us to, one hop.
		for _, next := range nodes {
			if morePeers, _, err := d.GetPeers(infohash, next.Addr); err == nil {
				collect(morePeers)
			}
		}
	}
	return allPeers, nil
}
// GetStats returns a copy of the peer storage statistics, taken under the
// storage's read lock.
func (ps *PeerStorage) GetStats() PeerStorageStats {
	ps.mu.RLock()
	snapshot := ps.stats
	ps.mu.RUnlock()
	return snapshot
}
// GetTorrentCount reports how many infohashes currently have an in-memory
// peer list.
func (ps *PeerStorage) GetTorrentCount() int {
	ps.mu.RLock()
	n := len(ps.infohashes)
	ps.mu.RUnlock()
	return n
}
// GetPeerCountForTorrent reports how many peers are known for infohash.
// An in-memory peer list takes precedence. Otherwise the database is asked,
// if one is configured. A database error is treated as "no peers" and
// yields 0.
func (ps *PeerStorage) GetPeerCountForTorrent(infohash string) int {
	ps.mu.RLock()
	peerList, cached := ps.infohashes[infohash]
	ps.mu.RUnlock()

	if cached {
		peerList.mu.RLock()
		defer peerList.mu.RUnlock()
		return len(peerList.peers)
	}
	if ps.db == nil {
		return 0
	}
	count, err := ps.db.GetPeerCount(infohash)
	if err != nil {
		return 0
	}
	return count
}
// GetStats returns a copy of the DHT statistics, with StoredItems refreshed
// from the current peer-storage total.
//
// The peer-storage snapshot is taken before statsMu is acquired. This method
// therefore never holds statsMu while waiting on the peer storage lock. That
// matches performMaintenance, which also reads peer stats before locking
// statsMu, and avoids creating a statsMu -> peer-storage lock dependency.
func (d *DHT) GetStats() Stats {
	peerStats := d.peerStorage.GetStats()

	d.statsMu.Lock()
	defer d.statsMu.Unlock()
	d.stats.StoredItems = int(peerStats.TotalPeers)
	return d.stats
}
// AddNode inserts or refreshes node in its k-bucket, following Kademlia rules.
//
// Buckets are kept in least-recently-seen order: index 0 (the head) is the
// stalest entry and the tail is the most recently seen. The stored entry is
// always a copy of *node, marked NodeGood with FailedQueries reset.
//
//   - A node already in the bucket (same ID) is moved to the tail.
//   - Otherwise, if the bucket has room, the node is appended at the tail.
//   - Otherwise an entry is evicted. Preference order: the first NodeBad
//     entry, then the first NodeQuestionable entry, then the head if it has
//     not been seen for more than 15 minutes. The new node goes to the tail.
//     It must never take the evicted slot, which could be the head; that
//     would make a brand-new node the next eviction candidate.
//   - If none of those apply, the new node is discarded, keeping the
//     long-lived node.
//
// Our own ID is never added.
//
// NOTE: a full Kademlia implementation would ping questionable or stale
// entries before evicting them; this version evicts them directly.
func (rt *RoutingTable) AddNode(node *Node) {
	if node.ID == rt.selfID {
		return
	}

	bucketIndex := leadingZeros(xor(rt.selfID, node.ID))
	if bucketIndex >= NumBuckets {
		bucketIndex = NumBuckets - 1
	}
	bucket := rt.buckets[bucketIndex]
	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	fresh := *node
	fresh.Health = NodeGood
	fresh.FailedQueries = 0

	// moveToTail removes the entry at index drop (when drop >= 0) and appends
	// the fresh copy as the most recently seen entry.
	moveToTail := func(drop int) {
		if drop >= 0 {
			bucket.nodes = append(bucket.nodes[:drop], bucket.nodes[drop+1:]...)
		}
		bucket.nodes = append(bucket.nodes, &fresh)
		bucket.lastChanged = time.Now()
	}

	for i, existing := range bucket.nodes {
		if existing.ID == node.ID {
			moveToTail(i)
			return
		}
	}

	if len(bucket.nodes) < BucketSize {
		moveToTail(-1)
		return
	}

	// Bucket is full: pick a victim, best candidates first.
	victim := -1
	for i, n := range bucket.nodes {
		if n.Health == NodeBad {
			victim = i
			break
		}
	}
	if victim < 0 {
		for i, n := range bucket.nodes {
			if n.Health == NodeQuestionable {
				victim = i
				break
			}
		}
	}
	if victim < 0 && time.Since(bucket.nodes[0].LastSeen) > 15*time.Minute {
		victim = 0
	}
	if victim < 0 {
		return // head is fresh: keep it and drop the newcomer
	}
	moveToTail(victim)
}
// FindClosestNodes returns up to count non-bad nodes from the routing table,
// ordered by increasing XOR distance to targetID.
//
// Buckets are visited outward from the bucket targetID itself would occupy
// (home, home+1, home-1, home+2, ...). Candidates are collected until at
// least 2*count have been gathered or every bucket has been visited. They
// are then sorted by distance with a selection sort and truncated to count.
// Each bucket is read under its own read lock. The returned pointers are the
// routing table's own *Node values, not copies.
func (rt *RoutingTable) FindClosestNodes(targetID NodeID, count int) []*Node {
	home := leadingZeros(xor(rt.selfID, targetID))
	if home >= NumBuckets {
		home = NumBuckets - 1
	}

	var pool []*Node
	gather := func(idx int) {
		b := rt.buckets[idx]
		b.mu.RLock()
		for _, n := range b.nodes {
			if n.Health != NodeBad {
				pool = append(pool, n)
			}
		}
		b.mu.RUnlock()
	}

	for step := 0; step < NumBuckets; step++ {
		if up := home + step; up < NumBuckets {
			gather(up)
		}
		if down := home - step; step > 0 && down >= 0 {
			gather(down)
		}
		if len(pool) >= count*2 {
			break
		}
	}

	type ranked struct {
		node *Node
		dist NodeID
	}
	order := make([]ranked, len(pool))
	for i, n := range pool {
		order[i] = ranked{node: n, dist: xor(n.ID, targetID)}
	}

	// Selection sort: move the nearest remaining candidate into slot i.
	for i := range order {
		best := i
		for j := i + 1; j < len(order); j++ {
			if compareNodeIDs(order[j].dist, order[best].dist) < 0 {
				best = j
			}
		}
		order[i], order[best] = order[best], order[i]
	}

	var result []*Node
	for i := 0; i < len(order) && i < count; i++ {
		result = append(result, order[i].node)
	}
	return result
}
// UpdateNodeHealth sets the health of the node with nodeID, if it is present
// in its bucket. Marking it NodeBad also increments FailedQueries. Marking it
// NodeGood resets FailedQueries and refreshes LastSeen. Unknown IDs are
// ignored. The update happens under the bucket's write lock.
func (rt *RoutingTable) UpdateNodeHealth(nodeID NodeID, health NodeHealth) {
	idx := leadingZeros(xor(rt.selfID, nodeID))
	if idx >= NumBuckets {
		idx = NumBuckets - 1
	}
	bucket := rt.buckets[idx]
	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	for _, node := range bucket.nodes {
		if node.ID != nodeID {
			continue
		}
		node.Health = health
		switch health {
		case NodeBad:
			node.FailedQueries++
		case NodeGood:
			node.FailedQueries = 0
			node.LastSeen = time.Now()
		}
		return
	}
}
// RemoveNode deletes the node with nodeID from its bucket, if present, and
// bumps the bucket's lastChanged. Unknown IDs are ignored.
func (rt *RoutingTable) RemoveNode(nodeID NodeID) {
	idx := leadingZeros(xor(rt.selfID, nodeID))
	if idx >= NumBuckets {
		idx = NumBuckets - 1
	}
	bucket := rt.buckets[idx]
	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	for i := range bucket.nodes {
		if bucket.nodes[i].ID != nodeID {
			continue
		}
		bucket.nodes = append(bucket.nodes[:i], bucket.nodes[i+1:]...)
		bucket.lastChanged = time.Now()
		return
	}
}
// GetBucketInfo reports, for debugging, how many nodes bucket bucketIndex holds
// and when it last changed. Out-of-range indices yield (0, zero time).
func (rt *RoutingTable) GetBucketInfo(bucketIndex int) (int, time.Time) {
	if bucketIndex < 0 || bucketIndex >= NumBuckets {
		return 0, time.Time{}
	}
	b := rt.buckets[bucketIndex]
	b.mu.RLock()
	size, changed := len(b.nodes), b.lastChanged
	b.mu.RUnlock()
	return size, changed
}
// RefreshBuckets refreshes every bucket that has gone more than 15 minutes
// without a refresh. For each such bucket it stamps lastRefresh, picks a
// random ID that falls in that bucket, and starts an asynchronous
// performLookup for it. It does not wait for the lookups to finish.
func (rt *RoutingTable) RefreshBuckets(dht *DHT) {
	const refreshInterval = 15 * time.Minute
	now := time.Now()

	for idx, bucket := range rt.buckets {
		bucket.mu.Lock()
		if now.Sub(bucket.lastRefresh) <= refreshInterval {
			bucket.mu.Unlock()
			continue
		}
		target := rt.generateRandomIDForBucket(idx)
		bucket.lastRefresh = now
		bucket.mu.Unlock()

		go dht.performLookup(target)
	}
}
// generateRandomIDForBucket returns a random node ID that falls into bucket
// bucketIndex of this routing table. Such an ID's XOR distance from
// rt.selfID has exactly bucketIndex leading zero bits, the quantity AddNode
// and FindClosestNodes use to choose a bucket.
//
// Construction: keep the first bucketIndex bits of selfID, flip the next bit
// (counting from the most-significant bit of byte 0), and randomize every
// bit after it. An index outside [0, NumBuckets) returns selfID unchanged.
//
// crypto/rand is used because reading time.Now() repeatedly yields nearly
// identical bytes. If it fails, the random part is zero and the ID still
// lands in the requested bucket.
func (rt *RoutingTable) generateRandomIDForBucket(bucketIndex int) NodeID {
	targetID := rt.selfID
	if bucketIndex < 0 || bucketIndex >= NumBuckets {
		return targetID
	}

	var noise NodeID
	if _, err := rand.Read(noise[:]); err != nil {
		noise = NodeID{}
	}

	byteIndex := bucketIndex / 8
	flip := byte(0x80) >> uint(bucketIndex%8)
	// Flip the distinguishing bit; randomize the lower bits of the same byte.
	targetID[byteIndex] ^= flip | (noise[byteIndex] & (flip - 1))
	// Randomize every following byte.
	for i := byteIndex + 1; i < NodeIDLength; i++ {
		targetID[i] ^= noise[i]
	}
	return targetID
}
// PerformHealthCheck ages nodes by how long ago they were last seen:
//   - a NodeGood node unseen for more than 15 minutes becomes
//     NodeQuestionable and is pinged asynchronously via dht.pingNode;
//   - a NodeQuestionable node unseen for more than 30 minutes becomes NodeBad.
//
// Each bucket is processed under its write lock.
func (rt *RoutingTable) PerformHealthCheck(dht *DHT) {
	const (
		questionableAfter = 15 * time.Minute
		badAfter          = 30 * time.Minute
	)
	now := time.Now()

	for _, bucket := range rt.buckets {
		bucket.mu.Lock()
		for _, node := range bucket.nodes {
			age := now.Sub(node.LastSeen)
			switch {
			case node.Health == NodeGood && age > questionableAfter:
				node.Health = NodeQuestionable
				go dht.pingNode(node)
			case node.Health == NodeQuestionable && age > badAfter:
				node.Health = NodeBad
			}
		}
		bucket.mu.Unlock()
	}
}
// CleanupBadNodes drops nodes that are NodeBad and have at least 3 recorded
// failed queries. Bad nodes with fewer failures are kept; they are still
// skipped by FindClosestNodes and may be replaced by AddNode. A bucket's
// lastChanged is bumped only when something was actually removed.
func (rt *RoutingTable) CleanupBadNodes() {
	for _, bucket := range rt.buckets {
		bucket.mu.Lock()
		kept := make([]*Node, 0, len(bucket.nodes))
		for _, node := range bucket.nodes {
			evict := node.Health == NodeBad && node.FailedQueries >= 3
			if !evict {
				kept = append(kept, node)
			}
		}
		if len(kept) < len(bucket.nodes) {
			bucket.nodes = kept
			bucket.lastChanged = time.Now()
		}
		bucket.mu.Unlock()
	}
}
// pingNode pings node and records the outcome on the routing-table entry with
// the same ID.
//
// On success the entry is marked NodeGood, its FailedQueries reset, LastSeen
// refreshed and RTT set to the measured round trip. On failure FailedQueries
// is incremented, and the entry becomes NodeBad at 3 failures.
//
// The update is applied under the owning bucket's lock, to whichever entry
// currently carries node.ID. Two reasons:
//   - pingNode runs on its own goroutine (PerformHealthCheck starts it with
//     `go`), so unsynchronized writes would race with readers holding the
//     bucket lock;
//   - a successful Ping re-inserts the responder through AddNode, which
//     stores a fresh copy, so the *Node passed in may no longer be the one
//     in the table.
//
// If no entry with node.ID remains, the result is dropped. The bucket lock is
// taken only after the ping returns, so it is never held across network I/O
// or while AddNode runs.
func (d *DHT) pingNode(node *Node) {
	start := time.Now()
	err := d.sendPing(node.Addr)
	rtt := time.Since(start)

	bucketIndex := leadingZeros(xor(d.routingTable.selfID, node.ID))
	if bucketIndex >= NumBuckets {
		bucketIndex = NumBuckets - 1
	}
	bucket := d.routingTable.buckets[bucketIndex]
	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	for _, current := range bucket.nodes {
		if current.ID != node.ID {
			continue
		}
		if err == nil {
			current.Health = NodeGood
			current.FailedQueries = 0
			current.LastSeen = time.Now()
			current.RTT = rtt
		} else {
			current.FailedQueries++
			if current.FailedQueries >= 3 {
				current.Health = NodeBad
			}
		}
		return
	}
}
// performLookup refreshes the routing table around targetID. It sends one
// find_node query to each of the Alpha closest known nodes, each on its own
// goroutine, and passes any compact "nodes" in the replies to
// parseAndAddNodes. It returns immediately without waiting for replies, and
// it does not iterate toward closer nodes.
func (d *DHT) performLookup(targetID NodeID) {
	for _, peer := range d.routingTable.FindClosestNodes(targetID, Alpha) {
		query := map[string]interface{}{
			"id":     string(d.nodeID[:]),
			"target": string(targetID[:]),
		}
		go func(addr *net.UDPAddr, args map[string]interface{}) {
			resp, err := d.sendQuery("find_node", args, addr)
			if err != nil || resp.Response == nil {
				return
			}
			if compact, ok := resp.Response["nodes"].(string); ok {
				d.parseAndAddNodes([]byte(compact))
			}
		}(peer.Addr, query)
	}
}
// ============ CORE DHT OPERATIONS ============
// Ping sends a BEP-5 ping query to addr. On a well-formed reply (a 20-byte
// "id"), it adds the responder to the routing table as NodeGood and returns
// that node.
//
// RTT runs from just before the query is sent until sendQuery returns. The
// old expression time.Since(time.Now()) was always effectively zero.
func (d *DHT) Ping(addr *net.UDPAddr) (*Node, error) {
	args := map[string]interface{}{
		"id": string(d.nodeID[:]),
	}

	start := time.Now()
	response, err := d.sendQuery("ping", args, addr)
	rtt := time.Since(start)
	if err != nil {
		return nil, fmt.Errorf("ping failed: %w", err)
	}
	if response.Response == nil {
		return nil, fmt.Errorf("invalid ping response")
	}

	nodeIDBytes, ok := response.Response["id"].(string)
	if !ok || len(nodeIDBytes) != NodeIDLength {
		return nil, fmt.Errorf("invalid node ID in ping response")
	}
	var nodeID NodeID
	copy(nodeID[:], nodeIDBytes)

	node := &Node{
		ID:       nodeID,
		Addr:     addr,
		LastSeen: time.Now(),
		Health:   NodeGood,
		RTT:      rtt,
	}
	d.routingTable.AddNode(node)
	log.Printf("Ping successful: node %x at %s", nodeID[:8], addr)
	return node, nil
}
// FindNode sends find_node for targetID to addr. Two things are added to the
// routing table: the responder (when its "id" is a valid 20-byte ID) and
// every node in the compact "nodes" reply. The parsed nodes are returned. A
// reply without a "nodes" string is an error.
func (d *DHT) FindNode(targetID NodeID, addr *net.UDPAddr) ([]*Node, error) {
	query := map[string]interface{}{
		"id":     string(d.nodeID[:]),
		"target": string(targetID[:]),
	}
	resp, err := d.sendQuery("find_node", query, addr)
	if err != nil {
		return nil, fmt.Errorf("find_node failed: %w", err)
	}
	reply := resp.Response
	if reply == nil {
		return nil, fmt.Errorf("invalid find_node response")
	}

	if rawID, ok := reply["id"].(string); ok && len(rawID) == NodeIDLength {
		responder := &Node{Addr: addr, LastSeen: time.Now(), Health: NodeGood}
		copy(responder.ID[:], rawID)
		d.routingTable.AddNode(responder)
	}

	compact, ok := reply["nodes"].(string)
	if !ok {
		return nil, fmt.Errorf("no nodes in find_node response")
	}
	discovered := d.parseCompactNodes([]byte(compact))
	for _, n := range discovered {
		d.routingTable.AddNode(n)
	}

	log.Printf("FindNode successful: discovered %d nodes from %s", len(discovered), addr)
	return discovered, nil
}
// GetPeers sends get_peers for infohash to addr. The responder is added to
// the routing table when its "id" is a valid 20-byte ID.
//
// If the reply carries "values", each 6-byte compact entry (IPv4 + big-endian
// port) becomes a Peer, and (peers, nil, nil) is returned; malformed entries
// are skipped. Otherwise the compact "nodes" are parsed, added to the routing
// table and returned as (nil, nodes, nil). A reply with neither is an error.
// Any token in the reply is not returned.
func (d *DHT) GetPeers(infohash []byte, addr *net.UDPAddr) ([]*Peer, []*Node, error) {
	query := map[string]interface{}{
		"id":        string(d.nodeID[:]),
		"info_hash": string(infohash),
	}
	resp, err := d.sendQuery("get_peers", query, addr)
	if err != nil {
		return nil, nil, fmt.Errorf("get_peers failed: %w", err)
	}
	reply := resp.Response
	if reply == nil {
		return nil, nil, fmt.Errorf("invalid get_peers response")
	}

	if rawID, ok := reply["id"].(string); ok && len(rawID) == NodeIDLength {
		responder := &Node{Addr: addr, LastSeen: time.Now(), Health: NodeGood}
		copy(responder.ID[:], rawID)
		d.routingTable.AddNode(responder)
	}

	if values, ok := reply["values"].([]interface{}); ok {
		var peers []*Peer
		for _, v := range values {
			entry, isStr := v.(string)
			if !isStr || len(entry) != 6 {
				continue
			}
			peers = append(peers, &Peer{
				IP:    net.IP(entry[:4]),
				Port:  int(binary.BigEndian.Uint16([]byte(entry[4:]))),
				Added: time.Now(),
			})
		}
		log.Printf("GetPeers successful: found %d peers from %s", len(peers), addr)
		return peers, nil, nil
	}

	compact, ok := reply["nodes"].(string)
	if !ok {
		return nil, nil, fmt.Errorf("no peers or nodes in get_peers response")
	}
	nodes := d.parseCompactNodes([]byte(compact))
	for _, n := range nodes {
		d.routingTable.AddNode(n)
	}
	log.Printf("GetPeers successful: found %d nodes from %s", len(nodes), addr)
	return nil, nodes, nil
}
// AnnouncePeer sends a BEP-5 announce_peer query to addr, advertising that
// this node serves infohash on port. token should be the token addr returned
// from an earlier get_peers; the remote node validates it, this code does
// not. A responder with a valid 20-byte "id" is added to the routing table.
//
// The remote node is not recorded as a peer for infohash. It is a DHT node we
// announce to, not a torrent peer, and storing addr.IP with our port would
// hand a bogus peer to anyone asking us for this infohash.
func (d *DHT) AnnouncePeer(infohash []byte, port int, token string, addr *net.UDPAddr) error {
	args := map[string]interface{}{
		"id":        string(d.nodeID[:]),
		"info_hash": string(infohash),
		"port":      port,
		"token":     token,
	}
	response, err := d.sendQuery("announce_peer", args, addr)
	if err != nil {
		return fmt.Errorf("announce_peer failed: %w", err)
	}
	if response.Response == nil {
		return fmt.Errorf("invalid announce_peer response")
	}

	if nodeIDBytes, ok := response.Response["id"].(string); ok && len(nodeIDBytes) == NodeIDLength {
		var nodeID NodeID
		copy(nodeID[:], nodeIDBytes)
		d.routingTable.AddNode(&Node{
			ID:       nodeID,
			Addr:     addr,
			LastSeen: time.Now(),
			Health:   NodeGood,
		})
	}

	// Short prefix for logging; an unconditional [:8] panics on short input.
	infohashHex := fmt.Sprintf("%x", infohash)
	if len(infohashHex) > 8 {
		infohashHex = infohashHex[:8]
	}
	log.Printf("AnnouncePeer successful: announced port %d for %s to %s", port, infohashHex, addr)
	return nil
}
// ============ HELPER FUNCTIONS ============
// parseCompactNodes decodes BEP-5 compact node info. The input is consecutive
// 26-byte entries: a 20-byte node ID, a 4-byte IPv4 address and a 2-byte
// big-endian port. A trailing partial entry is ignored. Each node is returned
// as NodeGood with LastSeen set to now; nothing is added to the routing table
// here.
//
// The IP is copied out of nodesBytes rather than sub-sliced. Returned nodes
// may be kept long-term in the routing table, and a sub-slice would alias the
// caller's buffer and keep all of it alive.
func (d *DHT) parseCompactNodes(nodesBytes []byte) []*Node {
	const nodeSize = NodeIDLength + net.IPv4len + 2

	var nodes []*Node
	for off := 0; off+nodeSize <= len(nodesBytes); off += nodeSize {
		entry := nodesBytes[off : off+nodeSize]

		var nodeID NodeID
		copy(nodeID[:], entry[:NodeIDLength])

		ip := make(net.IP, net.IPv4len)
		copy(ip, entry[NodeIDLength:NodeIDLength+net.IPv4len])
		port := binary.BigEndian.Uint16(entry[NodeIDLength+net.IPv4len:])

		nodes = append(nodes, &Node{
			ID:       nodeID,
			Addr:     &net.UDPAddr{IP: ip, Port: int(port)},
			LastSeen: time.Now(),
			Health:   NodeGood,
		})
	}
	return nodes
}
// parseAndAddNodes decodes compact node info and inserts every decoded node
// into the routing table.
func (d *DHT) parseAndAddNodes(nodesBytes []byte) {
	for _, n := range d.parseCompactNodes(nodesBytes) {
		d.routingTable.AddNode(n)
	}
}
// ============ ITERATIVE LOOKUP FUNCTIONS ============
// iterativeFindNode runs an iterative find_node lookup for targetID. The
// lookup is seeded with initialNodes plus routing-table nodes, and the
// closest nodes found are returned.
func (d *DHT) iterativeFindNode(targetID NodeID, initialNodes []*net.UDPAddr) []*Node {
	const queryType = "find_node"
	return d.iterativeLookup(targetID, initialNodes, queryType)
}
// iterativeFindPeers runs an iterative get_peers lookup toward targetID.
// NOTE: peers seen during the lookup are not collected yet. The first result
// is therefore always nil, and only the closest nodes are returned.
func (d *DHT) iterativeFindPeers(targetID NodeID, initialNodes []*net.UDPAddr) ([]*Peer, []*Node) {
	closest := d.iterativeLookup(targetID, initialNodes, "get_peers")
	return nil, closest
}
// iterativeLookup is the lookup loop shared by find_node and get_peers;
// queryType selects which query is sent.
//
// The shortlist is keyed by "ip:port" and starts with two sets of nodes:
//   - initialNodes, as placeholder nodes with zero IDs (their IDs are
//     unknown until queried);
//   - up to 2*BucketSize routing-table nodes closest to targetID, which
//     replace any placeholder at the same address.
//
// Then, for at most 10 rounds, the Alpha closest unqueried shortlist nodes
// are queried in parallel, and newly reported addresses join the shortlist.
// The loop stops early when nothing is left to query or a round adds no new
// address. The BucketSize closest shortlist nodes are returned.
func (d *DHT) iterativeLookup(targetID NodeID, initialNodes []*net.UDPAddr, queryType string) []*Node {
	const maxRounds = 10
	addrKey := func(a *net.UDPAddr) string { return fmt.Sprintf("%s:%d", a.IP, a.Port) }

	shortlist := make(map[string]*Node)
	queried := make(map[string]bool)

	for _, addr := range initialNodes {
		shortlist[addrKey(addr)] = &Node{Addr: addr, LastSeen: time.Now(), Health: NodeGood}
	}
	for _, node := range d.routingTable.FindClosestNodes(targetID, BucketSize*2) {
		shortlist[addrKey(node.Addr)] = node
	}

	log.Printf("Starting iterative lookup for target %x with %d initial nodes", targetID[:8], len(shortlist))

	for round := 0; round < maxRounds; round++ {
		candidates := d.selectCandidates(shortlist, queried, targetID, Alpha)
		if len(candidates) == 0 {
			log.Printf("No more candidates, stopping after %d iterations", round)
			break
		}
		log.Printf("Iteration %d: querying %d nodes", round+1, len(candidates))

		grew := false
		for _, node := range d.queryNodesParallel(candidates, targetID, queryType, queried) {
			key := addrKey(node.Addr)
			if _, known := shortlist[key]; known {
				continue
			}
			shortlist[key] = node
			grew = true
		}
		if !grew {
			log.Printf("No new nodes found, stopping after %d iterations", round+1)
			break
		}
	}

	result := d.getClosestFromShortlist(shortlist, targetID, BucketSize)
	log.Printf("Iterative lookup completed: found %d nodes", len(result))
	return result
}
// selectCandidates picks up to count shortlist nodes that are neither already
// queried nor NodeBad, nearest to targetID first. Placeholder nodes with an
// all-zero ID are ranked by the SHA-1 of their "ip:port" string instead of
// their ID.
func (d *DHT) selectCandidates(shortlist map[string]*Node, queried map[string]bool, targetID NodeID, count int) []*Node {
	type ranked struct {
		node *Node
		dist NodeID
	}

	var pool []ranked
	for key, node := range shortlist {
		if queried[key] || node.Health == NodeBad {
			continue
		}
		id := node.ID
		if id == (NodeID{}) {
			digest := sha1.Sum([]byte(fmt.Sprintf("%s:%d", node.Addr.IP, node.Addr.Port)))
			copy(id[:], digest[:])
		}
		pool = append(pool, ranked{node: node, dist: xor(id, targetID)})
	}

	// Exchange sort by ascending distance.
	for i := 0; i < len(pool)-1; i++ {
		for j := i + 1; j < len(pool); j++ {
			if compareNodeIDs(pool[j].dist, pool[i].dist) < 0 {
				pool[i], pool[j] = pool[j], pool[i]
			}
		}
	}

	var picked []*Node
	for i := 0; i < len(pool) && i < count; i++ {
		picked = append(picked, pool[i].node)
	}
	return picked
}
// queryNodesParallel sends queryType ("find_node" or "get_peers") for targetID
// to every candidate concurrently and returns all nodes reported back.
//
// Each candidate's "ip:port" key is recorded in queried on the calling
// goroutine, before any worker starts. Go maps are not safe for concurrent
// writes, and workers writing queried[key] themselves can abort the process
// with "concurrent map writes".
//
// Results are collected for at most 30 seconds; on timeout, whatever has
// already arrived is returned. The result channel has one buffered slot per
// worker and is never closed. A worker that finishes after the timeout can
// therefore always send without blocking and cannot race a close.
//
// The underlying FindNode/GetPeers calls take no context, so the timeout stops
// waiting but does not cancel in-flight queries. A panicking query is logged
// and counted as an empty result. Candidates with no address are skipped.
func (d *DHT) queryNodesParallel(candidates []*Node, targetID NodeID, queryType string, queried map[string]bool) []*Node {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	results := make(chan []*Node, len(candidates))
	launched := 0
	for _, candidate := range candidates {
		if candidate == nil || candidate.Addr == nil {
			continue
		}
		queried[fmt.Sprintf("%s:%d", candidate.Addr.IP, candidate.Addr.Port)] = true
		launched++

		go func(node *Node) {
			var discovered []*Node
			// Exactly one send per worker, even after a panic, so the
			// collector's count stays accurate.
			defer func() {
				if r := recover(); r != nil {
					log.Printf("DHT query panic recovered: %v", r)
				}
				results <- discovered
			}()

			switch queryType {
			case "find_node":
				if nodes, err := d.FindNode(targetID, node.Addr); err == nil {
					discovered = nodes
				}
			case "get_peers":
				if _, nodes, err := d.GetPeers(targetID[:], node.Addr); err == nil && nodes != nil {
					discovered = nodes
				}
			}
		}(candidate)
	}

	var allNodes []*Node
	for ; launched > 0; launched-- {
		select {
		case nodes := <-results:
			allNodes = append(allNodes, nodes...)
		case <-ctx.Done():
			log.Printf("DHT parallel query timeout after 30s")
			return allNodes
		}
	}
	return allNodes
}
// getClosestFromShortlist returns up to count non-bad shortlist nodes, nearest
// to targetID first. Nodes with an all-zero (unknown) ID are ranked by the
// SHA-1 of their "ip:port" string.
func (d *DHT) getClosestFromShortlist(shortlist map[string]*Node, targetID NodeID, count int) []*Node {
	type ranked struct {
		node *Node
		dist NodeID
	}

	var pool []ranked
	for _, node := range shortlist {
		if node.Health == NodeBad {
			continue
		}
		id := node.ID
		if id == (NodeID{}) {
			digest := sha1.Sum([]byte(fmt.Sprintf("%s:%d", node.Addr.IP, node.Addr.Port)))
			copy(id[:], digest[:])
		}
		pool = append(pool, ranked{node: node, dist: xor(id, targetID)})
	}

	// Exchange sort by ascending distance.
	for i := 0; i < len(pool)-1; i++ {
		for j := i + 1; j < len(pool); j++ {
			if compareNodeIDs(pool[j].dist, pool[i].dist) < 0 {
				pool[i], pool[j] = pool[j], pool[i]
			}
		}
	}

	var closest []*Node
	for i := 0; i < len(pool) && i < count; i++ {
		closest = append(closest, pool[i].node)
	}
	return closest
}
// xor returns the Kademlia XOR distance between a and b, computed byte by
// byte.
func xor(a, b NodeID) NodeID {
	out := a
	for i := range out {
		out[i] ^= b[i]
	}
	return out
}
// leadingZeros returns the number of leading zero bits in id, read from the
// most-significant bit of byte 0. An all-zero id yields NodeIDLength*8.
func leadingZeros(id NodeID) int {
	count := 0
	for _, b := range id {
		if b == 0 {
			count += 8
			continue
		}
		// b has at least one set bit, so this loop terminates.
		for mask := byte(0x80); b&mask == 0; mask >>= 1 {
			count++
		}
		return count
	}
	return count
}
// SimpleDatabaseImpl wraps a *sql.DB and implements the DHT peer-persistence
// queries against the dht_stored_peers table. Its SQL uses SQLite functions
// such as datetime('now', ...).
type SimpleDatabaseImpl struct {
	db *sql.DB
}

// NewSimpleDatabase wraps db as-is. No validation or schema setup happens
// here, and db may be nil; call CreateDHTTables to create the schema.
func NewSimpleDatabase(db *sql.DB) *SimpleDatabaseImpl {
	impl := new(SimpleDatabaseImpl)
	impl.db = db
	return impl
}
// CreateDHTTables creates the dht_stored_peers table if it is missing. The
// table is keyed by (infohash, ip, port) and carries reliability and
// verification columns. Indexes on infohash, last_announced and
// reliability_score are created as well. Every statement uses IF NOT EXISTS,
// so repeated calls are harmless.
//
// NOTE(review): all statements go through a single Exec call. This relies on
// the driver accepting several statements per Exec (SQLite drivers generally
// do); confirm for the driver in use.
func (sdb *SimpleDatabaseImpl) CreateDHTTables() error {
	const schema = `
CREATE TABLE IF NOT EXISTS dht_stored_peers (
infohash TEXT NOT NULL,
ip TEXT NOT NULL,
port INTEGER NOT NULL,
peer_id TEXT,
last_announced DATETIME DEFAULT CURRENT_TIMESTAMP,
reliability_score INTEGER DEFAULT 50,
announce_count INTEGER DEFAULT 1,
verified BOOLEAN DEFAULT FALSE,
last_verified DATETIME,
PRIMARY KEY (infohash, ip, port)
);
CREATE INDEX IF NOT EXISTS idx_dht_peers_infohash ON dht_stored_peers(infohash);
CREATE INDEX IF NOT EXISTS idx_dht_peers_last_announced ON dht_stored_peers(last_announced);
CREATE INDEX IF NOT EXISTS idx_dht_peers_reliability ON dht_stored_peers(reliability_score);
`
	if _, err := sdb.db.Exec(schema); err != nil {
		return err
	}
	return nil
}
// StorePeer records an announce of ip:port for infohash.
//
// A new (infohash, ip, port) row starts with reliability_score 50 and
// announce_count 1. For an existing row, peer_id and last_announced are
// overwritten, announce_count is incremented, and reliability_score rises by
// 5, capped at 100.
//
// Both cases are handled by one SQLite UPSERT (INSERT ... ON CONFLICT DO
// UPDATE, which needs SQLite >= 3.24) keyed on the table's primary key. A
// separate SELECT followed by an INSERT is not atomic: two concurrent
// announces of the same new peer could both see "no row", and the second
// INSERT would then fail on the primary key.
//
// lastAnnounced is stored in UTC because the read and expiry queries compare
// it against datetime('now', ...), which SQLite evaluates in UTC. A local-zone
// timestamp would make peers look expired or fresh by the zone offset.
// NOTE(review): this assumes the driver serializes time.Time as text in the
// value's own location (as mattn/go-sqlite3 does); confirm for the driver in use.
func (sdb *SimpleDatabaseImpl) StorePeer(infohash, ip string, port int, peerID string, lastAnnounced time.Time) error {
	const query = `INSERT INTO dht_stored_peers (infohash, ip, port, peer_id, last_announced, reliability_score, announce_count)
VALUES (?, ?, ?, ?, ?, 50, 1)
ON CONFLICT(infohash, ip, port) DO UPDATE SET
peer_id = excluded.peer_id,
last_announced = excluded.last_announced,
announce_count = announce_count + 1,
reliability_score = MIN(100, reliability_score + 5)`
	_, err := sdb.db.Exec(query, infohash, ip, port, peerID, lastAnnounced.UTC())
	return err
}
// UpdatePeerReliability overwrites reliability_score, verified and
// last_verified on the row for (infohash, ip, port). A missing row is not an
// error; the UPDATE simply matches nothing.
func (sdb *SimpleDatabaseImpl) UpdatePeerReliability(infohash, ip string, port int, reliabilityScore int, verified bool, lastVerified time.Time) error {
	const stmt = `UPDATE dht_stored_peers
SET reliability_score = ?, verified = ?, last_verified = ?
WHERE infohash = ? AND ip = ? AND port = ?`
	if _, err := sdb.db.Exec(stmt, reliabilityScore, verified, lastVerified, infohash, ip, port); err != nil {
		return err
	}
	return nil
}
// GetPeersForTorrent loads up to MaxPeersPerTorrent peers for infohash that
// were announced within the last 30 minutes. Results are ordered by
// reliability_score, then last_announced, both descending. A NULL
// last_verified leaves StoredPeer.LastVerified at its zero value. A scan
// error aborts with (nil, err). An iteration error is returned together with
// the rows read so far.
func (sdb *SimpleDatabaseImpl) GetPeersForTorrent(infohash string) ([]*StoredPeer, error) {
	const query = `SELECT infohash, ip, port, peer_id, last_announced, reliability_score, announce_count, verified, last_verified
FROM dht_stored_peers
WHERE infohash = ? AND last_announced > datetime('now', '-30 minutes')
ORDER BY reliability_score DESC, last_announced DESC
LIMIT ?`
	rows, err := sdb.db.Query(query, infohash, MaxPeersPerTorrent)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var peers []*StoredPeer
	for rows.Next() {
		p := new(StoredPeer)
		var verifiedAt sql.NullTime
		if err := rows.Scan(&p.InfoHash, &p.IP, &p.Port, &p.PeerID,
			&p.LastAnnounced, &p.ReliabilityScore, &p.AnnounceCount,
			&p.Verified, &verifiedAt); err != nil {
			return nil, err
		}
		if verifiedAt.Valid {
			p.LastVerified = verifiedAt.Time
		}
		peers = append(peers, p)
	}
	return peers, rows.Err()
}
// CleanExpiredPeers deletes rows whose last_announced is more than 30 minutes
// old, and logs how many were removed when that count is known and non-zero.
func (sdb *SimpleDatabaseImpl) CleanExpiredPeers() error {
	res, err := sdb.db.Exec(`DELETE FROM dht_stored_peers WHERE last_announced < datetime('now', '-30 minutes')`)
	if err != nil {
		return err
	}
	// The affected-row count only feeds the log line. If the driver cannot
	// report it, the line is skipped.
	if n, _ := res.RowsAffected(); n > 0 {
		log.Printf("DHT: Cleaned %d expired peers from database", n)
	}
	return nil
}
// GetPeerCount counts the peers for infohash announced within the last 30
// minutes. On error the count is 0.
func (sdb *SimpleDatabaseImpl) GetPeerCount(infohash string) (int, error) {
	const query = `SELECT COUNT(*) FROM dht_stored_peers
WHERE infohash = ? AND last_announced > datetime('now', '-30 minutes')`
	var n int
	if err := sdb.db.QueryRow(query, infohash).Scan(&n); err != nil {
		return n, err
	}
	return n, nil
}
// GetStoredPeersWithNodes answers a peer query from local state. When peer
// storage has peers for infoHash, they are returned with nil nodes.
// Otherwise the MaxResponseNodes closest routing-table nodes are returned.
// The target is infoHash decoded from hex when that yields 20 bytes, and
// SHA-1(infoHash) otherwise. The error is always nil.
func (d *DHT) GetStoredPeersWithNodes(infoHash string) (peers []*Peer, nodes []*Node, err error) {
	if stored := d.peerStorage.GetPeers(infoHash); len(stored) > 0 {
		return stored, nil, nil
	}

	raw, decodeErr := hex.DecodeString(infoHash)
	if decodeErr != nil || len(raw) != 20 {
		digest := sha1.Sum([]byte(infoHash))
		raw = digest[:]
	}
	var targetID NodeID
	copy(targetID[:], raw)

	return nil, d.routingTable.FindClosestNodes(targetID, MaxResponseNodes), nil
}
// GetAllStoredTorrents lists every infohash that has an in-memory peer list,
// in unspecified (map) order. It returns nil when there are none.
func (ps *PeerStorage) GetAllStoredTorrents() []string {
	ps.mu.RLock()
	defer ps.mu.RUnlock()

	if len(ps.infohashes) == 0 {
		return nil
	}
	torrents := make([]string, 0, len(ps.infohashes))
	for infohash := range ps.infohashes {
		torrents = append(torrents, infohash)
	}
	return torrents
}
// GetStorageInfo returns a summary of peer storage for status and debug
// output. It includes the peer-storage counters, the number and list of
// stored torrents, and the configured storage limits.
func (d *DHT) GetStorageInfo() map[string]interface{} {
	stats := d.peerStorage.GetStats()
	torrentCount := d.peerStorage.GetTorrentCount()
	torrents := d.peerStorage.GetAllStoredTorrents()

	info := make(map[string]interface{}, 10)
	// Live counters.
	info["total_peers"] = stats.TotalPeers
	info["total_torrents"] = torrentCount
	info["cache_hits"] = stats.CacheHits
	info["cache_misses"] = stats.CacheMisses
	info["database_writes"] = stats.DatabaseWrites
	info["database_reads"] = stats.DatabaseReads
	info["stored_torrents"] = torrents
	// Static limits.
	info["max_peers_limit"] = MaxPeersPerTorrent
	info["peer_expiry_mins"] = PeerExpiryMinutes
	info["token_rotation_mins"] = TokenRotationMinutes
	return info
}
// FindPeersWithFallback resolves peers for infoHash in up to three stages:
//  1. peers already held in local storage;
//  2. failing that, the closest routing-table nodes to the infohash, which
//     the caller can query next;
//  3. failing that, a network lookup via FindPeersFromNetwork. This is only
//     attempted when infoHash is a 40-character hex string.
//
// On success exactly one of the returned peers/nodes slices is non-nil.
// Because stage 2 returns whenever the routing table yields any node, stage 3
// only runs when FindClosestNodes returned none. When every stage comes up
// empty, an error is returned.
func (d *DHT) FindPeersWithFallback(infoHash string) ([]*Peer, []*Node, error) {
	// Short label for log lines. infoHash is not guaranteed to be a 40-char
	// hex string, so slicing [:8] unconditionally could panic.
	label := infoHash
	if len(label) > 8 {
		label = label[:8]
	}

	peers, nodes, err := d.GetStoredPeersWithNodes(infoHash)
	if err != nil {
		return nil, nil, err
	}
	if len(peers) > 0 {
		log.Printf("DHT: Found %d stored peers for %s", len(peers), label)
		return peers, nil, nil
	}
	if len(nodes) > 0 {
		log.Printf("DHT: Found %d closest nodes for %s", len(nodes), label)
		return nil, nodes, nil
	}

	// Last resort: ask the network directly.
	infoHashBytes, decodeErr := hex.DecodeString(infoHash)
	if decodeErr == nil && len(infoHashBytes) == NodeIDLength {
		networkPeers, lookupErr := d.FindPeersFromNetwork(infoHashBytes)
		switch {
		case lookupErr != nil:
			log.Printf("DHT: Network peer lookup failed for %s: %v", label, lookupErr)
		case len(networkPeers) > 0:
			log.Printf("DHT: Found %d network peers for %s", len(networkPeers), label)
			return networkPeers, nil, nil
		}
	}

	log.Printf("DHT: No peers found for %s", label)
	return nil, nil, fmt.Errorf("no peers found")
}
// selectReliablePeers returns at most limit peers, preferring higher
// ReliabilityScore values.
//
// The result is ordered by descending score. Peers with equal scores keep
// their relative input order. Peers scoring below MinReliabilityScore
// therefore only appear when there are not enough peers at or above the
// threshold to fill limit, since they sort after all of those.
//
// Edge cases:
//   - When len(peers) <= limit, the input slice is returned unchanged.
//   - A non-positive limit yields nil.
//   - The input slice is never reordered; sorting works on a copy.
func (d *DHT) selectReliablePeers(peers []*Peer, limit int) []*Peer {
	if len(peers) <= limit {
		return peers
	}
	if limit <= 0 {
		return nil
	}

	sorted := make([]*Peer, len(peers))
	copy(sorted, peers)

	// Stable insertion sort, descending by score. An element only moves left
	// past strictly lower scores, so ties stay in input order. The input is
	// bounded per torrent, and nearly-sorted input costs O(n).
	for i := 1; i < len(sorted); i++ {
		p := sorted[i]
		j := i
		for j > 0 && sorted[j-1].ReliabilityScore < p.ReliabilityScore {
			sorted[j] = sorted[j-1]
			j--
		}
		sorted[j] = p
	}

	// Clip capacity so a caller's append cannot write into the hidden tail.
	return sorted[:limit:limit]
}
// handlePeerVerification is the verification worker loop. It takes requests
// off verificationQueue one at a time and runs verifyPeer on each
// synchronously, so a slow dial delays the next request. It returns as soon
// as stopCh can be received from.
func (d *DHT) handlePeerVerification() {
	for {
		var req *PeerVerificationRequest
		select {
		case <-d.stopCh:
			return
		case req = <-d.verificationQueue:
		}
		d.verifyPeer(req)
	}
}
// verifyPeer checks that a stored peer is reachable by opening a TCP
// connection to it within PeerVerificationTimeout. The connection is closed
// right away; no protocol handshake is attempted.
//
// On success the peer's reliability is raised via updatePeerReliability. On
// failure the request is passed to handleVerificationFailure, which decides
// whether to retry.
func (d *DHT) verifyPeer(req *PeerVerificationRequest) {
	// net.JoinHostPort brackets IPv6 literals ("[::1]:6881"). A plain
	// "%s:%d" format yields an address net.Dial rejects for IPv6, which would
	// make every IPv6 peer fail verification.
	peerAddr := net.JoinHostPort(req.Peer.IP.String(), fmt.Sprint(req.Peer.Port))

	conn, err := net.DialTimeout("tcp", peerAddr, PeerVerificationTimeout)
	if err != nil {
		d.handleVerificationFailure(req)
		return
	}
	// Reachability is all we test. Release the socket before taking storage
	// locks; a Close error carries no useful information here.
	_ = conn.Close()

	d.updatePeerReliability(req.InfoHash, req.Peer, true)

	label := req.InfoHash
	if len(label) > 8 {
		label = label[:8]
	}
	log.Printf("DHT: Verified peer %s for %s", peerAddr, label)
}
// handleVerificationFailure handles a failed reachability check for req.
//
// While req.Retries is below 2, it increments Retries and schedules the
// request to be re-queued after a delay of 30s × Retries (30s, then 60s). If
// the DHT is stopped during that wait, the retry is dropped. If the queue is
// full when the delay elapses, the retry is skipped and logged.
//
// Once the retries are exhausted, the peer's reliability is lowered via
// updatePeerReliability.
func (d *DHT) handleVerificationFailure(req *PeerVerificationRequest) {
	// JoinHostPort keeps IPv6 addresses unambiguous in the log output.
	peerAddr := net.JoinHostPort(req.Peer.IP.String(), fmt.Sprint(req.Peer.Port))

	const maxRetries = 2
	if req.Retries < maxRetries {
		req.Retries++
		delay := time.Duration(30*req.Retries) * time.Second
		// Capture the attempt number now. Once req is back on the queue, the
		// verification worker owns it and may bump Retries concurrently, so
		// reading req.Retries after the send would be a data race.
		attempt := req.Retries + 1

		go func() {
			timer := time.NewTimer(delay)
			defer timer.Stop()
			select {
			case <-d.stopCh:
				// Shutting down: do not sleep on or feed the queue after stop.
				return
			case <-timer.C:
			}

			select {
			case d.verificationQueue <- req:
				log.Printf("DHT: Retrying verification for peer %s (attempt %d)", peerAddr, attempt)
			default:
				// Queue full, skip retry
				log.Printf("DHT: Failed to queue verification retry for peer %s", peerAddr)
			}
		}()
		return
	}

	// All retries failed - mark peer as unreliable
	d.updatePeerReliability(req.InfoHash, req.Peer, false)
	log.Printf("DHT: Peer verification failed for %s after %d attempts", peerAddr, req.Retries+1)
}
// updatePeerReliability adjusts the reliability record of the stored peer
// under infoHash that matches peer's IP and port. When a database is
// configured, it then persists the new values via updatePeerInDatabase.
//
// verified=true adds 20 to ReliabilityScore (capped via
// min(MaxReliabilityScore, ...)), sets Verified and stamps LastVerified with
// the current time. verified=false subtracts 30 (floored via max(0, ...)) and
// leaves Verified and LastVerified unchanged.
//
// Nothing happens when infoHash or the peer is not present in storage. A
// database error is logged, not returned.
//
// Locking: peerStorage.mu is taken for writing, then peerList.mu. This is the
// same outer-then-inner order CleanUnreliablePeers uses. Both are held until
// the function returns.
// NOTE(review): the database write runs while both locks are held, so a slow
// write blocks every other peerStorage user. Consider persisting a snapshot
// after unlocking, once the ordering of concurrent updates to one peer is
// confirmed safe.
func (d *DHT) updatePeerReliability(infoHash string, peer *Peer, verified bool) {
	d.peerStorage.mu.Lock()
	defer d.peerStorage.mu.Unlock()
	peerList, exists := d.peerStorage.infohashes[infoHash]
	if !exists {
		return
	}
	peerList.mu.Lock()
	defer peerList.mu.Unlock()
	// Lookup key is "ip:port". It must match how entries are keyed on insert;
	// the insertion code is outside this view, so verify before changing it.
	key := fmt.Sprintf("%s:%d", peer.IP.String(), peer.Port)
	if storedPeer, exists := peerList.peers[key]; exists {
		now := time.Now()
		if verified {
			// Increase reliability score
			storedPeer.ReliabilityScore = min(MaxReliabilityScore, storedPeer.ReliabilityScore+20)
			storedPeer.Verified = true
			storedPeer.LastVerified = now
		} else {
			// Decrease reliability score
			storedPeer.ReliabilityScore = max(0, storedPeer.ReliabilityScore-30)
			// Don't mark as verified=false immediately, might be temporary network issue
		}
		// Update database if available
		if d.peerStorage.db != nil {
			err := d.updatePeerInDatabase(infoHash, storedPeer)
			if err != nil {
				log.Printf("DHT: Failed to update peer reliability in database: %v", err)
			}
		}
		log.Printf("DHT: Updated peer %s reliability score to %d (verified: %v)",
			key, storedPeer.ReliabilityScore, storedPeer.Verified)
	}
}
// max returns whichever of a and b is greater (either one when they are equal).
func max(a, b int) int {
	if a < b {
		return b
	}
	return a
}
// updatePeerInDatabase persists peer's reliability fields (score, Verified
// flag, LastVerified time) for infoHash through the storage database's
// UpdatePeerReliability. The row is keyed by IP string and port. It is a
// no-op returning nil when no database is configured.
func (d *DHT) updatePeerInDatabase(infoHash string, peer *Peer) error {
	db := d.peerStorage.db
	if db == nil {
		return nil
	}
	ip := peer.IP.String()
	return db.UpdatePeerReliability(infoHash, ip, peer.Port,
		peer.ReliabilityScore, peer.Verified, peer.LastVerified)
}
// SetPeerVerification sets the peerVerification flag to enabled and logs the
// new state.
func (d *DHT) SetPeerVerification(enabled bool) {
	d.peerVerification = enabled

	state := "disabled"
	if enabled {
		state = "enabled"
	}
	log.Printf("DHT: Peer verification %s", state)
}
// GetVerificationStats counts every peer in the store, taking the storage
// read lock and then each peer list's read lock. It reports:
//   - the total peer count;
//   - how many peers are Verified;
//   - how many score at least 80 ("high") or at most MinReliabilityScore
//     ("low", inclusive);
//   - the verification flag and the current queue length;
//   - the verified percentage (0.0 when there are no peers).
func (d *DHT) GetVerificationStats() map[string]interface{} {
	var total, verified, high, low int

	d.peerStorage.mu.RLock()
	for _, list := range d.peerStorage.infohashes {
		list.mu.RLock()
		for _, p := range list.peers {
			total++
			if p.Verified {
				verified++
			}
			switch {
			case p.ReliabilityScore >= 80:
				high++
			case p.ReliabilityScore <= MinReliabilityScore:
				low++
			}
		}
		list.mu.RUnlock()
	}
	d.peerStorage.mu.RUnlock()

	rate := 0.0
	if total > 0 {
		rate = float64(verified) / float64(total) * 100
	}

	return map[string]interface{}{
		"total_peers":               total,
		"verified_peers":            verified,
		"high_reliability_peers":    high,
		"low_reliability_peers":     low,
		"verification_enabled":      d.peerVerification,
		"verification_queue_length": len(d.verificationQueue),
		"verification_rate":         rate,
	}
}
// CleanUnreliablePeers drops in-memory peers that both score below 10 and
// have not announced within the last 10 minutes. Each removal is logged.
// Infohash entries left with no peers are deleted from storage. The storage
// write lock is held for the whole sweep, and each peer list is locked while
// it is rebuilt.
func (d *DHT) CleanUnreliablePeers() {
	d.peerStorage.mu.Lock()
	defer d.peerStorage.mu.Unlock()

	removed := 0
	for hash, list := range d.peerStorage.infohashes {
		list.mu.Lock()
		kept := make(map[string]*Peer)
		for addr, p := range list.peers {
			// A recent announce keeps a peer even when its score is poor.
			if p.ReliabilityScore < 10 && time.Since(p.LastAnnounced) >= 10*time.Minute {
				removed++
				log.Printf("DHT: Removed unreliable peer %s (score: %d) for %s",
					addr, p.ReliabilityScore, hash[:8])
				continue
			}
			kept[addr] = p
		}
		list.peers = kept
		list.mu.Unlock()

		if len(kept) == 0 {
			// Deleting during range is permitted for Go maps.
			delete(d.peerStorage.infohashes, hash)
		}
	}

	if removed > 0 {
		log.Printf("DHT: Cleaned %d unreliable peers", removed)
	}
}
// GetRateLimitStats reports, under the rate limiter's read lock:
//   - how many IPs have announce records;
//   - how many blocklist entries have an unblock time still in the future;
//   - the configured rate-limit constants.
func (d *DHT) GetRateLimitStats() map[string]interface{} {
	d.rateLimiter.mu.RLock()
	defer d.rateLimiter.mu.RUnlock()

	now := time.Now()
	blocked := 0
	for _, until := range d.rateLimiter.blocklist {
		if until.After(now) {
			blocked++
		}
	}

	stats := make(map[string]interface{}, 5)
	stats["active_ips"] = len(d.rateLimiter.announces)
	stats["blocked_ips"] = blocked
	stats["max_announces_per_hour"] = MaxAnnouncesPerIPPerHour
	stats["max_storage_per_ip"] = MaxStoragePerIP
	stats["blocklist_duration_mins"] = BlocklistDurationMinutes
	return stats
}
// GetComprehensiveStats bundles every statistics view of this DHT into one
// map. It holds the node, storage, verification and rate-limit stats, plus
// the hex node ID, the configured port and the bucket count.
func (d *DHT) GetComprehensiveStats() map[string]interface{} {
	all := make(map[string]interface{}, 7)
	all["node_stats"] = d.GetStats()
	all["storage_stats"] = d.GetStorageInfo()
	all["verification_stats"] = d.GetVerificationStats()
	all["rate_limit_stats"] = d.GetRateLimitStats()
	all["node_id"] = hex.EncodeToString(d.nodeID[:])
	all["port"] = d.config.Port
	all["routing_table_buckets"] = NumBuckets
	return all
}
// ForceVerifyPeer queues a reachability check for the peer stored under
// infoHash whose address is ip:port.
//
// ip may be any form net.ParseIP accepts and is compared by value, so
// "2001:DB8::1" matches a stored "2001:db8::1". It returns an error when:
//   - ip is not a valid address;
//   - no matching peer is stored;
//   - the verification queue is full (the request is not retried).
func (d *DHT) ForceVerifyPeer(infoHash string, ip string, port int) error {
	target := net.ParseIP(ip)
	if target == nil {
		return fmt.Errorf("invalid peer IP %q", ip)
	}

	for _, peer := range d.peerStorage.GetPeers(infoHash) {
		if peer.Port != port || !peer.IP.Equal(target) {
			continue
		}
		req := &PeerVerificationRequest{
			InfoHash: infoHash,
			Peer:     peer,
			Retries:  0,
		}
		// Non-blocking: a full queue is reported rather than waited on.
		select {
		case d.verificationQueue <- req:
			log.Printf("DHT: Queued forced verification for peer %s:%d", ip, port)
			return nil
		default:
			return fmt.Errorf("verification queue full")
		}
	}
	return fmt.Errorf("peer not found")
}
// GetTokenStats reports the token-secret state under tokenMu's read lock:
//   - when the secret was last rotated;
//   - next_rotation, computed as that time plus TokenRotationMinutes;
//   - the rotation interval and token length constants;
//   - whether a non-empty previous secret is held.
func (d *DHT) GetTokenStats() map[string]interface{} {
	d.tokenMu.RLock()
	defer d.tokenMu.RUnlock()

	rotatedAt := d.lastSecretRotation
	interval := TokenRotationMinutes * time.Minute

	stats := make(map[string]interface{}, 5)
	stats["secret_rotated_at"] = rotatedAt
	stats["next_rotation"] = rotatedAt.Add(interval)
	stats["rotation_interval_mins"] = TokenRotationMinutes
	stats["token_length_bytes"] = TokenLength
	stats["has_previous_secret"] = len(d.previousSecret) > 0
	return stats
}
// GenerateInfoHash derives a 40-character lowercase hex string from the SHA-1
// digest of name. It is intended for tests that need a stable, well-formed
// infohash.
func GenerateInfoHash(name string) string {
	digest := sha1.Sum([]byte(name))
	return fmt.Sprintf("%x", digest[:])
}