Skip to content

Commit 81fbad6

Browse files
tim-bandTim Band
andauthored
Fixed some exception escapes (#71)
Co-authored-by: Tim Band <t.b@ucl>
1 parent 8d43205 commit 81fbad6

File tree

5 files changed

+136
-9
lines changed

5 files changed

+136
-9
lines changed

datafaker/interactive/base.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ class DbCmd(ABC, cmd.Cmd):
104104
"Error: '{0}' is not the name of a table"
105105
" in this database or a column in this table"
106106
)
107+
ERROR_FAILED_SQL = 'SQL query "{query}" caused exception {exc}'
108+
ERROR_FAILED_DISPLAY = "Error: Failed to display: {}"
107109
ROW_COUNT_MSG = "Total row count: {}"
108110

109111
@abstractmethod
@@ -325,7 +327,7 @@ def do_counts(self, _arg: str) -> None:
325327
return
326328
table_name = self.table_name()
327329
nonnull_columns = self.get_nonnull_columns(table_name)
328-
colcounts = [f", COUNT({nnc}) AS {nnc}" for nnc in nonnull_columns]
330+
colcounts = [f', COUNT("{nnc}") AS "{nnc}"' for nnc in nonnull_columns]
329331
with self.sync_engine.connect() as connection:
330332
result = (
331333
connection.execute(
@@ -353,19 +355,24 @@ def do_counts(self, _arg: str) -> None:
353355
def do_select(self, arg: str) -> None:
354356
"""Run a select query over the database and show the first 50 results."""
355357
max_select_rows = 50
358+
query = "SELECT " + arg
356359
with self.sync_engine.connect() as connection:
357360
try:
358-
result = connection.execute(sqlalchemy.text("SELECT " + arg))
361+
result = connection.execute(sqlalchemy.text(query))
359362
except sqlalchemy.exc.DatabaseError as exc:
360-
self.print("Failed to execute: {}", exc)
363+
self.print(self.ERROR_FAILED_SQL, exc, query)
361364
return
362365
row_count = result.rowcount
363366
self.print(self.ROW_COUNT_MSG, row_count)
364367
if 50 < row_count:
365368
self.print("Showing the first {} rows", max_select_rows)
366369
fields = list(result.keys())
367370
rows = result.fetchmany(max_select_rows)
368-
self.print_table(fields, rows)
371+
try:
372+
self.print_table(fields, rows)
373+
except ValueError as exc:
374+
self.print(self.ERROR_FAILED_DISPLAY, exc)
375+
return
369376

370377
def do_peek(self, arg: str) -> None:
371378
"""
@@ -383,9 +390,9 @@ def do_peek(self, arg: str) -> None:
383390
col_names = arg.split()
384391
if not col_names:
385392
col_names = self._get_column_names()
386-
nonnulls = [cn + " IS NOT NULL" for cn in col_names]
393+
nonnulls = [f'"{cn}" IS NOT NULL' for cn in col_names]
387394
with self.sync_engine.connect() as connection:
388-
cols = ",".join(col_names)
395+
cols = ", ".join(f'"{cn}"' for cn in col_names)
389396
where = "WHERE" if nonnulls else ""
390397
nonnull = " OR ".join(nonnulls)
391398
query = sqlalchemy.text(
@@ -395,7 +402,7 @@ def do_peek(self, arg: str) -> None:
395402
try:
396403
result = connection.execute(query)
397404
except sqlalchemy.exc.SQLAlchemyError as exc:
398-
self.print(f'SQL query "{query}" caused exception {exc}')
405+
self.print(self.ERROR_FAILED_SQL, exc, query)
399406
return
400407
self.print_table(list(result.keys()), result.fetchmany(max_peek_rows))
401408

datafaker/interactive/table.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,11 @@ def __init__(
7979
src_schema: str | None,
8080
metadata: MetaData,
8181
config: MutableMapping[str, Any],
82+
*args: Any,
83+
**kwargs: Any,
8284
) -> None:
8385
"""Initialise a TableCmd."""
84-
super().__init__(src_dsn, src_schema, metadata, config)
86+
super().__init__(src_dsn, src_schema, metadata, config, *args, **kwargs)
8587
self.set_prompt()
8688

8789
@property

tests/examples/tricky.sql

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- DROP DATABASE IF EXISTS tricky WITH (FORCE);
2+
CREATE DATABASE tricky WITH TEMPLATE template0 ENCODING = 'UTF8' LOCALE = 'en_US.utf8';
3+
ALTER DATABASE tricky OWNER TO postgres;
4+
5+
\connect tricky
6+
7+
CREATE TABLE public.names (
8+
id INTEGER NOT NULL,
9+
"offset" INTEGER,
10+
"count" INTEGER NOT NULL,
11+
sensible TEXT
12+
);
13+
14+
ALTER TABLE ONLY public.names ADD CONSTRAINT names_pkey PRIMARY KEY (id);
15+
16+
ALTER TABLE public.names OWNER TO postgres;
17+
18+
INSERT INTO public.names VALUES (1, 10, 5, 'reasonable');
19+
INSERT INTO public.names VALUES (2, NULL, 6, 'clear-headed');

tests/test_interactive_table.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import select
66

77
from datafaker.interactive import TableCmd
8+
from datafaker.serialize_metadata import dict_to_metadata
89
from tests.utils import RequiresDBTestCase, TestDbCmdMixin
910

1011

@@ -396,3 +397,98 @@ def test_sanity_checks_errors_only(self) -> None:
396397
{},
397398
),
398399
)
400+
401+
402+
class TrickyTests(ConfigureTablesTests):
403+
"""Testing configure-tables with the instrument.sql database."""
404+
405+
dump_file_path = "tricky.sql"
406+
database_name = "tricky"
407+
schema_name = "public"
408+
409+
def do_and_test_peek_tricky(self, tc: TestTableCmd) -> None:
410+
"""Peek the "names" table and check the output."""
411+
tc.reset()
412+
tc.do_peek("")
413+
self.assertSetEqual(set(tc.headings), {"id", "offset", "count", "sensible"})
414+
self.assertSetEqual(
415+
set(tc.rows), {(1, 10, 5, "reasonable"), (2, None, 6, "clear-headed")}
416+
)
417+
418+
def test_peek_with_tricky_names(self) -> None:
419+
"""
420+
Peek with column names that are function names (#66).
421+
"""
422+
with self._get_cmd({}) as tc:
423+
tc.do_next("names")
424+
self.do_and_test_peek_tricky(tc)
425+
426+
def test_count_with_tricky_names(self) -> None:
427+
"""
428+
Count with column names that are function names (#66).
429+
"""
430+
with self._get_cmd({}) as tc:
431+
tc.do_next("names")
432+
self.do_and_test_peek_tricky(tc)
433+
tc.do_counts("")
434+
self.assertSequenceEqual(tc.rows, [["offset", 1], ["sensible", 0]])
435+
436+
def test_incorrect_orm_yaml_columns(self) -> None:
437+
"""
438+
Peek with incorrect columns in orm.yaml (#70).
439+
"""
440+
self.metadata = dict_to_metadata(
441+
{
442+
"tables": {
443+
"names": {
444+
"columns": {
445+
"id": {
446+
"primary": True,
447+
"nullable": False,
448+
"type": "INTEGER",
449+
},
450+
"sensible": {
451+
"primary": False,
452+
"nullable": False,
453+
"type": "TEXT",
454+
},
455+
"nonexistent": {
456+
"primary": False,
457+
"nullable": False,
458+
"type": "TEXT",
459+
},
460+
}
461+
}
462+
}
463+
}
464+
)
465+
with self._get_cmd({}) as tc:
466+
tc.reset()
467+
tc.do_peek("")
468+
self.assertIn("SQL query", "/".join(m for (m, _a, _kw) in tc.messages))
469+
470+
def test_repeated_field_does_not_throw_exception(self) -> None:
471+
"""
472+
Select with repeated fields (#70).
473+
"""
474+
with TestTableCmd(
475+
src_dsn=self.dsn,
476+
src_schema=self.schema_name,
477+
metadata=self.metadata,
478+
config={},
479+
print_tables=True,
480+
) as tc:
481+
tc.reset()
482+
tc.do_select('sensible AS same, "offset" AS same FROM names')
483+
self.assertIn(
484+
"Failed to display", "/".join(m for (m, _a, _kw) in tc.messages)
485+
)
486+
487+
def test_sql_error_does_not_throw_exception(self) -> None:
488+
"""
489+
Select with a SQL error.
490+
"""
491+
with self._get_cmd({}) as tc:
492+
tc.reset()
493+
tc.do_select("+++")
494+
self.assertIn("SQL query", "/".join(m for (m, a, kw) in tc.messages))

tests/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,10 @@ def generate_data(
293293
class TestDbCmdMixin(DbCmd):
294294
"""A mixin for capturing output from interactive commands."""
295295

296-
def __init__(self, *args: Any, **kwargs: Any) -> None:
296+
def __init__(self, *args: Any, print_tables: bool = False, **kwargs: Any) -> None:
297297
"""Initialize a TestDbCmdMixin"""
298298
super().__init__(*args, **kwargs)
299+
self._print_tables = print_tables
299300
self.reset()
300301

301302
def reset(self) -> None:
@@ -316,6 +317,8 @@ def print_table(
316317
"""Capture the printed table."""
317318
self.headings = headings
318319
self.rows = rows
320+
if self._print_tables:
321+
super().print_table(headings, rows)
319322

320323
def print_table_by_columns(self, columns: Mapping[str, Sequence[str]]) -> None:
321324
"""Capture the printed table."""

0 commit comments

Comments
 (0)