Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@

class SQLAlchemyBuildContext(BaseBuildContext):
skip_computed_fields: bool
skip_system_fields: bool


class SQLAlchemyConstraints(Constraints):
computed: NotRequired[bool]
system: NotRequired[bool]


class SQLAlchemyPersistenceMethod(enum.Enum):
Expand Down Expand Up @@ -161,20 +163,24 @@ def _get_build_context(
build_context = cast("SQLAlchemyBuildContext", super()._get_build_context(build_context))
if build_context.get("skip_computed_fields") is None:
build_context["skip_computed_fields"] = False
if build_context.get("skip_system_fields") is None:
build_context["skip_system_fields"] = False

return build_context

@classmethod
def create_sync(cls, **kwargs: Any) -> T:
build_context = cls._get_build_context(kwargs.get("_build_context"))
build_context["skip_computed_fields"] = True
build_context["skip_system_fields"] = True
kwargs["_build_context"] = build_context
return super().create_sync(**kwargs)

@classmethod
async def create_async(cls, **kwargs: Any) -> T:
build_context = cls._get_build_context(kwargs.get("_build_context"))
build_context["skip_computed_fields"] = True
build_context["skip_system_fields"] = True
kwargs["_build_context"] = build_context
return await super().create_async(**kwargs)

Expand Down Expand Up @@ -240,6 +246,8 @@ def should_set_field_value(cls, field_meta: FieldMeta, **kwargs: Any) -> bool:
constraints = cast("SQLAlchemyConstraints", field_meta.constraints)
if constraints.get("computed") and build_context.get("skip_computed_fields"):
return False
if constraints.get("system") and build_context.get("skip_system_fields"):
return False

return super().should_set_field_value(field_meta, **kwargs)

Expand Down Expand Up @@ -309,6 +317,10 @@ def get_type_from_column(cls, column: Column) -> type:
constraints: SQLAlchemyConstraints = {"computed": True}
annotation = Annotated[annotation, Frozendict(constraints)] # type: ignore[assignment]

if getattr(column, "system", False):
system_constraints: SQLAlchemyConstraints = {"system": True}
annotation = Annotated[annotation, Frozendict(system_constraints)] # type: ignore[assignment]

return annotation

@classmethod
Expand Down
Loading