Skip to content

Commit 9c8ceb6

Browse files
committed
Merge pull request #1060 from touilleMan/GenericReferenceField-choices
Fix GenericReferenceField choices parameter
2 parents 9671ca5 + bebce2c commit 9c8ceb6

File tree

3 files changed

+100
-17
lines changed

3 files changed

+100
-17
lines changed

mongoengine/base/fields.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -165,26 +165,29 @@ def validate(self, value, clean=True):
165165
"""
166166
pass
167167

168-
def _validate(self, value, **kwargs):
168+
def _validate_choices(self, value):
169169
Document = _import_class('Document')
170170
EmbeddedDocument = _import_class('EmbeddedDocument')
171171

172-
# Check the Choices Constraint
173-
if self.choices:
172+
choice_list = self.choices
173+
if isinstance(choice_list[0], (list, tuple)):
174+
choice_list = [k for k, _ in choice_list]
175+
176+
# Choices which are other types of Documents
177+
if isinstance(value, (Document, EmbeddedDocument)):
178+
if not any(isinstance(value, c) for c in choice_list):
179+
self.error(
180+
'Value must be instance of %s' % unicode(choice_list)
181+
)
182+
# Choices which are types other than Documents
183+
elif value not in choice_list:
184+
self.error('Value must be one of %s' % unicode(choice_list))
174185

175-
choice_list = self.choices
176-
if isinstance(self.choices[0], (list, tuple)):
177-
choice_list = [k for k, v in self.choices]
178186

179-
# Choices which are other types of Documents
180-
if isinstance(value, (Document, EmbeddedDocument)):
181-
if not any(isinstance(value, c) for c in choice_list):
182-
self.error(
183-
'Value must be instance of %s' % unicode(choice_list)
184-
)
185-
# Choices which are types other than Documents
186-
elif value not in choice_list:
187-
self.error('Value must be one of %s' % unicode(choice_list))
187+
def _validate(self, value, **kwargs):
188+
# Check the Choices Constraint
189+
if self.choices:
190+
self._validate_choices(value)
188191

189192
# check validation argument
190193
if self.validation is not None:
@@ -308,7 +311,7 @@ def to_python(self, value):
308311
value_dict[k] = self.to_python(v)
309312

310313
if is_list: # Convert back to a list
311-
return [v for k, v in sorted(value_dict.items(),
314+
return [v for _, v in sorted(value_dict.items(),
312315
key=operator.itemgetter(0))]
313316
return value_dict
314317

@@ -375,7 +378,7 @@ def to_mongo(self, value):
375378
value_dict[k] = self.to_mongo(v)
376379

377380
if is_list: # Convert back to a list
378-
return [v for k, v in sorted(value_dict.items(),
381+
return [v for _, v in sorted(value_dict.items(),
379382
key=operator.itemgetter(0))]
380383
return value_dict
381384

mongoengine/fields.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,30 @@ class GenericReferenceField(BaseField):
11401140
.. versionadded:: 0.3
11411141
"""
11421142

1143+
def __init__(self, *args, **kwargs):
1144+
choices = kwargs.pop('choices', None)
1145+
super(GenericReferenceField, self).__init__(*args, **kwargs)
1146+
self.choices = []
1147+
# Keep the choices as a list of allowed Document class names
1148+
if choices:
1149+
for choice in choices:
1150+
if isinstance(choice, basestring):
1151+
self.choices.append(choice)
1152+
elif isinstance(choice, type) and issubclass(choice, Document):
1153+
self.choices.append(choice._class_name)
1154+
else:
1155+
self.error('Invalid choices provided: must be a list of'
1156+
'Document subclasses and/or basestrings')
1157+
1158+
def _validate_choices(self, value):
1159+
if isinstance(value, dict):
1160+
# If the field has not been dereferenced, it is still a dict
1161+
# of class and DBRef
1162+
value = value.get('_cls')
1163+
elif isinstance(value, Document):
1164+
value = value._class_name
1165+
super(GenericReferenceField, self)._validate_choices(value)
1166+
11431167
def __get__(self, instance, owner):
11441168
if instance is None:
11451169
return self

tests/fields/fields.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2396,6 +2396,62 @@ class Bookmark(Document):
23962396
bm = Bookmark.objects.first()
23972397
self.assertEqual(bm.bookmark_object, post_1)
23982398

2399+
def test_generic_reference_string_choices(self):
2400+
"""Ensure that a GenericReferenceField can handle choices as strings
2401+
"""
2402+
class Link(Document):
2403+
title = StringField()
2404+
2405+
class Post(Document):
2406+
title = StringField()
2407+
2408+
class Bookmark(Document):
2409+
bookmark_object = GenericReferenceField(choices=('Post', Link))
2410+
2411+
Link.drop_collection()
2412+
Post.drop_collection()
2413+
Bookmark.drop_collection()
2414+
2415+
link_1 = Link(title="Pitchfork")
2416+
link_1.save()
2417+
2418+
post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
2419+
post_1.save()
2420+
2421+
bm = Bookmark(bookmark_object=link_1)
2422+
bm.save()
2423+
2424+
bm = Bookmark(bookmark_object=post_1)
2425+
bm.save()
2426+
2427+
bm = Bookmark(bookmark_object=bm)
2428+
self.assertRaises(ValidationError, bm.validate)
2429+
2430+
def test_generic_reference_choices_no_dereference(self):
2431+
"""Ensure that a GenericReferenceField can handle choices on
2432+
non-derefenreced (i.e. DBRef) elements
2433+
"""
2434+
class Post(Document):
2435+
title = StringField()
2436+
2437+
class Bookmark(Document):
2438+
bookmark_object = GenericReferenceField(choices=(Post, ))
2439+
other_field = StringField()
2440+
2441+
Post.drop_collection()
2442+
Bookmark.drop_collection()
2443+
2444+
post_1 = Post(title="Behind the Scenes of the Pavement Reunion")
2445+
post_1.save()
2446+
2447+
bm = Bookmark(bookmark_object=post_1)
2448+
bm.save()
2449+
2450+
bm = Bookmark.objects.get(id=bm.id)
2451+
# bookmark_object is now a DBRef
2452+
bm.other_field = 'dummy_change'
2453+
bm.save()
2454+
23992455
def test_generic_reference_list_choices(self):
24002456
"""Ensure that a ListField properly dereferences generic references and
24012457
respects choices.

0 commit comments

Comments
 (0)