Skip to content

Commit 0132273

Browse files
committed
ReferenceFields can now reference abstract Document types
A class that inherits from an abstract Document type is stored in the database as a reference with a 'cls' field that is the class name of the document being stored. Fixes #837
1 parent 19cbb44 commit 0132273

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

mongoengine/fields.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -928,9 +928,14 @@ def __get__(self, instance, owner):
928928
self._auto_dereference = instance._fields[self.name]._auto_dereference
929929
# Dereference DBRefs
930930
if self._auto_dereference and isinstance(value, DBRef):
931-
value = self.document_type._get_db().dereference(value)
931+
if hasattr(value, 'cls'):
932+
# Dereference using the class type specified in the reference
933+
cls = get_document(value.cls)
934+
else:
935+
cls = self.document_type
936+
value = cls._get_db().dereference(value)
932937
if value is not None:
933-
instance._data[self.name] = self.document_type._from_son(value)
938+
instance._data[self.name] = cls._from_son(value)
934939

935940
return super(ReferenceField, self).__get__(instance, owner)
936941

@@ -940,22 +945,30 @@ def to_mongo(self, document):
940945
return document.id
941946
return document
942947

943-
id_field_name = self.document_type._meta['id_field']
944-
id_field = self.document_type._fields[id_field_name]
945-
946948
if isinstance(document, Document):
947949
# We need the id from the saved object to create the DBRef
948950
id_ = document.pk
949951
if id_ is None:
950952
self.error('You can only reference documents once they have'
951953
' been saved to the database')
954+
955+
# Use the attributes from the document instance, so that they
956+
# override the attributes of this field's document type
957+
cls = document
952958
else:
953959
id_ = document
960+
cls = self.document_type
961+
962+
id_field_name = cls._meta['id_field']
963+
id_field = cls._fields[id_field_name]
954964

955965
id_ = id_field.to_mongo(id_)
956966
if self.dbref:
957-
collection = self.document_type._get_collection_name()
967+
collection = cls._get_collection_name()
958968
return DBRef(collection, id_)
969+
elif self.document_type._meta.get('abstract'):
970+
collection = cls._get_collection_name()
971+
return DBRef(collection, id_, cls=cls._class_name)
959972

960973
return id_
961974

tests/fields/fields.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,56 @@ class BlogPost(Document):
22812281
Member.drop_collection()
22822282
BlogPost.drop_collection()
22832283

2284+
def test_reference_class_with_abstract_parent(self):
2285+
"""Ensure that a class with an abstract parent can be referenced.
2286+
"""
2287+
class Sibling(Document):
2288+
name = StringField()
2289+
meta = {"abstract": True}
2290+
2291+
class Sister(Sibling):
2292+
pass
2293+
2294+
class Brother(Sibling):
2295+
sibling = ReferenceField(Sibling)
2296+
2297+
Sister.drop_collection()
2298+
Brother.drop_collection()
2299+
2300+
sister = Sister(name="Alice")
2301+
sister.save()
2302+
brother = Brother(name="Bob", sibling=sister)
2303+
brother.save()
2304+
2305+
self.assertEquals(Brother.objects[0].sibling.name, sister.name)
2306+
2307+
Sister.drop_collection()
2308+
Brother.drop_collection()
2309+
2310+
def test_reference_abstract_class(self):
2311+
"""Ensure that an abstract class instance cannot be used in the
2312+
reference of that abstract class.
2313+
"""
2314+
class Sibling(Document):
2315+
name = StringField()
2316+
meta = {"abstract": True}
2317+
2318+
class Sister(Sibling):
2319+
pass
2320+
2321+
class Brother(Sibling):
2322+
sibling = ReferenceField(Sibling)
2323+
2324+
Sister.drop_collection()
2325+
Brother.drop_collection()
2326+
2327+
sister = Sibling(name="Alice")
2328+
brother = Brother(name="Bob", sibling=sister)
2329+
self.assertRaises(ValidationError, brother.save)
2330+
2331+
Sister.drop_collection()
2332+
Brother.drop_collection()
2333+
22842334
def test_generic_reference(self):
22852335
"""Ensure that a GenericReferenceField properly dereferences items.
22862336
"""

0 commit comments

Comments
 (0)