Skip to content

Commit f3f2562

Browse files
committed
update embedding tests
1 parent a8f5a21 commit f3f2562

17 files changed

Lines changed: 580 additions & 73 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
node_modules/
1+
**/node_modules/
22
dist/
33
*.log
44
.DS_Store
@@ -12,5 +12,6 @@ coverage/
1212
.pnpm-store
1313
.cursor
1414
*.db
15+
examples/seekdb-prisma/generated/**
1516
spec/
1617
.vscode

examples/seekdb-prisma/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"scripts": {
66
"start": "tsx index.ts",
77
"start:embedded": "tsx index-embedded.ts",
8-
"db:generate": "prisma generate",
8+
"postinstall": "prisma generate",
99
"db:push": "prisma db push"
1010
},
1111
"dependencies": {

examples/seekdb-prisma/schema.prisma

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Prisma schema: relational tables only. Vector tables are managed by seekdb Collection.
22
generator client {
33
provider = "prisma-client-js"
4+
output = "./generated/client"
45
}
56

67
datasource db {
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"compilerOptions": {
3+
"target": "ES2022",
4+
"module": "NodeNext",
5+
"moduleResolution": "NodeNext",
6+
"strict": true,
7+
"noEmit": true,
8+
"skipLibCheck": true
9+
},
10+
"include": ["*.ts", "generated/**/*.ts"]
11+
}

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"build": "pnpm --filter seekdb run build && pnpm --filter '@seekdb/*' run build",
77
"build:seekdb": "pnpm --filter seekdb run build",
88
"build:embeddings": "pnpm --filter '@seekdb/*' run build",
9-
"test": "pnpm --filter seekdb run test -- run && pnpm --filter @seekdb/prisma-adapter run test",
9+
"test": "pnpm --filter seekdb run test && pnpm --filter @seekdb/prisma-adapter run test",
1010
"lint": "pnpm -r run lint",
1111
"type-check": "pnpm -r run type-check",
1212
"prettier": "prettier --write .",
Lines changed: 103 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,153 @@
1-
import { describe, it, expect, beforeEach, vi } from "vitest";
1+
import { describe, it, expect, beforeEach, afterEach, vi } from "vitest";
22
import { AmazonBedrockEmbeddingFunction } from "./index";
33

4-
// Mock AWS Bedrock Runtime client
5-
vi.mock("@aws-sdk/client-bedrock-runtime", () => {
6-
const mockSend = vi.fn().mockResolvedValue({
7-
body: new TextEncoder().encode(
8-
JSON.stringify({
9-
embedding: Array(1536)
10-
.fill(0)
11-
.map((_, i) => i / 1000),
12-
})
13-
),
14-
});
15-
16-
return {
17-
BedrockRuntimeClient: vi.fn().mockImplementation(() => ({
18-
send: mockSend,
19-
})),
20-
InvokeModelCommand: vi.fn().mockImplementation((params) => params),
21-
};
22-
});
4+
const mockFetch = vi.hoisted(() => vi.fn());
235

