add ALLOWED_NETS, use netip, small linting fixes

This commit is contained in:
David Dean 2023-12-26 07:00:06 +11:00
parent d0347549a4
commit fa37f4a22f
No known key found for this signature in database
GPG Key ID: 5B60492A7C71BBFC
8 changed files with 95 additions and 75 deletions

View File

@ -26,6 +26,7 @@ Simple socks5 server using go-socks5 with authentication, allowed ips list and d
|PROXY_PORT|String|1080|Set listen port for application inside docker container| |PROXY_PORT|String|1080|Set listen port for application inside docker container|
|ALLOWED_DEST_FQDN|String|EMPTY|Allowed destination address regular expression pattern. Default allows all.| |ALLOWED_DEST_FQDN|String|EMPTY|Allowed destination address regular expression pattern. Default allows all.|
|ALLOWED_IPS|String|Empty|Set allowed IP's that can connect to proxy, separator `,`| |ALLOWED_IPS|String|Empty|Set allowed IP's that can connect to proxy, separator `,`|
|ALLOWED_NETS|String|Empty|Set allowed networks that can connect to proxy, separator `,`|
# Build your own image: # Build your own image:

View File

@ -3,8 +3,9 @@ package main
import ( import (
"regexp" "regexp"
"context"
"github.com/armon/go-socks5" "github.com/armon/go-socks5"
"golang.org/x/net/context"
) )
// PermitDestAddrPattern returns a RuleSet which selectively allows addresses // PermitDestAddrPattern returns a RuleSet which selectively allows addresses

View File

@ -2,19 +2,20 @@ package main
import ( import (
"log" "log"
"net" "net/netip"
"os" "os"
"github.com/armon/go-socks5" "github.com/armon/go-socks5"
"github.com/caarlos0/env/v6" env "github.com/caarlos0/env/v6"
) )
type params struct { type params struct {
User string `env:"PROXY_USER" envDefault:""` User string `env:"PROXY_USER" envDefault:""`
Password string `env:"PROXY_PASSWORD" envDefault:""` Password string `env:"PROXY_PASSWORD" envDefault:""`
Port string `env:"PROXY_PORT" envDefault:"1080"` Port string `env:"PROXY_PORT" envDefault:"1080"`
AllowedDestFqdn string `env:"ALLOWED_DEST_FQDN" envDefault:""` AllowedDestFqdn string `env:"ALLOWED_DEST_FQDN" envDefault:""`
AllowedIPs []string `env:"ALLOWED_IPS" envSeparator:"," envDefault:""` AllowedIPs []string `env:"ALLOWED_IPS" envSeparator:"," envDefault:""`
AllowedNets []string `env:"ALLOWED_NETS" envSeparator:"," envDefault:""`
} }
func main() { func main() {
@ -48,15 +49,25 @@ func main() {
} }
// Set IP whitelist // Set IP whitelist
if len(cfg.AllowedIPs) > 0 { if len(cfg.AllowedIPs) > 0 || len(cfg.AllowedNets) > 0 {
whitelist := make([]net.IP, len(cfg.AllowedIPs)) whitelist := make([]netip.Addr, len(cfg.AllowedIPs))
for i, ip := range cfg.AllowedIPs { whitelistnet := make([]netip.Prefix, len(cfg.AllowedNets))
whitelist[i] = net.ParseIP(ip)
if len(cfg.AllowedIPs) > 0 {
for i, ip := range cfg.AllowedIPs {
whitelist[i], _ = netip.ParseAddr(ip)
}
} }
server.SetIPWhitelist(whitelist) if len(cfg.AllowedNets) > 0 {
for i, inet := range cfg.AllowedNets {
whitelistnet[i], _ = netip.ParsePrefix(inet)
}
}
server.SetIPWhitelist(whitelist, whitelistnet)
} }
log.Printf("Start listening proxy service on port %s\n", cfg.Port) log.Printf("Started proxy service listening on port %s\n", cfg.Port)
if err := server.ListenAndServe("tcp", ":"+cfg.Port); err != nil { if err := server.ListenAndServe("tcp", ":"+cfg.Port); err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -15,8 +15,8 @@ const (
) )
var ( var (
UserAuthFailed = fmt.Errorf("User authentication failed") UserAuthFailed = fmt.Errorf("user authentication failed")
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") NoSupportedAuth = fmt.Errorf("no supported authentication mechanism")
) )
// A Request encapsulates authentication state provided // A Request encapsulates authentication state provided
@ -71,7 +71,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
// Ensure we are compatible // Ensure we are compatible
if header[0] != userAuthVersion { if header[0] != userAuthVersion {
return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) return nil, fmt.Errorf("unsupported auth version: %v", header[0])
} }
// Get the user name // Get the user name
@ -114,7 +114,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
// Get the methods // Get the methods
methods, err := readMethods(bufConn) methods, err := readMethods(bufConn)
if err != nil { if err != nil {
return nil, fmt.Errorf("Failed to get auth methods: %v", err) return nil, fmt.Errorf("failed to get auth methods: %v", err)
} }
// Select a usable method // Select a usable method

