|
1 | | -import { describe, it, expect, beforeEach, vi } from "vitest"; |
| 1 | +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest"; |
2 | 2 | import { AmazonBedrockEmbeddingFunction } from "./index"; |
3 | 3 |
|
4 | | -// Mock AWS Bedrock Runtime client |
5 | | -vi.mock("@aws-sdk/client-bedrock-runtime", () => { |
6 | | - const mockSend = vi.fn().mockResolvedValue({ |
7 | | - body: new TextEncoder().encode( |
8 | | - JSON.stringify({ |
9 | | - embedding: Array(1536) |
10 | | - .fill(0) |
11 | | - .map((_, i) => i / 1000), |
12 | | - }) |
13 | | - ), |
14 | | - }); |
15 | | - |
16 | | - return { |
17 | | - BedrockRuntimeClient: vi.fn().mockImplementation(() => ({ |
18 | | - send: mockSend, |
19 | | - })), |
20 | | - InvokeModelCommand: vi.fn().mockImplementation((params) => params), |
21 | | - }; |
22 | | -}); |
| 4 | +const mockFetch = vi.hoisted(() => vi.fn()); |
23 | 5 |
|
24 | 6 | describe("AmazonBedrockEmbeddingFunction", () => { |
25 | 7 | beforeEach(() => { |
| 8 | + vi.stubGlobal("fetch", mockFetch); |
| 9 | + let call = 0; |
| 10 | + mockFetch.mockImplementation(async () => ({ |
| 11 | + ok: true, |
| 12 | + status: 200, |
| 13 | + statusText: "OK", |
| 14 | + text: async () => "", |
| 15 | + json: async () => ({ |
| 16 | + embedding: Array(1024) |
| 17 | + .fill(0) |
| 18 | + .map((_, i) => (i + call++) / 1000), |
| 19 | + }), |
| 20 | + })); |
26 | 21 | vi.clearAllMocks(); |
| 22 | + process.env.AMAZON_BEDROCK_API_KEY = "test-api-key"; |
| 23 | + }); |
| 24 | + |
| 25 | + afterEach(() => { |
| 26 | + vi.unstubAllGlobals(); |
| 27 | + delete process.env.AMAZON_BEDROCK_API_KEY; |
27 | 28 | }); |
28 | 29 |
|
29 | 30 | it("should initialize with default parameters", () => { |
30 | | - const embedder = new AmazonBedrockEmbeddingFunction({}); |
31 | | - expect(embedder.name).toBe("bedrock"); |
| 31 | + const embedder = new AmazonBedrockEmbeddingFunction({ |
| 32 | + region: "us-east-1", |
| 33 | + }); |
| 34 | + expect(embedder.name).toBe("amazon_bedrock"); |
32 | 35 |
|
33 | 36 | const config = embedder.getConfig(); |
34 | | - expect(config.model_id).toBe("amazon.titan-embed-text-v1"); |
| 37 | + expect(config.model_name).toBe("amazon.titan-embed-text-v2"); |
| 38 | + expect(config.region).toBe("us-east-1"); |
| 39 | + expect(config.api_key_env).toBe("AMAZON_BEDROCK_API_KEY"); |
35 | 40 | }); |
36 | 41 |
|
37 | 42 | it("should initialize with custom parameters", () => { |
38 | 43 | const embedder = new AmazonBedrockEmbeddingFunction({ |
| 44 | + apiKey: "direct-key", |
39 | 45 | region: "us-west-2", |
40 | | - modelId: "custom-model-id", |
| 46 | + modelName: "amazon.titan-embed-text-v1", |
| 47 | + apiKeyEnv: "CUSTOM_BEDROCK_KEY", |
41 | 48 | }); |
42 | 49 |
|
43 | 50 | const config = embedder.getConfig(); |
44 | 51 | expect(config.region).toBe("us-west-2"); |
45 | | - expect(config.model_id).toBe("custom-model-id"); |
| 52 | + expect(config.model_name).toBe("amazon.titan-embed-text-v1"); |
| 53 | + expect(config.api_key_env).toBe("CUSTOM_BEDROCK_KEY"); |
| 54 | + }); |
| 55 | + |
| 56 | + it("should throw when API key is missing", () => { |
| 57 | + delete process.env.AMAZON_BEDROCK_API_KEY; |
| 58 | + expect(() => { |
| 59 | + new AmazonBedrockEmbeddingFunction({ region: "us-east-1" }); |
| 60 | + }).toThrow(/apiKey is required/); |
| 61 | + }); |
| 62 | + |
| 63 | + it("should throw when region is missing", () => { |
| 64 | + expect(() => { |
| 65 | + new AmazonBedrockEmbeddingFunction({ apiKey: "k" }); |
| 66 | + }).toThrow(/region is required/); |
46 | 67 | }); |
47 | 68 |
|
48 | 69 | it("should generate embeddings", async () => { |
49 | | - const embedder = new AmazonBedrockEmbeddingFunction({}); |
| 70 | + const embedder = new AmazonBedrockEmbeddingFunction({ |
| 71 | + region: "us-east-1", |
| 72 | + }); |
50 | 73 | const texts = ["Hello world", "Test text"]; |
51 | 74 | const embeddings = await embedder.generate(texts); |
52 | 75 |
|
53 | 76 | expect(embeddings.length).toBe(texts.length); |
54 | 77 | embeddings.forEach((embedding) => { |
55 | | - expect(embedding.length).toBe(1536); |
| 78 | + expect(embedding.length).toBe(1024); |
56 | 79 | }); |
| 80 | + expect(embeddings[0]).not.toEqual(embeddings[1]); |
| 81 | + }); |
| 82 | + |
| 83 | + it("should return empty array for empty input", async () => { |
| 84 | + const embedder = new AmazonBedrockEmbeddingFunction({ |
| 85 | + region: "us-east-1", |
| 86 | + }); |
| 87 | + expect(await embedder.generate([])).toEqual([]); |
| 88 | + expect(mockFetch).not.toHaveBeenCalled(); |
57 | 89 | }); |
58 | 90 |
|
59 | 91 | it("should build from config", () => { |
60 | 92 | const snakeCaseConfig = { |
61 | 93 | region: "eu-west-1", |
62 | | - model_id: "amazon.titan-embed-text-v2:0", |
| 94 | + model_name: "amazon.titan-embed-text-v2:0" as const, |
| 95 | + api_key_env: "AMAZON_BEDROCK_API_KEY", |
63 | 96 | }; |
64 | 97 |
|
65 | 98 | const embedder = |
66 | 99 | AmazonBedrockEmbeddingFunction.buildFromConfig(snakeCaseConfig); |
67 | 100 |
|
68 | 101 | expect(embedder).toBeInstanceOf(AmazonBedrockEmbeddingFunction); |
69 | | - expect(embedder.name).toBe("bedrock"); |
| 102 | + expect(embedder.name).toBe("amazon_bedrock"); |
70 | 103 |
|
71 | 104 | const config = embedder.getConfig(); |
72 | 105 | expect(config.region).toBe("eu-west-1"); |
73 | | - expect(config.model_id).toBe("amazon.titan-embed-text-v2:0"); |
| 106 | + expect(config.model_name).toBe("amazon.titan-embed-text-v2:0"); |
74 | 107 | }); |
75 | 108 |
|
76 | | - it("should handle empty config", () => { |
77 | | - const embedder = AmazonBedrockEmbeddingFunction.buildFromConfig({}); |
| 109 | + it("should build instance with defaults from partial config", () => { |
| 110 | + const embedder = AmazonBedrockEmbeddingFunction.buildFromConfig({ |
| 111 | + region: "us-east-1", |
| 112 | + }); |
78 | 113 |
|
79 | 114 | expect(embedder).toBeInstanceOf(AmazonBedrockEmbeddingFunction); |
80 | 115 | const config = embedder.getConfig(); |
81 | | - expect(config.model_id).toBe("amazon.titan-embed-text-v1"); |
| 116 | + expect(config.model_name).toBe("amazon.titan-embed-text-v2"); |
| 117 | + }); |
| 118 | + |
| 119 | + it("should call fetch with correct URL and headers", async () => { |
| 120 | + const embedder = new AmazonBedrockEmbeddingFunction({ |
| 121 | + region: "us-east-1", |
| 122 | + modelName: "amazon.titan-embed-text-v2", |
| 123 | + }); |
| 124 | + |
| 125 | + await embedder.generate(["hello"]); |
| 126 | + |
| 127 | + expect(mockFetch).toHaveBeenCalledWith( |
| 128 | + "https://bedrock-runtime.us-east-1.amazonaws.com/model/amazon.titan-embed-text-v2/invoke", |
| 129 | + expect.objectContaining({ |
| 130 | + method: "POST", |
| 131 | + headers: expect.objectContaining({ |
| 132 | + Authorization: "Bearer test-api-key", |
| 133 | + "Content-Type": "application/json", |
| 134 | + Accept: "application/json", |
| 135 | + }), |
| 136 | + body: JSON.stringify({ inputText: "hello" }), |
| 137 | + }) |
| 138 | + ); |
| 139 | + }); |
| 140 | + |
| 141 | + it("should expose dimension and getModelDimensions", () => { |
| 142 | + const embedder = new AmazonBedrockEmbeddingFunction({ |
| 143 | + region: "us-east-1", |
| 144 | + modelName: "amazon.titan-embed-text-v1", |
| 145 | + }); |
| 146 | + expect(embedder.dimension).toBe(1536); |
| 147 | + expect( |
| 148 | + AmazonBedrockEmbeddingFunction.getModelDimensions()[ |
| 149 | + "amazon.titan-embed-text-v2" |
| 150 | + ] |
| 151 | + ).toBe(1024); |
82 | 152 | }); |
83 | 153 | }); |
0 commit comments