diff --git a/redash/query_runner/trino.py b/redash/query_runner/trino.py index 6bb6c5bccc..fca8977a38 100644 --- a/redash/query_runner/trino.py +++ b/redash/query_runner/trino.py @@ -1,5 +1,6 @@ import logging import os +import re from redash.models.users import ApiUser, User from redash.query_runner import ( @@ -62,6 +63,31 @@ def _convert_row_types(value): } +_DECIMAL_SCALE_RE = re.compile(r"^decimal\(\d+,\s*(\d+)\)$") + + +def _map_trino_type(type_name): + """Map a Trino type name to a Redash column type. + + Handles parameterised types such as ``timestamp(3)`` or ``decimal(10,2)`` + by falling back to the base type when an exact match is not found. + """ + if not type_name: + return None + mapped = TRINO_TYPES_MAPPING.get(type_name) + if mapped is not None: + return mapped + # Strip parameters: "timestamp(3)" -> "timestamp" + base = type_name.split("(", 1)[0] + mapped = TRINO_TYPES_MAPPING.get(base) + # decimal(p, s) with s > 0 has fractional digits + if base == "decimal": + m = _DECIMAL_SCALE_RE.match(type_name) + if m and int(m.group(1)) > 0: + mapped = TYPE_FLOAT + return mapped + + class Trino(BaseQueryRunner): noop_query = "SELECT 1" should_annotate_query = ANNOTATE_QUERY @@ -215,7 +241,7 @@ def run_query(self, query, user): cursor.execute(query) results = cursor.fetchall() description = cursor.description - columns = self.fetch_columns([(c[0], TRINO_TYPES_MAPPING.get(c[1], None)) for c in description]) + columns = self.fetch_columns([(c[0], _map_trino_type(c[1])) for c in description]) column_names = [c["name"] for c in columns] rows = [dict(zip(column_names, [_convert_row_types(v) for v in r])) for r in results] data = {"columns": columns, "rows": rows} diff --git a/tests/query_runner/test_trino.py b/tests/query_runner/test_trino.py index b5fad8e6ea..fa6280fba1 100644 --- a/tests/query_runner/test_trino.py +++ b/tests/query_runner/test_trino.py @@ -7,7 +7,7 @@ from trino.types import NamedRowTuple -from redash.query_runner.trino import Trino, _convert_row_types +from redash.query_runner.trino import Trino, _convert_row_types, _map_trino_type class TestTrino(TestCase): @@ -96,3 +96,38 @@ def test_unnamed_fields_get_positional_names(self): row = NamedRowTuple([1, 2], [None, None], ["integer", "integer"]) result = _convert_row_types(row) self.assertEqual(result, {"_field0": 1, "_field1": 2}) + + +class TestMapTrinoType(TestCase): + def test_exact_match(self): + self.assertEqual(_map_trino_type("timestamp"), "datetime") + self.assertEqual(_map_trino_type("integer"), "integer") + self.assertEqual(_map_trino_type("varchar"), "string") + self.assertEqual(_map_trino_type("date"), "date") + + def test_parameterised_timestamp(self): + self.assertEqual(_map_trino_type("timestamp(0)"), "datetime") + self.assertEqual(_map_trino_type("timestamp(3)"), "datetime") + self.assertEqual(_map_trino_type("timestamp(6)"), "datetime") + + def test_parameterised_decimal_with_scale(self): + self.assertEqual(_map_trino_type("decimal(10,2)"), "float") + self.assertEqual(_map_trino_type("decimal(18,6)"), "float") + + def test_parameterised_decimal_without_scale(self): + self.assertEqual(_map_trino_type("decimal(10,0)"), "integer") + self.assertEqual(_map_trino_type("decimal"), "integer") + + def test_parameterised_varchar(self): + self.assertEqual(_map_trino_type("varchar(255)"), "string") + self.assertEqual(_map_trino_type("char(1)"), "string") + + def test_unknown_type_returns_none(self): + self.assertIsNone(_map_trino_type("unknown")) + self.assertIsNone(_map_trino_type("unknown(3)")) + + def test_none_returns_none(self): + self.assertIsNone(_map_trino_type(None)) + + def test_empty_string_returns_none(self): + self.assertIsNone(_map_trino_type(""))