package socks5 import ( "bufio" "fmt" "log" "net" "os" "golang.org/x/net/context" ) const ( socks5Version = uint8(5) ) // Config is used to setup and configure a Server type Config struct { // AuthMethods can be provided to implement custom authentication // By default, "auth-less" mode is enabled. // For password-based auth use UserPassAuthenticator. AuthMethods []Authenticator // If provided, username/password authentication is enabled, // by appending a UserPassAuthenticator to AuthMethods. If not provided, // and AUthMethods is nil, then "auth-less" mode is enabled. Credentials CredentialStore // Resolver can be provided to do custom name resolution. // Defaults to DNSResolver if not provided. Resolver NameResolver // Rules is provided to enable custom logic around permitting // various commands. If not provided, PermitAll is used. Rules RuleSet // Rewriter can be used to transparently rewrite addresses. // This is invoked before the RuleSet is invoked. // Defaults to NoRewrite. Rewriter AddressRewriter // BindIP is used for bind or udp associate BindIP net.IP // Logger can be used to provide a custom log target. // Defaults to stdout. Logger *log.Logger // Optional function for dialing out Dial func(ctx context.Context, network, addr string) (net.Conn, error) } // Server is reponsible for accepting connections and handling // the details of the SOCKS5 protocol type Server struct { config *Config authMethods map[uint8]Authenticator isIPAllowed func(net.IP) bool } // New creates a new Server and potentially returns an error func New(conf *Config) (*Server, error) { // Ensure we have at least one authentication method enabled if len(conf.AuthMethods) == 0 { if conf.Credentials != nil { conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} } else { conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} } } // Ensure we have a DNS resolver if conf.Resolver == nil { conf.Resolver = DNSResolver{} } // Ensure we have a rule set if conf.Rules == nil { conf.Rules = PermitAll() } // Ensure we have a log target if conf.Logger == nil { conf.Logger = log.New(os.Stdout, "", log.LstdFlags) } server := &Server{ config: conf, } server.authMethods = make(map[uint8]Authenticator) for _, a := range conf.AuthMethods { server.authMethods[a.GetCode()] = a } // Set default IP whitelist function server.isIPAllowed = func(ip net.IP) bool { return true // default allow all IPs } return server, nil } // ListenAndServe is used to create a listener and serve on it func (s *Server) ListenAndServe(network, addr string) error { l, err := net.Listen(network, addr) if err != nil { return err } return s.Serve(l) } // Serve is used to serve connections from a listener func (s *Server) Serve(l net.Listener) error { for { conn, err := l.Accept() if err != nil { return err } go s.ServeConn(conn) } return nil } // SetIPWhitelist sets the function to check if a given IP is allowed func (s *Server) SetIPWhitelist(allowedIPs []net.IP) { s.isIPAllowed = func(ip net.IP) bool { for _, allowedIP := range allowedIPs { if ip.Equal(allowedIP) { return true } } return false } } // ServeConn is used to serve a single connection. func (s *Server) ServeConn(conn net.Conn) error { defer conn.Close() bufConn := bufio.NewReader(conn) // Check client IP against whitelist clientIP, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err != nil { s.config.Logger.Printf("[ERR] socks: Failed to get client IP address: %v", err) return err } ip := net.ParseIP(clientIP) if s.isIPAllowed(ip) { s.config.Logger.Printf("[INFO] socks: Connection from allowed IP address: %s", clientIP) } else { s.config.Logger.Printf("[WARN] socks: Connection from not allowed IP address: %s", clientIP) return fmt.Errorf("connection from not allowed IP address") } // Read the version byte version := []byte{0} if _, err := bufConn.Read(version); err != nil { s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err) return err } // Ensure we are compatible if version[0] != socks5Version { err := fmt.Errorf("Unsupported SOCKS version: %v", version) s.config.Logger.Printf("[ERR] socks: %v", err) return err } // Authenticate the connection authContext, err := s.authenticate(conn, bufConn) if err != nil { err = fmt.Errorf("Failed to authenticate: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err) return err } request, err := NewRequest(bufConn) if err != nil { if err == unrecognizedAddrType { if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { return fmt.Errorf("Failed to send reply: %v", err) } } return fmt.Errorf("Failed to read destination address: %v", err) } request.AuthContext = authContext if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} } // Process the client request if err := s.handleRequest(request, conn); err != nil { err = fmt.Errorf("Failed to handle request: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err) return err } return nil }