// enki b3204ea07a
// Some checks are pending
// CI Pipeline / Run Tests (push) Waiting to run
// CI Pipeline / Lint Code (push) Waiting to run
// CI Pipeline / Security Scan (push) Waiting to run
// CI Pipeline / Build Docker Images (push) Blocked by required conditions
// CI Pipeline / E2E Tests (push) Blocked by required conditions
// first commit
// 2025-08-18 00:40:15 -07:00
//
// 435 lines
// 9.1 KiB
// Go

package dht
import (
"crypto/rand"
"crypto/sha1"
"encoding/binary"
"fmt"
"log"
"net"
"sync"
"time"
"git.sovbit.dev/enki/torrentGateway/internal/config"
)
// Routing-table geometry and lookup tuning.
const (
	// NodeIDLength is the size of a node identifier in bytes (160 bits, the
	// width of a SHA-1 digest).
	NodeIDLength = 20
	// BucketSize is the maximum number of nodes kept in one k-bucket.
	BucketSize = 8
	// NumBuckets is the number of k-buckets: one per bit of a node ID.
	NumBuckets = 160
	// Alpha is the lookup parallelism parameter.
	Alpha = 3
)

// NodeID is a 160-bit DHT node identifier.
type NodeID [NodeIDLength]byte

// Node is a remote DHT participant tracked by the routing table.
type Node struct {
	ID       NodeID
	Addr     *net.UDPAddr
	LastSeen time.Time
}

// Bucket is a single k-bucket; mu must be held while reading or writing nodes.
type Bucket struct {
	mu    sync.RWMutex
	nodes []*Node
}

// RoutingTable is a Kademlia-style routing table: NumBuckets k-buckets
// indexed relative to the local node's selfID.
type RoutingTable struct {
	mu      sync.RWMutex
	selfID  NodeID
	buckets [NumBuckets]*Bucket
}
// DHT is a local DHT node. It holds the node's identity, routing table, UDP
// socket, a local key/value store, background-worker plumbing and counters.
type DHT struct {
	config       *config.DHTConfig
	nodeID       NodeID
	routingTable *RoutingTable
	conn         *net.UDPConn

	// storage is a simple in-memory store (demo only), guarded by storageMu.
	storage   map[string][]byte
	storageMu sync.RWMutex

	// stopCh is closed by Stop to tell background goroutines to exit;
	// announceQueue carries requests to handleAnnouncements.
	stopCh        chan struct{}
	announceQueue chan AnnounceRequest

	// stats is guarded by statsMu.
	stats   Stats
	statsMu sync.RWMutex
}
// Stats holds DHT activity counters. GetStats returns a copy; NodesInTable and
// StoredItems are refreshed during periodic maintenance.
type Stats struct {
	PacketsSent     int64
	PacketsReceived int64
	NodesInTable    int
	StoredItems     int
}

// AnnounceRequest asks the DHT to announce that the torrent identified by
// InfoHash is available on the local Port.
type AnnounceRequest struct {
	InfoHash string
	Port     int
}

// Peer is a BitTorrent peer address.
type Peer struct {
	IP   net.IP
	Port int
}

