2121from pydantic import VERSION as P_VERSION
2222from pydantic import BaseModel
2323from pydantic .fields import FieldInfo
24- from typing_extensions import get_args , get_origin
24+ from typing_extensions import Annotated , get_args , get_origin
2525
2626# Reassign variable to make it reexported for mypy
2727PYDANTIC_VERSION = P_VERSION
@@ -177,16 +177,17 @@ def is_field_noneable(field: "FieldInfo") -> bool:
177177 return False
178178 return False
179179
180- def get_type_from_field (field : Any ) -> Any :
181- type_ : Any = field .annotation
180+ def get_sa_type_from_type_annotation (annotation : Any ) -> Any :
182181 # Resolve Optional fields
183- if type_ is None :
182+ if annotation is None :
184183 raise ValueError ("Missing field type" )
185- origin = get_origin (type_ )
184+ origin = get_origin (annotation )
186185 if origin is None :
187- return type_
186+ return annotation
187+ elif origin is Annotated :
188+ return get_sa_type_from_type_annotation (get_args (annotation )[0 ])
188189 if _is_union_type (origin ):
189- bases = get_args (type_ )
190+ bases = get_args (annotation )
190191 if len (bases ) > 2 :
191192 raise ValueError (
192193 "Cannot have a (non-optional) union as a SQLAlchemy field"
@@ -197,9 +198,14 @@ def get_type_from_field(field: Any) -> Any:
197198 "Cannot have a (non-optional) union as a SQLAlchemy field"
198199 )
199200 # Optional unions are allowed
200- return bases [0 ] if bases [0 ] is not NoneType else bases [1 ]
201+ use_type = bases [0 ] if bases [0 ] is not NoneType else bases [1 ]
202+ return get_sa_type_from_type_annotation (use_type )
201203 return origin
202204
205+ def get_sa_type_from_field (field : Any ) -> Any :
206+ type_ : Any = field .annotation
207+ return get_sa_type_from_type_annotation (type_ )
208+
203209 def get_field_metadata (field : Any ) -> Any :
204210 for meta in field .metadata :
205211 if isinstance (meta , (PydanticMetadata , MaxLen )):
@@ -444,7 +450,7 @@ def is_field_noneable(field: "FieldInfo") -> bool:
444450 )
445451 return field .allow_none # type: ignore[no-any-return, attr-defined]
446452
447- def get_type_from_field (field : Any ) -> Any :
453+ def get_sa_type_from_field (field : Any ) -> Any :
448454 if isinstance (field .type_ , type ) and field .shape == SHAPE_SINGLETON :
449455 return field .type_
450456 raise ValueError (f"The field { field .name } has no matching SQLAlchemy type" )
0 commit comments