View File

@ -1,13 +1,13 @@
package socks5 package socks5
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"golang.org/x/net/context"
) )
const ( const (
@ -32,7 +32,7 @@ const (
) )
var ( var (
unrecognizedAddrType = fmt.Errorf("Unrecognized address type") unrecognizedAddrType = fmt.Errorf("unrecognized address type")
) )
// AddressRewriter is used to rewrite a destination transparently // AddressRewriter is used to rewrite a destination transparently
@ -44,7 +44,7 @@ type AddressRewriter interface {
// which may be specified as IPv4, IPv6, or a FQDN // which may be specified as IPv4, IPv6, or a FQDN
type AddrSpec struct { type AddrSpec struct {
FQDN string FQDN string
IP net.IP IP netip.Addr
Port int Port int
} }
@ -58,7 +58,7 @@ func (a *AddrSpec) String() string {
// Address returns a string suitable to dial; prefer returning IP-based // Address returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN // address, fallback to FQDN
func (a AddrSpec) Address() string { func (a AddrSpec) Address() string {
if 0 != len(a.IP) { if len(a.String()) != 0 {
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
} }
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
@ -91,12 +91,12 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
// Read the version byte // Read the version byte
header := []byte{0, 0, 0} header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return nil, fmt.Errorf("Failed to get command version: %v", err) return nil, fmt.Errorf("failed to get command version: %v", err)
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != socks5Version { if header[0] != socks5Version {
return nil, fmt.Errorf("Unsupported command version: %v", header[0]) return nil, fmt.Errorf("unsupported command version: %v", header[0])
} }
// Read in the destination address // Read in the destination address
@ -125,9 +125,9 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
if err != nil { if err != nil {
if err := sendReply(conn, hostUnreachable, nil); err != nil { if err := sendReply(conn, hostUnreachable, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) return fmt.Errorf("failed to resolve destination '%v': %v", dest.FQDN, err)
} }
ctx = ctx_ ctx = ctx_
dest.IP = addr dest.IP = addr
@ -149,9 +149,9 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
return s.handleAssociate(ctx, conn, req) return s.handleAssociate(ctx, conn, req)
default: default:
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return fmt.Errorf("Unsupported command: %v", req.Command) return fmt.Errorf("unsupported command: %v", req.Command)
} }
} }
@ -160,9 +160,9 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
// Check if this is allowed // Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) return fmt.Errorf("connect to %v blocked by rules", req.DestAddr)
} else { } else {
ctx = ctx_ ctx = ctx_
} }
@ -184,17 +184,18 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
resp = networkUnreachable resp = networkUnreachable
} }
if err := sendReply(conn, resp, nil); err != nil { if err := sendReply(conn, resp, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) return fmt.Errorf("connect to %v failed: %v", req.DestAddr, err)
} }
defer target.Close() defer target.Close()
// Send success // Send success
local := target.LocalAddr().(*net.TCPAddr) local := target.LocalAddr().(*net.TCPAddr)
bind := AddrSpec{IP: local.IP, Port: local.Port} localAddr, _ := netip.AddrFromSlice(local.IP)
bind := AddrSpec{IP: localAddr, Port: local.Port}
if err := sendReply(conn, successReply, &bind); err != nil { if err := sendReply(conn, successReply, &bind); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
// Start proxying // Start proxying
@ -218,16 +219,16 @@ func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error
// Check if this is allowed // Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) return fmt.Errorf("bind to %v blocked by rules", req.DestAddr)
} else { } else {
ctx = ctx_ ctx = ctx_
} }
// TODO: Support bind // TODO: Support bind
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return nil return nil
} }
@ -237,16 +238,16 @@ func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) e
// Check if this is allowed // Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) return fmt.Errorf("associate to %v blocked by rules", req.DestAddr)
} else { } else {
ctx = ctx_ ctx = ctx_
} }
// TODO: Support associate // TODO: Support associate
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
return nil return nil
} }
@ -269,14 +270,14 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) {
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err return nil, err
} }
d.IP = net.IP(addr) d.IP, _ = netip.AddrFromSlice(addr)
case ipv6Address: case ipv6Address:
addr := make([]byte, 16) addr := make([]byte, 16)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err return nil, err
} }
d.IP = net.IP(addr) d.IP, _ = netip.AddrFromSlice(addr)
case fqdnAddress: case fqdnAddress:
if _, err := r.Read(addrType); err != nil { if _, err := r.Read(addrType); err != nil {
@ -320,18 +321,18 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
addrPort = uint16(addr.Port) addrPort = uint16(addr.Port)
case addr.IP.To4() != nil: case addr.IP.As4() != [4]byte{}:
addrType = ipv4Address addrType = ipv4Address
addrBody = []byte(addr.IP.To4()) addrBody = addr.IP.AsSlice()
addrPort = uint16(addr.Port) addrPort = uint16(addr.Port)
case addr.IP.To16() != nil: case addr.IP.As16() != [16]byte{}:
addrType = ipv6Address addrType = ipv6Address
addrBody = []byte(addr.IP.To16()) addrBody = addr.IP.AsSlice()
addrPort = uint16(addr.Port) addrPort = uint16(addr.Port)
default: default:
return fmt.Errorf("Failed to format address: %v", addr) return fmt.Errorf("failed to format address: %v", addr)
} }
// Format the message // Format the message

