Skip to content

Commit 6868f66

Browse files
author
Omer Katz
committed
Merge pull request #1155 from AWhetter/fix837
ReferenceFields can now reference abstract Document types
2 parents 3c0b00e + 04497ae commit 6868f66

File tree

4 files changed

+109
-7
lines changed

4 files changed

+109
-7
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,4 @@ that much better:
230230
* Amit Lichtenberg (https://github.com/amitlicht)
231231
* Lars Butler (https://github.com/larsbutler)
232232
* George Macon (https://github.com/gmacon)
233+
* Ashley Whetter (https://github.com/AWhetter)

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Changes in 0.10.2
66
=================
77
- Allow shard key to point to a field in an embedded document. #551
88
- Allow arbirary metadata in fields. #1129
9+
- ReferenceFields now support abstract document types. #837
910

1011
Changes in 0.10.1
1112
=======================

mongoengine/fields.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,10 @@ def __init__(self, document_type, dbref=False,
895895
or as the :class:`~pymongo.objectid.ObjectId`.id .
896896
:param reverse_delete_rule: Determines what to do when the referring
897897
object is deleted
898+
899+
.. note ::
900+
A reference to an abstract document type is always stored as a
901+
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
898902
"""
899903
if not isinstance(document_type, basestring):
900904
if not issubclass(document_type, (Document, basestring)):
@@ -927,9 +931,14 @@ def __get__(self, instance, owner):
927931
self._auto_dereference = instance._fields[self.name]._auto_dereference
928932
# Dereference DBRefs
929933
if self._auto_dereference and isinstance(value, DBRef):
930-
value = self.document_type._get_db().dereference(value)
934+
if hasattr(value, 'cls'):
935+
# Dereference using the class type specified in the reference
936+
cls = get_document(value.cls)
937+
else:
938+
cls = self.document_type
939+
value = cls._get_db().dereference(value)
931940
if value is not None:
932-
instance._data[self.name] = self.document_type._from_son(value)
941+
instance._data[self.name] = cls._from_son(value)
933942

934943
return super(ReferenceField, self).__get__(instance, owner)
935944

@@ -939,21 +948,29 @@ def to_mongo(self, document):
939948
return document.id
940949
return document
941950

942-
id_field_name = self.document_type._meta['id_field']
943-
id_field = self.document_type._fields[id_field_name]
944-
945951
if isinstance(document, Document):
946952
# We need the id from the saved object to create the DBRef
947953
id_ = document.pk
948954
if id_ is None:
949955
self.error('You can only reference documents once they have'
950956
' been saved to the database')
957+
958+
# Use the attributes from the document instance, so that they
959+
# override the attributes of this field's document type
960+
cls = document
951961
else:
952962
id_ = document
963+
cls = self.document_type
964+
965+
id_field_name = cls._meta['id_field']
966+
id_field = cls._fields[id_field_name]
953967

954968
id_ = id_field.to_mongo(id_)
955-
if self.dbref:
956-
collection = self.document_type._get_collection_name()
969+
if self.document_type._meta.get('abstract'):
970+
collection = cls._get_collection_name()
971+
return DBRef(collection, id_, cls=cls._class_name)
972+
elif self.dbref:
973+
collection = cls._get_collection_name()
957974
return DBRef(collection, id_)
958975

959976
return id_
@@ -982,6 +999,14 @@ def validate(self, value):
982999
self.error('You can only reference documents once they have been '
9831000
'saved to the database')
9841001

1002+
if self.document_type._meta.get('abstract') and \
1003+
not isinstance(value, self.document_type):
1004+
self.error('%s is not an instance of abstract reference'
1005+
' type %s' % (value._class_name,
1006+
self.document_type._class_name)
1007+
)
1008+
1009+
9851010
def lookup_member(self, member_name):
9861011
return self.document_type._fields.get(member_name)
9871012

tests/fields/fields.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,81 @@ class BlogPost(Document):
22812281
Member.drop_collection()
22822282
BlogPost.drop_collection()
22832283

2284+
def test_reference_class_with_abstract_parent(self):
2285+
"""Ensure that a class with an abstract parent can be referenced.
2286+
"""
2287+
class Sibling(Document):
2288+
name = StringField()
2289+
meta = {"abstract": True}
2290+
2291+
class Sister(Sibling):
2292+
pass
2293+
2294+
class Brother(Sibling):
2295+
sibling = ReferenceField(Sibling)
2296+
2297+
Sister.drop_collection()
2298+
Brother.drop_collection()
2299+
2300+
sister = Sister(name="Alice")
2301+
sister.save()
2302+
brother = Brother(name="Bob", sibling=sister)
2303+
brother.save()
2304+
2305+
self.assertEquals(Brother.objects[0].sibling.name, sister.name)
2306+
2307+
Sister.drop_collection()
2308+
Brother.drop_collection()
2309+
2310+
def test_reference_abstract_class(self):
2311+
"""Ensure that an abstract class instance cannot be used in the
2312+
reference of that abstract class.
2313+
"""
2314+
class Sibling(Document):
2315+
name = StringField()
2316+
meta = {"abstract": True}
2317+
2318+
class Sister(Sibling):
2319+
pass
2320+
2321+
class Brother(Sibling):
2322+
sibling = ReferenceField(Sibling)
2323+
2324+
Sister.drop_collection()
2325+
Brother.drop_collection()
2326+
2327+
sister = Sibling(name="Alice")
2328+
brother = Brother(name="Bob", sibling=sister)
2329+
self.assertRaises(ValidationError, brother.save)
2330+
2331+
Sister.drop_collection()
2332+
Brother.drop_collection()
2333+
2334+
def test_abstract_reference_base_type(self):
2335+
"""Ensure that an an abstract reference fails validation when given a
2336+
Document that does not inherit from the abstract type.
2337+
"""
2338+
class Sibling(Document):
2339+
name = StringField()
2340+
meta = {"abstract": True}
2341+
2342+
class Brother(Sibling):
2343+
sibling = ReferenceField(Sibling)
2344+
2345+
class Mother(Document):
2346+
name = StringField()
2347+
2348+
Brother.drop_collection()
2349+
Mother.drop_collection()
2350+
2351+
mother = Mother(name="Carol")
2352+
mother.save()
2353+
brother = Brother(name="Bob", sibling=mother)
2354+
self.assertRaises(ValidationError, brother.save)
2355+
2356+
Brother.drop_collection()
2357+
Mother.drop_collection()
2358+
22842359
def test_generic_reference(self):
22852360
"""Ensure that a GenericReferenceField properly dereferences items.
22862361
"""

0 commit comments

Comments
 (0)