Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion redash/query_runner/trino.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import re

from redash.models.users import ApiUser, User
from redash.query_runner import (
Expand Down Expand Up @@ -62,6 +63,31 @@ def _convert_row_types(value):
}


_DECIMAL_SCALE_RE = re.compile(r"^decimal\(\d+,\s*(\d+)\)$")


def _map_trino_type(type_name):
"""Map a Trino type name to a Redash column type.

Handles parameterised types such as ``timestamp(3)`` or ``decimal(10,2)``
by falling back to the base type when an exact match is not found.
"""
if not type_name:
return None
mapped = TRINO_TYPES_MAPPING.get(type_name)
if mapped is not None:
return mapped
# Strip parameters: "timestamp(3)" -> "timestamp"
base = type_name.split("(", 1)[0]
mapped = TRINO_TYPES_MAPPING.get(base)
# decimal(p, s) with s > 0 has fractional digits
if base == "decimal":
m = _DECIMAL_SCALE_RE.match(type_name)
if m and int(m.group(1)) > 0:
mapped = TYPE_FLOAT
return mapped


class Trino(BaseQueryRunner):
noop_query = "SELECT 1"
should_annotate_query = ANNOTATE_QUERY
Expand Down Expand Up @@ -215,7 +241,7 @@ def run_query(self, query, user):
cursor.execute(query)
results = cursor.fetchall()
description = cursor.description
columns = self.fetch_columns([(c[0], TRINO_TYPES_MAPPING.get(c[1], None)) for c in description])
columns = self.fetch_columns([(c[0], _map_trino_type(c[1])) for c in description])
column_names = [c["name"] for c in columns]
rows = [dict(zip(column_names, [_convert_row_types(v) for v in r])) for r in results]
data = {"columns": columns, "rows": rows}
Expand Down
37 changes: 36 additions & 1 deletion tests/query_runner/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from trino.types import NamedRowTuple

from redash.query_runner.trino import Trino, _convert_row_types
from redash.query_runner.trino import Trino, _convert_row_types, _map_trino_type


class TestTrino(TestCase):
Expand Down Expand Up @@ -96,3 +96,38 @@ def test_unnamed_fields_get_positional_names(self):
row = NamedRowTuple([1, 2], [None, None], ["integer", "integer"])
result = _convert_row_types(row)
self.assertEqual(result, {"_field0": 1, "_field1": 2})


class TestMapTrinoType(TestCase):
def test_exact_match(self):
self.assertEqual(_map_trino_type("timestamp"), "datetime")
self.assertEqual(_map_trino_type("integer"), "integer")
self.assertEqual(_map_trino_type("varchar"), "string")
self.assertEqual(_map_trino_type("date"), "date")

def test_parameterised_timestamp(self):
self.assertEqual(_map_trino_type("timestamp(0)"), "datetime")
self.assertEqual(_map_trino_type("timestamp(3)"), "datetime")
self.assertEqual(_map_trino_type("timestamp(6)"), "datetime")

def test_parameterised_decimal_with_scale(self):
self.assertEqual(_map_trino_type("decimal(10,2)"), "float")
self.assertEqual(_map_trino_type("decimal(18,6)"), "float")

def test_parameterised_decimal_without_scale(self):
self.assertEqual(_map_trino_type("decimal(10,0)"), "integer")
self.assertEqual(_map_trino_type("decimal"), "integer")

def test_parameterised_varchar(self):
self.assertEqual(_map_trino_type("varchar(255)"), "string")
self.assertEqual(_map_trino_type("char(1)"), "string")

def test_unknown_type_returns_none(self):
self.assertIsNone(_map_trino_type("unknown"))
self.assertIsNone(_map_trino_type("unknown(3)"))

def test_none_returns_none(self):
self.assertIsNone(_map_trino_type(None))

def test_empty_string_returns_none(self):
self.assertIsNone(_map_trino_type(""))
Loading