Skip to content

Commit cc5f0b6

Browse files
committed
Added distance functions for Doctrine
1 parent 368f5e1 commit cc5f0b6

File tree

6 files changed

+92
-12
lines changed

6 files changed

+92
-12
lines changed

src/doctrine/CosineDistance.php

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
class CosineDistance extends DistanceNode
6+
{
7+
protected function getOp(): string
8+
{
9+
return '<=>';
10+
}
11+
}

src/doctrine/DistanceNode.php

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\ORM\Query\AST\Functions\FunctionNode;
6+
use Doctrine\ORM\Query\Parser;
7+
use Doctrine\ORM\Query\SqlWalker;
8+
use Doctrine\ORM\Query\TokenType;
9+
10+
abstract class DistanceNode extends FunctionNode
11+
{
12+
public $left;
13+
public $right;
14+
15+
abstract protected function getOp(): string;
16+
17+
public function parse(Parser $parser): void
18+
{
19+
$parser->match(TokenType::T_IDENTIFIER);
20+
$parser->match(TokenType::T_OPEN_PARENTHESIS);
21+
$this->left = $parser->ArithmeticPrimary();
22+
$parser->match(TokenType::T_COMMA);
23+
$this->right = $parser->ArithmeticPrimary();
24+
$parser->match(TokenType::T_CLOSE_PARENTHESIS);
25+
}
26+
27+
public function getSql(SqlWalker $sqlWalker): string
28+
{
29+
return sprintf(
30+
'(%s %s %s)',
31+
$this->left->dispatch($sqlWalker),
32+
$this->getOp(),
33+
$this->right->dispatch($sqlWalker),
34+
);
35+
}
36+
}

src/doctrine/L1Distance.php

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
class L1Distance extends DistanceNode
6+
{
7+
protected function getOp(): string
8+
{
9+
return '<+>';
10+
}
11+
}

src/doctrine/L2Distance.php

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
class L2Distance extends DistanceNode
6+
{
7+
protected function getOp(): string
8+
{
9+
return '<->';
10+
}
11+
}

src/doctrine/MaxInnerProduct.php

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
class MaxInnerProduct extends DistanceNode
6+
{
7+
protected function getOp(): string
8+
{
9+
return '<#>';
10+
}
11+
}

tests/DoctrineTest.php

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ public static function setUpBeforeClass(): void
2525
paths: [__DIR__ . '/models'],
2626
isDevMode: true
2727
);
28+
$config->addCustomNumericFunction('l2_distance', 'Pgvector\Doctrine\L2Distance');
29+
$config->addCustomNumericFunction('max_inner_product', 'Pgvector\Doctrine\MaxInnerProduct');
30+
$config->addCustomNumericFunction('cosine_distance', 'Pgvector\Doctrine\CosineDistance');
31+
$config->addCustomNumericFunction('l1_distance', 'Pgvector\Doctrine\L1Distance');
2832

2933
$connection = DriverManager::getConnection([
3034
'driver' => 'pgsql',
@@ -84,10 +88,9 @@ public function testTypes()
8488
public function testVectorL2Distance()
8589
{
8690
$this->createItems();
87-
$rsm = new ResultSetMappingBuilder(self::$em);
88-
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
89-
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <-> ? LIMIT 5', $rsm)
91+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY l2_distance(i.embedding, ?1)')
9092
->setParameter(1, new Vector([1, 1, 1]))
93+
->setMaxResults(5)
9194
->getResult();
9295
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
9396
$this->assertEquals([[1, 1, 1], [1, 1, 2], [2, 2, 2]], array_map(fn ($v) => $v->getEmbedding()->toArray(), $neighbors));
@@ -96,32 +99,29 @@ public function testVectorL2Distance()
9699
public function testVectorMaxInnerProduct()
97100
{
98101
$this->createItems();
99-
$rsm = new ResultSetMappingBuilder(self::$em);
100-
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
101-
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <#> ? LIMIT 5', $rsm)
102+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY max_inner_product(i.embedding, ?1)')
102103
->setParameter(1, new Vector([1, 1, 1]))
104+
->setMaxResults(5)
103105
->getResult();
104106
$this->assertEquals([2, 3, 1], array_map(fn ($v) => $v->getId(), $neighbors));
105107
}
106108

107109
public function testVectorCosineDistance()
108110
{
109111
$this->createItems();
110-
$rsm = new ResultSetMappingBuilder(self::$em);
111-
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
112-
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <=> ? LIMIT 5', $rsm)
112+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY cosine_distance(i.embedding, ?1)')
113113
->setParameter(1, new Vector([1, 1, 1]))
114+
->setMaxResults(5)
114115
->getResult();
115116
$this->assertEquals([1, 2, 3], array_map(fn ($v) => $v->getId(), $neighbors));
116117
}
117118

118119
public function testVectorL1Distance()
119120
{
120121
$this->createItems();
121-
$rsm = new ResultSetMappingBuilder(self::$em);
122-
$rsm->addRootEntityFromClassMetadata('DoctrineItem', 'i');
123-
$neighbors = self::$em->createNativeQuery('SELECT * FROM doctrine_items i ORDER BY embedding <+> ? LIMIT 5', $rsm)
122+
$neighbors = self::$em->createQuery('SELECT i FROM DoctrineItem i ORDER BY l1_distance(i.embedding, ?1)')
124123
->setParameter(1, new Vector([1, 1, 1]))
124+
->setMaxResults(5)
125125
->getResult();
126126
$this->assertEquals([1, 3, 2], array_map(fn ($v) => $v->getId(), $neighbors));
127127
}

0 commit comments

Comments
 (0)