// DHT message type identifiers: query, response and error.
const (
	MsgQuery    = "q"
	MsgResponse = "r"
	MsgError    = "e"
)
// NewDHT creates a DHT node with a freshly generated random node ID. No
// network I/O happens until Start is called.
//
// It returns an error if cfg is nil (Start and bootstrap dereference it) or
// if the random node ID cannot be generated.
func NewDHT(cfg *config.DHTConfig) (*DHT, error) {
	if cfg == nil {
		return nil, fmt.Errorf("DHT config must not be nil")
	}

	var nodeID NodeID
	if _, err := rand.Read(nodeID[:]); err != nil {
		return nil, fmt.Errorf("failed to generate node ID: %w", err)
	}

	return &DHT{
		config:       cfg,
		nodeID:       nodeID,
		routingTable: NewRoutingTable(nodeID),
		storage:      make(map[string][]byte),
		stopCh:       make(chan struct{}),
		// Buffered so bursts of Announce calls are queued; Announce drops
		// requests rather than blocking once the buffer is full.
		announceQueue: make(chan AnnounceRequest, 100),
	}, nil
}
// NewRoutingTable returns an empty routing table centred on selfID. Every one
// of the NumBuckets buckets is allocated up front with capacity BucketSize.
func NewRoutingTable(selfID NodeID) *RoutingTable {
	table := &RoutingTable{selfID: selfID}
	for idx := 0; idx < NumBuckets; idx++ {
		table.buckets[idx] = &Bucket{nodes: make([]*Node, 0, BucketSize)}
	}
	return table
}
// Start binds the UDP socket on the configured port. It then launches the
// background goroutines: packet reader, bootstrap, maintenance and announce
// processing. They run until Stop is called.
//
// Start may be called only once per DHT. A second call would leak the
// existing socket, and after Stop the new goroutines would exit immediately
// because stopCh is already closed. In either case Start returns an error.
func (d *DHT) Start() error {
	if d.conn != nil {
		return fmt.Errorf("DHT node already started")
	}

	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", d.config.Port))
	if err != nil {
		return fmt.Errorf("failed to resolve UDP address: %w", err)
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return fmt.Errorf("failed to listen on UDP: %w", err)
	}
	d.conn = conn

	// Report the port actually bound: with a configured Port of 0 the kernel
	// picks an ephemeral port.
	port := d.config.Port
	if local, ok := conn.LocalAddr().(*net.UDPAddr); ok {
		port = local.Port
	}
	log.Printf("DHT node started on port %d with ID %x", port, d.nodeID)

	go d.handlePackets()
	go d.bootstrap()
	go d.maintenance()
	go d.handleAnnouncements()
	return nil
}
// Stop signals all background goroutines to exit by closing stopCh, then
// closes the UDP socket if one was opened, returning its Close error.
//
// Repeated calls are a no-op that return nil. Stop is not safe to call
// concurrently with itself.
func (d *DHT) Stop() error {
	select {
	case <-d.stopCh:
		// Already stopped; closing stopCh a second time would panic.
		return nil
	default:
	}
	close(d.stopCh)

	if d.conn != nil {
		return d.conn.Close()
	}
	return nil
}
// handlePackets is the UDP receive loop. Each read uses a one-second deadline
// so the loop re-checks stopCh regularly and exits soon after Stop. Every
// datagram is counted in PacketsReceived and passed to handlePacket on its
// own goroutine.
func (d *DHT) handlePackets() {
	buffer := make([]byte, 2048)
	for {
		select {
		case <-d.stopCh:
			return
		default:
		}

		// If setting the deadline fails (e.g. the socket was closed), the
		// read below reports the error, so it is not checked separately.
		_ = d.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
		n, addr, err := d.conn.ReadFromUDP(buffer)
		if err != nil {
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue
			}
			log.Printf("Error reading UDP packet: %v", err)
			continue
		}

		d.statsMu.Lock()
		d.stats.PacketsReceived++
		d.statsMu.Unlock()

		// buffer is overwritten by the next read, so the handler goroutine
		// needs its own copy; sharing buffer[:n] is a data race.
		packet := make([]byte, n)
		copy(packet, buffer[:n])
		go d.handlePacket(packet, addr)
	}
}
// handlePacket is a placeholder handler for one datagram. For now it only logs
// the sender and payload size. A complete implementation would decode the
// bencoded message and dispatch on its type: ping, find_node, get_peers or
// announce_peer.
func (d *DHT) handlePacket(data []byte, addr *net.UDPAddr) {
	size := len(data)
	log.Printf("Received packet from %s: %d bytes", addr, size)
}
// bootstrap pings each configured bootstrap node in turn and pauses one second
// after each ping. Resolution and send failures are logged and the node is
// skipped. It returns early if Stop is called during a pause.
func (d *DHT) bootstrap() {
	for _, bootstrapAddr := range d.config.BootstrapNodes {
		addr, err := net.ResolveUDPAddr("udp", bootstrapAddr)
		if err != nil {
			log.Printf("Failed to resolve bootstrap node %s: %v", bootstrapAddr, err)
			continue
		}

		if err := d.sendPing(addr); err != nil {
			log.Printf("Failed to ping bootstrap node %s: %v", bootstrapAddr, err)
		}

		// A stoppable pause instead of time.Sleep, so Stop is not held up
		// by one second per remaining bootstrap node.
		pause := time.NewTimer(1 * time.Second)
		select {
		case <-d.stopCh:
			pause.Stop()
			return
		case <-pause.C:
		}
	}
}
// sendPing writes a placeholder "ping" datagram to addr. It is not a bencoded
// KRPC message. A successful write increments PacketsSent; any write error is
// returned unchanged.
func (d *DHT) sendPing(addr *net.UDPAddr) error {
	payload := []byte("ping")
	if _, werr := d.conn.WriteToUDP(payload, addr); werr != nil {
		return werr
	}

	d.statsMu.Lock()
	d.stats.PacketsSent++
	d.statsMu.Unlock()
	return nil
}
// maintenance calls performMaintenance every five minutes until stopCh is
// closed.
func (d *DHT) maintenance() {
	tick := time.NewTicker(5 * time.Minute)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			d.performMaintenance()
		case <-d.stopCh:
			return
		}
	}
}
// performMaintenance removes nodes not seen in the last 15 minutes from every
// bucket. It then records the remaining node count and the number of stored
// items in the statistics, and logs both.
func (d *DHT) performMaintenance() {
	cutoff := time.Now().Add(-15 * time.Minute)

	// Lock order: routing table, then bucket, matching AddNode's use of the
	// bucket lock.
	d.routingTable.mu.Lock()
	totalNodes := 0
	for _, bucket := range d.routingTable.buckets {
		if bucket == nil {
			continue
		}
		bucket.mu.Lock()
		activeNodes := make([]*Node, 0, len(bucket.nodes))
		for _, node := range bucket.nodes {
			if node.LastSeen.After(cutoff) {
				activeNodes = append(activeNodes, node)
			}
		}
		bucket.nodes = activeNodes
		totalNodes += len(activeNodes)
		bucket.mu.Unlock()
	}
	d.routingTable.mu.Unlock()

	// processAnnounce writes storage under storageMu; reading its length
	// without that lock is a data race.
	d.storageMu.RLock()
	storedItems := len(d.storage)
	d.storageMu.RUnlock()

	d.statsMu.Lock()
	d.stats.NodesInTable = totalNodes
	d.stats.StoredItems = storedItems
	d.statsMu.Unlock()

	log.Printf("DHT maintenance: %d nodes in routing table, %d stored items", totalNodes, storedItems)
}
// handleAnnouncements takes requests off announceQueue and passes each to
// processAnnounce, one at a time, until stopCh is closed.
func (d *DHT) handleAnnouncements() {
	for {
		select {
		case next := <-d.announceQueue:
			d.processAnnounce(next)
		case <-d.stopCh:
			return
		}
	}
}
// processAnnounce stores this node as a peer for req.InfoHash in local storage.
// The record is 6 bytes: a 4-byte IPv4 address followed by a 2-byte
// big-endian port. The address is hard-coded to 127.0.0.1 because there is no
// external-IP detection yet.
//
// Requests whose port does not fit in 16 bits are logged and ignored.
// Converting such a port to uint16 would silently store a different port.
func (d *DHT) processAnnounce(req AnnounceRequest) {
	if req.Port < 0 || req.Port > 0xFFFF {
		log.Printf("Rejecting announce for torrent %s: invalid port %d", req.InfoHash, req.Port)
		return
	}

	peerInfo := make([]byte, 6) // 4 bytes IP + 2 bytes port
	ip := net.ParseIP("127.0.0.1").To4()
	copy(peerInfo[:4], ip)
	binary.BigEndian.PutUint16(peerInfo[4:], uint16(req.Port))

	d.storageMu.Lock()
	d.storage[req.InfoHash] = peerInfo
	d.storageMu.Unlock()

	log.Printf("Announced torrent %s on port %d", req.InfoHash, req.Port)
}
// Announce queues an announcement of infoHash on port without blocking. If
// the queue is full, the request is dropped and that fact is logged.
func (d *DHT) Announce(infoHash string, port int) {
	req := AnnounceRequest{InfoHash: infoHash, Port: port}
	select {
	case d.announceQueue <- req:
		log.Printf("Queued announce for torrent %s", infoHash)
	default:
		log.Printf("Announce queue full, dropping announce for %s", infoHash)
	}
}
// FindPeers returns the peers stored locally for infoHash. Only this node's
// storage is consulted (no network lookup), so it yields at most one peer.
// An unknown info hash or a malformed record yields an empty, non-nil slice
// and a nil error.
//
// The returned IP is a fresh copy, so callers may modify it without
// corrupting the stored record.
func (d *DHT) FindPeers(infoHash string) ([]Peer, error) {
	d.storageMu.RLock()
	peerData, exists := d.storage[infoHash]
	d.storageMu.RUnlock()

	// Records are 4 bytes of IPv4 address + 2 bytes of big-endian port.
	if !exists || len(peerData) < 6 {
		return []Peer{}, nil
	}

	ip := make(net.IP, 4)
	copy(ip, peerData[:4])
	return []Peer{{
		IP:   ip,
		Port: int(binary.BigEndian.Uint16(peerData[4:6])),
	}}, nil
}
// GetStats returns a snapshot copy of the current DHT statistics, taken under
// the stats read lock.
func (d *DHT) GetStats() Stats {
	d.statsMu.RLock()
	snapshot := d.stats
	d.statsMu.RUnlock()
	return snapshot
}
// AddNode inserts node into the k-bucket chosen by the number of leading zero
// bits in the XOR distance between node.ID and the table's own ID.
//
// Insertion rules:
//   - A node with the same ID already in the bucket is replaced in place.
//   - Otherwise the node is appended while the bucket holds fewer than
//     BucketSize entries.
//   - When the bucket is full, the entry with the oldest LastSeen is
//     evicted in favour of node. This is a simplification: strict Kademlia
//     pings the oldest entry and keeps it if it still responds.
//
// A nil node, or one carrying the table's own ID, is ignored. A node never
// routes to itself, and its zero distance would otherwise land in the last
// bucket.
func (rt *RoutingTable) AddNode(node *Node) {
	if node == nil || node.ID == rt.selfID {
		return
	}

	bucketIndex := leadingZeros(xor(rt.selfID, node.ID))
	if bucketIndex >= NumBuckets {
		// Defensive: only a zero distance (excluded above) reaches NumBuckets.
		bucketIndex = NumBuckets - 1
	}

	bucket := rt.buckets[bucketIndex]
	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	for i, existing := range bucket.nodes {
		if existing.ID == node.ID {
			bucket.nodes[i] = node
			return
		}
	}

	if len(bucket.nodes) < BucketSize {
		bucket.nodes = append(bucket.nodes, node)
		return
	}

	oldest := 0
	for i := 1; i < len(bucket.nodes); i++ {
		if bucket.nodes[i].LastSeen.Before(bucket.nodes[oldest].LastSeen) {
			oldest = i
		}
	}
	bucket.nodes[oldest] = node
}
// xor returns the bytewise XOR of a and b, which is the Kademlia distance
// between the two IDs.
func xor(a, b NodeID) NodeID {
	var dist NodeID
	for i := range dist {
		dist[i] = a[i] ^ b[i]
	}
	return dist
}
// leadingZeros returns the number of consecutive zero bits at the start of
// id, scanning from the most significant bit of id[0]. An all-zero ID yields
// NodeIDLength*8.
func leadingZeros(id NodeID) int {
	count := 0
	for _, b := range id {
		if b == 0 {
			count += 8
			continue
		}
		// b is non-zero, so this stops before mask shifts out to zero.
		for mask := byte(0x80); b&mask == 0; mask >>= 1 {
			count++
		}
		return count
	}
	return count
}
// GenerateInfoHash returns the lowercase hex SHA-1 digest of name. It is a
// stand-in info hash intended for testing.
func GenerateInfoHash(name string) string {
	digest := sha1.Sum([]byte(name))
	return fmt.Sprintf("%x", digest[:])
}