|
19 | 19 |
|
20 | 20 | package io.milvus.bulkwriter.common.utils; |
21 | 21 |
|
22 | | -import io.milvus.param.collection.CollectionSchemaParam; |
23 | | -import io.milvus.param.collection.FieldType; |
24 | 22 | import org.apache.parquet.schema.LogicalTypeAnnotation; |
25 | 23 | import org.apache.parquet.schema.MessageType; |
26 | 24 | import org.apache.parquet.schema.PrimitiveType; |
27 | 25 | import org.apache.parquet.schema.Types; |
| 26 | +import io.milvus.v2.service.collection.request.CreateCollectionReq; |
28 | 27 |
|
29 | 28 | import java.util.List; |
30 | 29 |
|
31 | 30 | import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME; |
32 | 31 |
|
33 | 32 | public class ParquetUtils { |
34 | | - public static MessageType parseCollectionSchema(CollectionSchemaParam collectionSchema) { |
35 | | - List<FieldType> fieldTypes = collectionSchema.getFieldTypes(); |
| 33 | + private static void setMessageType(Types.MessageTypeBuilder builder, |
| 34 | + PrimitiveType.PrimitiveTypeName primitiveName, |
| 35 | + LogicalTypeAnnotation logicType, |
| 36 | + CreateCollectionReq.FieldSchema field, |
| 37 | + boolean isListType) { |
| 38 | + // Note: |
| 39 | + // Ideally, if the field is nullable, the builder should be builder.requiredList() or builder.required(). |
| 40 | + // But in milvus (versions <= v2.5.4), the milvus server logic cannot handle parquet files with |
| 41 | + // requiredList()/required(), the server will crash in the file /internal/util/importutilv2/parquet/field_reader.go, |
| 42 | + // in the parquet.FieldReader.Next() with a runtime error: "index out of range [0] with length 0". |
| 43 | + // This issue is tracked by https://github.com/milvus-io/milvus/issues/40291 |
| 44 | + // The python sdk BulkWriter uses Pandas to generate parquet files, the Pandas sets all schema to be "optional" |
| 45 | + // so that the crash is by-passed. |
| 46 | + // To avoid the crash, in Java SDK, we use optionalList()/optional() even if the field is nullable. |
| 47 | + if (isListType) { |
| 48 | + // FloatVector/BinaryVector/Float16Vector/BFloat16Vector/Array enter this section |
| 49 | + if (logicType == null) { |
| 50 | + builder.optionalList().optionalElement(primitiveName).named(field.getName()); |
| 51 | + } else { |
| 52 | + builder.optionalList().optionalElement(primitiveName).as(logicType).named(field.getName()); |
| 53 | + } |
| 54 | + } else { |
| 55 | + // SparseFloatVector/Bool/Int8/Int16/Int32/Int64/Float/Double/Varchar/JSON enter this section |
| 56 | + if (logicType == null) { |
| 57 | + builder.optional(primitiveName).named(field.getName()); |
| 58 | + } else { |
| 59 | + builder.optional(primitiveName).as(logicType).named(field.getName()); |
| 60 | + } |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSchema collectionSchema) { |
| 65 | + List<CreateCollectionReq.FieldSchema> fields = collectionSchema.getFieldSchemaList(); |
| 66 | + List<String> outputFieldNames = V2AdapterUtils.getOutputFieldNames(collectionSchema); |
36 | 67 | Types.MessageTypeBuilder messageTypeBuilder = Types.buildMessage(); |
37 | | - for (FieldType fieldType : fieldTypes) { |
38 | | - if (fieldType.isAutoID()) { |
| 68 | + for (CreateCollectionReq.FieldSchema field : fields) { |
| 69 | + if (field.getIsPrimaryKey() && field.getAutoID()) { |
39 | 70 | continue; |
40 | 71 | } |
41 | | - switch (fieldType.getDataType()) { |
| 72 | + if (outputFieldNames.contains(field.getName())) { |
| 73 | + continue; |
| 74 | + } |
| 75 | + |
| 76 | + switch (field.getDataType()) { |
42 | 77 | case FloatVector: |
43 | | - messageTypeBuilder.requiredList() |
44 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.FLOAT) |
45 | | - .named(fieldType.getName()); |
| 78 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, true); |
46 | 79 | break; |
47 | 80 | case BinaryVector: |
48 | 81 | case Float16Vector: |
49 | 82 | case BFloat16Vector: |
50 | | - messageTypeBuilder.requiredList() |
51 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false)) |
52 | | - .named(fieldType.getName()); |
| 83 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, |
| 84 | + LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false), field, true); |
53 | 85 | break; |
54 | 86 | case Array: |
55 | | - fillArrayType(messageTypeBuilder, fieldType); |
| 87 | + fillArrayType(messageTypeBuilder, field); |
56 | 88 | break; |
57 | 89 |
|
58 | 90 | case Int64: |
59 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT64) |
60 | | - .named(fieldType.getName()); |
| 91 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, false); |
61 | 92 | break; |
62 | 93 | case VarChar: |
63 | 94 | case JSON: |
64 | 95 | case SparseFloatVector: // sparse vector is parsed as JSON format string in the server side |
65 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()) |
66 | | - .named(fieldType.getName()); |
| 96 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY, |
| 97 | + LogicalTypeAnnotation.stringType(), field, false); |
67 | 98 | break; |
68 | 99 | case Int8: |
69 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true)) |
70 | | - .named(fieldType.getName()); |
| 100 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, |
| 101 | + LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), field, false); |
71 | 102 | break; |
72 | 103 | case Int16: |
73 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true)) |
74 | | - .named(fieldType.getName()); |
| 104 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, |
| 105 | + LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), field, false); |
75 | 106 | break; |
76 | 107 | case Int32: |
77 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT32) |
78 | | - .named(fieldType.getName()); |
| 108 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, false); |
79 | 109 | break; |
80 | 110 | case Float: |
81 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.FLOAT) |
82 | | - .named(fieldType.getName()); |
| 111 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, false); |
83 | 112 | break; |
84 | 113 | case Double: |
85 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.DOUBLE) |
86 | | - .named(fieldType.getName()); |
| 114 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, false); |
87 | 115 | break; |
88 | 116 | case Bool: |
89 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BOOLEAN) |
90 | | - .named(fieldType.getName()); |
| 117 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, false); |
91 | 118 | break; |
92 | 119 |
|
93 | 120 | } |
94 | 121 | } |
95 | 122 |
|
96 | 123 | if (collectionSchema.isEnableDynamicField()) { |
97 | | - messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()) |
| 124 | + messageTypeBuilder.optional(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()) |
98 | 125 | .named(DYNAMIC_FIELD_NAME); |
99 | 126 | } |
100 | 127 | return messageTypeBuilder.named("schema"); |
101 | 128 | } |
102 | 129 |
|
103 | | - private static void fillArrayType(Types.MessageTypeBuilder messageTypeBuilder, FieldType fieldType) { |
104 | | - switch (fieldType.getElementType()) { |
| 130 | + private static void fillArrayType(Types.MessageTypeBuilder messageTypeBuilder, CreateCollectionReq.FieldSchema field) { |
| 131 | + switch (field.getElementType()) { |
105 | 132 | case Int64: |
106 | | - messageTypeBuilder.requiredList() |
107 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.INT64) |
108 | | - .named(fieldType.getName()); |
| 133 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, true); |
109 | 134 | break; |
110 | 135 | case VarChar: |
111 | | - messageTypeBuilder.requiredList() |
112 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType()) |
113 | | - .named(fieldType.getName()); |
| 136 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY, |
| 137 | + LogicalTypeAnnotation.stringType(), field, true); |
114 | 138 | break; |
115 | 139 | case Int8: |
116 | | - messageTypeBuilder.requiredList() |
117 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true)) |
118 | | - .named(fieldType.getName()); |
| 140 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, |
| 141 | + LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), field, true); |
119 | 142 | break; |
120 | 143 | case Int16: |
121 | | - messageTypeBuilder.requiredList() |
122 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true)) |
123 | | - .named(fieldType.getName()); |
| 144 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, |
| 145 | + LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), field, true); |
124 | 146 | break; |
125 | 147 | case Int32: |
126 | | - messageTypeBuilder.requiredList() |
127 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.INT32) |
128 | | - .named(fieldType.getName()); |
| 148 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, true); |
129 | 149 | break; |
130 | 150 | case Float: |
131 | | - messageTypeBuilder.requiredList() |
132 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.FLOAT) |
133 | | - .named(fieldType.getName()); |
| 151 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, true); |
134 | 152 | break; |
135 | 153 | case Double: |
136 | | - messageTypeBuilder.requiredList() |
137 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.DOUBLE) |
138 | | - .named(fieldType.getName()); |
| 154 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, true); |
139 | 155 | break; |
140 | 156 | case Bool: |
141 | | - messageTypeBuilder.requiredList() |
142 | | - .requiredElement(PrimitiveType.PrimitiveTypeName.BOOLEAN) |
143 | | - .named(fieldType.getName()); |
| 157 | + setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, true); |
144 | 158 | break; |
145 | | - |
146 | 159 | } |
147 | 160 | } |
148 | 161 | } |
0 commit comments