@@ -17,8 +17,11 @@ package app
17
17
18
18
import (
19
19
"encoding/base64"
20
+ "errors"
20
21
"fmt"
21
22
"log"
23
+ "os"
24
+ "path/filepath"
22
25
"strconv"
23
26
"strings"
24
27
"time"
@@ -28,7 +31,7 @@ import (
28
31
29
32
"github.com/spf13/pflag"
30
33
31
- validator "github.com/go-playground/validator/v10 "
34
+ validator "github.com/asaskevich/govalidator "
32
35
)
33
36
34
37
type FlagType string
@@ -71,31 +74,59 @@ func initializePFlagMap() {
71
74
},
72
75
operatorFlag : func () pflag.Value {
73
76
// this validates a valid operator name
74
- return valueFactory (operatorFlag , validateString ("oneof=and or" ), "" )
77
+ operatorFlagValidator := func (val string ) error {
78
+ o := struct {
79
+ Value string `valid:"in(and|or)"`
80
+ }{val }
81
+ _ , err := validator .ValidateStruct (o )
82
+ return err
83
+ }
84
+ return valueFactory (operatorFlag , operatorFlagValidator , "" )
75
85
},
76
86
emailFlag : func () pflag.Value {
77
87
// this validates an email address
78
- return valueFactory (emailFlag , validateString ("required,email" ), "" )
88
+ emailValidator := func (val string ) error {
89
+ if ! validator .IsEmail (val ) {
90
+ return fmt .Errorf ("'%v' is not a valid email address" , val )
91
+ }
92
+ return nil
93
+ }
94
+ return valueFactory (emailFlag , emailValidator , "" )
79
95
},
80
96
logIndexFlag : func () pflag.Value {
81
97
// this checks for a valid integer >= 0
82
- return valueFactory (logIndexFlag , validateLogIndex , "" )
98
+ return valueFactory (logIndexFlag , validateUint , "" )
83
99
},
84
100
pkiFormatFlag : func () pflag.Value {
85
101
// this ensures a PKI implementation exists for the requested format
86
- return valueFactory (pkiFormatFlag , validateString (fmt .Sprintf ("required,oneof=%v" , strings .Join (pki .SupportedFormats (), " " ))), "pgp" )
102
+ pkiFormatValidator := func (val string ) error {
103
+ if ! validator .IsIn (val , pki .SupportedFormats ()... ) {
104
+ return fmt .Errorf ("'%v' is not a valid pki format" , val )
105
+ }
106
+ return nil
107
+ }
108
+ return valueFactory (pkiFormatFlag , pkiFormatValidator , "pgp" )
87
109
},
88
110
typeFlag : func () pflag.Value {
89
111
// this ensures the type of the log entry matches a type supported in the CLI
90
112
return valueFactory (typeFlag , validateTypeFlag , "rekord" )
91
113
},
92
114
fileFlag : func () pflag.Value {
93
115
// this validates that the file exists and can be opened by the current uid
94
- return valueFactory (fileFlag , validateString ( "required,file" ) , "" )
116
+ return valueFactory (fileFlag , validateFile , "" )
95
117
},
96
118
urlFlag : func () pflag.Value {
97
119
// this validates that the string is a valid http/https URL
98
- return valueFactory (urlFlag , validateString ("required,url,startswith=http|startswith=https" ), "" )
120
+ httpHTTPSValidator := func (val string ) error {
121
+ if ! validator .IsURL (val ) {
122
+ return fmt .Errorf ("'%v' is not a valid url" , val )
123
+ }
124
+ if ! (strings .HasPrefix (val , "http" ) || strings .HasPrefix (val , "https" )) {
125
+ return errors .New ("URL must be for http or https scheme" )
126
+ }
127
+ return nil
128
+ }
129
+ return valueFactory (urlFlag , httpHTTPSValidator , "" )
99
130
},
100
131
fileOrURLFlag : func () pflag.Value {
101
132
// applies logic of fileFlag OR urlFlag validators from above
@@ -111,7 +142,13 @@ func initializePFlagMap() {
111
142
},
112
143
formatFlag : func () pflag.Value {
113
144
// this validates the output format requested
114
- return valueFactory (formatFlag , validateString ("required,oneof=json default tle" ), "" )
145
+ formatValidator := func (val string ) error {
146
+ if ! validator .IsIn (val , "json" , "default" , "tle" ) {
147
+ return fmt .Errorf ("'%v' is not a valid output format" , val )
148
+ }
149
+ return nil
150
+ }
151
+ return valueFactory (formatFlag , formatValidator , "" )
115
152
},
116
153
timeoutFlag : func () pflag.Value {
117
154
// this validates the timeout is >= 0
@@ -257,33 +294,23 @@ func validateID(v string) error {
257
294
return fmt .Errorf ("ID len error, expected %v (EntryID) or %v (UUID) but got len %v for ID %v" , sharding .EntryIDHexStringLen , sharding .UUIDHexStringLen , len (v ), v )
258
295
}
259
296
260
- if err := validateString ( "required,hexadecimal" )( v ); err != nil {
297
+ if ! validator . IsHexadecimal ( v ) {
261
298
return fmt .Errorf ("invalid uuid: %v" , v )
262
299
}
263
300
264
301
return nil
265
302
}
266
303
267
- // validateLogIndex ensures that the supplied string is a valid log index (integer >= 0)
268
- func validateLogIndex (v string ) error {
269
- i , err := strconv .Atoi (v )
270
- if err != nil {
271
- return err
272
- }
273
- l := struct {
274
- Index int `validate:"gte=0"`
275
- }{i }
276
-
277
- return useValidator (logIndexFlag , l )
278
- }
279
-
280
304
// validateOID ensures that the supplied string is a valid ASN.1 object identifier
281
305
func validateOID (v string ) error {
282
- o := struct {
283
- Oid []string `validate:"dive,numeric"`
284
- }{strings .Split (v , "." )}
306
+ values := strings .Split (v , "." )
307
+ for _ , value := range values {
308
+ if ! validator .IsNumeric (value ) {
309
+ return fmt .Errorf ("field '%v' is not a valid number" , value )
310
+ }
311
+ }
285
312
286
- return useValidator ( oidFlag , o )
313
+ return nil
287
314
}
288
315
289
316
// validateTimeout ensures that the supplied string is a valid time.Duration value >= 0
@@ -292,10 +319,10 @@ func validateTimeout(v string) error {
292
319
if err != nil {
293
320
return err
294
321
}
295
- d := struct {
296
- Duration time. Duration `validate:"min=0"`
297
- }{ duration }
298
- return useValidator ( timeoutFlag , d )
322
+ if duration < 0 {
323
+ return errors . New ( "timeout must be a positive value" )
324
+ }
325
+ return nil
299
326
}
300
327
301
328
// validateBase64 ensures that the supplied string is valid base64 encoded data
@@ -312,26 +339,6 @@ func validateTypeFlag(v string) error {
312
339
return err
313
340
}
314
341
315
- // validateString returns a function that validates an input string against the specified tag,
316
- // as defined in the format supported by go-playground/validator
317
- func validateString (tag string ) validationFunc {
318
- return func (v string ) error {
319
- validator := validator .New ()
320
- return validator .Var (v , tag )
321
- }
322
- }
323
-
324
- // useValidator performs struct level validation on s as defined in the struct's tags using
325
- // the go-playground/validator library
326
- func useValidator (flagType FlagType , s interface {}) error {
327
- validate := validator .New ()
328
- if err := validate .Struct (s ); err != nil {
329
- return fmt .Errorf ("error parsing %v flag: %w" , flagType , err )
330
- }
331
-
332
- return nil
333
- }
334
-
335
342
// validateUint ensures that the supplied string is a valid unsigned integer >= 0
336
343
func validateUint (v string ) error {
337
344
i , err := strconv .Atoi (v )
@@ -341,9 +348,17 @@ func validateUint(v string) error {
341
348
if i < 0 {
342
349
return fmt .Errorf ("invalid unsigned int: %v" , v )
343
350
}
344
- u := struct {
345
- Uint uint `validate:"gte=0"`
346
- }{uint (i )}
351
+ return nil
352
+ }
347
353
348
- return useValidator (uintFlag , u )
354
+ // validateFile ensures that the supplied string is a valid path to a file that exists
355
+ func validateFile (v string ) error {
356
+ fileInfo , err := os .Stat (filepath .Clean (v ))
357
+ if err != nil {
358
+ return err
359
+ }
360
+ if fileInfo .IsDir () {
361
+ return errors .New ("path to a directory was provided" )
362
+ }
363
+ return nil
349
364
}
0 commit comments