View File

@ -1,23 +1,23 @@
package socks5 package socks5
import ( import (
"net" "context"
"net/netip"
"golang.org/x/net/context"
) )
// NameResolver is used to implement custom name resolution // NameResolver is used to implement custom name resolution
type NameResolver interface { type NameResolver interface {
Resolve(ctx context.Context, name string) (context.Context, net.IP, error) Resolve(ctx context.Context, name string) (context.Context, netip.Addr, error)
} }
// DNSResolver uses the system DNS to resolve host names // DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{} type DNSResolver struct{}
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, netip.Addr, error) {
addr, err := net.ResolveIPAddr("ip", name) addr, err := netip.ParseAddr(name)
if err != nil { if err != nil {
return ctx, nil, err return ctx, netip.Addr{}, err
} }
return ctx, addr.IP, err return ctx, addr, err
} }

View File

@ -1,7 +1,7 @@
package socks5 package socks5
import ( import (
"golang.org/x/net/context" "context"
) )
// RuleSet is used to provide custom rules to allow or prohibit actions // RuleSet is used to provide custom rules to allow or prohibit actions

View File

@ -2,12 +2,12 @@ package socks5
import ( import (
"bufio" "bufio"
"context"
"fmt" "fmt"
"log" "log"
"net" "net"
"net/netip"
"os" "os"
"golang.org/x/net/context"
) )
const ( const (
@ -40,7 +40,7 @@ type Config struct {
Rewriter AddressRewriter Rewriter AddressRewriter
// BindIP is used for bind or udp associate // BindIP is used for bind or udp associate
BindIP net.IP BindIP netip.Addr
// Logger can be used to provide a custom log target. // Logger can be used to provide a custom log target.
// Defaults to stdout. // Defaults to stdout.
@ -53,9 +53,10 @@ type Config struct {
// Server is reponsible for accepting connections and handling // Server is reponsible for accepting connections and handling
// the details of the SOCKS5 protocol // the details of the SOCKS5 protocol
type Server struct { type Server struct {
config *Config config *Config
authMethods map[uint8]Authenticator authMethods map[uint8]Authenticator
isIPAllowed func(net.IP) bool isIPAllowed func(netip.Addr) bool
isNetAllowed func(netip.Addr) bool
} }
// New creates a new Server and potentially returns an error // New creates a new Server and potentially returns an error
@ -95,7 +96,7 @@ func New(conf *Config) (*Server, error) {
} }
// Set default IP whitelist function // Set default IP whitelist function
server.isIPAllowed = func(ip net.IP) bool { server.isIPAllowed = func(ip netip.Addr) bool {
return true // default allow all IPs return true // default allow all IPs
} }
@ -120,14 +121,18 @@ func (s *Server) Serve(l net.Listener) error {
} }
go s.ServeConn(conn) go s.ServeConn(conn)
} }
return nil
} }
// SetIPWhitelist sets the function to check if a given IP is allowed // SetIPWhitelist sets the function to check if a given IP is allowed
func (s *Server) SetIPWhitelist(allowedIPs []net.IP) { func (s *Server) SetIPWhitelist(allowedIPs []netip.Addr, allowedNets []netip.Prefix) {
s.isIPAllowed = func(ip net.IP) bool { s.isIPAllowed = func(ip netip.Addr) bool {
for _, allowedIP := range allowedIPs { for _, allowedIP := range allowedIPs {
if ip.Equal(allowedIP) { if ip.Compare(allowedIP) == 0 {
return true
}
}
for _, allowedNet := range allowedNets {
if allowedNet.Contains(ip) {
return true return true
} }
} }
@ -146,7 +151,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err) s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err)
return err return err
} }
ip := net.ParseIP(clientIP) ip, _ := netip.ParseAddr(string(clientIP))
if s.isIPAllowed(ip) { if s.isIPAllowed(ip) {
s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP) s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP)
} else { } else {
@ -163,7 +168,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
// Ensure we are compatible // Ensure we are compatible
if version[0] != socks5Version { if version[0] != socks5Version {
err := fmt.Errorf("Unsupported SOCKS version: %v", version) err := fmt.Errorf("unsupported SOCKS version: %v", version)
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err
} }
@ -171,7 +176,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
// Authenticate the connection // Authenticate the connection
authContext, err := s.authenticate(conn, bufConn) authContext, err := s.authenticate(conn, bufConn)
if err != nil { if err != nil {
err = fmt.Errorf("Failed to authenticate: %v", err) err = fmt.Errorf("failed to authenticate: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err
} }
@ -180,19 +185,20 @@ func (s *Server) ServeConn(conn net.Conn) error {
if err != nil { if err != nil {
if err == unrecognizedAddrType { if err == unrecognizedAddrType {
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
} }
return fmt.Errorf("Failed to read destination address: %v", err) return fmt.Errorf("failed to read destination address: %v", err)
} }
request.AuthContext = authContext request.AuthContext = authContext
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} addr, _ := netip.ParseAddr(string(client.IP))
request.RemoteAddr = &AddrSpec{IP: addr, Port: client.Port}
} }
// Process the client request // Process the client request
if err := s.handleRequest(request, conn); err != nil { if err := s.handleRequest(request, conn); err != nil {
err = fmt.Errorf("Failed to handle request: %v", err) err = fmt.Errorf("failed to handle request: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err
} }