Skip to content

Commit 6a8ec2a

Browse files
perf: add SparkDirectSeeder (bypass dbt seed) and tune Spark config for -n8 parallelism
- Add SparkDirectSeeder that executes CREATE TABLE + INSERT VALUES directly via the dbt adapter, bypassing the ~4s dbt subprocess overhead per seed
- Add execute_sql() and schema_name property to AdapterQueryRunner
- DbtProject auto-selects SparkDirectSeeder when target is 'spark'
- Tune spark-defaults.conf: executor.cores=4, default.parallelism=4, thriftServer.async=true for better concurrent session handling
- Restore -n8 parallelism for Spark in CI (was -n4)

Co-Authored-By: Itamar Hartstein <haritamar@gmail.com>
1 parent 6ed6a33 commit 6a8ec2a

File tree

5 files changed

+96
-11
lines changed

5 files changed

+96
-11
lines changed

.github/workflows/test-warehouse.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ jobs:
177177
178178
- name: Test
179179
working-directory: "${{ env.TESTS_DIR }}/tests"
180-
run: py.test -n${{ (inputs.warehouse-type == 'spark' && '4') || '8' }} -vvv --target "${{ inputs.warehouse-type }}" --junit-xml=test-results.xml --html=detailed_report_${{ inputs.warehouse-type }}_dbt_${{ inputs.dbt-version }}.html --self-contained-html --clear-on-end ${{ (inputs.dbt-version == 'fusion' && '--runner-method fusion') || '' }}
180+
run: py.test -n8 -vvv --target "${{ inputs.warehouse-type }}" --junit-xml=test-results.xml --html=detailed_report_${{ inputs.warehouse-type }}_dbt_${{ inputs.dbt-version }}.html --self-contained-html --clear-on-end ${{ (inputs.dbt-version == 'fusion' && '--runner-method fusion') || '' }}
181181

182182
- name: Upload test results
183183
if: always()
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
spark.driver.memory 2g
22
spark.executor.memory 2g
3+
spark.executor.cores 4
34
spark.hadoop.datanucleus.autoCreateTables true
45
spark.hadoop.datanucleus.schema.autoCreateTables true
56
spark.hadoop.datanucleus.fixedDatastore false
@@ -8,6 +9,7 @@ spark.driver.userClassPathFirst true
89
spark.sql.extensions io.delta.sql.DeltaSparkSessionExtension
910
spark.sql.catalog.spark_catalog org.apache.spark.sql.delta.catalog.DeltaCatalog
1011
spark.sql.shuffle.partitions 2
11-
spark.default.parallelism 2
12+
spark.default.parallelism 4
1213
spark.ui.enabled false
1314
spark.sql.adaptive.enabled true
15+
spark.sql.hive.thriftServer.async true

integration_tests/tests/adapter_query_runner.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,16 @@ def has_non_ref_jinja(query: str) -> bool:
239239
stripped = _SOURCE_PATTERN.sub("", stripped)
240240
return bool(_JINJA_EXPR_PATTERN.search(stripped))
241241

def execute_sql(self, sql: str) -> None:
    """Run a result-less SQL statement (DDL/DML) through the dbt adapter."""
    adapter = self._adapter
    with adapter.connection_named("execute_sql"):
        adapter.execute(sql, fetch=False)
246+
247+
@property
def schema_name(self) -> str:
    """The base schema configured on the adapter's connection credentials."""
    credentials = self._adapter.config.credentials
    return credentials.schema
251+
242252
def run_query(self, prerendered_query: str) -> List[Dict[str, Any]]:
243253
"""Render Jinja refs/sources and execute a query, returning rows as dicts.
244254

integration_tests/tests/data_seeder.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
import csv
22
from contextlib import contextmanager
33
from pathlib import Path
4-
from typing import Generator, List
4+
from typing import TYPE_CHECKING, Generator, List
55

66
from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
77
from logger import get_logger
88

9-
# TODO: Write more performant data seeders per adapter.
9+
if TYPE_CHECKING:
10+
from adapter_query_runner import AdapterQueryRunner
1011

1112
logger = get_logger(__name__)
1213

1314

1415
class DbtDataSeeder:
16+
"""Default seeder: writes a CSV and calls ``dbt seed``."""
17+
1518
def __init__(
1619
self, dbt_runner: BaseDbtRunner, dbt_project_path: Path, seeds_dir_path: Path
1720
):
@@ -48,3 +51,67 @@ def seed(self, data: List[dict], table_name: str) -> Generator[None, None, None]
4851
yield
4952
finally:
5053
seed_path.unlink()
54+
55+
56+
# Maximum number of rows per INSERT VALUES statement. Spark's Thrift
# protocol can choke on very large statements, so we batch inserts.
_INSERT_BATCH_SIZE = 500


class SparkDirectSeeder:
    """Fast seeder for Spark: executes CREATE TABLE + INSERT directly.

    Bypasses the ``dbt seed`` subprocess entirely, avoiding the ~4 s
    Python/manifest-parsing overhead per invocation. All columns are
    created as STRING, which matches ``dbt seed`` behaviour.
    """

    def __init__(self, query_runner: "AdapterQueryRunner", schema: str) -> None:
        self._query_runner = query_runner
        self._schema = schema

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _escape(value: object) -> str:
        """Escape *value* for use as a Spark SQL string literal.

        ``None`` and the empty string both map to SQL ``NULL`` (matching how
        ``dbt seed`` treats empty CSV cells); anything else is stringified
        and single-quoted.
        """
        if value is None or (isinstance(value, str) and value == ""):
            return "NULL"
        text = str(value)
        # Replace backslashes first, then single-quotes, so the escape
        # character itself is never double-processed.
        text = text.replace("\\", "\\\\")
        text = text.replace("'", "\\'")
        # Spark INSERT VALUES doesn't support embedded newlines.
        text = text.replace("\n", " ").replace("\r", " ")
        return f"'{text}'"

    # ------------------------------------------------------------------
    # public API (same shape as DbtDataSeeder)
    # ------------------------------------------------------------------

    @contextmanager
    def seed(self, data: List[dict], table_name: str) -> Generator[None, None, None]:
        """Create ``table_name`` from *data*, insert all rows, then yield.

        The table is dropped and recreated on entry and left in place on
        exit (mirroring ``dbt seed``, which also leaves the seed table).

        Raises:
            ValueError: if *data* is empty — the column set (and therefore
                the CREATE TABLE statement) cannot be derived from zero rows.
        """
        if not data:
            # Fail loudly: indexing data[0] below would otherwise raise an
            # opaque IndexError.
            raise ValueError(
                f"Cannot seed empty data into table '{table_name}': "
                "at least one row is required to derive the columns"
            )

        columns = list(data[0].keys())
        # NOTE(review): identifiers are backtick-quoted but backticks inside
        # them are not escaped — assumes test schema/table/column names
        # contain no backticks.
        col_defs = ", ".join(f"`{col}` STRING" for col in columns)
        fq_table = f"`{self._schema}`.`{table_name}`"

        # DROP + CREATE is the fastest way to get a clean table.
        self._query_runner.execute_sql(f"DROP TABLE IF EXISTS {fq_table}")
        self._query_runner.execute_sql(
            f"CREATE TABLE {fq_table} ({col_defs}) USING DELTA"
        )

        # Insert in batches to keep each statement under the Thrift limit.
        for batch_start in range(0, len(data), _INSERT_BATCH_SIZE):
            batch = data[batch_start : batch_start + _INSERT_BATCH_SIZE]
            rows_sql = ", ".join(
                "(" + ", ".join(self._escape(row.get(c)) for c in columns) + ")"
                for row in batch
            )
            self._query_runner.execute_sql(f"INSERT INTO {fq_table} VALUES {rows_sql}")

        logger.info("SparkDirectSeeder: loaded %d rows into %s", len(data), fq_table)

        yield

integration_tests/tests/dbt_project.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from uuid import uuid4
88

99
from adapter_query_runner import AdapterQueryRunner, UnsupportedJinjaError
10-
from data_seeder import DbtDataSeeder
10+
from data_seeder import DbtDataSeeder, SparkDirectSeeder
1111
from dbt_utils import get_database_and_schema_properties
1212
from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
1313
from elementary.clients.dbt.factory import RunnerMethod, create_dbt_runner
@@ -326,10 +326,18 @@ def test(
326326
}
327327
return [test_result] if multiple_results else test_result
328328

def _create_seeder(self) -> Union[DbtDataSeeder, SparkDirectSeeder]:
    """Return the fastest available seeder for the current target."""
    if self.target != "spark":
        return DbtDataSeeder(
            self.dbt_runner, self.project_dir_path, self.seeds_dir_path
        )
    query_runner = self._get_query_runner()
    target_schema = query_runner.schema_name + SCHEMA_NAME_SUFFIX
    return SparkDirectSeeder(query_runner, target_schema)
339+
def seed(self, data: List[dict], table_name: str):
340+
with self._create_seeder().seed(data, table_name):
333341
self._fix_seed_if_needed(table_name)
334342

335343
def _fix_seed_if_needed(self, table_name: str):
@@ -345,9 +353,7 @@ def _fix_seed_if_needed(self, table_name: str):
345353
def seed_context(
    self, data: List[dict], table_name: str
) -> Generator[None, None, None]:
    """Generator body: seed the data, yield control, clean up on exit."""
    seeder = self._create_seeder()
    with seeder.seed(data, table_name):
        yield
352358

353359
@contextmanager

0 commit comments

Comments
 (0)