
Commit e928b3e

Change default linker back to CVM
* Allow the linker to be changed at runtime, affecting both the `Mode` and `FAST_RUN` modes
* Remove float32 parametrization from the test suite. Tests that care about float32 should cover it explicitly
1 parent c3b70b3 commit e928b3e
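
The first bullet changes user-facing behaviour: the `linker` flag can now be switched at runtime and the default mode follows it. A minimal sketch of the intended usage, assuming only the flag values and modes that appear in the diffs below (not code from this commit):

# Sketch only: assumes the "auto"/"numba" linker options and the runtime-mutable
# linker flag introduced by this commit.
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = (x ** 2).sum()

# With linker left at its default ("auto"), FAST_RUN resolves to the CVM linker
# when a C++ compiler is available, otherwise to the pure-Python VM.
f_cvm = pytensor.function([x], y, mode="FAST_RUN")

# Because the flag is now mutable, it can be changed at runtime and the next
# compilation with the default Mode or FAST_RUN picks up the new backend.
pytensor.config.linker = "numba"
f_numba = pytensor.function([x], y, mode="FAST_RUN")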

File tree

8 files changed: +108 / -125 lines changed


.github/workflows/test.yml

Lines changed: 43 additions & 85 deletions
@@ -65,7 +65,7 @@ jobs:
       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1

   test:
-    name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
+    name: "mode ${{ matrix.default-mode }} : py${{ matrix.python-version }} : ${{ matrix.os }} : ${{ matrix.part[0] }}"
     needs:
       - changes
       - style
@@ -74,101 +74,62 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: ["ubuntu-latest"]
+        default-mode: ["C", "NUMBA", "FAST_COMPILE"]
         python-version: ["3.11", "3.14"]
-        fast-compile: [0, 1]
-        float32: [0, 1]
-        install-numba: [0]
+        os: ["ubuntu-latest"]
         install-jax: [0]
         install-torch: [0]
         install-mlx: [0]
         install-xarray: [0]
         part:
-          - "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor --ignore=tests/link/numba"
-          - "tests/scan"
-          - "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/linalg --ignore=tests/tensor/test_nlinalg.py --ignore=tests/tensor/test_slinalg.py --ignore=tests/tensor/test_pad.py"
-          - "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
-          - "tests/tensor/test_math.py"
-          - "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv tests/tensor/test_pad.py"
-          - "tests/tensor/rewriting"
-          - "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py"
+          - [ "*rest", "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor --ignore=tests/link/numba" ]
+          - [ "scan", "tests/scan" ]
+          - [ "tensor *rest", "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/linalg --ignore=tests/tensor/test_nlinalg.py --ignore=tests/tensor/test_slinalg.py --ignore=tests/tensor/test_pad.py" ]
+          - [ "tensor basic+elemwise", "tests/tensor/test_basic.py tests/tensor/test_elemwise.py" ]
+          - [ "tensor math", "tests/tensor/test_math.py" ]
+          - [ "tensor scipy+blas+conv+pad", "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/conv tests/tensor/test_pad.py" ]
+          - [ "tensor rewriting", "tests/tensor/rewriting" ]
+          - [ "tensor linalg", "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py" ]
         exclude:
           - python-version: "3.11"
-            fast-compile: 1
-          - python-version: "3.11"
-            float32: 1
-          - fast-compile: 1
-            float32: 1
+            default-mode: "FAST_COMPILE"
         include:
-          - os: "ubuntu-latest"
-            part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
+          - part: ["doctests", "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"]
+            default-mode: "C"
             python-version: "3.12"
-            fast-compile: 0
-            float32: 0
-            install-numba: 0
-            install-jax: 0
-            install-torch: 0
-            install-mlx: 0
-            install-xarray: 0
-          - install-numba: 1
             os: "ubuntu-latest"
-            python-version: "3.11"
-            fast-compile: 0
-            float32: 0
-            part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
-          - install-numba: 1
+          - part: ["numba link", "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"]
+            default-mode: "C"
+            python-version: "3.12"
             os: "ubuntu-latest"
-            python-version: "3.14"
-            fast-compile: 0
-            float32: 0
-            part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
-          - install-numba: 1
+          - part: ["numba link slinalg", "tests/link/numba/test_slinalg.py"]
+            default-mode: "C"
+            python-version: "3.13"
             os: "ubuntu-latest"
+          - part: ["jax link", "tests/link/jax"]
+            install-jax: 1
+            default-mode: "C"
             python-version: "3.14"
-            fast-compile: 0
-            float32: 0
-            part: "tests/link/numba/test_slinalg.py"
-          - install-jax: 1
             os: "ubuntu-latest"
+          - part: ["pytorch link", "tests/link/pytorch"]
+            install-torch: 1
+            default-mode: "C"
             python-version: "3.11"
-            fast-compile: 0
-            float32: 0
-            part: "tests/link/jax"
-          - install-jax: 1
             os: "ubuntu-latest"
+          - part: ["xtensor", "tests/xtensor"]
+            install-xarray: 1
+            default-mode: "C"
             python-version: "3.14"
-            fast-compile: 0
-            float32: 0
-            part: "tests/link/jax"
-          - install-torch: 1
             os: "ubuntu-latest"
-            python-version: "3.11"
-            fast-compile: 0
-            float32: 0
-            part: "tests/link/pytorch"
-          - install-xarray: 1
-            os: "ubuntu-latest"
-            python-version: "3.14"
-            fast-compile: 0
-            float32: 0
-            part: "tests/xtensor"
-          - os: "macos-15"
-            python-version: "3.11"
-            fast-compile: 0
-            float32: 0
+          - part: ["mlx link", "tests/link/mlx"]
             install-mlx: 1
-            install-numba: 0
-            install-jax: 0
-            install-torch: 0
-            part: "tests/link/mlx"
-          - os: "macos-15"
+            default-mode: "C"
+            python-version: "3.11"
+            os: "macos-15"
+          - part: ["macos smoke test", "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"]
+            default-mode: "C"
             python-version: "3.14"
-            fast-compile: 0
-            float32: 0
-            install-numba: 0
-            install-jax: 0
-            install-torch: 0
-            part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
+            os: "macos-15"

     steps:
       - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
@@ -199,11 +160,10 @@ jobs:
         run: |

           if [[ $OS == "macos-15" ]]; then
-            micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
+            micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
           else
-            micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
+            micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx mkl mkl-service;
           fi
-          pip install "numba>=0.63"
           if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
           if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
           if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi
@@ -219,28 +179,26 @@ jobs:
           fi
         env:
           PYTHON_VERSION: ${{ matrix.python-version }}
-          INSTALL_NUMBA: ${{ matrix.install-numba }}
           INSTALL_JAX: ${{ matrix.install-jax }}
-          INSTALL_TORCH: ${{ matrix.install-torch}}
+          INSTALL_TORCH: ${{ matrix.install-torch }}
           INSTALL_XARRAY: ${{ matrix.install-xarray }}
           INSTALL_MLX: ${{ matrix.install-mlx }}
           OS: ${{ matrix.os}}

       - name: Run tests
         shell: micromamba-shell {0}
         run: |
-          if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
-          if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
+          if [[ $DEFAULT_MODE == "FAST_COMPILE" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
+          if [[ $DEFAULT_MODE == "NUMBA" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,linker=numba; fi
           export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
           python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
         env:
           MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
           MKL_THREADING_LAYER: GNU
           MKL_NUM_THREADS: 1
           OMP_NUM_THREADS: 1
-          PART: ${{ matrix.part }}
-          FAST_COMPILE: ${{ matrix.fast-compile }}
-          FLOAT32: ${{ matrix.float32 }}
+          PART: ${{ matrix.part[1] }}
+          DEFAULT_MODE: ${{ matrix.default-mode }}

       - name: Upload coverage file
         uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

pytensor/compile/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -17,8 +17,8 @@
 )
 from pytensor.compile.io import In, Out, SymbolicInput, SymbolicOutput
 from pytensor.compile.mode import (
+    CVM,
     FAST_COMPILE,
-    FAST_RUN,
     JAX,
     NUMBA,
     OPT_FAST_COMPILE,
@@ -33,6 +33,7 @@
     PYTORCH,
     AddDestroyHandler,
     AddFeatureOptimizer,
+    C,
     Mode,
     PrintCurrentFunctionGraph,
     get_default_mode,
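
One visible consequence of this export change: the module-level `FAST_RUN` mode object is no longer re-exported (it is now resolved dynamically), while `C` and `CVM` are. A hedged import sketch, based solely on the export list above:

# Assumed to work given the export list above; FAST_RUN is intentionally absent.
from pytensor.compile import C, CVM, NUMBA, FAST_COMPILE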

pytensor/compile/mode.py

Lines changed: 37 additions & 24 deletions
@@ -5,7 +5,7 @@

 import logging
 import warnings
-from typing import Literal
+from typing import Any, Literal

 from pytensor.compile.function.types import Supervisor
 from pytensor.configdefaults import config
@@ -62,20 +62,17 @@ def register_linker(name, linker):
     predefined_linkers[name] = linker


-exclude = []
-if not config.cxx:
-    exclude = ["cxx_only"]
-OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
+OPT_NONE = RewriteDatabaseQuery(include=[])
 # Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
-OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
+OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"])
 # Even if multiple merge optimizer call will be there, this shouldn't
 # impact performance.
-OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
-OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
+OPT_MERGE = RewriteDatabaseQuery(include=["merge"])
+OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"])
 OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")

-OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude)
-OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
+OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"])
+OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"])
 OPT_STABILIZE.position_cutoff = 1.5000001
 OPT_NONE.name = "OPT_NONE"
 OPT_MINIMUM.name = "OPT_MINIMUM"
@@ -313,6 +310,8 @@ def __init__(
     ):
         if linker is None:
             linker = config.linker
+        if isinstance(linker, str) and linker == "auto":
+            linker = "cvm" if config.cxx else "vm"
         if isinstance(optimizer, str) and optimizer == "default":
             optimizer = config.optimizer

@@ -448,20 +447,15 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
         return new_mode


+C = Mode("c", "fast_run")
+CVM = Mode("cvm", "fast_run")
+VM = (Mode("vm", "fast_run"),)
+
 NUMBA = Mode(
     NumbaLinker(),
     RewriteDatabaseQuery(include=["fast_run", "numba"]),
 )

-FAST_COMPILE = Mode(
-    NumbaLinker(),
-    RewriteDatabaseQuery(include=["fast_compile"]),
-)
-FAST_RUN = NUMBA
-
-C = Mode("c", "fast_run")
-CVM = Mode("cvm", "fast_run")
-
 JAX = Mode(
     JAXLinker(),
     RewriteDatabaseQuery(include=["fast_run", "jax"]),
@@ -476,10 +470,19 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     RewriteDatabaseQuery(include=["fast_run"]),
 )

+FAST_COMPILE = Mode(
+    VMLinker(use_cloop=False, c_thunks=False),
+    RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
+)
+
+fast_run_linkers_to_mode = {
+    "cvm": CVM,
+    "vm": VM,
+    "numba": NUMBA,
+}

 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
-    "FAST_RUN": FAST_RUN,
     "C": C,
     "CVM": CVM,
     "JAX": JAX,
@@ -488,7 +491,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
     "MLX": MLX,
 }

-_CACHED_RUNTIME_MODES: dict[str, Mode] = {}
+_CACHED_RUNTIME_MODES: dict[Any, Mode] = {}


 def get_mode(orig_string):
@@ -506,10 +509,20 @@ def get_mode(orig_string):
     if upper_string in predefined_modes:
         return predefined_modes[upper_string]

+    if upper_string == "FAST_RUN":
+        linker = config.linker
+        if linker == "auto":
+            return CVM if config.cxx else VM
+        return fast_run_linkers_to_mode[linker]
+
     global _CACHED_RUNTIME_MODES

-    if upper_string in _CACHED_RUNTIME_MODES:
-        return _CACHED_RUNTIME_MODES[upper_string]
+    cache_key = ("MODE", config.linker) if upper_string == "MODE" else upper_string
+
+    try:
+        return _CACHED_RUNTIME_MODES[cache_key]
+    except KeyError:
+        pass

     # Need to define the mode for the first time
     if upper_string == "MODE":
@@ -535,7 +548,7 @@ def get_mode(orig_string):
     if config.optimizer_requiring:
         ret = ret.requiring(*config.optimizer_requiring.split(":"))
     # Cache the mode for next time
-    _CACHED_RUNTIME_MODES[upper_string] = ret
+    _CACHED_RUNTIME_MODES[cache_key] = ret

     return ret

pytensor/configdefaults.py

Lines changed: 7 additions & 6 deletions
@@ -371,24 +371,26 @@ def add_compile_configvars():
     )
     del param

-    default_linker = "numba"
+    default_linker = "auto"

     if rc == 0 and config.cxx != "":
         # Keep the default linker the same as the one for the mode FAST_RUN
         linker_options = [
-            "cvmc|py",
+            "cvm",
+            "c|py",
             "py",
             "c",
             "c|py_nogc",
             "vm",
             "vm_nogc",
             "cvm_nogc",
+            "numba",
             "jax",
         ]
     else:
         # g++ is not present or the user disabled it,
         # linker should default to python only.
-        linker_options = ["py", "vm", "vm_nogc", "jax"]
+        linker_options = ["py", "vm", "vm_nogc", "numba", "jax"]
         if type(config).cxx.is_default:
             # If the user provided an empty value for cxx, do not warn.
             _logger.warning(
@@ -400,9 +402,8 @@ def add_compile_configvars():

     config.add(
         "linker",
-        "Default linker used if the pytensor flags mode is Mode",
-        # Not mutable because the default mode is cached after the first use.
-        EnumStr(default_linker, linker_options, mutable=False),
+        "Default linker used if the pytensor flags mode is Mode or FAST_RUN",
+        EnumStr(default_linker, linker_options, mutable=True),
         in_c_key=False,
     )

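
The `auto` default defers the concrete choice of linker until a `Mode` is actually constructed. A condensed sketch of that resolution, mirroring the `Mode.__init__` change earlier in this commit (an assumed simplification, not the actual implementation):

# Condensed sketch of the "auto" resolution added to Mode.__init__.
from pytensor.configdefaults import config

def resolve_linker(linker: str) -> str:
    if linker == "auto":
        # Prefer the C-backed CVM linker when a C++ compiler is configured,
        # otherwise fall back to the pure-Python VM.
        return "cvm" if config.cxx else "vm"
    return linker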

pytensor/configparser.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ class PyTensorConfigParser:
     unpickle_function: bool
     # add_compile_configvars
     mode: str
+    fast_run_backend: str
     cxx: str
     linker: str
     allow_gc: bool
