@@ -50,25 +50,27 @@ def initialize(cls):
50
50
is_generic = hasattr (field_type , "__origin__" )
51
51
if (
52
52
is_generic
53
- and field_type . __origin__ == Union
54
- and field_type . __args__ [- 1 ] == None .__class__
53
+ and typing_get_origin ( field_type ) == Union
54
+ and typing_get_args ( field_type ) [- 1 ] == None .__class__
55
55
):
56
- field_type = field_type . __args__ [0 ]
56
+ field_type = typing_get_args ( field_type ) [0 ]
57
57
is_generic = hasattr (field_type , "__origin__" )
58
58
59
59
if (
60
60
is_generic
61
- and field_type . __origin__ == List
62
- and issubclass (field_type . __args__ [0 ], Model )
61
+ and typing_get_origin ( field_type ) in ( List , list )
62
+ and issubclass (typing_get_args ( field_type ) [0 ], Model )
63
63
):
64
- cls ._nested_model_list_fields [field ] = field_type .__args__ [0 ]
64
+ cls ._nested_model_list_fields [field ] = typing_get_args (field_type )[
65
+ 0
66
+ ]
65
67
66
68
elif (
67
69
is_generic
68
- and field_type . __origin__ == Tuple
69
- and any ([issubclass (v , Model ) for v in field_type . __args__ ])
70
+ and typing_get_origin ( field_type ) in ( Tuple , tuple )
71
+ and any ([issubclass (v , Model ) for v in typing_get_args ( field_type ) ])
70
72
):
71
- cls ._nested_model_tuple_fields [field ] = field_type . __args__
73
+ cls ._nested_model_tuple_fields [field ] = typing_get_args ( field_type )
72
74
73
75
elif issubclass (field_type , Model ):
74
76
cls ._nested_model_fields [field ] = field_type
@@ -401,7 +403,7 @@ def __get_select_fields(cls, columns: Optional[List[str]]) -> Optional[List[str]
401
403
if isinstance (field_type , type (Model )):
402
404
fields .append (f"{ NESTED_MODEL_PREFIX } { col } " )
403
405
elif issubclass (field_type , List ) and isinstance (
404
- field_type . __args__ [0 ], type (Model )
406
+ typing_get_args ( field_type ) [0 ], type (Model )
405
407
):
406
408
fields .append (f"{ NESTED_MODEL_LIST_FIELD_PREFIX } { col } " )
407
409
else :
@@ -487,3 +489,19 @@ def strip_leading(word: str, substring: str) -> str:
487
489
if word .startswith (substring ):
488
490
return word [len (substring ) :]
489
491
return word
492
+
493
+
494
+ def typing_get_args (v : Any ) -> Tuple [Any , ...]:
495
+ """Gets the __args__ of the annotations of a given typing"""
496
+ try :
497
+ return typing .get_args (v )
498
+ except AttributeError :
499
+ return getattr (v , "__args__" , ()) if v is not typing .Generic else typing .Generic
500
+
501
+
502
+ def typing_get_origin (v : Any ) -> Optional [Any ]:
503
+ """Gets the __origin__ of the annotations of a given typing"""
504
+ try :
505
+ return typing .get_origin (v )
506
+ except AttributeError :
507
+ return getattr (v , "__origin__" , None )
0 commit comments