Skip to content

Commit e773008

Browse files
authored
BulkWriter supports nullable/default_value (#1313)
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent f71c685 commit e773008

15 files changed

Lines changed: 1013 additions & 440 deletions

File tree

examples/src/main/java/io/milvus/v2/BulkWriterExample.java

Lines changed: 283 additions & 145 deletions
Large diffs are not rendered by default.

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/BulkWriter.java

Lines changed: 136 additions & 46 deletions
Large diffs are not rendered by default.

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/LocalBulkWriter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import com.google.gson.JsonObject;
2424
import io.milvus.bulkwriter.common.clientenum.BulkFileType;
2525
import io.milvus.bulkwriter.writer.FormatFileWriter;
26-
import io.milvus.param.collection.CollectionSchemaParam;
26+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
2727
import org.apache.commons.collections4.CollectionUtils;
2828
import org.slf4j.Logger;
2929
import org.slf4j.LoggerFactory;
@@ -38,7 +38,7 @@
3838
import java.util.concurrent.TimeUnit;
3939
import java.util.concurrent.locks.ReentrantLock;
4040

41-
public class LocalBulkWriter extends BulkWriter implements AutoCloseable {
41+
public class LocalBulkWriter extends BulkWriter {
4242
private static final Logger logger = LoggerFactory.getLogger(LocalBulkWriter.class);
4343

4444
private Map<String, Thread> workingThread;
@@ -52,7 +52,7 @@ public LocalBulkWriter(LocalBulkWriterParam bulkWriterParam) throws IOException
5252
this.localFiles = Lists.newArrayList();
5353
}
5454

55-
protected LocalBulkWriter(CollectionSchemaParam collectionSchema,
55+
protected LocalBulkWriter(CreateCollectionReq.CollectionSchema collectionSchema,
5656
long chunkSize,
5757
BulkFileType fileType,
5858
String localPath,

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/LocalBulkWriterParam.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
@Getter
3939
@ToString
4040
public class LocalBulkWriterParam {
41-
private final CollectionSchemaParam collectionSchema;
41+
private final CreateCollectionReq.CollectionSchema collectionSchema;
4242
private final String localPath;
4343
private final long chunkSize;
4444
private final BulkFileType fileType;
@@ -60,7 +60,7 @@ public static Builder newBuilder() {
6060
* Builder for {@link LocalBulkWriterParam} class.
6161
*/
6262
public static final class Builder {
63-
private CollectionSchemaParam collectionSchema;
63+
private CreateCollectionReq.CollectionSchema collectionSchema;
6464
private String localPath;
6565
private long chunkSize = 128 * 1024 * 1024;
6666
private BulkFileType fileType = BulkFileType.PARQUET;
@@ -76,7 +76,7 @@ private Builder() {
7676
* @return <code>Builder</code>
7777
*/
7878
public Builder withCollectionSchema(@NonNull CollectionSchemaParam collectionSchema) {
79-
this.collectionSchema = collectionSchema;
79+
this.collectionSchema = V2AdapterUtils.convertV1Schema(collectionSchema);
8080
return this;
8181
}
8282

@@ -87,7 +87,7 @@ public Builder withCollectionSchema(@NonNull CollectionSchemaParam collectionSch
8787
* @return <code>Builder</code>
8888
*/
8989
public Builder withCollectionSchema(@NonNull CreateCollectionReq.CollectionSchema collectionSchema) {
90-
this.collectionSchema = V2AdapterUtils.convertV2Schema(collectionSchema);
90+
this.collectionSchema = collectionSchema;
9191
return this;
9292
}
9393

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/RemoteBulkWriter.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,6 @@ protected void callBack(List<String> fileList) {
168168
}
169169

170170
for (String filePath : fileList) {
171-
String ext = getExtension(filePath);
172-
if (!Lists.newArrayList(".parquet").contains(ext)) {
173-
continue;
174-
}
175-
176171
String relativeFilePath = filePath.replace(super.getDataPath(), "");
177172
String minioFilePath = getMinioFilePath(remotePath, relativeFilePath);
178173

@@ -277,16 +272,4 @@ private static String getMinioFilePath(String remotePath, String relativeFilePat
277272
Path joinedPath = remote.resolve(relative);
278273
return joinedPath.toString();
279274
}
280-
281-
private static String getExtension(String filePath) {
282-
Path path = Paths.get(filePath);
283-
String fileName = path.getFileName().toString();
284-
int dotIndex = fileName.lastIndexOf('.');
285-
286-
if (dotIndex == -1 || dotIndex == fileName.length() - 1) {
287-
return "";
288-
} else {
289-
return fileName.substring(dotIndex);
290-
}
291-
}
292275
}

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/RemoteBulkWriterParam.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
@Getter
4141
@ToString
4242
public class RemoteBulkWriterParam {
43-
private final CollectionSchemaParam collectionSchema;
43+
private final CreateCollectionReq.CollectionSchema collectionSchema;
4444
private final StorageConnectParam connectParam;
4545
private final String remotePath;
4646
private final long chunkSize;
@@ -64,7 +64,7 @@ public static Builder newBuilder() {
6464
* Builder for {@link RemoteBulkWriterParam} class.
6565
*/
6666
public static final class Builder {
67-
private CollectionSchemaParam collectionSchema;
67+
private CreateCollectionReq.CollectionSchema collectionSchema;
6868
private StorageConnectParam connectParam;
6969
private String remotePath;
7070
private long chunkSize = 128 * 1024 * 1024;
@@ -81,7 +81,7 @@ private Builder() {
8181
* @return <code>Builder</code>
8282
*/
8383
public Builder withCollectionSchema(@NonNull CollectionSchemaParam collectionSchema) {
84-
this.collectionSchema = collectionSchema;
84+
this.collectionSchema = V2AdapterUtils.convertV1Schema(collectionSchema);
8585
return this;
8686
}
8787

@@ -92,7 +92,7 @@ public Builder withCollectionSchema(@NonNull CollectionSchemaParam collectionSch
9292
* @return <code>Builder</code>
9393
*/
9494
public Builder withCollectionSchema(@NonNull CreateCollectionReq.CollectionSchema collectionSchema) {
95-
this.collectionSchema = V2AdapterUtils.convertV2Schema(collectionSchema);
95+
this.collectionSchema = collectionSchema;
9696
return this;
9797
}
9898

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/clientenum/TypeSize.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package io.milvus.bulkwriter.common.clientenum;
2121

2222
import io.milvus.exception.ParamException;
23-
import io.milvus.grpc.DataType;
23+
import io.milvus.v2.common.DataType;
2424

2525
public enum TypeSize {
2626
BOOL(DataType.Bool, 1),

sdk-bulkwriter/src/main/java/io/milvus/bulkwriter/common/utils/ParquetUtils.java

Lines changed: 71 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,130 +19,143 @@
1919

2020
package io.milvus.bulkwriter.common.utils;
2121

22-
import io.milvus.param.collection.CollectionSchemaParam;
23-
import io.milvus.param.collection.FieldType;
2422
import org.apache.parquet.schema.LogicalTypeAnnotation;
2523
import org.apache.parquet.schema.MessageType;
2624
import org.apache.parquet.schema.PrimitiveType;
2725
import org.apache.parquet.schema.Types;
26+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
2827

2928
import java.util.List;
3029

3130
import static io.milvus.param.Constant.DYNAMIC_FIELD_NAME;
3231

3332
public class ParquetUtils {
34-
public static MessageType parseCollectionSchema(CollectionSchemaParam collectionSchema) {
35-
List<FieldType> fieldTypes = collectionSchema.getFieldTypes();
33+
private static void setMessageType(Types.MessageTypeBuilder builder,
34+
PrimitiveType.PrimitiveTypeName primitiveName,
35+
LogicalTypeAnnotation logicType,
36+
CreateCollectionReq.FieldSchema field,
37+
boolean isListType) {
38+
// Note:
39+
// Ideally, if the field is not nullable, the builder should use builder.requiredList() or builder.required().
40+
// But in milvus (versions <= v2.5.4), the server cannot handle parquet files written with
41+
// requiredList()/required(): it crashes in the file /internal/util/importutilv2/parquet/field_reader.go,
42+
// in parquet.FieldReader.Next(), with the runtime error "index out of range [0] with length 0".
43+
// This issue is tracked by https://github.com/milvus-io/milvus/issues/40291
44+
// The Python SDK BulkWriter uses pandas to generate parquet files, and pandas marks every schema field as "optional",
45+
// so the crash is bypassed.
46+
// To avoid the crash, the Java SDK uses optionalList()/optional() even if the field is not nullable.
47+
if (isListType) {
48+
// FloatVector/BinaryVector/Float16Vector/BFloat16Vector/Array enter this section
49+
if (logicType == null) {
50+
builder.optionalList().optionalElement(primitiveName).named(field.getName());
51+
} else {
52+
builder.optionalList().optionalElement(primitiveName).as(logicType).named(field.getName());
53+
}
54+
} else {
55+
// SparseFloatVector/Bool/Int8/Int16/Int32/Int64/Float/Double/Varchar/JSON enter this section
56+
if (logicType == null) {
57+
builder.optional(primitiveName).named(field.getName());
58+
} else {
59+
builder.optional(primitiveName).as(logicType).named(field.getName());
60+
}
61+
}
62+
}
63+
64+
public static MessageType parseCollectionSchema(CreateCollectionReq.CollectionSchema collectionSchema) {
65+
List<CreateCollectionReq.FieldSchema> fields = collectionSchema.getFieldSchemaList();
66+
List<String> outputFieldNames = V2AdapterUtils.getOutputFieldNames(collectionSchema);
3667
Types.MessageTypeBuilder messageTypeBuilder = Types.buildMessage();
37-
for (FieldType fieldType : fieldTypes) {
38-
if (fieldType.isAutoID()) {
68+
for (CreateCollectionReq.FieldSchema field : fields) {
69+
if (field.getIsPrimaryKey() && field.getAutoID()) {
3970
continue;
4071
}
41-
switch (fieldType.getDataType()) {
72+
if (outputFieldNames.contains(field.getName())) {
73+
continue;
74+
}
75+
76+
switch (field.getDataType()) {
4277
case FloatVector:
43-
messageTypeBuilder.requiredList()
44-
.requiredElement(PrimitiveType.PrimitiveTypeName.FLOAT)
45-
.named(fieldType.getName());
78+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, true);
4679
break;
4780
case BinaryVector:
4881
case Float16Vector:
4982
case BFloat16Vector:
50-
messageTypeBuilder.requiredList()
51-
.requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false))
52-
.named(fieldType.getName());
83+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
84+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, false), field, true);
5385
break;
5486
case Array:
55-
fillArrayType(messageTypeBuilder, fieldType);
87+
fillArrayType(messageTypeBuilder, field);
5688
break;
5789

5890
case Int64:
59-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT64)
60-
.named(fieldType.getName());
91+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, false);
6192
break;
6293
case VarChar:
6394
case JSON:
6495
case SparseFloatVector: // sparse vector is parsed as JSON format string in the server side
65-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType())
66-
.named(fieldType.getName());
96+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
97+
LogicalTypeAnnotation.stringType(), field, false);
6798
break;
6899
case Int8:
69-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true))
70-
.named(fieldType.getName());
100+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
101+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), field, false);
71102
break;
72103
case Int16:
73-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true))
74-
.named(fieldType.getName());
104+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
105+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), field, false);
75106
break;
76107
case Int32:
77-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.INT32)
78-
.named(fieldType.getName());
108+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, false);
79109
break;
80110
case Float:
81-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.FLOAT)
82-
.named(fieldType.getName());
111+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, false);
83112
break;
84113
case Double:
85-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.DOUBLE)
86-
.named(fieldType.getName());
114+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, false);
87115
break;
88116
case Bool:
89-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BOOLEAN)
90-
.named(fieldType.getName());
117+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, false);
91118
break;
92119

93120
}
94121
}
95122

96123
if (collectionSchema.isEnableDynamicField()) {
97-
messageTypeBuilder.required(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType())
124+
messageTypeBuilder.optional(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType())
98125
.named(DYNAMIC_FIELD_NAME);
99126
}
100127
return messageTypeBuilder.named("schema");
101128
}
102129

103-
private static void fillArrayType(Types.MessageTypeBuilder messageTypeBuilder, FieldType fieldType) {
104-
switch (fieldType.getElementType()) {
130+
private static void fillArrayType(Types.MessageTypeBuilder messageTypeBuilder, CreateCollectionReq.FieldSchema field) {
131+
switch (field.getElementType()) {
105132
case Int64:
106-
messageTypeBuilder.requiredList()
107-
.requiredElement(PrimitiveType.PrimitiveTypeName.INT64)
108-
.named(fieldType.getName());
133+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT64, null, field, true);
109134
break;
110135
case VarChar:
111-
messageTypeBuilder.requiredList()
112-
.requiredElement(PrimitiveType.PrimitiveTypeName.BINARY).as(LogicalTypeAnnotation.stringType())
113-
.named(fieldType.getName());
136+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BINARY,
137+
LogicalTypeAnnotation.stringType(), field, true);
114138
break;
115139
case Int8:
116-
messageTypeBuilder.requiredList()
117-
.requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true))
118-
.named(fieldType.getName());
140+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
141+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(8, true), field, true);
119142
break;
120143
case Int16:
121-
messageTypeBuilder.requiredList()
122-
.requiredElement(PrimitiveType.PrimitiveTypeName.INT32).as(LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true))
123-
.named(fieldType.getName());
144+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32,
145+
LogicalTypeAnnotation.IntLogicalTypeAnnotation.intType(16, true), field, true);
124146
break;
125147
case Int32:
126-
messageTypeBuilder.requiredList()
127-
.requiredElement(PrimitiveType.PrimitiveTypeName.INT32)
128-
.named(fieldType.getName());
148+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.INT32, null, field, true);
129149
break;
130150
case Float:
131-
messageTypeBuilder.requiredList()
132-
.requiredElement(PrimitiveType.PrimitiveTypeName.FLOAT)
133-
.named(fieldType.getName());
151+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.FLOAT, null, field, true);
134152
break;
135153
case Double:
136-
messageTypeBuilder.requiredList()
137-
.requiredElement(PrimitiveType.PrimitiveTypeName.DOUBLE)
138-
.named(fieldType.getName());
154+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.DOUBLE, null, field, true);
139155
break;
140156
case Bool:
141-
messageTypeBuilder.requiredList()
142-
.requiredElement(PrimitiveType.PrimitiveTypeName.BOOLEAN)
143-
.named(fieldType.getName());
157+
setMessageType(messageTypeBuilder, PrimitiveType.PrimitiveTypeName.BOOLEAN, null, field, true);
144158
break;
145-
146159
}
147160
}
148161
}

0 commit comments

Comments
 (0)