diff --git a/sqlmodel/_compat.py b/sqlmodel/_compat.py index a220b193f1..cb3d296e79 100644 --- a/sqlmodel/_compat.py +++ b/sqlmodel/_compat.py @@ -242,6 +242,9 @@ def sqlmodel_table_construct( _extra = {} for k, v in values.items(): _extra[k] = v + setattr_ = ( + object.__setattr__ if self_instance.model_config.get("frozen") else setattr + ) # SQLModel override, do not include everything, only the model fields # else: # fields_values.update(values) @@ -251,7 +254,7 @@ def sqlmodel_table_construct( # object.__setattr__(new_obj, "__dict__", fields_values) # instrumentation for key, value in {**old_dict, **fields_values}.items(): - setattr(self_instance, key, value) + setattr_(self_instance, key, value) # End SQLModel override object.__setattr__(self_instance, "__pydantic_fields_set__", _fields_set) if not cls.__pydantic_root_model__: @@ -268,7 +271,7 @@ def sqlmodel_table_construct( for key in self_instance.__sqlmodel_relationships__: value = values.get(key, Undefined) if value is not Undefined: - setattr(self_instance, key, value) + setattr_(self_instance, key, value) # End SQLModel override return self_instance @@ -305,6 +308,7 @@ def sqlmodel_validate( context=context, self_instance=new_obj, ) + setattr_ = object.__setattr__ if new_obj.model_config.get("frozen") else setattr # Capture fields set to restore it later fields_set = new_obj.__pydantic_fields_set__.copy() if not is_table_model_class(cls): @@ -314,7 +318,7 @@ def sqlmodel_validate( # Do not set __dict__, instead use setattr to trigger SQLAlchemy # instrumentation for key, value in {**old_dict, **new_obj.__dict__}.items(): - setattr(new_obj, key, value) + setattr_(new_obj, key, value) # Restore fields set object.__setattr__(new_obj, "__pydantic_fields_set__", fields_set) # Get and set any relationship objects @@ -322,7 +326,7 @@ def sqlmodel_validate( for key in new_obj.__sqlmodel_relationships__: value = getattr(use_obj, key, Undefined) if value is not Undefined: - setattr(new_obj, key, value) + setattr_(new_obj, key, value) return new_obj diff --git a/tests/test_frozen.py b/tests/test_frozen.py new file mode 100644 index 0000000000..b11befe71e --- /dev/null +++ b/tests/test_frozen.py @@ -0,0 +1,80 @@ +import pytest +from pydantic import ConfigDict, ValidationError +from sqlmodel import Field, Session, SQLModel, create_engine, select + + +def test_frozen_non_table_model_creation(clear_sqlmodel): + class HeroBase(SQLModel): + model_config = ConfigDict(frozen=True) + + name: str + age: int | None = None + + hero = HeroBase(name="Deadpond", age=30) + + assert hero.name == "Deadpond" + assert hero.age == 30 + + +def test_frozen_non_table_model_is_immutable(clear_sqlmodel): + class HeroBase(SQLModel): + model_config = ConfigDict(frozen=True) + + name: str + age: int | None = None + + hero = HeroBase(name="Deadpond", age=30) + + with pytest.raises((ValidationError, TypeError)): + hero.name = "Spider-Boy" # type: ignore[misc] + + +def test_frozen_table_model_creation(clear_sqlmodel): + class Hero(SQLModel, table=True): + model_config = ConfigDict(frozen=True) + + id: int | None = Field(default=None, primary_key=True) + name: str + age: int | None = None + + hero = Hero(name="Deadpond", age=30) + + assert hero.name == "Deadpond" + assert hero.age == 30 + + +def test_frozen_table_model_persists_and_retrieves(clear_sqlmodel): + class Hero(SQLModel, table=True): + model_config = ConfigDict(frozen=True) + + id: int | None = Field(default=None, primary_key=True) + name: str + age: int | None = None + + engine = create_engine("sqlite://") + SQLModel.metadata.create_all(engine) + + with Session(engine) as session: + hero = Hero(name="Deadpond", age=30) + session.add(hero) + session.commit() + session.refresh(hero) + + with Session(engine) as session: + retrieved = session.exec(select(Hero)).one() + assert retrieved.name == "Deadpond" + assert retrieved.age == 30 + + +def test_frozen_table_model_validate(clear_sqlmodel): + class Hero(SQLModel, table=True): + model_config = ConfigDict(frozen=True) + + id: int | None = Field(default=None, primary_key=True) + name: str + age: int | None = None + + hero = Hero.model_validate({"name": "Deadpond", "age": 30}) + + assert hero.name == "Deadpond" + assert hero.age == 30