add ALLOWED_NETS, use netip, small linting fixes

This commit is contained in:
David Dean 2023-12-26 07:00:06 +11:00
parent d0347549a4
commit fa37f4a22f
No known key found for this signature in database
GPG Key ID: 5B60492A7C71BBFC
8 changed files with 95 additions and 75 deletions

View File

@ -26,6 +26,7 @@ Simple socks5 server using go-socks5 with authentication, allowed ips list and d
|PROXY_PORT|String|1080|Set listen port for application inside docker container|
|ALLOWED_DEST_FQDN|String|EMPTY|Allowed destination address regular expression pattern. Default allows all.|
|ALLOWED_IPS|String|Empty|Set allowed IP's that can connect to proxy, separator `,`|
|ALLOWED_NETS|String|Empty|Set allowed networks that can connect to proxy, separator `,`|
# Build your own image:

View File

@ -3,8 +3,9 @@ package main
import (
"regexp"
"context"
"github.com/armon/go-socks5"
"golang.org/x/net/context"
)
// PermitDestAddrPattern returns a RuleSet which selectively allows addresses

View File

@ -2,11 +2,11 @@ package main
import (
"log"
"net"
"net/netip"
"os"
"github.com/armon/go-socks5"
"github.com/caarlos0/env/v6"
env "github.com/caarlos0/env/v6"
)
type params struct {
@ -15,6 +15,7 @@ type params struct {
Port string `env:"PROXY_PORT" envDefault:"1080"`
AllowedDestFqdn string `env:"ALLOWED_DEST_FQDN" envDefault:""`
AllowedIPs []string `env:"ALLOWED_IPS" envSeparator:"," envDefault:""`
AllowedNets []string `env:"ALLOWED_NETS" envSeparator:"," envDefault:""`
}
func main() {
@ -48,15 +49,25 @@ func main() {
}
// Set IP whitelist
if len(cfg.AllowedIPs) > 0 || len(cfg.AllowedNets) > 0 {
whitelist := make([]netip.Addr, len(cfg.AllowedIPs))
whitelistnet := make([]netip.Prefix, len(cfg.AllowedNets))
if len(cfg.AllowedIPs) > 0 {
whitelist := make([]net.IP, len(cfg.AllowedIPs))
for i, ip := range cfg.AllowedIPs {
whitelist[i] = net.ParseIP(ip)
whitelist[i], _ = netip.ParseAddr(ip)
}
}
if len(cfg.AllowedNets) > 0 {
for i, inet := range cfg.AllowedNets {
whitelistnet[i], _ = netip.ParsePrefix(inet)
}
server.SetIPWhitelist(whitelist)
}
log.Printf("Start listening proxy service on port %s\n", cfg.Port)
server.SetIPWhitelist(whitelist, whitelistnet)
}
log.Printf("Started proxy service listening on port %s\n", cfg.Port)
if err := server.ListenAndServe("tcp", ":"+cfg.Port); err != nil {
log.Fatal(err)
}

View File

@ -15,8 +15,8 @@ const (
)
var (
UserAuthFailed = fmt.Errorf("User authentication failed")
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
UserAuthFailed = fmt.Errorf("user authentication failed")
NoSupportedAuth = fmt.Errorf("no supported authentication mechanism")
)
// A Request encapsulates authentication state provided
@ -71,7 +71,7 @@ func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer)
// Ensure we are compatible
if header[0] != userAuthVersion {
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
return nil, fmt.Errorf("unsupported auth version: %v", header[0])
}
// Get the user name
@ -114,7 +114,7 @@ func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext,
// Get the methods
methods, err := readMethods(bufConn)
if err != nil {
return nil, fmt.Errorf("Failed to get auth methods: %v", err)
return nil, fmt.Errorf("failed to get auth methods: %v", err)
}
// Select a usable method

View File

@ -1,13 +1,13 @@
package socks5
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"golang.org/x/net/context"
)
const (
@ -32,7 +32,7 @@ const (
)
var (
unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
unrecognizedAddrType = fmt.Errorf("unrecognized address type")
)
// AddressRewriter is used to rewrite a destination transparently
@ -44,7 +44,7 @@ type AddressRewriter interface {
// which may be specified as IPv4, IPv6, or a FQDN
type AddrSpec struct {
FQDN string
IP net.IP
IP netip.Addr
Port int
}
@ -58,7 +58,7 @@ func (a *AddrSpec) String() string {
// Address returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN
func (a AddrSpec) Address() string {
if 0 != len(a.IP) {
if len(a.String()) != 0 {
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
}
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
@ -91,12 +91,12 @@ func NewRequest(bufConn io.Reader) (*Request, error) {
// Read the version byte
header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return nil, fmt.Errorf("Failed to get command version: %v", err)
return nil, fmt.Errorf("failed to get command version: %v", err)
}
// Ensure we are compatible
if header[0] != socks5Version {
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
return nil, fmt.Errorf("unsupported command version: %v", header[0])
}
// Read in the destination address
@ -125,9 +125,9 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
if err != nil {
if err := sendReply(conn, hostUnreachable, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
return fmt.Errorf("failed to resolve destination '%v': %v", dest.FQDN, err)
}
ctx = ctx_
dest.IP = addr
@ -149,9 +149,9 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
return s.handleAssociate(ctx, conn, req)
default:
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("Unsupported command: %v", req.Command)
return fmt.Errorf("unsupported command: %v", req.Command)
}
}
@ -160,9 +160,9 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
return fmt.Errorf("connect to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
@ -184,17 +184,18 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
resp = networkUnreachable
}
if err := sendReply(conn, resp, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
return fmt.Errorf("connect to %v failed: %v", req.DestAddr, err)
}
defer target.Close()
// Send success
local := target.LocalAddr().(*net.TCPAddr)
bind := AddrSpec{IP: local.IP, Port: local.Port}
localAddr, _ := netip.AddrFromSlice(local.IP)
bind := AddrSpec{IP: localAddr, Port: local.Port}
if err := sendReply(conn, successReply, &bind); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
// Start proxying
@ -218,16 +219,16 @@ func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
return fmt.Errorf("bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support bind
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
}
@ -237,16 +238,16 @@ func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) e
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
return fmt.Errorf("associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support associate
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
return nil
}
@ -269,14 +270,14 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) {
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err
}
d.IP = net.IP(addr)
d.IP, _ = netip.AddrFromSlice(addr)
case ipv6Address:
addr := make([]byte, 16)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err
}
d.IP = net.IP(addr)
d.IP, _ = netip.AddrFromSlice(addr)
case fqdnAddress:
if _, err := r.Read(addrType); err != nil {
@ -320,18 +321,18 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
addrPort = uint16(addr.Port)
case addr.IP.To4() != nil:
case addr.IP.As4() != [4]byte{}:
addrType = ipv4Address
addrBody = []byte(addr.IP.To4())
addrBody = addr.IP.AsSlice()
addrPort = uint16(addr.Port)
case addr.IP.To16() != nil:
case addr.IP.As16() != [16]byte{}:
addrType = ipv6Address
addrBody = []byte(addr.IP.To16())
addrBody = addr.IP.AsSlice()
addrPort = uint16(addr.Port)
default:
return fmt.Errorf("Failed to format address: %v", addr)
return fmt.Errorf("failed to format address: %v", addr)
}
// Format the message

View File

@ -1,23 +1,23 @@
package socks5
import (
"net"
"golang.org/x/net/context"
"context"
"net/netip"
)
// NameResolver is used to implement custom name resolution
type NameResolver interface {
Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
Resolve(ctx context.Context, name string) (context.Context, netip.Addr, error)
}
// DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{}
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := net.ResolveIPAddr("ip", name)
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, netip.Addr, error) {
addr, err := netip.ParseAddr(name)
if err != nil {
return ctx, nil, err
return ctx, netip.Addr{}, err
}
return ctx, addr.IP, err
return ctx, addr, err
}

