
Commit 7856573

[python] Fix default file compression and support file compression when writing avro file (#6996)
1 parent da8e246 commit 7856573

File tree

7 files changed: +237 -31 lines changed

.github/workflows/paimon-python-checks.yml

Lines changed: 2 additions & 1 deletion

@@ -97,7 +97,7 @@ jobs:
         else
           python -m pip install --upgrade pip
           pip install torch --index-url https://download.pytorch.org/whl/cpu
-          python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
+          python -m pip install pyroaring readerwriterlock==1.0.9 fsspec==2024.3.1 cachetools==5.3.3 ossfs==2023.12.0 ray==2.48.0 fastavro==1.11.1 pyarrow==16.0.0 zstandard==0.24.0 polars==1.32.0 duckdb==1.3.2 numpy==1.24.3 pandas==2.0.3 pylance==0.39.0 cramjam flake8==4.0.1 pytest~=7.0 py4j==0.10.9.9 requests parameterized==0.9.0
         fi
         df -h
     - name: Run lint-python.sh

@@ -177,6 +177,7 @@ jobs:
           duckdb==1.3.2 \
           numpy==1.24.3 \
           pandas==2.0.3 \
+          cramjam \
           pytest~=7.0 \
           py4j==0.10.9.9 \
           requests \
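
The cramjam dependency is new in both install paths above. A minimal sketch (file path and record are illustrative, not from the commit) of why it matters: in recent fastavro versions the snappy codec is provided via cramjam, so writing a snappy-compressed Avro file typically fails with a missing-library error unless cramjam is installed.

    import fastavro

    schema = {"type": "record", "name": "Row",
              "fields": [{"name": "f0", "type": "int"}]}
    with open("/tmp/example.avro", "wb") as out:
        # Without cramjam installed, fastavro raises an error that the
        # snappy codec's backing library is unavailable.
        fastavro.writer(out, schema, [{"f0": 1}], codec="snappy")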

paimon-python/dev/requirements.txt

Lines changed: 2 additions & 1 deletion

@@ -38,4 +38,5 @@ pyroaring
 ray>=2.10,<3
 readerwriterlock>=1,<2
 torch
-zstandard>=0.19,<1
+zstandard>=0.19,<1
+cramjam>=0.6,<1; python_version>="3.7"

paimon-python/pypaimon/common/file_io.py

Lines changed: 38 additions & 5 deletions

@@ -369,20 +369,30 @@ def read_overwritten_file_utf8(self, path: str) -> Optional[str]:
 
         return None
 
-    def write_parquet(self, path: str, data: pyarrow.Table, compression: str = 'snappy', **kwargs):
+    def write_parquet(self, path: str, data: pyarrow.Table, compression: str = 'zstd',
+                      zstd_level: int = 1, **kwargs):
         try:
             import pyarrow.parquet as pq
 
             with self.new_output_stream(path) as output_stream:
+                if compression.lower() == 'zstd':
+                    kwargs['compression_level'] = zstd_level
                 pq.write_table(data, output_stream, compression=compression, **kwargs)
 
         except Exception as e:
             self.delete_quietly(path)
             raise RuntimeError(f"Failed to write Parquet file {path}: {e}") from e
 
-    def write_orc(self, path: str, data: pyarrow.Table, compression: str = 'zstd', **kwargs):
+    def write_orc(self, path: str, data: pyarrow.Table, compression: str = 'zstd',
+                  zstd_level: int = 1, **kwargs):
         try:
-            """Write ORC file using PyArrow ORC writer."""
+            """Write ORC file using PyArrow ORC writer.
+
+            Note: PyArrow's ORC writer doesn't support compression_level parameter.
+            ORC files will use zstd compression with default level
+            (which is 3, see https://github.com/facebook/zstd/blob/dev/programs/zstdcli.c)
+            instead of the specified level.
+            """
             import sys
             import pyarrow.orc as orc
 

@@ -402,7 +412,10 @@ def write_orc(self, path: str, data: pyarrow.Table, compression: str = 'zstd', *
             self.delete_quietly(path)
             raise RuntimeError(f"Failed to write ORC file {path}: {e}") from e
 
-    def write_avro(self, path: str, data: pyarrow.Table, avro_schema: Optional[Dict[str, Any]] = None, **kwargs):
+    def write_avro(
+            self, path: str, data: pyarrow.Table,
+            avro_schema: Optional[Dict[str, Any]] = None,
+            compression: str = 'zstd', zstd_level: int = 1, **kwargs):
         import fastavro
         if avro_schema is None:
             from pypaimon.schema.data_types import PyarrowFieldParser

@@ -417,8 +430,28 @@ def record_generator():
 
         records = record_generator()
 
+        codec_map = {
+            'null': 'null',
+            'deflate': 'deflate',
+            'snappy': 'snappy',
+            'bzip2': 'bzip2',
+            'xz': 'xz',
+            'zstandard': 'zstandard',
+            'zstd': 'zstandard',  # zstd is commonly used in Paimon
+        }
+        compression_lower = compression.lower()
+
+        codec = codec_map.get(compression_lower)
+        if codec is None:
+            raise ValueError(
+                f"Unsupported compression '{compression}' for Avro format. "
+                f"Supported compressions: {', '.join(sorted(codec_map.keys()))}."
+            )
+
         with self.new_output_stream(path) as output_stream:
-            fastavro.writer(output_stream, avro_schema, records, **kwargs)
+            if codec == 'zstandard':
+                kwargs['codec_compression_level'] = zstd_level
+            fastavro.writer(output_stream, avro_schema, records, codec=codec, **kwargs)
 
     def write_lance(self, path: str, data: pyarrow.Table, **kwargs):
         try:
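
At the library level, the new code paths reduce to pyarrow's and fastavro's own compression knobs. A minimal sketch (paths and data are illustrative, not from the commit): pyarrow.parquet takes a numeric zstd level via compression_level, and fastavro takes codec plus codec_compression_level; pyarrow's ORC writer exposes no level parameter at all, which is exactly what the docstring note above records.

    import pyarrow as pa
    import pyarrow.parquet as pq
    import fastavro

    table = pa.table({'f0': [1, 2, 3]})

    # Parquet: per-call zstd level via pyarrow.
    pq.write_table(table, '/tmp/example.parquet',
                   compression='zstd', compression_level=1)

    # Avro: fastavro's zstandard codec with an explicit level.
    avro_schema = {'type': 'record', 'name': 'Row',
                   'fields': [{'name': 'f0', 'type': 'int'}]}
    with open('/tmp/example.avro', 'wb') as out:
        fastavro.writer(out, avro_schema,
                        [{'f0': v} for v in [1, 2, 3]],
                        codec='zstandard', codec_compression_level=1)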

paimon-python/pypaimon/common/options/core_options.py

Lines changed: 15 additions & 2 deletions

@@ -125,8 +125,18 @@ class CoreOptions:
     FILE_COMPRESSION: ConfigOption[str] = (
         ConfigOptions.key("file.compression")
        .string_type()
-        .default_value("lz4")
-        .with_description("Default file compression format.")
+        .default_value("zstd")
+        .with_description("Default file compression format. For faster read and write, it is recommended to use zstd.")
+    )
+
+    FILE_COMPRESSION_ZSTD_LEVEL: ConfigOption[int] = (
+        ConfigOptions.key("file.compression.zstd-level")
+        .int_type()
+        .default_value(1)
+        .with_description(
+            "Default file compression zstd level. For higher compression rates, it can be configured to 9, "
+            "but the read and write speed will significantly decrease."
+        )
     )
 
     FILE_COMPRESSION_PER_LEVEL: ConfigOption[Dict[str, str]] = (

@@ -346,6 +356,9 @@ def file_format(self, default=None):
     def file_compression(self, default=None):
         return self.options.get(CoreOptions.FILE_COMPRESSION, default)
 
+    def file_compression_zstd_level(self, default=None):
+        return self.options.get(CoreOptions.FILE_COMPRESSION_ZSTD_LEVEL, default)
+
     def file_compression_per_level(self, default=None):
         return self.options.get(CoreOptions.FILE_COMPRESSION_PER_LEVEL, default)
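
A usage sketch of the new options, mirroring the table-options pattern used in the tests below (warehouse path, database, and table names are illustrative; setting the zstd level through the options dict is an assumption, since the tests only set file.format and file.compression):

    import pyarrow as pa
    from pypaimon import CatalogFactory, Schema

    pa_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
    schema = Schema.from_pyarrow_schema(
        pa_schema,
        options={
            'file.format': 'avro',
            'file.compression': 'zstd',          # now the default
            'file.compression.zstd-level': '9',  # assumed: trade write speed for ratio
        }
    )
    catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
    catalog.create_database('example_db', False)
    catalog.create_table('example_db.example_table', schema, False)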

paimon-python/pypaimon/tests/reader_base_test.py

Lines changed: 173 additions & 16 deletions

@@ -28,6 +28,7 @@
 
 import pandas as pd
 import pyarrow as pa
+from parameterized import parameterized
 
 from pypaimon import CatalogFactory, Schema
 from pypaimon.manifest.manifest_file_manager import ManifestFileManager

@@ -675,7 +676,12 @@ def test_types(self):
             l2.append(field.to_dict())
         self.assertEqual(l1, l2)
 
-    def test_write(self):
+    @parameterized.expand([
+        ('parquet',),
+        ('orc',),
+        ('avro',),
+    ])
+    def test_write(self, file_format):
         pa_schema = pa.schema([
             ('f0', pa.int32()),
             ('f1', pa.string()),

@@ -684,9 +690,15 @@ def test_write(self):
         catalog = CatalogFactory.create({
             "warehouse": self.warehouse
         })
-        catalog.create_database("test_write_db", False)
-        catalog.create_table("test_write_db.test_table", Schema.from_pyarrow_schema(pa_schema), False)
-        table = catalog.get_table("test_write_db.test_table")
+        db_name = f"test_write_{file_format}_db"
+        table_name = f"test_{file_format}_table"
+        catalog.create_database(db_name, False)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={'file.format': file_format}
+        )
+        catalog.create_table(f"{db_name}.{table_name}", schema, False)
+        table = catalog.get_table(f"{db_name}.{table_name}")
 
         data = {
             'f0': [1, 2, 3],

@@ -704,17 +716,7 @@
         table_write.close()
         table_commit.close()
 
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/snapshot/LATEST"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/snapshot/snapshot-1"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/manifest"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/bucket-0"))
-        self.assertEqual(len(glob.glob(self.warehouse + "/test_write_db.db/test_table/manifest/*")), 3)
-        self.assertEqual(len(glob.glob(self.warehouse + "/test_write_db.db/test_table/bucket-0/*.parquet")), 1)
-
-        with open(self.warehouse + '/test_write_db.db/test_table/snapshot/snapshot-1', 'r', encoding='utf-8') as file:
-            content = ''.join(file.readlines())
-            self.assertTrue(content.__contains__('\"totalRecordCount\": 3'))
-            self.assertTrue(content.__contains__('\"deltaRecordCount\": 3'))
+        self._verify_file_compression(file_format, db_name, table_name, expected_rows=3)
 
         write_builder = table.new_batch_write_builder()
         table_write = write_builder.new_write()

@@ -725,11 +727,166 @@
         table_write.close()
         table_commit.close()
 
-        with open(self.warehouse + '/test_write_db.db/test_table/snapshot/snapshot-2', 'r', encoding='utf-8') as file:
+        snapshot_path = os.path.join(self.warehouse, f"{db_name}.db", table_name, "snapshot", "snapshot-2")
+        with open(snapshot_path, 'r', encoding='utf-8') as file:
             content = ''.join(file.readlines())
             self.assertTrue(content.__contains__('\"totalRecordCount\": 6'))
             self.assertTrue(content.__contains__('\"deltaRecordCount\": 3'))
 
+    @parameterized.expand([
+        ('parquet', 'zstd'),
+        ('parquet', 'lz4'),
+        ('parquet', 'snappy'),
+        ('orc', 'zstd'),
+        ('orc', 'lz4'),
+        ('orc', 'snappy'),
+        ('avro', 'zstd'),
+        ('avro', 'snappy'),
+    ])
+    def test_write_with_compression(self, file_format, compression):
+        pa_schema = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string()),
+            ('f2', pa.string())
+        ])
+        catalog = CatalogFactory.create({
+            "warehouse": self.warehouse
+        })
+        db_name = f"test_write_{file_format}_{compression}_db"
+        table_name = f"test_{file_format}_{compression}_table"
+        catalog.create_database(db_name, False)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={
+                'file.format': file_format,
+                'file.compression': compression
+            }
+        )
+        catalog.create_table(f"{db_name}.{table_name}", schema, False)
+        table = catalog.get_table(f"{db_name}.{table_name}")
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': ['X', 'Y', 'Z']
+        }
+        expect = pa.Table.from_pydict(data, schema=pa_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        try:
+            table_write.write_arrow(expect)
+            commit_messages = table_write.prepare_commit()
+            table_commit.commit(commit_messages)
+            table_write.close()
+            table_commit.close()
+
+            self._verify_file_compression_with_format(
+                file_format, compression, db_name, table_name, expected_rows=3
+            )
+        except (ValueError, RuntimeError):
+            raise
+
+    def _verify_file_compression_with_format(
+            self, file_format: str, compression: str,
+            db_name: str, table_name: str, expected_rows: int = 3, expected_zstd_level: int = 1):
+        if file_format == 'parquet':
+            parquet_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.parquet")
+            self.assertEqual(len(parquet_files), 1)
+            import pyarrow.parquet as pq
+            parquet_file_path = parquet_files[0]
+            parquet_metadata = pq.read_metadata(parquet_file_path)
+            for i in range(parquet_metadata.num_columns):
+                column_metadata = parquet_metadata.row_group(0).column(i)
+                actual_compression = column_metadata.compression
+                compression_str = str(actual_compression).upper()
+                expected_compression_upper = compression.upper()
+                self.assertIn(
+                    expected_compression_upper, compression_str,
+                    f"Expected compression to be {compression}, but got {actual_compression}")
+                if compression.lower() == 'zstd' and hasattr(column_metadata, 'compression_level'):
+                    actual_level = column_metadata.compression_level
+                    self.assertEqual(
+                        actual_level, expected_zstd_level,
+                        f"Expected zstd compression level to be {expected_zstd_level}, but got {actual_level}")
+        elif file_format == 'orc':
+            orc_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.orc")
+            self.assertEqual(len(orc_files), 1)
+            import pyarrow.orc as orc
+            orc_file_path = orc_files[0]
+            orc_file = orc.ORCFile(orc_file_path)
+            try:
+                table = orc_file.read()
+                self.assertEqual(table.num_rows, expected_rows, "ORC file should contain expected rows")
+            except Exception as e:
+                self.fail(f"Failed to read ORC file (compression may be incorrect): {e}")
+        elif file_format == 'avro':
+            avro_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.avro")
+            self.assertEqual(len(avro_files), 1)
+            import fastavro
+            avro_file_path = avro_files[0]
+            with open(avro_file_path, 'rb') as f:
+                reader = fastavro.reader(f)
+                codec = reader.codec
+                expected_codec_map = {
+                    'zstd': 'zstandard',
+                    'zstandard': 'zstandard',
+                    'snappy': 'snappy',
+                    'deflate': 'deflate',
+                }
+                expected_codec = expected_codec_map.get(
+                    compression.lower(), compression.lower())
+                self.assertEqual(
+                    codec, expected_codec,
+                    f"Expected compression codec to be '{expected_codec}', but got '{codec}'")
+
+    def _verify_file_compression(self, file_format: str, db_name: str, table_name: str,
+                                 expected_rows: int = 3, expected_zstd_level: int = 1):
+        if file_format == 'parquet':
+            parquet_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.parquet")
+            self.assertEqual(len(parquet_files), 1)
+            import pyarrow.parquet as pq
+            parquet_file_path = parquet_files[0]
+            parquet_metadata = pq.read_metadata(parquet_file_path)
+            for i in range(parquet_metadata.num_columns):
+                column_metadata = parquet_metadata.row_group(0).column(i)
+                compression = column_metadata.compression
+                compression_str = str(compression).upper()
+                self.assertIn(
+                    'ZSTD', compression_str,
+                    f"Expected compression to be ZSTD, "
+                    f"but got {compression}")
+                if hasattr(column_metadata, 'compression_level'):
+                    actual_level = column_metadata.compression_level
+                    self.assertEqual(
+                        actual_level, expected_zstd_level,
+                        f"Expected zstd compression level to be {expected_zstd_level}, but got {actual_level}")
+        elif file_format == 'orc':
+            orc_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.orc")
+            self.assertEqual(len(orc_files), 1)
+            import pyarrow.orc as orc
+            orc_file_path = orc_files[0]
+            orc_file = orc.ORCFile(orc_file_path)
+            try:
+                table = orc_file.read()
+                self.assertEqual(table.num_rows, expected_rows, "ORC file should contain expected rows")
+            except Exception as e:
+                self.fail(f"Failed to read ORC file (compression may be incorrect): {e}")
+        elif file_format == 'avro':
+            avro_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.avro")
+            self.assertEqual(len(avro_files), 1)
+            import fastavro
+            avro_file_path = avro_files[0]
+            with open(avro_file_path, 'rb') as f:
+                reader = fastavro.reader(f)
+                codec = reader.codec
+                self.assertEqual(
+                    codec, 'zstandard',
+                    f"Expected compression codec to be 'zstandard', "
+                    f"but got '{codec}'")
+
     def _test_value_stats_cols_case(self, manifest_manager, table, value_stats_cols, expected_fields_count, test_name):
         """Helper method to test a specific _VALUE_STATS_COLS case."""
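
To run just the new parameterized compression cases locally, an invocation along these lines should work (the exact path and selector are illustrative):

    pytest paimon-python/pypaimon/tests/reader_base_test.py -k "test_write_with_compression"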

paimon-python/pypaimon/write/writer/data_blob_writer.py

Lines changed: 3 additions & 3 deletions

@@ -255,11 +255,11 @@ def _write_normal_data_to_file(self, data: pa.Table) -> Optional[DataFileMeta]:
 
         # Write file based on format
         if self.file_format == CoreOptions.FILE_FORMAT_PARQUET:
-            self.file_io.write_parquet(file_path, data, compression=self.compression)
+            self.file_io.write_parquet(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_ORC:
-            self.file_io.write_orc(file_path, data, compression=self.compression)
+            self.file_io.write_orc(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_AVRO:
-            self.file_io.write_avro(file_path, data)
+            self.file_io.write_avro(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_LANCE:
             self.file_io.write_lance(file_path, data)
         else:

paimon-python/pypaimon/write/writer/data_writer.py

Lines changed: 4 additions & 3 deletions

@@ -56,6 +56,7 @@ def __init__(self, table, partition: Tuple, bucket: int, max_seq_number: int, op
         )
         self.file_format = self.options.file_format(default_format)
         self.compression = self.options.file_compression()
+        self.zstd_level = self.options.file_compression_zstd_level()
         self.sequence_generator = SequenceGenerator(max_seq_number)
 
         self.pending_data: Optional[pa.Table] = None

@@ -169,11 +170,11 @@ def _write_data_to_file(self, data: pa.Table):
         external_path_str = None
 
         if self.file_format == CoreOptions.FILE_FORMAT_PARQUET:
-            self.file_io.write_parquet(file_path, data, compression=self.compression)
+            self.file_io.write_parquet(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_ORC:
-            self.file_io.write_orc(file_path, data, compression=self.compression)
+            self.file_io.write_orc(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_AVRO:
-            self.file_io.write_avro(file_path, data)
+            self.file_io.write_avro(file_path, data, compression=self.compression, zstd_level=self.zstd_level)
         elif self.file_format == CoreOptions.FILE_FORMAT_BLOB:
             self.file_io.write_blob(file_path, data, self.blob_as_descriptor)
         elif self.file_format == CoreOptions.FILE_FORMAT_LANCE:
