# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""End-to-end multiprocessing tests for :py:class:`datafusion.Expr` pickling.

Motivation: users want to fan work out across processes via
:py:mod:`multiprocessing` and pass an :py:class:`Expr` to each worker.

The ``spawn`` start method is forced so each worker is a fresh interpreter and
the ``Expr`` argument is genuinely sent through pickle. Under ``fork`` the
child inherits the parent's address space and would never exercise the wire
format.

Two scenarios are covered:

1. A built-in expression — the workers can resolve everything against a
   default :py:class:`SessionContext`.
2. A custom Python UDF — the worker must register the UDF on its global
   context *before* unpickling, since ``Expr.__setstate__`` resolves
   function references by name against the global context.
"""

from __future__ import annotations

import multiprocessing as mp

import pyarrow as pa
import pyarrow.compute as pc
from datafusion import SessionContext, col, lit, udf

# Module-scope helpers — must be importable by name so the `spawn` workers
# can resolve them after re-importing this module.

_UDF_NAME = "mp_pickle_add_ten"


def _add_ten_impl(array: pa.Array) -> pa.Array:
    return pc.add(array, 10)


def _build_add_ten_udf():
    return udf(
        _add_ten_impl,
        [pa.int64()],
        pa.int64(),
        volatility="immutable",
        name=_UDF_NAME,
    )


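# A note on the shared factory above: the same ``_build_add_ten_udf`` call
# runs in the parent (to build the pickled ``Expr``) and in each worker's
# initializer below, so both sides agree on the UDF's name and signature.
# The name is what ``Expr.__setstate__`` uses to re-link the reference when
# a task argument is unpickled.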
def _register_udf_on_global_ctx() -> None:
    """Pool initializer: install a global ctx in the worker that knows the UDF.

    ``Expr.__setstate__`` resolves UDF references by name against the
    *global* context, so the registration must happen before any task arg is
    unpickled — i.e. in the Pool's ``initializer``, not in the task body.
    """
    ctx = SessionContext()
    ctx.register_udf(_build_add_ten_udf())
    ctx.set_as_global()


def _apply_builtin_expr(args: tuple) -> list:
    expr, values = args
    # A fresh default context suffices here: built-in expressions carry no
    # function references that need pre-registration (see module docstring).
    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([pa.array(values, type=pa.int64())], names=["a"])
    df = ctx.create_dataframe([[batch]], name="t")
    return df.select(expr.alias("out")).collect()[0].column(0).to_pylist()


def _apply_udf_expr(args: tuple) -> list:
    expr, values = args
    # Reuse the worker's global ctx so the UDF registered by the initializer
    # is visible during execution as well as during arg unpickling.
    ctx = SessionContext.global_ctx()
    batch = pa.RecordBatch.from_arrays([pa.array(values, type=pa.int64())], names=["a"])
    df = ctx.create_dataframe([[batch]], name="t_udf")
    return df.select(expr.alias("out")).collect()[0].column(0).to_pylist()


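# Before the cross-process tests, a cheap in-process sanity check of the same
# wire format the pool workers rely on. A minimal sketch: it assumes ``Expr``
# round-trips through the standard ``pickle`` module (the feature under test)
# and that its display form is stable across the round trip.
def test_expr_pickle_roundtrip_in_process() -> None:
    import pickle

    expr = (col("a") * lit(2)) + lit(1)
    restored = pickle.loads(pickle.dumps(expr))
    assert str(restored) == str(expr)

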
def test_builtin_expr_through_multiprocessing_pool() -> None:
    """A built-in ``Expr`` survives a real ``multiprocessing.Pool`` dispatch."""
    spawn_ctx = mp.get_context("spawn")
    expr = (col("a") * lit(2)) + lit(1)
    chunks = [[1, 2, 3], [10, 20, 30]]

    with spawn_ctx.Pool(processes=2) as pool:
        results = pool.map(_apply_builtin_expr, [(expr, c) for c in chunks])

    assert results == [[3, 5, 7], [21, 41, 61]]


def test_udf_expr_through_multiprocessing_pool() -> None:
    """A UDF-backed ``Expr`` survives ``Pool.map``.

    The worker registers the UDF on its global context via the Pool
    initializer.
    """
    spawn_ctx = mp.get_context("spawn")
    add_ten = _build_add_ten_udf()
    expr = add_ten(col("a"))
    chunks = [[1, 2, 3], [10, 20, 30]]

    with spawn_ctx.Pool(processes=2, initializer=_register_udf_on_global_ctx) as pool:
        results = pool.map(_apply_udf_expr, [(expr, c) for c in chunks])

    assert results == [[11, 12, 13], [20, 30, 40]]
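
# Note: both pool tests request the ``spawn`` context explicitly rather than
# relying on the platform default, so the pickled ``Expr`` path is exercised
# even on Linux, where the default start method has historically been ``fork``.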