From fa37f4a22fe28570bbcae897ccfae8c6ade57a92 Mon Sep 17 00:00:00 2001 From: David Dean Date: Tue, 26 Dec 2023 07:00:06 +1100 Subject: [PATCH] add ALLOWED_NETS, use netip, small linting fixes --- README.md | 1 + ruleset.go | 3 +- server.go | 37 +++++++---- vendor/github.com/armon/go-socks5/auth.go | 8 +-- vendor/github.com/armon/go-socks5/request.go | 61 ++++++++++--------- vendor/github.com/armon/go-socks5/resolver.go | 16 ++--- vendor/github.com/armon/go-socks5/ruleset.go | 2 +- vendor/github.com/armon/go-socks5/socks5.go | 42 +++++++------ 8 files changed, 95 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index 79bfca0..6e58339 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Simple socks5 server using go-socks5 with authentication, allowed ips list and d |PROXY_PORT|String|1080|Set listen port for application inside docker container| |ALLOWED_DEST_FQDN|String|EMPTY|Allowed destination address regular expression pattern. Default allows all.| |ALLOWED_IPS|String|Empty|Set allowed IP's that can connect to proxy, separator `,`| +|ALLOWED_NETS|String|Empty|Set allowed networks that can connect to proxy, separator `,`| # Build your own image: diff --git a/ruleset.go b/ruleset.go index 4b9f6d6..1e5367e 100644 --- a/ruleset.go +++ b/ruleset.go @@ -3,8 +3,9 @@ package main import ( "regexp" + "context" + "github.com/armon/go-socks5" - "golang.org/x/net/context" ) // PermitDestAddrPattern returns a RuleSet which selectively allows addresses diff --git a/server.go b/server.go index 921f1ac..be2d846 100644 --- a/server.go +++ b/server.go @@ -2,19 +2,20 @@ package main import ( "log" - "net" + "net/netip" "os" "github.com/armon/go-socks5" - "github.com/caarlos0/env/v6" + env "github.com/caarlos0/env/v6" ) type params struct { - User string `env:"PROXY_USER" envDefault:""` - Password string `env:"PROXY_PASSWORD" envDefault:""` - Port string `env:"PROXY_PORT" envDefault:"1080"` - AllowedDestFqdn string `env:"ALLOWED_DEST_FQDN" envDefault:""` - AllowedIPs []string `env:"ALLOWED_IPS" envSeparator:"," envDefault:""` + User string `env:"PROXY_USER" envDefault:""` + Password string `env:"PROXY_PASSWORD" envDefault:""` + Port string `env:"PROXY_PORT" envDefault:"1080"` + AllowedDestFqdn string `env:"ALLOWED_DEST_FQDN" envDefault:""` + AllowedIPs []string `env:"ALLOWED_IPS" envSeparator:"," envDefault:""` + AllowedNets []string `env:"ALLOWED_NETS" envSeparator:"," envDefault:""` } func main() { @@ -48,15 +49,25 @@ func main() { } // Set IP whitelist - if len(cfg.AllowedIPs) > 0 { - whitelist := make([]net.IP, len(cfg.AllowedIPs)) - for i, ip := range cfg.AllowedIPs { - whitelist[i] = net.ParseIP(ip) + if len(cfg.AllowedIPs) > 0 || len(cfg.AllowedNets) > 0 { + whitelist := make([]netip.Addr, len(cfg.AllowedIPs)) + whitelistnet := make([]netip.Prefix, len(cfg.AllowedNets)) + + if len(cfg.AllowedIPs) > 0 { + for i, ip := range cfg.AllowedIPs { + whitelist[i], _ = netip.ParseAddr(ip) + } } - server.SetIPWhitelist(whitelist) + if len(cfg.AllowedNets) > 0 { + for i, inet := range cfg.AllowedNets { + whitelistnet[i], _ = netip.ParsePrefix(inet) + } + } + + server.SetIPWhitelist(whitelist, whitelistnet) } - log.Printf("Start listening proxy service on port %s\n", cfg.Port) + log.Printf("Started proxy service listening on port %s\n", cfg.Port) if err := server.ListenAndServe("tcp", ":"+cfg.Port); err != nil { log.Fatal(err) } diff --git a/vendor/github.com/armon/go-socks5/auth.go b/vendor/github.com/armon/go-socks5/auth.go index 7811e2a..93d4957 100644 --- a/vendor/github.com/armon/go-socks5/auth.go +++ b/vendor/github.com/armon/go-socks5/auth.go @@ -15,8 +15,8 @@ const ( ) var ( - UserAuthFailed = fmt.Errorf("User authentication failed") - NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") + UserAuthFailed = fmt.Errorf("user authentication failed") + NoSupportedAuth = fmt.Errorf("no supported authentication mechanism") ) // A Request encapsulates authentication state provided @@ -71,7 +71,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) // Ensure we are compatible if header[0] != userAuthVersion { - return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) + return nil, fmt.Errorf("unsupported auth version: %v", header[0]) } // Get the user name @@ -114,7 +114,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, // Get the methods methods, err := readMethods(bufConn) if err != nil { - return nil, fmt.Errorf("Failed to get auth methods: %v", err) + return nil, fmt.Errorf("failed to get auth methods: %v", err) } // Select a usable method diff --git a/vendor/github.com/armon/go-socks5/request.go b/vendor/github.com/armon/go-socks5/request.go index b615fcb..fb00fd6 100644 --- a/vendor/github.com/armon/go-socks5/request.go +++ b/vendor/github.com/armon/go-socks5/request.go @@ -1,13 +1,13 @@ package socks5 import ( + "context" "fmt" "io" "net" + "net/netip" "strconv" "strings" - - "golang.org/x/net/context" ) const ( @@ -32,7 +32,7 @@ const ( ) var ( - unrecognizedAddrType = fmt.Errorf("Unrecognized address type") + unrecognizedAddrType = fmt.Errorf("unrecognized address type") ) // AddressRewriter is used to rewrite a destination transparently @@ -44,7 +44,7 @@ type AddressRewriter interface { // which may be specified as IPv4, IPv6, or a FQDN type AddrSpec struct { FQDN string - IP net.IP + IP netip.Addr Port int } @@ -58,7 +58,7 @@ func (a *AddrSpec) String() string { // Address returns a string suitable to dial; prefer returning IP-based // address, fallback to FQDN func (a AddrSpec) Address() string { - if 0 != len(a.IP) { + if len(a.String()) != 0 { return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) } return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) @@ -91,12 +91,12 @@ func NewRequest(bufConn io.Reader) (*Request, error) { // Read the version byte header := []byte{0, 0, 0} if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { - return nil, fmt.Errorf("Failed to get command version: %v", err) + return nil, fmt.Errorf("failed to get command version: %v", err) } // Ensure we are compatible if header[0] != socks5Version { - return nil, fmt.Errorf("Unsupported command version: %v", header[0]) + return nil, fmt.Errorf("unsupported command version: %v", header[0]) } // Read in the destination address @@ -125,9 +125,9 @@ func (s *Server) handleRequest(req *Request, conn conn) error { ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) if err != nil { if err := sendReply(conn, hostUnreachable, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } - return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) + return fmt.Errorf("failed to resolve destination '%v': %v", dest.FQDN, err) } ctx = ctx_ dest.IP = addr @@ -149,9 +149,9 @@ func (s *Server) handleRequest(req *Request, conn conn) error { return s.handleAssociate(ctx, conn, req) default: if err := sendReply(conn, commandNotSupported, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } - return fmt.Errorf("Unsupported command: %v", req.Command) + return fmt.Errorf("unsupported command: %v", req.Command) } } @@ -160,9 +160,9 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err // Check if this is allowed if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } - return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) + return fmt.Errorf("connect to %v blocked by rules", req.DestAddr) } else { ctx = ctx_ } @@ -184,17 +184,18 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err resp = networkUnreachable } if err := sendReply(conn, resp, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } - return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) + return fmt.Errorf("connect to %v failed: %v", req.DestAddr, err) } defer target.Close() // Send success local := target.LocalAddr().(*net.TCPAddr) - bind := AddrSpec{IP: local.IP, Port: local.Port} + localAddr, _ := netip.AddrFromSlice(local.IP) + bind := AddrSpec{IP: localAddr, Port: local.Port} if err := sendReply(conn, successReply, &bind); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } // Start proxying @@ -218,16 +219,16 @@ func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error // Check if this is allowed if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } - return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) + return fmt.Errorf("bind to %v blocked by rules", req.DestAddr) } else { ctx = ctx_ } // TODO: Support bind if err := sendReply(conn, commandNotSupported, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } return nil } @@ -237,16 +238,16 @@ func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) e // Check if this is allowed if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } - return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) + return fmt.Errorf("associate to %v blocked by rules", req.DestAddr) } else { ctx = ctx_ } // TODO: Support associate if err := sendReply(conn, commandNotSupported, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } return nil } @@ -269,14 +270,14 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) { if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { return nil, err } - d.IP = net.IP(addr) + d.IP, _ = netip.AddrFromSlice(addr) case ipv6Address: addr := make([]byte, 16) if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { return nil, err } - d.IP = net.IP(addr) + d.IP, _ = netip.AddrFromSlice(addr) case fqdnAddress: if _, err := r.Read(addrType); err != nil { @@ -320,18 +321,18 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) addrPort = uint16(addr.Port) - case addr.IP.To4() != nil: + case addr.IP.As4() != [4]byte{}: addrType = ipv4Address - addrBody = []byte(addr.IP.To4()) + addrBody = addr.IP.AsSlice() addrPort = uint16(addr.Port) - case addr.IP.To16() != nil: + case addr.IP.As16() != [16]byte{}: addrType = ipv6Address - addrBody = []byte(addr.IP.To16()) + addrBody = addr.IP.AsSlice() addrPort = uint16(addr.Port) default: - return fmt.Errorf("Failed to format address: %v", addr) + return fmt.Errorf("failed to format address: %v", addr) } // Format the message diff --git a/vendor/github.com/armon/go-socks5/resolver.go b/vendor/github.com/armon/go-socks5/resolver.go index b75a5c4..9c9b82d 100644 --- a/vendor/github.com/armon/go-socks5/resolver.go +++ b/vendor/github.com/armon/go-socks5/resolver.go @@ -1,23 +1,23 @@ package socks5 import ( - "net" - - "golang.org/x/net/context" + "context" + "net/netip" ) // NameResolver is used to implement custom name resolution type NameResolver interface { - Resolve(ctx context.Context, name string) (context.Context, net.IP, error) + Resolve(ctx context.Context, name string) (context.Context, netip.Addr, error) } // DNSResolver uses the system DNS to resolve host names type DNSResolver struct{} -func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { - addr, err := net.ResolveIPAddr("ip", name) +func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, netip.Addr, error) { + addr, err := netip.ParseAddr(name) + if err != nil { - return ctx, nil, err + return ctx, netip.Addr{}, err } - return ctx, addr.IP, err + return ctx, addr, err } diff --git a/vendor/github.com/armon/go-socks5/ruleset.go b/vendor/github.com/armon/go-socks5/ruleset.go index ba0e353..d65699d 100644 --- a/vendor/github.com/armon/go-socks5/ruleset.go +++ b/vendor/github.com/armon/go-socks5/ruleset.go @@ -1,7 +1,7 @@ package socks5 import ( - "golang.org/x/net/context" + "context" ) // RuleSet is used to provide custom rules to allow or prohibit actions diff --git a/vendor/github.com/armon/go-socks5/socks5.go b/vendor/github.com/armon/go-socks5/socks5.go index 2d630fb..81d9f82 100644 --- a/vendor/github.com/armon/go-socks5/socks5.go +++ b/vendor/github.com/armon/go-socks5/socks5.go @@ -2,12 +2,12 @@ package socks5 import ( "bufio" + "context" "fmt" "log" "net" + "net/netip" "os" - - "golang.org/x/net/context" ) const ( @@ -40,7 +40,7 @@ type Config struct { Rewriter AddressRewriter // BindIP is used for bind or udp associate - BindIP net.IP + BindIP netip.Addr // Logger can be used to provide a custom log target. // Defaults to stdout. @@ -53,9 +53,10 @@ type Config struct { // Server is reponsible for accepting connections and handling // the details of the SOCKS5 protocol type Server struct { - config *Config - authMethods map[uint8]Authenticator - isIPAllowed func(net.IP) bool + config *Config + authMethods map[uint8]Authenticator + isIPAllowed func(netip.Addr) bool + isNetAllowed func(netip.Addr) bool } // New creates a new Server and potentially returns an error @@ -95,7 +96,7 @@ func New(conf *Config) (*Server, error) { } // Set default IP whitelist function - server.isIPAllowed = func(ip net.IP) bool { + server.isIPAllowed = func(ip netip.Addr) bool { return true // default allow all IPs } @@ -120,14 +121,18 @@ func (s *Server) Serve(l net.Listener) error { } go s.ServeConn(conn) } - return nil } // SetIPWhitelist sets the function to check if a given IP is allowed -func (s *Server) SetIPWhitelist(allowedIPs []net.IP) { - s.isIPAllowed = func(ip net.IP) bool { +func (s *Server) SetIPWhitelist(allowedIPs []netip.Addr, allowedNets []netip.Prefix) { + s.isIPAllowed = func(ip netip.Addr) bool { for _, allowedIP := range allowedIPs { - if ip.Equal(allowedIP) { + if ip.Compare(allowedIP) == 0 { + return true + } + } + for _, allowedNet := range allowedNets { + if allowedNet.Contains(ip) { return true } } @@ -146,7 +151,7 @@ func (s *Server) ServeConn(conn net.Conn) error { s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err) return err } - ip := net.ParseIP(clientIP) + ip, _ := netip.ParseAddr(string(clientIP)) if s.isIPAllowed(ip) { s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP) } else { @@ -163,7 +168,7 @@ func (s *Server) ServeConn(conn net.Conn) error { // Ensure we are compatible if version[0] != socks5Version { - err := fmt.Errorf("Unsupported SOCKS version: %v", version) + err := fmt.Errorf("unsupported SOCKS version: %v", version) s.config.Logger.Printf("[ERR] socks: %v", err) return err } @@ -171,7 +176,7 @@ func (s *Server) ServeConn(conn net.Conn) error { // Authenticate the connection authContext, err := s.authenticate(conn, bufConn) if err != nil { - err = fmt.Errorf("Failed to authenticate: %v", err) + err = fmt.Errorf("failed to authenticate: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err) return err } @@ -180,19 +185,20 @@ func (s *Server) ServeConn(conn net.Conn) error { if err != nil { if err == unrecognizedAddrType { if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { - return fmt.Errorf("Failed to send reply: %v", err) + return fmt.Errorf("failed to send reply: %v", err) } } - return fmt.Errorf("Failed to read destination address: %v", err) + return fmt.Errorf("failed to read destination address: %v", err) } request.AuthContext = authContext if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { - request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} + addr, _ := netip.ParseAddr(string(client.IP)) + request.RemoteAddr = &AddrSpec{IP: addr, Port: client.Port} } // Process the client request if err := s.handleRequest(request, conn); err != nil { - err = fmt.Errorf("Failed to handle request: %v", err) + err = fmt.Errorf("failed to handle request: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err) return err }