246
describe("AmazonBedrockEmbeddingFunction", () => {
257
beforeEach(() => {
8+
vi.stubGlobal("fetch", mockFetch);
9+
let call = 0;
10+
mockFetch.mockImplementation(async () => ({
11+
ok: true,
12+
status: 200,
13+
statusText: "OK",
14+
text: async () => "",
15+
json: async () => ({
16+
embedding: Array(1024)
17+
.fill(0)
18+
.map((_, i) => (i + call++) / 1000),
19+
}),
20+
}));
2621
vi.clearAllMocks();
22+
process.env.AMAZON_BEDROCK_API_KEY = "test-api-key";
23+
});
24+
25+
afterEach(() => {
26+
vi.unstubAllGlobals();
27+
delete process.env.AMAZON_BEDROCK_API_KEY;
2728
});
2829

2930
it("should initialize with default parameters", () => {
30-
const embedder = new AmazonBedrockEmbeddingFunction({});
31-
expect(embedder.name).toBe("bedrock");
31+
const embedder = new AmazonBedrockEmbeddingFunction({
32+
region: "us-east-1",
33+
});
34+
expect(embedder.name).toBe("amazon_bedrock");
3235

3336
const config = embedder.getConfig();
34-
expect(config.model_id).toBe("amazon.titan-embed-text-v1");
37+
expect(config.model_name).toBe("amazon.titan-embed-text-v2");
38+
expect(config.region).toBe("us-east-1");
39+
expect(config.api_key_env).toBe("AMAZON_BEDROCK_API_KEY");
3540
});
3641

3742
it("should initialize with custom parameters", () => {
3843
const embedder = new AmazonBedrockEmbeddingFunction({
44+
apiKey: "direct-key",
3945
region: "us-west-2",
40-
modelId: "custom-model-id",
46+
modelName: "amazon.titan-embed-text-v1",
47+
apiKeyEnv: "CUSTOM_BEDROCK_KEY",
4148
});
4249

4350
const config = embedder.getConfig();
4451
expect(config.region).toBe("us-west-2");
45-
expect(config.model_id).toBe("custom-model-id");
52+
expect(config.model_name).toBe("amazon.titan-embed-text-v1");
53+
expect(config.api_key_env).toBe("CUSTOM_BEDROCK_KEY");
54+
});
55+
56+
it("should throw when API key is missing", () => {
57+
delete process.env.AMAZON_BEDROCK_API_KEY;
58+
expect(() => {
59+
new AmazonBedrockEmbeddingFunction({ region: "us-east-1" });
60+
}).toThrow(/apiKey is required/);
61+
});
62+
63+
it("should throw when region is missing", () => {
64+
expect(() => {
65+
new AmazonBedrockEmbeddingFunction({ apiKey: "k" });
66+
}).toThrow(/region is required/);
4667
});
4768

4869
it("should generate embeddings", async () => {
49-
const embedder = new AmazonBedrockEmbeddingFunction({});
70+
const embedder = new AmazonBedrockEmbeddingFunction({
71+
region: "us-east-1",
72+
});
5073
const texts = ["Hello world", "Test text"];
5174
const embeddings = await embedder.generate(texts);
5275

5376
expect(embeddings.length).toBe(texts.length);
5477
embeddings.forEach((embedding) => {
55-
expect(embedding.length).toBe(1536);
78+
expect(embedding.length).toBe(1024);
5679
});
80+
expect(embeddings[0]).not.toEqual(embeddings[1]);
81+
});
82+
83+
it("should return empty array for empty input", async () => {
84+
const embedder = new AmazonBedrockEmbeddingFunction({
85+
region: "us-east-1",
86+
});
87+
expect(await embedder.generate([])).toEqual([]);
88+
expect(mockFetch).not.toHaveBeenCalled();
5789
});
5890

5991
it("should build from config", () => {
6092
const snakeCaseConfig = {
6193
region: "eu-west-1",
62-
model_id: "amazon.titan-embed-text-v2:0",
94+
model_name: "amazon.titan-embed-text-v2:0" as const,
95+
api_key_env: "AMAZON_BEDROCK_API_KEY",
6396
};
6497

6598
const embedder =
6699
AmazonBedrockEmbeddingFunction.buildFromConfig(snakeCaseConfig);
67100

68101
expect(embedder).toBeInstanceOf(AmazonBedrockEmbeddingFunction);
69-
expect(embedder.name).toBe("bedrock");
102+
expect(embedder.name).toBe("amazon_bedrock");
70103

71104
const config = embedder.getConfig();
72105
expect(config.region).toBe("eu-west-1");
73-
expect(config.model_id).toBe("amazon.titan-embed-text-v2:0");
106+
expect(config.model_name).toBe("amazon.titan-embed-text-v2:0");
74107
});
75108

76-
it("should handle empty config", () => {
77-
const embedder = AmazonBedrockEmbeddingFunction.buildFromConfig({});
109+
it("should build instance with defaults from partial config", () => {
110+
const embedder = AmazonBedrockEmbeddingFunction.buildFromConfig({
111+
region: "us-east-1",
112+
});
78113

79114
expect(embedder).toBeInstanceOf(AmazonBedrockEmbeddingFunction);
80115
const config = embedder.getConfig();
81-
expect(config.model_id).toBe("amazon.titan-embed-text-v1");
116+
expect(config.model_name).toBe("amazon.titan-embed-text-v2");
117+
});
118+
119+
it("should call fetch with correct URL and headers", async () => {
120+
const embedder = new AmazonBedrockEmbeddingFunction({
121+
region: "us-east-1",
122+
modelName: "amazon.titan-embed-text-v2",
123+
});
124+
125+
await embedder.generate(["hello"]);
126+
127+
expect(mockFetch).toHaveBeenCalledWith(
128+
"https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v2/invoke",
129+
expect.objectContaining({
130+
method: "POST",
131+
headers: expect.objectContaining({
132+
Authorization: "Bearer test-api-key",
133+
"Content-Type": "application/json",
134+
Accept: "application/json",
135+
}),
136+
body: JSON.stringify({ inputText: "hello" }),
137+
})
138+
);
139+
});
140+
141+
it("should expose dimension and getModelDimensions", () => {
142+
const embedder = new AmazonBedrockEmbeddingFunction({
143+
region: "us-east-1",
144+
modelName: "amazon.titan-embed-text-v1",
145+
});
146+
expect(embedder.dimension).toBe(1536);
147+
expect(
148+
AmazonBedrockEmbeddingFunction.getModelDimensions()[
149+
"amazon.titan-embed-text-v2"
150+
]
151+
).toBe(1024);
82152
});
83153
});

packages/embeddings/amazon-bedrock/index.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,6 @@ export class AmazonBedrockEmbeddingFunction implements EmbeddingFunction {
139139
static buildFromConfig(
140140
config: EmbeddingConfig
141141
): AmazonBedrockEmbeddingFunction {
142-
if (!config.api_key_env) {
143-
throw new Error(
144-
"Building Amazon bedrock embedding function from config: api_key_env is required in config."
145-
);
146-
}
147-
148142
return new AmazonBedrockEmbeddingFunction({
149143
apiKeyEnv: config.api_key_env,
150144
region: config.region,

0 commit comments

Comments
 (0)