Skip to content

Commit 1b94328

Browse files
authored
Add ability to filter the generic reference field by ObjectId and DBRef (#1425)
1 parent 25e0f12 commit 1b94328

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

mongoengine/fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,7 @@ def to_mongo(self, document):
12491249
if document is None:
12501250
return None
12511251

1252-
if isinstance(document, (dict, SON)):
1252+
if isinstance(document, (dict, SON, ObjectId, DBRef)):
12531253
return document
12541254

12551255
id_field_name = document.__class__._meta['id_field']

mongoengine/queryset/transform.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import defaultdict
22

3-
from bson import SON
3+
from bson import ObjectId, SON
4+
from bson.dbref import DBRef
45
import pymongo
56

67
from mongoengine.base.fields import UPDATE_OPERATORS
@@ -26,6 +27,7 @@
2627
STRING_OPERATORS + CUSTOM_OPERATORS)
2728

2829

30+
# TODO make this less complex
2931
def query(_doc_cls=None, **kwargs):
3032
"""Transform a query from Django-style format to Mongo format.
3133
"""
@@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs):
6264
parts = []
6365

6466
CachedReferenceField = _import_class('CachedReferenceField')
67+
GenericReferenceField = _import_class('GenericReferenceField')
6568

6669
cleaned_fields = []
6770
for field in fields:
@@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs):
101104
# 'in', 'nin' and 'all' require a list of values
102105
value = [field.prepare_query_value(op, v) for v in value]
103106

107+
# If we're querying a GenericReferenceField, we need to alter the
108+
# key depending on the value:
109+
# * If the value is a DBRef, the key should be "field_name._ref".
110+
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
111+
if isinstance(field, GenericReferenceField):
112+
if isinstance(value, DBRef):
113+
parts[-1] += '._ref'
114+
elif isinstance(value, ObjectId):
115+
parts[-1] += '._ref.$id'
116+
104117
# if op and op not in COMPARISON_OPERATORS:
105118
if op:
106119
if op in GEO_OPERATORS:
@@ -128,11 +141,13 @@ def query(_doc_cls=None, **kwargs):
128141

129142
for i, part in indices:
130143
parts.insert(i, part)
144+
131145
key = '.'.join(parts)
146+
132147
if op is None or key not in mongo_query:
133148
mongo_query[key] = value
134149
elif key in mongo_query:
135-
if key in mongo_query and isinstance(mongo_query[key], dict):
150+
if isinstance(mongo_query[key], dict):
136151
mongo_query[key].update(value)
137152
# $max/minDistance needs to come last - convert to SON
138153
value_dict = mongo_query[key]

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ tests = tests
99
[flake8]
1010
ignore=E501,F401,F403,F405,I201
1111
exclude=build,dist,docs,venv,.tox,.eggs,tests
12-
max-complexity=42
12+
max-complexity=45
1313
application-import-names=mongoengine,tests

tests/fields/fields.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2810,6 +2810,38 @@ class User(Document):
28102810
Post.drop_collection()
28112811
User.drop_collection()
28122812

2813+
def test_generic_reference_filter_by_dbref(self):
2814+
"""Ensure we can search for a specific generic reference by
2815+
providing its ObjectId.
2816+
"""
2817+
class Doc(Document):
2818+
ref = GenericReferenceField()
2819+
2820+
Doc.drop_collection()
2821+
2822+
doc1 = Doc.objects.create()
2823+
doc2 = Doc.objects.create(ref=doc1)
2824+
2825+
doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
2826+
self.assertEqual(doc, doc2)
2827+
2828+
def test_generic_reference_filter_by_objectid(self):
2829+
"""Ensure we can search for a specific generic reference by
2830+
providing its DBRef.
2831+
"""
2832+
class Doc(Document):
2833+
ref = GenericReferenceField()
2834+
2835+
Doc.drop_collection()
2836+
2837+
doc1 = Doc.objects.create()
2838+
doc2 = Doc.objects.create(ref=doc1)
2839+
2840+
self.assertTrue(isinstance(doc1.pk, ObjectId))
2841+
2842+
doc = Doc.objects.get(ref=doc1.pk)
2843+
self.assertEqual(doc, doc2)
2844+
28132845
def test_binary_fields(self):
28142846
"""Ensure that binary fields can be stored and retrieved.
28152847
"""

0 commit comments

Comments
 (0)