Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin
from typing_extensions import Annotated, get_args, get_origin

# Reassign variable to make it reexported for mypy
PYDANTIC_VERSION = P_VERSION
Expand Down Expand Up @@ -177,16 +177,17 @@ def is_field_noneable(field: "FieldInfo") -> bool:
return False
return False

def get_type_from_field(field: Any) -> Any:
type_: Any = field.annotation
def get_sa_type_from_type_annotation(annotation: Any) -> Any:
# Resolve Optional fields
if type_ is None:
if annotation is None:
raise ValueError("Missing field type")
origin = get_origin(type_)
origin = get_origin(annotation)
if origin is None:
return type_
return annotation
elif origin is Annotated:
return get_sa_type_from_type_annotation(get_args(annotation)[0])
if _is_union_type(origin):
bases = get_args(type_)
bases = get_args(annotation)
if len(bases) > 2:
raise ValueError(
"Cannot have a (non-optional) union as a SQLAlchemy field"
Expand All @@ -197,9 +198,14 @@ def get_type_from_field(field: Any) -> Any:
"Cannot have a (non-optional) union as a SQLAlchemy field"
)
# Optional unions are allowed
return bases[0] if bases[0] is not NoneType else bases[1]
use_type = bases[0] if bases[0] is not NoneType else bases[1]
return get_sa_type_from_type_annotation(use_type)
return origin

def get_sa_type_from_field(field: Any) -> Any:
type_: Any = field.annotation
return get_sa_type_from_type_annotation(type_)

def get_field_metadata(field: Any) -> Any:
for meta in field.metadata:
if isinstance(meta, (PydanticMetadata, MaxLen)):
Expand Down Expand Up @@ -444,7 +450,7 @@ def is_field_noneable(field: "FieldInfo") -> bool:
)
return field.allow_none # type: ignore[no-any-return, attr-defined]

def get_type_from_field(field: Any) -> Any:
def get_sa_type_from_field(field: Any) -> Any:
if isinstance(field.type_, type) and field.shape == SHAPE_SINGLETON:
return field.type_
raise ValueError(f"The field {field.name} has no matching SQLAlchemy type")
Expand Down
4 changes: 2 additions & 2 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
get_field_metadata,
get_model_fields,
get_relationship_to,
get_type_from_field,
get_sa_type_from_field,
init_pydantic_private_attrs,
is_field_noneable,
is_table_model_class,
Expand Down Expand Up @@ -649,7 +649,7 @@ def get_sqlalchemy_type(field: Any) -> Any:
if sa_type is not Undefined:
return sa_type

type_ = get_type_from_field(field)
type_ = get_sa_type_from_field(field)
metadata = get_field_metadata(field)

# Check enums first as an enum can also be a str, needed by Pydantic/FastAPI
Expand Down
26 changes: 26 additions & 0 deletions tests/test_annotated_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import uuid
from typing import Optional

from sqlmodel import Field, Session, SQLModel, create_engine, select

from tests.conftest import needs_pydanticv2


@needs_pydanticv2
def test_annotated_optional_types(clear_sqlmodel) -> None:
from pydantic import UUID4

class Hero(SQLModel, table=True):
# Pydantic UUID4 is: Annotated[UUID, UuidVersion(4)]
id: Optional[UUID4] = Field(default_factory=uuid.uuid4, primary_key=True)

engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
db.commit()
statement = select(Hero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(hero.id, uuid.UUID)