 
 import pandas as pd
 import pyarrow as pa
+from parameterized import parameterized
 
 from pypaimon import CatalogFactory, Schema
 from pypaimon.manifest.manifest_file_manager import ManifestFileManager
@@ -675,7 +676,12 @@ def test_types(self):
             l2.append(field.to_dict())
         self.assertEqual(l1, l2)
 
-    def test_write(self):
+    @parameterized.expand([
+        ('parquet',),
+        ('orc',),
+        ('avro',),
+    ])
+    def test_write(self, file_format):
         pa_schema = pa.schema([
             ('f0', pa.int32()),
             ('f1', pa.string()),
@@ -684,9 +690,15 @@ def test_write(self):
         catalog = CatalogFactory.create({
             "warehouse": self.warehouse
         })
-        catalog.create_database("test_write_db", False)
-        catalog.create_table("test_write_db.test_table", Schema.from_pyarrow_schema(pa_schema), False)
-        table = catalog.get_table("test_write_db.test_table")
+        db_name = f"test_write_{file_format}_db"
+        table_name = f"test_{file_format}_table"
+        catalog.create_database(db_name, False)
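+        # The format under test is applied via the table's 'file.format' option below.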
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={'file.format': file_format}
+        )
+        catalog.create_table(f"{db_name}.{table_name}", schema, False)
+        table = catalog.get_table(f"{db_name}.{table_name}")
 
         data = {
             'f0': [1, 2, 3],
@@ -704,17 +716,7 @@ def test_write(self):
         table_write.close()
         table_commit.close()
 
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/snapshot/LATEST"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/snapshot/snapshot-1"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/manifest"))
-        self.assertTrue(os.path.exists(self.warehouse + "/test_write_db.db/test_table/bucket-0"))
-        self.assertEqual(len(glob.glob(self.warehouse + "/test_write_db.db/test_table/manifest/*")), 3)
-        self.assertEqual(len(glob.glob(self.warehouse + "/test_write_db.db/test_table/bucket-0/*.parquet")), 1)
-
-        with open(self.warehouse + '/test_write_db.db/test_table/snapshot/snapshot-1', 'r', encoding='utf-8') as file:
-            content = ''.join(file.readlines())
-        self.assertTrue(content.__contains__('"totalRecordCount": 3'))
-        self.assertTrue(content.__contains__('"deltaRecordCount": 3'))
+        self._verify_file_compression(file_format, db_name, table_name, expected_rows=3)
 
         write_builder = table.new_batch_write_builder()
         table_write = write_builder.new_write()
@@ -725,11 +727,166 @@ def test_write(self):
         table_write.close()
         table_commit.close()
 
-        with open(self.warehouse + '/test_write_db.db/test_table/snapshot/snapshot-2', 'r', encoding='utf-8') as file:
+        snapshot_path = os.path.join(self.warehouse, f"{db_name}.db", table_name, "snapshot", "snapshot-2")
+        with open(snapshot_path, 'r', encoding='utf-8') as file:
             content = ''.join(file.readlines())
         self.assertTrue(content.__contains__('"totalRecordCount": 6'))
         self.assertTrue(content.__contains__('"deltaRecordCount": 3'))
 
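+    # ('avro', 'lz4') is omitted below: the Avro container format defines no lz4 codec.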
+    @parameterized.expand([
+        ('parquet', 'zstd'),
+        ('parquet', 'lz4'),
+        ('parquet', 'snappy'),
+        ('orc', 'zstd'),
+        ('orc', 'lz4'),
+        ('orc', 'snappy'),
+        ('avro', 'zstd'),
+        ('avro', 'snappy'),
+    ])
+    def test_write_with_compression(self, file_format, compression):
+        pa_schema = pa.schema([
+            ('f0', pa.int32()),
+            ('f1', pa.string()),
+            ('f2', pa.string())
+        ])
+        catalog = CatalogFactory.create({
+            "warehouse": self.warehouse
+        })
+        db_name = f"test_write_{file_format}_{compression}_db"
+        table_name = f"test_{file_format}_{compression}_table"
+        catalog.create_database(db_name, False)
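+        # 'file.format' chooses the data file writer; 'file.compression' chooses its codec.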
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            options={
+                'file.format': file_format,
+                'file.compression': compression
+            }
+        )
+        catalog.create_table(f"{db_name}.{table_name}", schema, False)
+        table = catalog.get_table(f"{db_name}.{table_name}")
+
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': ['X', 'Y', 'Z']
+        }
+        expect = pa.Table.from_pydict(data, schema=pa_schema)
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        table_write.write_arrow(expect)
+        commit_messages = table_write.prepare_commit()
+        table_commit.commit(commit_messages)
+        table_write.close()
+        table_commit.close()
+
+        self._verify_file_compression_with_format(
+            file_format, compression, db_name, table_name, expected_rows=3
+        )
+
+    def _verify_file_compression_with_format(
+            self, file_format: str, compression: str,
+            db_name: str, table_name: str, expected_rows: int = 3, expected_zstd_level: int = 1):
+        if file_format == 'parquet':
+            parquet_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.parquet")
+            self.assertEqual(len(parquet_files), 1)
+            import pyarrow.parquet as pq
+            parquet_file_path = parquet_files[0]
+            parquet_metadata = pq.read_metadata(parquet_file_path)
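+            # Parquet records the codec per column chunk; verify every column in row group 0.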
+            for i in range(parquet_metadata.num_columns):
+                column_metadata = parquet_metadata.row_group(0).column(i)
+                actual_compression = column_metadata.compression
+                compression_str = str(actual_compression).upper()
+                expected_compression_upper = compression.upper()
+                self.assertIn(
+                    expected_compression_upper, compression_str,
+                    f"Expected compression to be {compression}, but got {actual_compression}")
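+                # compression_level is not exposed by every pyarrow version, hence the hasattr guard.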
+                if compression.lower() == 'zstd' and hasattr(column_metadata, 'compression_level'):
+                    actual_level = column_metadata.compression_level
+                    self.assertEqual(
+                        actual_level, expected_zstd_level,
+                        f"Expected zstd compression level to be {expected_zstd_level}, but got {actual_level}")
+        elif file_format == 'orc':
+            orc_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.orc")
+            self.assertEqual(len(orc_files), 1)
+            import pyarrow.orc as orc
+            orc_file_path = orc_files[0]
+            orc_file = orc.ORCFile(orc_file_path)
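+            # A successful full read doubles as the integrity check for the chosen codec.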
+            try:
+                table = orc_file.read()
+                self.assertEqual(table.num_rows, expected_rows, "ORC file should contain expected rows")
+            except Exception as e:
+                self.fail(f"Failed to read ORC file (compression may be incorrect): {e}")
+        elif file_format == 'avro':
+            avro_files = glob.glob(self.warehouse + f"/{db_name}.db/{table_name}/bucket-0/*.avro")
+            self.assertEqual(len(avro_files), 1)
+            import fastavro
+            avro_file_path = avro_files[0]
+            with open(avro_file_path, 'rb') as f:
+                reader = fastavro.reader(f)
+                codec = reader.codec
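+                # fastavro exposes the container codec; Avro's name for zstd is 'zstandard'.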
+                expected_codec_map = {
+                    'zstd': 'zstandard',
+                    'zstandard': 'zstandard',
+                    'snappy': 'snappy',
+                    'deflate': 'deflate',
+                }
+                expected_codec = expected_codec_map.get(
+                    compression.lower(), compression.lower())
+                self.assertEqual(
+                    codec, expected_codec,
+                    f"Expected compression codec to be '{expected_codec}', but got '{codec}'")
+
+    def _verify_file_compression(self, file_format: str, db_name: str, table_name: str,
+                                 expected_rows: int = 3, expected_zstd_level: int = 1):
+        # test_write leaves 'file.compression' unset and relies on the zstd default,
+        # so the per-format checks are shared with _verify_file_compression_with_format.
+        self._verify_file_compression_with_format(
+            file_format, 'zstd', db_name, table_name, expected_rows, expected_zstd_level)
+
     def _test_value_stats_cols_case(self, manifest_manager, table, value_stats_cols, expected_fields_count, test_name):
         """Helper method to test a specific _VALUE_STATS_COLS case."""
 