Skip to content

Commit 07ffa21

Browse files
committed
Add multiprocessing test since that was motivating factor
1 parent 5796b53 commit 07ffa21

1 file changed

Lines changed: 118 additions & 0 deletions

File tree

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""End-to-end multiprocessing tests for :py:class:`datafusion.Expr` pickling.
19+
20+
Motivation: users want to fan work out across processes
21+
via :py:mod:`multiprocessing` and pass an :py:class:`Expr` to each worker.
22+
23+
The ``spawn`` start method is forced so each worker is a fresh interpreter and
24+
the ``Expr`` argument is genuinely sent through pickle. Under ``fork`` the
25+
child inherits the parent's address space and would never exercise the wire
26+
format.
27+
28+
Two scenarios are covered:
29+
30+
1. A built-in expression — the workers can
31+
resolve everything against a default :py:class:`SessionContext`.
32+
2. A custom Python UDF — the worker must register the UDF on its global
33+
context *before* unpickling, since ``Expr.__setstate__`` resolves
34+
function references by name against the global context.
35+
"""
36+
37+
from __future__ import annotations
38+
39+
import multiprocessing as mp
40+
41+
import pyarrow as pa
42+
import pyarrow.compute as pc
43+
from datafusion import SessionContext, col, lit, udf
44+
45+
# Module-scope helpers — must be importable by name so the `spawn` workers
46+
# can resolve them after re-importing this module.
47+
48+
# Name given to the add-ten UDF (passed as ``name=`` in
# ``_build_add_ten_udf``). The parent test and the Pool initializer both
# build the UDF through that helper, so they share this one name.
_UDF_NAME = "mp_pickle_add_ten"
49+
50+
51+
def _add_ten_impl(array: pa.Array) -> pa.Array:
    """Python body of the test UDF: add 10 to ``array`` via ``pc.add``."""
    shifted = pc.add(array, 10)
    return shifted
53+
54+
55+
def _build_add_ten_udf():
    """Wrap ``_add_ten_impl`` as a datafusion UDF named ``_UDF_NAME``.

    The UDF is declared with a single ``int64`` input, an ``int64`` return
    type and ``"immutable"`` volatility. A fresh UDF object is built on
    every call, so both the parent test and the worker initializer can
    obtain their own instance.
    """
    input_types = [pa.int64()]
    return_type = pa.int64()
    return udf(
        _add_ten_impl,
        input_types,
        return_type,
        volatility="immutable",
        name=_UDF_NAME,
    )
63+
64+
65+
def _register_udf_on_global_ctx() -> None:
    """Pool initializer: give the worker a global ctx that knows the UDF.

    Unpickling an ``Expr`` (``Expr.__setstate__``) looks UDF references up
    by name on the *global* context. Task arguments are unpickled before
    the task body runs, so the UDF has to be registered here, in the
    Pool's ``initializer``, rather than inside the task function.
    """
    worker_ctx = SessionContext()
    add_ten = _build_add_ten_udf()
    worker_ctx.register_udf(add_ten)
    # Promote this context to the process-wide global one.
    worker_ctx.set_as_global()
75+
76+
77+
def _apply_builtin_expr(args: tuple) -> list:
    """Pool task: evaluate a built-in expression over one chunk of values.

    ``args`` is an ``(expr, values)`` pair. ``values`` is loaded as a single
    ``int64`` column named ``"a"`` into a dataframe on a brand-new
    ``SessionContext``. ``expr`` is selected under the alias ``"out"``, and
    the first column of the first collected batch is returned as a Python
    list.
    """
    expr, values = args
    session = SessionContext()
    column = pa.array(values, type=pa.int64())
    batch = pa.RecordBatch.from_arrays([column], names=["a"])
    frame = session.create_dataframe([[batch]], name="t")
    collected = frame.select(expr.alias("out")).collect()
    first_batch = collected[0]
    return first_batch.column(0).to_pylist()
83+
84+
85+
def _apply_udf_expr(args: tuple) -> list:
    """Pool task: evaluate a UDF-backed expression over one chunk of values.

    ``args`` is an ``(expr, values)`` pair. ``values`` is loaded as a single
    ``int64`` column named ``"a"``. ``expr`` is selected under the alias
    ``"out"``, and the first column of the first collected batch is
    returned as a Python list.
    """
    expr, values = args
    # Run on the worker's global ctx (installed by the Pool initializer) so
    # the UDF is visible at execution time, not only while the task
    # arguments were being unpickled.
    session = SessionContext.global_ctx()
    column = pa.array(values, type=pa.int64())
    batch = pa.RecordBatch.from_arrays([column], names=["a"])
    frame = session.create_dataframe([[batch]], name="t_udf")
    collected = frame.select(expr.alias("out")).collect()
    first_batch = collected[0]
    return first_batch.column(0).to_pylist()
93+
94+
95+
def test_builtin_expr_through_multiprocessing_pool() -> None:
    """Check that a built-in ``Expr`` round-trips through a spawned ``Pool``."""
    spawn_ctx = mp.get_context("spawn")
    double_plus_one = (col("a") * lit(2)) + lit(1)
    inputs = [[1, 2, 3], [10, 20, 30]]
    # Each task carries the expression itself, so it is pickled per task.
    tasks = [(double_plus_one, chunk) for chunk in inputs]

    with spawn_ctx.Pool(processes=2) as pool:
        results = pool.map(_apply_builtin_expr, tasks)

    expected = [[3, 5, 7], [21, 41, 61]]
    assert results == expected
105+
106+
107+
def test_udf_expr_through_multiprocessing_pool() -> None:
    """Check that a UDF-backed ``Expr`` round-trips through ``Pool.map``.

    This relies on every worker registering the UDF on its global context
    through the Pool ``initializer``.
    """
    spawn_ctx = mp.get_context("spawn")
    udf_expr = _build_add_ten_udf()(col("a"))
    inputs = [[1, 2, 3], [10, 20, 30]]
    tasks = [(udf_expr, chunk) for chunk in inputs]

    worker_pool = spawn_ctx.Pool(
        processes=2,
        initializer=_register_udf_on_global_ctx,
    )
    with worker_pool as pool:
        results = pool.map(_apply_udf_expr, tasks)

    expected = [[11, 12, 13], [20, 30, 40]]
    assert results == expected

0 commit comments

Comments
 (0)