Skip to content

Commit deb5677

Browse files
committed
Allow shard key to be in an embedded document (#551)
1 parent 5c464c3 commit deb5677

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

mongoengine/document.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,12 @@ def save(self, force_insert=False, validate=True, clean=True,
341341
select_dict['_id'] = object_id
342342
shard_key = self.__class__._meta.get('shard_key', tuple())
343343
for k in shard_key:
344-
actual_key = self._db_field_map.get(k, k)
345-
select_dict[actual_key] = doc[actual_key]
344+
path = self._lookup_field(k.split('.'))
345+
actual_key = [p.db_field for p in path]
346+
val = doc
347+
for ak in actual_key:
348+
val = val[ak]
349+
select_dict['.'.join(actual_key)] = val
346350

347351
def is_new_object(last_error):
348352
if last_error is not None:
@@ -444,7 +448,12 @@ def _object_key(self):
444448
select_dict = {'pk': self.pk}
445449
shard_key = self.__class__._meta.get('shard_key', tuple())
446450
for k in shard_key:
447-
select_dict[k] = getattr(self, k)
451+
path = self._lookup_field(k.split('.'))
452+
actual_key = [p.db_field for p in path]
453+
val = self
454+
for ak in actual_key:
455+
val = getattr(val, ak)
456+
select_dict['__'.join(actual_key)] = val
448457
return select_dict
449458

450459
def update(self, **kwargs):

tests/document/instance.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,20 @@ class Animal(Document):
484484
doc.reload()
485485
Animal.drop_collection()
486486

487+
def test_reload_sharded_nested(self):
488+
class SuperPhylum(EmbeddedDocument):
489+
name = StringField()
490+
491+
class Animal(Document):
492+
superphylum = EmbeddedDocumentField(SuperPhylum)
493+
meta = {'shard_key': ('superphylum.name',)}
494+
495+
Animal.drop_collection()
496+
doc = Animal(superphylum=SuperPhylum(name='Deuterostomia'))
497+
doc.save()
498+
doc.reload()
499+
Animal.drop_collection()
500+
487501
def test_reload_referencing(self):
488502
"""Ensures reloading updates weakrefs correctly
489503
"""
@@ -2715,6 +2729,32 @@ def change_shard_key():
27152729

27162730
self.assertRaises(OperationError, change_shard_key)
27172731

2732+
def test_shard_key_in_embedded_document(self):
2733+
class Foo(EmbeddedDocument):
2734+
foo = StringField()
2735+
2736+
class Bar(Document):
2737+
meta = {
2738+
'shard_key': ('foo.foo',)
2739+
}
2740+
foo = EmbeddedDocumentField(Foo)
2741+
bar = StringField()
2742+
2743+
foo_doc = Foo(foo='hello')
2744+
bar_doc = Bar(foo=foo_doc, bar='world')
2745+
bar_doc.save()
2746+
2747+
self.assertTrue(bar_doc.id is not None)
2748+
2749+
bar_doc.bar = 'baz'
2750+
bar_doc.save()
2751+
2752+
def change_shard_key():
2753+
bar_doc.foo.foo = 'something'
2754+
bar_doc.save()
2755+
2756+
self.assertRaises(OperationError, change_shard_key)
2757+
27182758
def test_shard_key_primary(self):
27192759
class LogEntry(Document):
27202760
machine = StringField(primary_key=True)

0 commit comments

Comments
 (0)