399 lines
9.5 KiB
Go
399 lines
9.5 KiB
Go
package env
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Sentinel errors reported by Parse, ParseWithFuncs and the field setters.
var (
	// ErrNotAStructPtr is returned if you pass something that is not a pointer to a
	// Struct to Parse
	ErrNotAStructPtr = errors.New("Expected a pointer to a Struct")

	// ErrUnsupportedType if the struct field type is not supported by env
	ErrUnsupportedType = errors.New("Type is not supported")

	// ErrUnsupportedSliceType if the slice element type is not supported by env
	ErrUnsupportedSliceType = errors.New("Unsupported slice type")
)

// OnEnvVarSet is an optional convenience callback, such as for logging purposes.
// When non-nil it is invoked after a field has been successfully set, with the
// field's metadata and the raw string value that was applied.
var OnEnvVarSet func(reflect.StructField, string)

// Reflect types of the slice field types that handleSlice populates directly.
var (
	sliceOfStrings   = reflect.TypeOf([]string{})
	sliceOfInts      = reflect.TypeOf([]int{})
	sliceOfInt64s    = reflect.TypeOf([]int64{})
	sliceOfUint64s   = reflect.TypeOf([]uint64{})
	sliceOfFloat32s  = reflect.TypeOf([]float32{})
	sliceOfFloat64s  = reflect.TypeOf([]float64{})
	sliceOfBools     = reflect.TypeOf([]bool{})
	sliceOfDurations = reflect.TypeOf([]time.Duration{})
)
|
|
|
|
type (
	// CustomParsers is a friendly name for the type that `ParseWithFuncs()` accepts:
	// it maps a field's reflect.Type to the ParserFunc that decodes values for it.
	CustomParsers map[reflect.Type]ParserFunc

	// ParserFunc defines the signature of a function that can be used within
	// `CustomParsers`. It receives the raw environment value and returns the
	// value to store in the field, or an error.
	ParserFunc func(v string) (interface{}, error)
)
|
|
|
|
// Parse parses a struct containing `env` tags and loads its values from
|
|
// environment variables.
|
|
func Parse(v interface{}) error {
|
|
ptrRef := reflect.ValueOf(v)
|
|
if ptrRef.Kind() != reflect.Ptr {
|
|
return ErrNotAStructPtr
|
|
}
|
|
ref := ptrRef.Elem()
|
|
if ref.Kind() != reflect.Struct {
|
|
return ErrNotAStructPtr
|
|
}
|
|
return doParse(ref, make(map[reflect.Type]ParserFunc, 0))
|
|
}
|
|
|
|
// ParseWithFuncs is the same as `Parse` except it also allows the user to pass
|
|
// in custom parsers.
|
|
func ParseWithFuncs(v interface{}, funcMap CustomParsers) error {
|
|
ptrRef := reflect.ValueOf(v)
|
|
if ptrRef.Kind() != reflect.Ptr {
|
|
return ErrNotAStructPtr
|
|
}
|
|
ref := ptrRef.Elem()
|
|
if ref.Kind() != reflect.Struct {
|
|
return ErrNotAStructPtr
|
|
}
|
|
return doParse(ref, funcMap)
|
|
}
|
|
|
|
func doParse(ref reflect.Value, funcMap CustomParsers) error {
|
|
refType := ref.Type()
|
|
var errorList []string
|
|
|
|
for i := 0; i < refType.NumField(); i++ {
|
|
refField := ref.Field(i)
|
|
if reflect.Ptr == refField.Kind() && !refField.IsNil() && refField.CanSet() {
|
|
err := Parse(refField.Interface())
|
|
if nil != err {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
refTypeField := refType.Field(i)
|
|
value, err := get(refTypeField)
|
|
if err != nil {
|
|
errorList = append(errorList, err.Error())
|
|
continue
|
|
}
|
|
if value == "" {
|
|
continue
|
|
}
|
|
if err := set(refField, refTypeField, value, funcMap); err != nil {
|
|
errorList = append(errorList, err.Error())
|
|
continue
|
|
}
|
|
if OnEnvVarSet != nil {
|
|
OnEnvVarSet(refTypeField, value)
|
|
}
|
|
}
|
|
if len(errorList) == 0 {
|
|
return nil
|
|
}
|
|
return errors.New(strings.Join(errorList, ". "))
|
|
}
|
|
|
|
func get(field reflect.StructField) (string, error) {
|
|
var (
|
|
val string
|
|
err error
|
|
)
|
|
|
|
key, opts := parseKeyForOption(field.Tag.Get("env"))
|
|
|
|
defaultValue := field.Tag.Get("envDefault")
|
|
val = getOr(key, defaultValue)
|
|
|
|
if len(opts) > 0 {
|
|
for _, opt := range opts {
|
|
// The only option supported is "required".
|
|
switch opt {
|
|
case "":
|
|
break
|
|
case "required":
|
|
val, err = getRequired(key)
|
|
default:
|
|
err = errors.New("Env tag option " + opt + " not supported.")
|
|
}
|
|
}
|
|
}
|
|
|
|
return val, err
|
|
}
|
|
|
|
// parseKeyForOption splits an `env` tag value on commas into the variable name
// (the first element) and the remaining options. For example, "FOO,required"
// yields ("FOO", ["required"]), and an empty tag yields ("", []).
func parseKeyForOption(key string) (string, []string) {
	fields := strings.Split(key, ",")
	name, options := fields[0], fields[1:]
	return name, options
}
|
|
|
|
// getRequired returns the value of the environment variable key, or an error
// naming the variable when it is not set at all. A variable that is set to the
// empty string is accepted and returned as "".
func getRequired(key string) (string, error) {
	value, present := os.LookupEnv(key)
	if !present {
		return "", errors.New("Required environment variable " + key + " is not set")
	}
	return value, nil
}
|
|
|
|
// getOr returns the value of the environment variable key when it is set, even
// if it is set to the empty string, and defaultValue otherwise.
func getOr(key, defaultValue string) string {
	if value, ok := os.LookupEnv(key); ok {
		return value
	}
	return defaultValue
}
|
|
|
|
func set(field reflect.Value, refType reflect.StructField, value string, funcMap CustomParsers) error {
|
|
switch field.Kind() {
|
|
case reflect.Slice:
|
|
separator := refType.Tag.Get("envSeparator")
|
|
return handleSlice(field, value, separator)
|
|
case reflect.String:
|
|
field.SetString(value)
|
|
case reflect.Bool:
|
|
bvalue, err := strconv.ParseBool(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.SetBool(bvalue)
|
|
case reflect.Int:
|
|
intValue, err := strconv.ParseInt(value, 10, 32)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.SetInt(intValue)
|
|
case reflect.Uint:
|
|
uintValue, err := strconv.ParseUint(value, 10, 32)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.SetUint(uintValue)
|
|
case reflect.Float32:
|
|
v, err := strconv.ParseFloat(value, 32)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.SetFloat(v)
|
|
case reflect.Float64:
|
|
v, err := strconv.ParseFloat(value, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(v))
|
|
case reflect.Int64:
|
|
if refType.Type.String() == "time.Duration" {
|
|
dValue, err := time.ParseDuration(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(dValue))
|
|
} else {
|
|
intValue, err := strconv.ParseInt(value, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.SetInt(intValue)
|
|
}
|
|
case reflect.Uint64:
|
|
uintValue, err := strconv.ParseUint(value, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.SetUint(uintValue)
|
|
case reflect.Struct:
|
|
return handleStruct(field, refType, value, funcMap)
|
|
default:
|
|
parserFunc, ok := funcMap[refType.Type]
|
|
if !ok {
|
|
return ErrUnsupportedType
|
|
}
|
|
val, err := parserFunc(value)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(val))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func handleStruct(field reflect.Value, refType reflect.StructField, value string, funcMap CustomParsers) error {
|
|
// Does the custom parser func map contain this type?
|
|
parserFunc, ok := funcMap[field.Type()]
|
|
if !ok {
|
|
// Map does not contain a custom parser for this type
|
|
return ErrUnsupportedType
|
|
}
|
|
|
|
// Call on the custom parser func
|
|
data, err := parserFunc(value)
|
|
if err != nil {
|
|
return fmt.Errorf("Custom parser error: %v", err)
|
|
}
|
|
|
|
// Set the field to the data returned by the customer parser func
|
|
rv := reflect.ValueOf(data)
|
|
field.Set(rv)
|
|
|
|
return nil
|
|
}
|
|
|
|
func handleSlice(field reflect.Value, value, separator string) error {
|
|
if separator == "" {
|
|
separator = ","
|
|
}
|
|
|
|
splitData := strings.Split(value, separator)
|
|
|
|
switch field.Type() {
|
|
case sliceOfStrings:
|
|
field.Set(reflect.ValueOf(splitData))
|
|
case sliceOfInts:
|
|
intData, err := parseInts(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(intData))
|
|
case sliceOfInt64s:
|
|
int64Data, err := parseInt64s(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(int64Data))
|
|
case sliceOfUint64s:
|
|
uint64Data, err := parseUint64s(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(uint64Data))
|
|
case sliceOfFloat32s:
|
|
data, err := parseFloat32s(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(data))
|
|
case sliceOfFloat64s:
|
|
data, err := parseFloat64s(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(data))
|
|
case sliceOfBools:
|
|
boolData, err := parseBools(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(boolData))
|
|
case sliceOfDurations:
|
|
durationData, err := parseDurations(splitData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
field.Set(reflect.ValueOf(durationData))
|
|
default:
|
|
return ErrUnsupportedSliceType
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// parseInts converts each element of data to an int. It fails on the first
// element that is not a base-10 integer within the range of int.
func parseInts(data []string) ([]int, error) {
	ints := make([]int, 0, len(data))
	for _, s := range data {
		// bitSize 0 bounds the value by the platform's int size, matching the
		// element type; a fixed 32 rejected valid values on 64-bit platforms.
		n, err := strconv.ParseInt(s, 10, 0)
		if err != nil {
			return nil, err
		}
		ints = append(ints, int(n))
	}
	return ints, nil
}
|
|
|
|
// parseInt64s converts every element of data to an int64 in base 10 and
// returns the first strconv error encountered.
func parseInt64s(data []string) ([]int64, error) {
	out := make([]int64, len(data))
	for i := range data {
		n, err := strconv.ParseInt(data[i], 10, 64)
		if err != nil {
			return nil, err
		}
		out[i] = n
	}
	return out, nil
}
|
|
|
|
// parseUint64s converts each element of data to a uint64 in base 10 and
// returns the first parse error encountered. Like the sibling parse* helpers,
// the result is pre-sized to len(data), so it is non-nil on success even for
// empty input.
func parseUint64s(data []string) ([]uint64, error) {
	uintSlice := make([]uint64, 0, len(data))
	for _, v := range data {
		uintValue, err := strconv.ParseUint(v, 10, 64)
		if err != nil {
			return nil, err
		}
		uintSlice = append(uintSlice, uintValue)
	}
	return uintSlice, nil
}
|
|
|
|
// parseFloat32s converts every element of data to a float32, parsing with
// 32-bit precision. It returns the first strconv error encountered.
func parseFloat32s(data []string) ([]float32, error) {
	out := make([]float32, len(data))
	for i, s := range data {
		f, err := strconv.ParseFloat(s, 32)
		if err != nil {
			return nil, err
		}
		out[i] = float32(f)
	}
	return out, nil
}
|
|
|
|
// parseFloat64s converts every element of data to a float64 and returns the
// first strconv error encountered.
func parseFloat64s(data []string) ([]float64, error) {
	out := make([]float64, len(data))
	for i, s := range data {
		f, err := strconv.ParseFloat(s, 64)
		if err != nil {
			return nil, err
		}
		out[i] = f
	}
	return out, nil
}
|
|
|
|
// parseBools converts every element of data with strconv.ParseBool and returns
// the first error encountered.
func parseBools(data []string) ([]bool, error) {
	flags := make([]bool, len(data))
	for i := 0; i < len(data); i++ {
		b, err := strconv.ParseBool(data[i])
		if err != nil {
			return nil, err
		}
		flags[i] = b
	}
	return flags, nil
}
|
|
|
|
// parseDurations converts every element of data with time.ParseDuration and
// returns the first error encountered.
func parseDurations(data []string) ([]time.Duration, error) {
	durations := make([]time.Duration, len(data))
	for i, raw := range data {
		d, err := time.ParseDuration(raw)
		if err != nil {
			return nil, err
		}
		durations[i] = d
	}
	return durations, nil
}
|