@@ -25,6 +25,10 @@ public static function setUpBeforeClass(): void
2525 paths: [__DIR__ . '/models ' ],
2626 isDevMode: true
2727 );
28+ $ config ->addCustomNumericFunction ('l2_distance ' , 'Pgvector\Doctrine\L2Distance ' );
29+ $ config ->addCustomNumericFunction ('max_inner_product ' , 'Pgvector\Doctrine\MaxInnerProduct ' );
30+ $ config ->addCustomNumericFunction ('cosine_distance ' , 'Pgvector\Doctrine\CosineDistance ' );
31+ $ config ->addCustomNumericFunction ('l1_distance ' , 'Pgvector\Doctrine\L1Distance ' );
2832
2933 $ connection = DriverManager::getConnection ([
3034 'driver ' => 'pgsql ' ,
@@ -84,10 +88,9 @@ public function testTypes()
8488 public function testVectorL2Distance ()
8589 {
8690 $ this ->createItems ();
87- $ rsm = new ResultSetMappingBuilder (self ::$ em );
88- $ rsm ->addRootEntityFromClassMetadata ('DoctrineItem ' , 'i ' );
89- $ neighbors = self ::$ em ->createNativeQuery ('SELECT * FROM doctrine_items i ORDER BY embedding <-> ? LIMIT 5 ' , $ rsm )
91+ $ neighbors = self ::$ em ->createQuery ('SELECT i FROM DoctrineItem i ORDER BY l2_distance(i.embedding, ?1) ' )
9092 ->setParameter (1 , new Vector ([1 , 1 , 1 ]))
93+ ->setMaxResults (5 )
9194 ->getResult ();
9295 $ this ->assertEquals ([1 , 3 , 2 ], array_map (fn ($ v ) => $ v ->getId (), $ neighbors ));
9396 $ this ->assertEquals ([[1 , 1 , 1 ], [1 , 1 , 2 ], [2 , 2 , 2 ]], array_map (fn ($ v ) => $ v ->getEmbedding ()->toArray (), $ neighbors ));
@@ -96,32 +99,29 @@ public function testVectorL2Distance()
9699 public function testVectorMaxInnerProduct ()
97100 {
98101 $ this ->createItems ();
99- $ rsm = new ResultSetMappingBuilder (self ::$ em );
100- $ rsm ->addRootEntityFromClassMetadata ('DoctrineItem ' , 'i ' );
101- $ neighbors = self ::$ em ->createNativeQuery ('SELECT * FROM doctrine_items i ORDER BY embedding <#> ? LIMIT 5 ' , $ rsm )
102+ $ neighbors = self ::$ em ->createQuery ('SELECT i FROM DoctrineItem i ORDER BY max_inner_product(i.embedding, ?1) ' )
102103 ->setParameter (1 , new Vector ([1 , 1 , 1 ]))
104+ ->setMaxResults (5 )
103105 ->getResult ();
104106 $ this ->assertEquals ([2 , 3 , 1 ], array_map (fn ($ v ) => $ v ->getId (), $ neighbors ));
105107 }
106108
107109 public function testVectorCosineDistance ()
108110 {
109111 $ this ->createItems ();
110- $ rsm = new ResultSetMappingBuilder (self ::$ em );
111- $ rsm ->addRootEntityFromClassMetadata ('DoctrineItem ' , 'i ' );
112- $ neighbors = self ::$ em ->createNativeQuery ('SELECT * FROM doctrine_items i ORDER BY embedding <=> ? LIMIT 5 ' , $ rsm )
112+ $ neighbors = self ::$ em ->createQuery ('SELECT i FROM DoctrineItem i ORDER BY cosine_distance(i.embedding, ?1) ' )
113113 ->setParameter (1 , new Vector ([1 , 1 , 1 ]))
114+ ->setMaxResults (5 )
114115 ->getResult ();
115116 $ this ->assertEquals ([1 , 2 , 3 ], array_map (fn ($ v ) => $ v ->getId (), $ neighbors ));
116117 }
117118
118119 public function testVectorL1Distance ()
119120 {
120121 $ this ->createItems ();
121- $ rsm = new ResultSetMappingBuilder (self ::$ em );
122- $ rsm ->addRootEntityFromClassMetadata ('DoctrineItem ' , 'i ' );
123- $ neighbors = self ::$ em ->createNativeQuery ('SELECT * FROM doctrine_items i ORDER BY embedding <+> ? LIMIT 5 ' , $ rsm )
122+ $ neighbors = self ::$ em ->createQuery ('SELECT i FROM DoctrineItem i ORDER BY l1_distance(i.embedding, ?1) ' )
124123 ->setParameter (1 , new Vector ([1 , 1 , 1 ]))
124+ ->setMaxResults (5 )
125125 ->getResult ();
126126 $ this ->assertEquals ([1 , 3 , 2 ], array_map (fn ($ v ) => $ v ->getId (), $ neighbors ));
127127 }
0 commit comments