Skip to content

Commit 6f2c733

Browse files
authored
Merge pull request #106 from seapagan/fix/transaction-rollback-issue-104
Fix transaction rollback bug (Issue #104)
2 parents 471a57b + 8f2b863 commit 6f2c733

File tree

9 files changed

+486
-70
lines changed

9 files changed

+486
-70
lines changed

docs/guide/transactions.md

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,8 @@ with db:
99
# If an exception occurs, the transaction will be rolled back
1010
```
1111

12-
> [!WARNING]
13-
> **Known Issue:** Transaction rollback is currently broken. Changes made
14-
> inside a transaction are NOT rolled back when an exception occurs.
15-
>
16-
> **Status:** This is being tracked in [issue #104](https://github.com/seapagan/sqliter-py/issues/104).
17-
>
18-
> **Workaround:** Do not rely on transaction rollback for data integrity until
19-
> this is fixed. All changes are committed immediately.
20-
>
21-
> Using the context manager will automatically commit the transaction at the end
22-
> (unless an exception occurs), regardless of the `auto_commit` setting. The
23-
> `close()` method will also be called when the context manager exits, so you
24-
> do not need to call it manually.
12+
Using the context manager will automatically commit the transaction at the end
13+
(unless an exception occurs), regardless of the `auto_commit` setting. If an
14+
exception occurs, all changes made within the transaction block are rolled back.
15+
The `close()` method will also be called when the context manager exits, so you
16+
do not need to call it manually.

docs/tui-demo/errors.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ print("\nAttempting to create product with invalid data...")
114114

115115
try:
116116
# Wrong types: price should be float, quantity should be int
117-
invalid_product = Product(name="Invalid Widget", price="free", quantity="lots")
118-
db.insert(invalid_product)
117+
# ValidationError is raised by Pydantic during model instantiation
118+
Product(name="Invalid Widget", price="free", quantity="lots")
119119
except ValidationError as e:
120120
print(f"\nCaught error: {type(e).__name__}")
121121
print(f"Message: {e}")
@@ -316,12 +316,12 @@ try:
316316
raise ValueError("Invalid operation")
317317
except ValueError as e:
318318
print(f"Transaction failed: {e}")
319-
# Note: Changes are NOT rolled back due to bug (issue #104)
319+
print("Changes rolled back automatically")
320320

321-
# Verify balance unchanged
321+
# Verify balance unchanged (rollback restored original value)
322322
reloaded = db.get(Account, account.pk)
323323
if reloaded is not None:
324-
print(f"Balance: {reloaded.balance}") # Was 100.0
324+
print(f"Balance: {reloaded.balance}") # Still 100.0
325325
```
326326

327327
## Error Handling Best Practices

docs/tui-demo/transactions.md

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ db.close()
4343

4444
### What Happens
4545

46-
- Both updates succeed or both fail (intended behaviour)
47-
- If an error occurs, changes should be rolled back (see warning below about current limitations)
46+
- Both updates succeed or both fail
47+
- If an error occurs, changes are rolled back
4848
- Database remains in a consistent state
4949

5050
## Transaction Rollback
@@ -83,12 +83,6 @@ db.close()
8383
# --8<-- [end:transaction-rollback]
8484
```
8585

86-
!!! warning
87-
**Known Issue:** Transaction rollback is currently broken in SQLiter.
88-
The `update()`, `insert()`, and `delete()` methods use nested context
89-
managers that commit prematurely. This demo shows the expected behavior,
90-
but actual rollback may not work correctly until this is fixed.
91-
9286
### Rollback Behavior
9387

9488
- All changes within the transaction should be undone

sqliter/query/query.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -601,10 +601,10 @@ def _execute_query(
601601
self.db._log_sql(sql, values) # noqa: SLF001
602602

603603
try:
604-
with self.db.connect() as conn:
605-
cursor = conn.cursor()
606-
cursor.execute(sql, values)
607-
return cursor.fetchall() if not fetch_one else cursor.fetchone()
604+
conn = self.db.connect()
605+
cursor = conn.cursor()
606+
cursor.execute(sql, values)
607+
return cursor.fetchall() if not fetch_one else cursor.fetchone()
608608
except sqlite3.Error as exc:
609609
raise RecordFetchError(self.table_name) from exc
610610

@@ -911,12 +911,16 @@ def delete(self) -> int:
911911
self.db._log_sql(sql, values) # noqa: SLF001
912912

913913
try:
914-
with self.db.connect() as conn:
915-
cursor = conn.cursor()
916-
cursor.execute(sql, values)
917-
deleted_count = cursor.rowcount
918-
self.db._maybe_commit() # noqa: SLF001
919-
self.db._cache_invalidate_table(self.table_name) # noqa: SLF001
920-
return deleted_count
914+
conn = self.db.connect()
915+
cursor = conn.cursor()
916+
cursor.execute(sql, values)
917+
deleted_count = cursor.rowcount
918+
self.db._maybe_commit() # noqa: SLF001
919+
self.db._cache_invalidate_table(self.table_name) # noqa: SLF001
921920
except sqlite3.Error as exc:
921+
# Rollback implicit transaction if not in user-managed transaction
922+
if not self.db._in_transaction and self.db.conn: # noqa: SLF001
923+
self.db.conn.rollback()
922924
raise RecordDeletionError(self.table_name) from exc
925+
else:
926+
return deleted_count

sqliter/sqliter.py

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -888,12 +888,15 @@ def insert(
888888
""" # noqa: S608
889889

890890
try:
891-
with self.connect() as conn:
892-
cursor = conn.cursor()
893-
cursor.execute(insert_sql, values)
894-
self._maybe_commit()
891+
conn = self.connect()
892+
cursor = conn.cursor()
893+
cursor.execute(insert_sql, values)
894+
self._maybe_commit()
895895

896896
except sqlite3.IntegrityError as exc:
897+
# Rollback implicit transaction if not in user-managed transaction
898+
if not self._in_transaction and self.conn:
899+
self.conn.rollback()
897900
# Check for foreign key constraint violation
898901
if "FOREIGN KEY constraint failed" in str(exc):
899902
fk_operation = "insert"
@@ -903,6 +906,9 @@ def insert(
903906
) from exc
904907
raise RecordInsertionError(table_name) from exc
905908
except sqlite3.Error as exc:
909+
# Rollback implicit transaction if not in user-managed transaction
910+
if not self._in_transaction and self.conn:
911+
self.conn.rollback()
906912
raise RecordInsertionError(table_name) from exc
907913
else:
908914
self._cache_invalidate_table(table_name)
@@ -936,10 +942,10 @@ def get(
936942
""" # noqa: S608
937943

938944
try:
939-
with self.connect() as conn:
940-
cursor = conn.cursor()
941-
cursor.execute(select_sql, (primary_key_value,))
942-
result = cursor.fetchone()
945+
conn = self.connect()
946+
cursor = conn.cursor()
947+
cursor.execute(select_sql, (primary_key_value,))
948+
result = cursor.fetchone()
943949

944950
if result:
945951
result_dict = {
@@ -990,18 +996,26 @@ def update(self, model_instance: BaseDBModel) -> None:
990996
""" # noqa: S608
991997

992998
try:
993-
with self.connect() as conn:
994-
cursor = conn.cursor()
995-
cursor.execute(update_sql, (*values, primary_key_value))
999+
conn = self.connect()
1000+
cursor = conn.cursor()
1001+
cursor.execute(update_sql, (*values, primary_key_value))
9961002

997-
# Check if any rows were updated
998-
if cursor.rowcount == 0:
999-
raise RecordNotFoundError(primary_key_value)
1003+
# Check if any rows were updated
1004+
if cursor.rowcount == 0:
1005+
raise RecordNotFoundError(primary_key_value) # noqa: TRY301
10001006

1001-
self._maybe_commit()
1002-
self._cache_invalidate_table(table_name)
1007+
self._maybe_commit()
1008+
self._cache_invalidate_table(table_name)
10031009

1010+
except RecordNotFoundError:
1011+
# Rollback implicit transaction if not in user-managed transaction
1012+
if not self._in_transaction and self.conn:
1013+
self.conn.rollback()
1014+
raise
10041015
except sqlite3.Error as exc:
1016+
# Rollback implicit transaction if not in user-managed transaction
1017+
if not self._in_transaction and self.conn:
1018+
self.conn.rollback()
10051019
raise RecordUpdateError(table_name) from exc
10061020

10071021
def delete(
@@ -1026,15 +1040,23 @@ def delete(
10261040
""" # noqa: S608
10271041

10281042
try:
1029-
with self.connect() as conn:
1030-
cursor = conn.cursor()
1031-
cursor.execute(delete_sql, (primary_key_value,))
1043+
conn = self.connect()
1044+
cursor = conn.cursor()
1045+
cursor.execute(delete_sql, (primary_key_value,))
10321046

1033-
if cursor.rowcount == 0:
1034-
raise RecordNotFoundError(primary_key_value)
1035-
self._maybe_commit()
1036-
self._cache_invalidate_table(table_name)
1047+
if cursor.rowcount == 0:
1048+
raise RecordNotFoundError(primary_key_value) # noqa: TRY301
1049+
self._maybe_commit()
1050+
self._cache_invalidate_table(table_name)
1051+
except RecordNotFoundError:
1052+
# Rollback implicit transaction if not in user-managed transaction
1053+
if not self._in_transaction and self.conn:
1054+
self.conn.rollback()
1055+
raise
10371056
except sqlite3.IntegrityError as exc:
1057+
# Rollback implicit transaction if not in user-managed transaction
1058+
if not self._in_transaction and self.conn:
1059+
self.conn.rollback()
10381060
# Check for foreign key constraint violation (RESTRICT)
10391061
if "FOREIGN KEY constraint failed" in str(exc):
10401062
fk_operation = "delete"
@@ -1044,6 +1066,9 @@ def delete(
10441066
) from exc
10451067
raise RecordDeletionError(table_name) from exc
10461068
except sqlite3.Error as exc:
1069+
# Rollback implicit transaction if not in user-managed transaction
1070+
if not self._in_transaction and self.conn:
1071+
self.conn.rollback()
10471072
raise RecordDeletionError(table_name) from exc
10481073

10491074
def select(

sqliter/tui/demos/errors.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,8 @@ class Product(BaseDBModel):
102102

103103
try:
104104
# Wrong types: price should be float, quantity should be int
105-
invalid_product = Product(
106-
name="Invalid Widget", price="free", quantity="lots"
107-
)
108-
db.insert(invalid_product)
105+
# ValidationError is raised by Pydantic during model instantiation
106+
Product(name="Invalid Widget", price="free", quantity="lots")
109107
except ValidationError as e:
110108
output.write(f"\nCaught error: {type(e).__name__}\n")
111109
output.write(f"Message: {e}\n")

sqliter/tui/demos/transactions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,8 @@ class Account(BaseDBModel):
5151
def _run_rollback() -> str:
5252
"""Demonstrate transaction rollback behavior.
5353
54-
NOTE: This demo currently shows a BUG in SQLiter's transaction handling.
55-
The value should be restored to 10 after rollback, but it's not.
56-
See: https://github.com/seapagan/sqliter-py/issues/104
54+
When an exception occurs inside a `with db:` block, all changes made
55+
within that transaction are automatically rolled back.
5756
"""
5857
output = io.StringIO()
5958

@@ -85,7 +84,6 @@ class Item(BaseDBModel):
8584
except RuntimeError:
8685
output.write("Error occurred - transaction rolled back\n")
8786
# Verify rollback with NEW connection
88-
# BUG: This shows 5 instead of 10 - rollback doesn't work!
8987
db2 = SqliterDB(db_filename=db_path)
9088
try:
9189
restored = db2.get(Item, item.pk)
@@ -98,9 +96,11 @@ class Item(BaseDBModel):
9896
if restored_quantity == expected_quantity:
9997
output.write("✓ Rollback worked correctly\n")
10098
else:
101-
output.write(
102-
"✗ BUG: Rollback failed (expected 10, got 5)\n"
99+
msg = (
100+
f"✗ Rollback failed: expected {expected_quantity}, "
101+
f"got {restored_quantity}\n"
103102
)
103+
output.write(msg)
104104
finally:
105105
db2.close()
106106
finally:

0 commit comments

Comments
 (0)