Implement reverse connection handling in SOCKS5 server

This commit is contained in:
dd dd 2024-07-25 23:25:56 +02:00
parent 6fce4c187f
commit e261e422c1
6 changed files with 119 additions and 13 deletions

View File

@ -183,6 +183,8 @@ func (e *Exit) processMessage(ctx context.Context, msg nostr.IncomingEvent) {
switch protocolMessage.Type { switch protocolMessage.Type {
case protocol.MessageConnect: case protocol.MessageConnect:
e.handleConnect(ctx, msg, protocolMessage, false) e.handleConnect(ctx, msg, protocolMessage, false)
case protocol.MessageConnectReverse:
e.handleConnectReverse(ctx, protocolMessage, false)
case protocol.MessageTypeSocks5: case protocol.MessageTypeSocks5:
e.handleSocks5ProxyMessage(msg, protocolMessage) e.handleSocks5ProxyMessage(msg, protocolMessage)
} }
@ -226,6 +228,33 @@ func (e *Exit) handleConnect(ctx context.Context, msg nostr.IncomingEvent, proto
go socks5.Proxy(connection, dst, nil) go socks5.Proxy(connection, dst, nil)
} }
func (e *Exit) handleConnectReverse(ctx context.Context, protocolMessage *protocol.Message, isTLS bool) {
e.mutexMap.Lock(protocolMessage.Key.String())
defer e.mutexMap.Unlock(protocolMessage.Key.String())
connection, err := net.Dial("tcp", ":1234")
if err != nil {
return
}
var dst net.Conn
if isTLS {
conf := tls.Config{InsecureSkipVerify: true}
dst, err = tls.Dial("tcp", e.config.BackendHost, &conf)
} else {
dst, err = net.Dial("tcp", e.config.BackendHost)
}
if err != nil {
slog.Error("could not connect to backend", "error", err)
return
}
_, err = connection.Write([]byte(protocolMessage.Key.String()))
if err != nil {
return
}
go socks5.Proxy(dst, connection, nil)
go socks5.Proxy(connection, dst, nil)
}
// handleSocks5ProxyMessage handles the SOCKS5 proxy message by writing it to the destination connection. // handleSocks5ProxyMessage handles the SOCKS5 proxy message by writing it to the destination connection.
// If the destination connection does not exist, the function returns without doing anything. // If the destination connection does not exist, the function returns without doing anything.
// //

View File

