Skip to content

Commit da8e246

Browse files
authored
[python] Support read paimon table as pytorch dataset (#6987)
1 parent 03f7c27 commit da8e246

File tree

10 files changed

+953
-117
lines changed

10 files changed

+953
-117
lines changed

.github/workflows/paimon-python-checks.yml

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
container: "python:${{ matrix.python-version }}-slim"
4747
strategy:
4848
matrix:
49-
python-version: ['3.6.15', '3.10']
49+
python-version: [ '3.6.15', '3.10' ]
5050

5151
steps:
5252
- name: Checkout code
@@ -70,6 +70,7 @@ jobs:
7070
build-essential \
7171
git \
7272
curl \
73+
&& apt-get clean \
7374
&& rm -rf /var/lib/apt/lists/*
7475
7576
- name: Verify Java and Maven installation
@@ -88,66 +89,82 @@ jobs:
8889
- name: Install Python dependencies
8990
shell: bash
9091
run: |
92+
df -h
9193
if [[ "${{ matrix.python-version }}" == "3.6.15" ]]; then
9294
python -m pip install --upgrade pip==21.3.1
9395
python --version
94-
python -m pip install -q pyroaring readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null
96+
python -m pip install --no-cache-dir pyroaring readerwriterlock==1.0.9 'fsspec==2021.10.1' 'cachetools==4.2.4' 'ossfs==2021.8.0' pyarrow==6.0.1 pandas==1.1.5 'polars==0.9.12' 'fastavro==1.4.7' zstandard==0.19.0 dataclasses==0.8.0 flake8 pytest py4j==0.10.9.9 requests parameterized==0.8.1 2>&1 >/dev/null
9597
else
9698
python -m pip install --upgrade pip
97-
python -m pip install -q pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0 2>&1 >/dev/null
99+
pip install torch --index-url https://download.pytorch.org/whl/cpu
100+
python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
98101
fi
102+
df -h
99103
- name: Run lint-python.sh
100104
shell: bash
101105
run: |
102106
chmod +x paimon-python/dev/lint-python.sh
103-
./paimon-python/dev/lint-python.sh
107+
./paimon-python/dev/lint-python.sh -e pytest_torch
104108
105-
requirement_version_compatible_test:
109+
torch_test:
106110
runs-on: ubuntu-latest
107111
container: "python:3.10-slim"
108112

109113
steps:
110114
- name: Checkout code
111115
uses: actions/checkout@v2
112116

113-
- name: Set up JDK ${{ env.JDK_VERSION }}
114-
uses: actions/setup-java@v4
115-
with:
116-
java-version: ${{ env.JDK_VERSION }}
117-
distribution: 'temurin'
118-
119-
- name: Set up Maven
120-
uses: stCarolas/[email protected]
121-
with:
122-
maven-version: 3.8.8
123-
124117
- name: Install system dependencies
125118
shell: bash
126119
run: |
127120
apt-get update && apt-get install -y \
128121
build-essential \
129122
git \
130123
curl \
124+
&& apt-get clean \
131125
&& rm -rf /var/lib/apt/lists/*
132126
133-
- name: Verify Java and Maven installation
134-
run: |
135-
java -version
136-
mvn -version
137-
138127
- name: Verify Python version
139128
run: python --version
140129

141-
- name: Build Java
130+
- name: Install Python dependencies
131+
shell: bash
142132
run: |
143-
echo "Start compiling modules"
144-
mvn -T 2C -B clean install -DskipTests
133+
python -m pip install --upgrade pip
134+
pip install torch --index-url https://download.pytorch.org/whl/cpu
135+
python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
136+
- name: Run lint-python.sh
137+
shell: bash
138+
run: |
139+
chmod +x paimon-python/dev/lint-python.sh
140+
./paimon-python/dev/lint-python.sh -i pytest_torch
141+
142+
requirement_version_compatible_test:
143+
runs-on: ubuntu-latest
144+
container: "python:3.10-slim"
145+
146+
steps:
147+
- name: Checkout code
148+
uses: actions/checkout@v2
149+
150+
- name: Install system dependencies
151+
shell: bash
152+
run: |
153+
apt-get update && apt-get install -y \
154+
build-essential \
155+
git \
156+
curl \
157+
&& rm -rf /var/lib/apt/lists/*
158+
159+
- name: Verify Python version
160+
run: python --version
145161

146162
- name: Install base Python dependencies
147163
shell: bash
148164
run: |
149165
python -m pip install --upgrade pip
150-
python -m pip install -q \
166+
pip install torch --index-url https://download.pytorch.org/whl/cpu
167+
python -m pip install --no-cache-dir \
151168
pyroaring \
152169
readerwriterlock==1.0.9 \
153170
fsspec==2024.3.1 \
@@ -165,36 +182,37 @@ jobs:
165182
requests \
166183
parameterized==0.9.0 \
167184
packaging
185+
168186
169187
- name: Test requirement version compatibility
170188
shell: bash
171189
run: |
172190
cd paimon-python
173-
191+
174192
# Test Ray version compatibility
175193
echo "=========================================="
176194
echo "Testing Ray version compatibility"
177195
echo "=========================================="
178196
for ray_version in 2.44.0 2.48.0 2.53.0; do
179197
echo "Testing Ray version: $ray_version"
180-
198+
181199
# Install specific Ray version
182-
python -m pip install -q ray==$ray_version
183-
200+
python -m pip install --no-cache-dir -q ray==$ray_version
201+
184202
# Verify Ray version
185203
python -c "import ray; print(f'Ray version: {ray.__version__}')"
186204
python -c "from packaging.version import parse; import ray; assert parse(ray.__version__) == parse('$ray_version'), f'Expected Ray $ray_version, got {ray.__version__}'"
187-
205+
188206
# Run tests
189207
python -m pytest pypaimon/tests/ray_data_test.py::RayDataTest -v --tb=short || {
190208
echo "Tests failed for Ray $ray_version"
191209
exit 1
192210
}
193-
211+
194212
# Uninstall Ray to avoid conflicts
195213
python -m pip uninstall -y ray
196214
done
197-
215+
198216
# Add other dependency version tests here in the future
199217
# Example:
200218
# echo "=========================================="

docs/content/program-api/python-api.md

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ catalog_options = {
7272
}
7373
catalog = CatalogFactory.create(catalog_options)
7474
```
75+
7576
{{< /tab >}}
7677
{{< /tabs >}}
7778

@@ -473,6 +474,38 @@ ray_dataset = table_read.to_ray(splits)
473474

474475
See [Ray Data API Documentation](https://docs.ray.io/en/latest/data/api/doc/ray.data.read_datasource.html) for more details.
475476

477+
### Read Pytorch Dataset
478+
479+
This requires `torch` to be installed.
480+
481+
You can read all the data into a `torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`:
482+
483+
```python
484+
from torch.utils.data import DataLoader
485+
486+
table_read = read_builder.new_read()
487+
dataset = table_read.to_torch(splits, streaming=True)
488+
dataloader = DataLoader(
489+
dataset,
490+
batch_size=2,
491+
num_workers=2, # Concurrency to read data
492+
shuffle=False
493+
)
494+
495+
# Collect all data from dataloader
496+
for batch_idx, batch_data in enumerate(dataloader):
497+
print(batch_data)
498+
499+
# output:
500+
# {'user_id': tensor([1, 2]), 'behavior': ['a', 'b']}
501+
# {'user_id': tensor([3, 4]), 'behavior': ['c', 'd']}
502+
# {'user_id': tensor([5, 6]), 'behavior': ['e', 'f']}
503+
# {'user_id': tensor([7, 8]), 'behavior': ['g', 'h']}
504+
```
505+
506+
When the `streaming` parameter is true, it will iteratively read;
507+
when it is false, it will read the full amount of data into memory.
508+
476509
### Incremental Read
477510

478511
This API allows reading data committed between two snapshot timestamps. The steps are as follows.
@@ -671,22 +704,22 @@ Key points about shard read:
671704
The following shows the supported features of Python Paimon compared to Java Paimon:
672705

673706
**Catalog Level**
674-
- FileSystemCatalog
675-
- RestCatalog
707+
- FileSystemCatalog
708+
- RestCatalog
676709

677710
**Table Level**
678-
- Append Tables
679-
- `bucket = -1` (unaware)
680-
- `bucket > 0` (fixed)
681-
- Primary Key Tables
682-
- only support deduplicate
683-
- `bucket = -2` (postpone)
684-
- `bucket > 0` (fixed)
685-
- read with deletion vectors enabled
686-
- Read/Write Operations
687-
- Batch read and write for append tables and primary key tables
688-
- Predicate filtering
689-
- Overwrite semantics
690-
- Incremental reading of Delta data
691-
- Reading and writing blob data
692-
- `with_shard` feature
711+
- Append Tables
712+
- `bucket = -1` (unaware)
713+
- `bucket > 0` (fixed)
714+
- Primary Key Tables
715+
- only support deduplicate
716+
- `bucket = -2` (postpone)
717+
- `bucket > 0` (fixed)
718+
- read with deletion vectors enabled
719+
- Read/Write Operations
720+
- Batch read and write for append tables and primary key tables
721+
- Predicate filtering
722+
- Overwrite semantics
723+
- Incremental reading of Delta data
724+
- Reading and writing blob data
725+
- `with_shard` feature

paimon-python/dev/lint-python.sh

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ function collect_checks() {
107107
function get_all_supported_checks() {
108108
_OLD_IFS=$IFS
109109
IFS=$'\n'
110-
SUPPORT_CHECKS=("flake8_check" "pytest_check" "mixed_check") # control the calling sequence
110+
SUPPORT_CHECKS=("flake8_check" "pytest_torch_check" "pytest_check" "mixed_check") # control the calling sequence
111111
for fun in $(declare -F); do
112112
if [[ `regexp_match "$fun" "_check$"` = true ]]; then
113113
check_name="${fun:11}"
@@ -179,7 +179,7 @@ function pytest_check() {
179179
TEST_DIR="pypaimon/tests/py36"
180180
echo "Running tests for Python 3.6: $TEST_DIR"
181181
else
182-
TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 --ignore=pypaimon/tests/e2e"
182+
TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36 --ignore=pypaimon/tests/e2e --ignore=pypaimon/tests/torch_read_test.py"
183183
echo "Running tests for Python $PYTHON_VERSION (excluding py36): pypaimon/tests --ignore=pypaimon/tests/py36"
184184
fi
185185

@@ -197,7 +197,32 @@ function pytest_check() {
197197
print_function "STAGE" "pytest checks... [SUCCESS]"
198198
fi
199199
}
200+
function pytest_torch_check() {
201+
print_function "STAGE" "pytest torch checks"
202+
if [ ! -f "$PYTEST_PATH" ]; then
203+
echo "For some unknown reasons, the pytest package is not complete."
204+
fi
200205

206+
# Get Python version
207+
PYTHON_VERSION=$(python -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')")
208+
echo "Detected Python version: $PYTHON_VERSION"
209+
TEST_DIR="pypaimon/tests/torch_read_test.py"
210+
echo "Running tests for Python $PYTHON_VERSION: pypaimon/tests/torch_read_test.py"
211+
212+
# the return value of a pipeline is the status of the last command to exit
213+
# with a non-zero status or zero if no command exited with a non-zero status
214+
set -o pipefail
215+
($PYTEST_PATH $TEST_DIR) 2>&1 | tee -a $LOG_FILE
216+
217+
PYCODESTYLE_STATUS=$?
218+
if [ $PYCODESTYLE_STATUS -ne 0 ]; then
219+
print_function "STAGE" "pytest checks... [FAILED]"
220+
# Stop the running script.
221+
exit 1;
222+
else
223+
print_function "STAGE" "pytest checks... [SUCCESS]"
224+
fi
225+
}
201226
# Mixed tests check - runs Java-Python interoperability tests
202227
function mixed_check() {
203228
# Get Python version
@@ -279,7 +304,7 @@ usage: $0 [options]
279304
-l list all checks supported.
280305
Examples:
281306
./lint-python.sh => exec all checks.
282-
./lint-python.sh -e tox,flake8 => exclude checks tox,flake8.
307+
./lint-python.sh -e flake8 => exclude checks flake8.
283308
./lint-python.sh -i flake8 => include checks flake8.
284309
./lint-python.sh -i mixed => include checks mixed.
285310
./lint-python.sh -l => list all checks supported.

paimon-python/dev/requirements.txt

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,23 @@
1919
cachetools>=4.2,<6; python_version=="3.6"
2020
cachetools>=5,<6; python_version>"3.6"
2121
dataclasses>=0.8; python_version < "3.7"
22-
fastavro>=1.4,<2; python_version<"3.9"
23-
fastavro>=1.4,<2; python_version>="3.9"
22+
fastavro>=1.4,<2
2423
fsspec>=2021.10,<2026; python_version<"3.8"
2524
fsspec>=2023,<2026; python_version>="3.8"
2625
ossfs>=2021.8; python_version<"3.8"
2726
ossfs>=2023; python_version>="3.8"
28-
packaging>=21,<26; python_version<"3.8"
29-
packaging>=21,<26; python_version>="3.8"
27+
packaging>=21,<26
3028
pandas>=1.1,<2; python_version < "3.7"
3129
pandas>=1.3,<3; python_version >= "3.7" and python_version < "3.9"
3230
pandas>=1.5,<3; python_version >= "3.9"
3331
polars>=0.9,<1; python_version<"3.8"
34-
polars>=1,<2; python_version=="3.8"
35-
polars>=1,<2; python_version>"3.8"
32+
polars>=1,<2; python_version>="3.8"
3633
pyarrow>=6,<7; python_version < "3.8"
37-
pyarrow>=16,<20; python_version >= "3.8" and python_version < "3.13"
38-
pyarrow>=16,<20; python_version >= "3.13"
34+
pyarrow>=16,<20; python_version >= "3.8"
35+
pylance>=0.20,<1; python_version>="3.9"
36+
pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9"
3937
pyroaring
4038
ray>=2.10,<3
4139
readerwriterlock>=1,<2
42-
zstandard>=0.19,<1; python_version<"3.9"
43-
zstandard>=0.19,<1; python_version>="3.9"
44-
pylance>=0.20,<1; python_version>="3.9"
45-
pylance>=0.10,<1; python_version>="3.8" and python_version<"3.9"
40+
torch
41+
zstandard>=0.19,<1

0 commit comments

Comments
 (0)