package dht import ( "crypto/rand" "crypto/sha1" "encoding/binary" "fmt" "log" "net" "sync" "time" "git.sovbit.dev/enki/torrentGateway/internal/config" ) const ( // Node ID length in bytes (160 bits for SHA-1) NodeIDLength = 20 // K-bucket size BucketSize = 8 // Number of buckets (160 bits = 160 buckets max) NumBuckets = 160 // DHT protocol constants Alpha = 3 // Parallelism parameter ) // NodeID represents a 160-bit node identifier type NodeID [NodeIDLength]byte // Node represents a DHT node type Node struct { ID NodeID Addr *net.UDPAddr LastSeen time.Time } // Bucket represents a k-bucket in the routing table type Bucket struct { mu sync.RWMutex nodes []*Node } // RoutingTable implements Kademlia routing table type RoutingTable struct { mu sync.RWMutex selfID NodeID buckets [NumBuckets]*Bucket } // DHT represents a DHT node type DHT struct { config *config.DHTConfig nodeID NodeID routingTable *RoutingTable conn *net.UDPConn storage map[string][]byte // Simple in-memory storage for demo storageMu sync.RWMutex // Channels for communication stopCh chan struct{} announceQueue chan AnnounceRequest // Statistics stats Stats statsMu sync.RWMutex } // Stats tracks DHT statistics type Stats struct { PacketsSent int64 PacketsReceived int64 NodesInTable int StoredItems int } // AnnounceRequest represents a request to announce a torrent type AnnounceRequest struct { InfoHash string Port int } // Peer represents a BitTorrent peer type Peer struct { IP net.IP Port int } // DHT message types const ( MsgQuery = "q" MsgResponse = "r" MsgError = "e" ) // NewDHT creates a new DHT node func NewDHT(config *config.DHTConfig) (*DHT, error) { // Generate random node ID var nodeID NodeID if _, err := rand.Read(nodeID[:]); err != nil { return nil, fmt.Errorf("failed to generate node ID: %w", err) } dht := &DHT{ config: config, nodeID: nodeID, routingTable: NewRoutingTable(nodeID), storage: make(map[string][]byte), stopCh: make(chan struct{}), announceQueue: make(chan AnnounceRequest, 100), } return dht, nil } // NewRoutingTable creates a new routing table func NewRoutingTable(selfID NodeID) *RoutingTable { rt := &RoutingTable{ selfID: selfID, } // Initialize buckets for i := range rt.buckets { rt.buckets[i] = &Bucket{ nodes: make([]*Node, 0, BucketSize), } } return rt } // Start starts the DHT node func (d *DHT) Start() error { // Listen on UDP port addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", d.config.Port)) if err != nil { return fmt.Errorf("failed to resolve UDP address: %w", err) } conn, err := net.ListenUDP("udp", addr) if err != nil { return fmt.Errorf("failed to listen on UDP: %w", err) } d.conn = conn log.Printf("DHT node started on port %d with ID %x", d.config.Port, d.nodeID) // Start goroutines go d.handlePackets() go d.bootstrap() go d.maintenance() go d.handleAnnouncements() return nil } // Stop stops the DHT node func (d *DHT) Stop() error { close(d.stopCh) if d.conn != nil { return d.conn.Close() } return nil } // handlePackets handles incoming UDP packets func (d *DHT) handlePackets() { buffer := make([]byte, 2048) for { select { case <-d.stopCh: return default: } d.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) n, addr, err := d.conn.ReadFromUDP(buffer) if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { continue } log.Printf("Error reading UDP packet: %v", err) continue } d.statsMu.Lock() d.stats.PacketsSent++ d.statsMu.Unlock() // Simple packet handling (in real implementation, would parse bencode) go d.handlePacket(buffer[:n], addr) } } // handlePacket processes a single packet func (d *DHT) handlePacket(data []byte, addr *net.UDPAddr) { // This is a simplified implementation // In a real DHT, you would parse bencode messages and handle: // - ping/pong // - find_node // - get_peers // - announce_peer log.Printf("Received packet from %s: %d bytes", addr, len(data)) } // bootstrap connects to bootstrap nodes func (d *DHT) bootstrap() { for _, bootstrapAddr := range d.config.BootstrapNodes { addr, err := net.ResolveUDPAddr("udp", bootstrapAddr) if err != nil { log.Printf("Failed to resolve bootstrap node %s: %v", bootstrapAddr, err) continue } // Send ping to bootstrap node d.sendPing(addr) time.Sleep(1 * time.Second) } } // sendPing sends a ping message to a node func (d *DHT) sendPing(addr *net.UDPAddr) error { // Simplified ping message (in real implementation, would use bencode) message := []byte("ping") _, err := d.conn.WriteToUDP(message, addr) if err == nil { d.statsMu.Lock() d.stats.PacketsSent++ d.statsMu.Unlock() } return err } // maintenance performs periodic maintenance tasks func (d *DHT) maintenance() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for { select { case <-d.stopCh: return case <-ticker.C: d.performMaintenance() } } } // performMaintenance cleans up old nodes and refreshes buckets func (d *DHT) performMaintenance() { now := time.Now() cutoff := now.Add(-15 * time.Minute) d.routingTable.mu.Lock() defer d.routingTable.mu.Unlock() totalNodes := 0 for _, bucket := range d.routingTable.buckets { if bucket == nil { continue } bucket.mu.Lock() // Remove stale nodes activeNodes := make([]*Node, 0, len(bucket.nodes)) for _, node := range bucket.nodes { if node.LastSeen.After(cutoff) { activeNodes = append(activeNodes, node) } } bucket.nodes = activeNodes totalNodes += len(activeNodes) bucket.mu.Unlock() } d.statsMu.Lock() d.stats.NodesInTable = totalNodes d.stats.StoredItems = len(d.storage) d.statsMu.Unlock() log.Printf("DHT maintenance: %d nodes in routing table, %d stored items", totalNodes, len(d.storage)) } // handleAnnouncements processes torrent announcements func (d *DHT) handleAnnouncements() { for { select { case <-d.stopCh: return case req := <-d.announceQueue: d.processAnnounce(req) } } } // processAnnounce processes a torrent announce request func (d *DHT) processAnnounce(req AnnounceRequest) { // Store our own peer info for this torrent peerInfo := make([]byte, 6) // 4 bytes IP + 2 bytes port // Get our external IP (simplified - would need proper detection) ip := net.ParseIP("127.0.0.1").To4() copy(peerInfo[:4], ip) binary.BigEndian.PutUint16(peerInfo[4:], uint16(req.Port)) d.storageMu.Lock() d.storage[req.InfoHash] = peerInfo d.storageMu.Unlock() log.Printf("Announced torrent %s on port %d", req.InfoHash, req.Port) } // Announce announces a torrent to the DHT func (d *DHT) Announce(infoHash string, port int) { select { case d.announceQueue <- AnnounceRequest{InfoHash: infoHash, Port: port}: log.Printf("Queued announce for torrent %s", infoHash) default: log.Printf("Announce queue full, dropping announce for %s", infoHash) } } // FindPeers searches for peers for a given info hash func (d *DHT) FindPeers(infoHash string) ([]Peer, error) { d.storageMu.RLock() peerData, exists := d.storage[infoHash] d.storageMu.RUnlock() if !exists { return []Peer{}, nil } // Parse peer data (simplified) if len(peerData) < 6 { return []Peer{}, nil } peer := Peer{ IP: net.IP(peerData[:4]), Port: int(binary.BigEndian.Uint16(peerData[4:])), } return []Peer{peer}, nil } // GetStats returns current DHT statistics func (d *DHT) GetStats() Stats { d.statsMu.RLock() defer d.statsMu.RUnlock() return d.stats } // AddNode adds a node to the routing table func (rt *RoutingTable) AddNode(node *Node) { distance := xor(rt.selfID, node.ID) bucketIndex := leadingZeros(distance) if bucketIndex >= NumBuckets { bucketIndex = NumBuckets - 1 } bucket := rt.buckets[bucketIndex] bucket.mu.Lock() defer bucket.mu.Unlock() // Check if node already exists for i, existingNode := range bucket.nodes { if existingNode.ID == node.ID { // Update existing node bucket.nodes[i] = node return } } // Add new node if bucket not full if len(bucket.nodes) < BucketSize { bucket.nodes = append(bucket.nodes, node) return } // Bucket is full - implement replacement logic // For simplicity, replace the oldest node oldestIndex := 0 oldestTime := bucket.nodes[0].LastSeen for i, n := range bucket.nodes { if n.LastSeen.Before(oldestTime) { oldestIndex = i oldestTime = n.LastSeen } } bucket.nodes[oldestIndex] = node } // xor calculates XOR distance between two node IDs func xor(a, b NodeID) NodeID { var result NodeID for i := 0; i < NodeIDLength; i++ { result[i] = a[i] ^ b[i] } return result } // leadingZeros counts leading zero bits func leadingZeros(id NodeID) int { for i, b := range id { if b != 0 { // Count zeros in this byte zeros := 0 for bit := 7; bit >= 0; bit-- { if (b>>bit)&1 == 0 { zeros++ } else { break } } return i*8 + zeros } } return NodeIDLength * 8 } // GenerateInfoHash generates an info hash for a torrent name (for testing) func GenerateInfoHash(name string) string { hash := sha1.Sum([]byte(name)) return fmt.Sprintf("%x", hash) }