From e261e422c11d7560923b87f0384932ffcce6f6ab Mon Sep 17 00:00:00 2001 From: dd dd Date: Thu, 25 Jul 2024 23:25:56 +0200 Subject: [PATCH] Implement reverse connection handling in SOCKS5 server --- exit/exit.go | 29 +++++++++++++++++++++ netstr/dial.go | 5 ++-- protocol/message.go | 5 ++-- socks5/request.go | 15 ++++++++--- socks5/socks5.go | 15 +++++++---- socks5/tcp.go | 63 +++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 socks5/tcp.go diff --git a/exit/exit.go b/exit/exit.go index a0e4c50..4936512 100644 --- a/exit/exit.go +++ b/exit/exit.go @@ -183,6 +183,8 @@ func (e *Exit) processMessage(ctx context.Context, msg nostr.IncomingEvent) { switch protocolMessage.Type { case protocol.MessageConnect: e.handleConnect(ctx, msg, protocolMessage, false) + case protocol.MessageConnectReverse: + e.handleConnectReverse(ctx, protocolMessage, false) case protocol.MessageTypeSocks5: e.handleSocks5ProxyMessage(msg, protocolMessage) } @@ -226,6 +228,33 @@ func (e *Exit) handleConnect(ctx context.Context, msg nostr.IncomingEvent, proto go socks5.Proxy(connection, dst, nil) } +func (e *Exit) handleConnectReverse(ctx context.Context, protocolMessage *protocol.Message, isTLS bool) { + e.mutexMap.Lock(protocolMessage.Key.String()) + defer e.mutexMap.Unlock(protocolMessage.Key.String()) + connection, err := net.Dial("tcp", ":1234") + if err != nil { + return + } + var dst net.Conn + if isTLS { + conf := tls.Config{InsecureSkipVerify: true} + dst, err = tls.Dial("tcp", e.config.BackendHost, &conf) + } else { + dst, err = net.Dial("tcp", e.config.BackendHost) + } + if err != nil { + slog.Error("could not connect to backend", "error", err) + return + } + + _, err = connection.Write([]byte(protocolMessage.Key.String())) + if err != nil { + return + } + go socks5.Proxy(dst, connection, nil) + go socks5.Proxy(connection, dst, nil) +} + // handleSocks5ProxyMessage handles the SOCKS5 proxy message by writing it to the destination connection. // If the destination connection does not exist, the function returns without doing anything. // diff --git a/netstr/dial.go b/netstr/dial.go index b7ffd58..706e8e9 100644 --- a/netstr/dial.go +++ b/netstr/dial.go @@ -17,10 +17,9 @@ import ( // It creates a signed event using the private key, public key, and destination address. // It ensures that the relays are available in the pool and publishes the signed event to each relay. // Finally, it returns the Connection and nil error. If there are any errors, nil connection and the error are returned. -func DialSocks(pool *nostr.SimplePool) func(ctx context.Context, net_, addr string) (net.Conn, error) { +func DialSocks(pool *nostr.SimplePool, connectionID uuid.UUID) func(ctx context.Context, net_, addr string) (net.Conn, error) { return func(ctx context.Context, net_, addr string) (net.Conn, error) { addr = strings.ReplaceAll(addr, ".", "") - connectionID := uuid.New() key := nostr.GeneratePrivateKey() connection := NewConnection(ctx, WithPrivateKey(key), @@ -39,7 +38,7 @@ func DialSocks(pool *nostr.SimplePool) func(ctx context.Context, net_, addr stri return nil, err } opts := []protocol.MessageOption{ - protocol.WithType(protocol.MessageConnect), + protocol.WithType(protocol.MessageConnectReverse), protocol.WithUUID(connectionID), protocol.WithDestination(addr), } diff --git a/protocol/message.go b/protocol/message.go index ddacdf5..7ecf161 100644 --- a/protocol/message.go +++ b/protocol/message.go @@ -8,8 +8,9 @@ import ( type MessageType string var ( - MessageTypeSocks5 = MessageType("SOCKS5") - MessageConnect = MessageType("CONNECT") + MessageTypeSocks5 = MessageType("SOCKS5") + MessageConnect = MessageType("CONNECT") + MessageConnectReverse = MessageType("CONNECTR") ) type Message struct { diff --git a/socks5/request.go b/socks5/request.go index b18dafc..f9cf735 100644 --- a/socks5/request.go +++ b/socks5/request.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/asmogo/nws/netstr" + "github.com/google/uuid" "io" "net" "strconv" @@ -167,8 +168,11 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) // Attempt to connect dial := s.config.Dial + ch := make(chan net.Conn) if dial == nil { - dial = netstr.DialSocks(s.pool) + connectionID := uuid.New() + s.tcpListener.AddConnectChannel(connectionID, ch) + dial = netstr.DialSocks(s.pool, connectionID) } target, err := dial(ctx, "tcp", req.realDestAddr.FQDN) if err != nil { @@ -192,11 +196,16 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) if err := SendReply(conn, successReply, &bind); err != nil { return fmt.Errorf("failed to send reply: %v", err) } + // read + + // wait for the connection + connR := <-ch + defer connR.Close() // Start proxying errCh := make(chan error, 2) - go Proxy(target, conn, errCh) - go Proxy(conn, target, errCh) + go Proxy(connR, conn, errCh) + go Proxy(conn, connR, errCh) // Wait for i := 0; i < 2; i++ { diff --git a/socks5/socks5.go b/socks5/socks5.go index 1bbb4f4..6d6a8d2 100644 --- a/socks5/socks5.go +++ b/socks5/socks5.go @@ -59,6 +59,7 @@ type Server struct { config *Config authMethods map[uint8]Authenticator pool *nostr.SimplePool + tcpListener *TCPListener } // New creates a new Server and potentially returns an error @@ -86,12 +87,16 @@ func New(conf *Config, pool *nostr.SimplePool) (*Server, error) { if conf.Logger == nil { conf.Logger = log.New(os.Stdout, "", log.LstdFlags) } - - server := &Server{ - config: conf, - pool: pool, + listener, err := NewTCPListener() + if err != nil { + return nil, err + } + go listener.Start() + server := &Server{ + config: conf, + pool: pool, + tcpListener: listener, } - server.authMethods = make(map[uint8]Authenticator) for _, a := range conf.AuthMethods { diff --git a/socks5/tcp.go b/socks5/tcp.go new file mode 100644 index 0000000..9edef36 --- /dev/null +++ b/socks5/tcp.go @@ -0,0 +1,63 @@ +package socks5 + +import ( + "github.com/google/uuid" + "github.com/puzpuzpuz/xsync/v3" + "log/slog" + "net" +) + +type TCPListener struct { + listener net.Listener + connectChannels *xsync.MapOf[string, chan net.Conn] // todo -- use [16]byte for uuid instead of string +} + +func NewTCPListener() (*TCPListener, error) { + l, err := net.Listen("tcp", ":1234") + if err != nil { + return nil, err + } + return &TCPListener{ + listener: l, + connectChannels: xsync.NewMapOf[string, chan net.Conn](), + }, nil +} + +func (l *TCPListener) AddConnectChannel(uuid uuid.UUID, ch chan net.Conn) { + l.connectChannels.Store(uuid.String(), ch) +} + +// Start starts the listener +func (l *TCPListener) Start() { + for { + conn, err := l.listener.Accept() + if err != nil { + return + } + go l.handleConnection(conn) + } +} + +// handleConnection handles the connection +func (l *TCPListener) handleConnection(conn net.Conn) { + //defer conn.Close() + for { + // read uuid from the connection + readbuffer := make([]byte, 36) + + _, err := conn.Read(readbuffer) + if err != nil { + return + } + + // check if uuid is in the map + ch, ok := l.connectChannels.Load(string(readbuffer)) + if !ok { + slog.Error("uuid not found in map") + return + } + // send the connection to the channel + ch <- conn + return + } +}