View File

@ -1,7 +1,7 @@
package socks5
import (
"golang.org/x/net/context"
"context"
)
// RuleSet is used to provide custom rules to allow or prohibit actions

View File

@ -2,12 +2,12 @@ package socks5
import (
"bufio"
"context"
"fmt"
"log"
"net"
"net/netip"
"os"
"golang.org/x/net/context"
)
const (
@ -40,7 +40,7 @@ type Config struct {
Rewriter AddressRewriter
// BindIP is used for bind or udp associate
BindIP net.IP
BindIP netip.Addr
// Logger can be used to provide a custom log target.
// Defaults to stdout.
@ -55,7 +55,8 @@ type Config struct {
type Server struct {
config *Config
authMethods map[uint8]Authenticator
isIPAllowed func(net.IP) bool
isIPAllowed func(netip.Addr) bool
isNetAllowed func(netip.Addr) bool
}
// New creates a new Server and potentially returns an error
@ -95,7 +96,7 @@ func New(conf *Config) (*Server, error) {
}
// Set default IP whitelist function
server.isIPAllowed = func(ip net.IP) bool {
server.isIPAllowed = func(ip netip.Addr) bool {
return true // default allow all IPs
}
@ -120,14 +121,18 @@ func (s *Server) Serve(l net.Listener) error {
}
go s.ServeConn(conn)
}
return nil
}
// SetIPWhitelist sets the function to check if a given IP is allowed
func (s *Server) SetIPWhitelist(allowedIPs []net.IP) {
s.isIPAllowed = func(ip net.IP) bool {
func (s *Server) SetIPWhitelist(allowedIPs []netip.Addr, allowedNets []netip.Prefix) {
s.isIPAllowed = func(ip netip.Addr) bool {
for _, allowedIP := range allowedIPs {
if ip.Equal(allowedIP) {
if ip.Compare(allowedIP) == 0 {
return true
}
}
for _, allowedNet := range allowedNets {
if allowedNet.Contains(ip) {
return true
}
}
@ -146,7 +151,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err)
return err
}
ip := net.ParseIP(clientIP)
ip, _ := netip.ParseAddr(string(clientIP))
if s.isIPAllowed(ip) {
s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP)
} else {
@ -163,7 +168,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
// Ensure we are compatible
if version[0] != socks5Version {
err := fmt.Errorf("Unsupported SOCKS version: %v", version)
err := fmt.Errorf("unsupported SOCKS version: %v", version)
s.config.Logger.Printf("[ERR] socks: %v", err)
return err
}
@ -171,7 +176,7 @@ func (s *Server) ServeConn(conn net.Conn) error {
// Authenticate the connection
authContext, err := s.authenticate(conn, bufConn)
if err != nil {
err = fmt.Errorf("Failed to authenticate: %v", err)
err = fmt.Errorf("failed to authenticate: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err)
return err
}
@ -180,19 +185,20 @@ func (s *Server) ServeConn(conn net.Conn) error {
if err != nil {
if err == unrecognizedAddrType {
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
return fmt.Errorf("failed to send reply: %v", err)
}
}
return fmt.Errorf("Failed to read destination address: %v", err)
return fmt.Errorf("failed to read destination address: %v", err)
}
request.AuthContext = authContext
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
addr, _ := netip.ParseAddr(string(client.IP))
request.RemoteAddr = &AddrSpec{IP: addr, Port: client.Port}
}
// Process the client request
if err := s.handleRequest(request, conn); err != nil {
err = fmt.Errorf("Failed to handle request: %v", err)
err = fmt.Errorf("failed to handle request: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err)
return err
}