@ -17,10 +17,9 @@ import (
// It creates a signed event using the private key, public key, and destination address. // It creates a signed event using the private key, public key, and destination address.
// It ensures that the relays are available in the pool and publishes the signed event to each relay. // It ensures that the relays are available in the pool and publishes the signed event to each relay.
// Finally, it returns the Connection and nil error. If there are any errors, nil connection and the error are returned. // Finally, it returns the Connection and nil error. If there are any errors, nil connection and the error are returned.
func DialSocks(pool *nostr.SimplePool) func(ctx context.Context, net_, addr string) (net.Conn, error) { func DialSocks(pool *nostr.SimplePool, connectionID uuid.UUID) func(ctx context.Context, net_, addr string) (net.Conn, error) {
return func(ctx context.Context, net_, addr string) (net.Conn, error) { return func(ctx context.Context, net_, addr string) (net.Conn, error) {
addr = strings.ReplaceAll(addr, ".", "") addr = strings.ReplaceAll(addr, ".", "")
connectionID := uuid.New()
key := nostr.GeneratePrivateKey() key := nostr.GeneratePrivateKey()
connection := NewConnection(ctx, connection := NewConnection(ctx,
WithPrivateKey(key), WithPrivateKey(key),
@ -39,7 +38,7 @@ func DialSocks(pool *nostr.SimplePool) func(ctx context.Context, net_, addr stri
return nil, err return nil, err
} }
opts := []protocol.MessageOption{ opts := []protocol.MessageOption{
protocol.WithType(protocol.MessageConnect), protocol.WithType(protocol.MessageConnectReverse),
protocol.WithUUID(connectionID), protocol.WithUUID(connectionID),
protocol.WithDestination(addr), protocol.WithDestination(addr),
} }

View File

@ -8,8 +8,9 @@ import (
type MessageType string type MessageType string
var ( var (
MessageTypeSocks5 = MessageType("SOCKS5") MessageTypeSocks5 = MessageType("SOCKS5")
MessageConnect = MessageType("CONNECT") MessageConnect = MessageType("CONNECT")
MessageConnectReverse = MessageType("CONNECTR")
) )
type Message struct { type Message struct {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/asmogo/nws/netstr" "github.com/asmogo/nws/netstr"
"github.com/google/uuid"
"io" "io"
"net" "net"
"strconv" "strconv"
@ -167,8 +168,11 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request)
// Attempt to connect // Attempt to connect
dial := s.config.Dial dial := s.config.Dial
ch := make(chan net.Conn)
if dial == nil { if dial == nil {
dial = netstr.DialSocks(s.pool) connectionID := uuid.New()
s.tcpListener.AddConnectChannel(connectionID, ch)
dial = netstr.DialSocks(s.pool, connectionID)
} }
target, err := dial(ctx, "tcp", req.realDestAddr.FQDN) target, err := dial(ctx, "tcp", req.realDestAddr.FQDN)
if err != nil { if err != nil {
@ -192,11 +196,16 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request)
if err := SendReply(conn, successReply, &bind); err != nil { if err := SendReply(conn, successReply, &bind); err != nil {
return fmt.Errorf("failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
// read
// wait for the connection
connR := <-ch
defer connR.Close()
// Start proxying // Start proxying
errCh := make(chan error, 2) errCh := make(chan error, 2)
go Proxy(target, conn, errCh) go Proxy(connR, conn, errCh)
go Proxy(conn, target, errCh) go Proxy(conn, connR, errCh)
// Wait // Wait
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {

View File

@ -59,6 +59,7 @@ type Server struct {
config *Config config *Config
authMethods map[uint8]Authenticator authMethods map[uint8]Authenticator
pool *nostr.SimplePool pool *nostr.SimplePool
tcpListener *TCPListener
} }
// New creates a new Server and potentially returns an error // New creates a new Server and potentially returns an error
@ -86,12 +87,16 @@ func New(conf *Config, pool *nostr.SimplePool) (*Server, error) {
if conf.Logger == nil { if conf.Logger == nil {
conf.Logger = log.New(os.Stdout, "", log.LstdFlags) conf.Logger = log.New(os.Stdout, "", log.LstdFlags)
} }
listener, err := NewTCPListener()
server := &Server{ if err != nil {
config: conf, return nil, err
pool: pool, }
go listener.Start()
server := &Server{
config: conf,
pool: pool,
tcpListener: listener,
} }
server.authMethods = make(map[uint8]Authenticator) server.authMethods = make(map[uint8]Authenticator)
for _, a := range conf.AuthMethods { for _, a := range conf.AuthMethods {

63
socks5/tcp.go Normal file
View File

@ -0,0 +1,63 @@
package socks5
import (
"github.com/google/uuid"
"github.com/puzpuzpuz/xsync/v3"
"log/slog"
"net"
)
type TCPListener struct {
listener net.Listener
connectChannels *xsync.MapOf[string, chan net.Conn] // todo -- use [16]byte for uuid instead of string
}
func NewTCPListener() (*TCPListener, error) {
l, err := net.Listen("tcp", ":1234")
if err != nil {
return nil, err
}
return &TCPListener{
listener: l,
connectChannels: xsync.NewMapOf[string, chan net.Conn](),
}, nil
}
func (l *TCPListener) AddConnectChannel(uuid uuid.UUID, ch chan net.Conn) {
l.connectChannels.Store(uuid.String(), ch)
}
// Start starts the listener
func (l *TCPListener) Start() {
for {
conn, err := l.listener.Accept()
if err != nil {
return
}
go l.handleConnection(conn)
}
}
// handleConnection handles the connection
func (l *TCPListener) handleConnection(conn net.Conn) {
//defer conn.Close()
for {
// read uuid from the connection
readbuffer := make([]byte, 36)
_, err := conn.Read(readbuffer)
if err != nil {
return
}
// check if uuid is in the map
ch, ok := l.connectChannels.Load(string(readbuffer))
if !ok {
slog.Error("uuid not found in map")
return
}
// send the connection to the channel
ch <- conn
return
}
}