Skip to content

Commit 8691706

Browse files
committed
Improved subquery support
1 parent 257cfeb commit 8691706

File tree

4 files changed

+65
-36
lines changed

4 files changed

+65
-36
lines changed

ftrack_query/abstract.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88

99
class AbstractQuery(object):
10-
"""Class to use for inheritance checks."""
10+
"""Base class to use mainly for inheritance checks."""
11+
12+
def __init__(self):
13+
self._where = []
1114

1215

1316
class AbstractComparison(object):
@@ -94,10 +97,7 @@ def parser(cls, *args, **kwargs):
9497
9598
args:
9699
Query: An unexecuted query object.
97-
This is not recommended, but an attempt will be made
98-
to execute it for a single result.
99-
It will raise an exception if multiple or none are
100-
found.
100+
This will be added as a subquery if supported.
101101
102102
dict: Like kargs, but with relationships allowed.
103103
A relationship like "parent.name" is not compatible
@@ -110,24 +110,23 @@ def parser(cls, *args, **kwargs):
110110
111111
Anything else passed in will get converted to strings.
112112
The comparison class has been designed to evaluate when
113-
__str__ is called, but any custom class could be used.
113+
to_str() is called, but any custom class could be used.
114114
115115
kwargs:
116116
Search for attributes of an entity.
117-
This is the recommended way to query if possible.
117+
`(x=y)` is the equivelant of `(entity.x == y)`.
118118
"""
119119
for arg in args:
120-
# The query has not been performed, attempt to execute
121-
# This shouldn't really be used, so don't catch any errors
122120
if isinstance(arg, AbstractQuery):
123-
arg = arg.one()
121+
for item in arg._where:
122+
yield item
124123

125-
if isinstance(arg, dict):
124+
elif isinstance(arg, dict):
126125
for key, value in arg.items():
127-
yield cls(key)==value
126+
yield cls(key) == value
128127

129128
elif isinstance(arg, ftrack_api.entity.base.Entity):
130-
raise TypeError("keyword required for {}".format(arg))
129+
raise TypeError('keyword required for {}'.format(arg))
131130

132131
# The object is likely a comparison object, so convert to str
133132
# If an actual string is input, then assume it's valid syntax
@@ -136,8 +135,11 @@ def parser(cls, *args, **kwargs):
136135

137136
for key, value in kwargs.items():
138137
if isinstance(value, AbstractQuery):
139-
value = value.one()
140-
yield cls(key)==value
138+
for item in value._where:
139+
yield item
140+
141+
else:
142+
yield cls(key) == value
141143

142144

143145
class AbstractStatement(object):

ftrack_query/query.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,9 @@ class Query(AbstractQuery):
221221
)
222222

223223
def __init__(self, session, entity):
224+
super(Query, self).__init__()
224225
self._session = session
225226
self._entity = entity
226-
self._where = []
227227
self._populate = []
228228
self._sort = []
229229
self._offset = 0
@@ -294,7 +294,6 @@ def __call__(self, *args, **kwargs):
294294
return result
295295
raise
296296

297-
298297
raise TypeError("'Query' object is not callable, "
299298
"perhaps you meant to use 'Query.where()'?")
300299

ftrack_query/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ def parse_operators(func):
4141
"""Parse the value when an operator is used."""
4242
@wraps(func)
4343
def wrapper(self, value):
44-
# If the item is constructed query, assume it's a single object
4544
if isinstance(value, AbstractQuery):
46-
value = value.one()
45+
raise NotImplementedError('query comparisons are not supported')
4746

4847
# If the item is an FTrack entity, use the ID
4948
if isinstance(value, ftrack_api.entity.base.Entity):

tests/test_query.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_any(self):
9494

9595
def test_sort(self):
9696
self.assertEqual(str(entity.a.desc()), 'a descending')
97-
self.assertNotEqual(str(entity.a.desc), 'a descending')
97+
self.assertEqual(str(entity.a.desc), 'a.desc')
9898
self.assertEqual(str(entity.a.b.asc()), 'a.b ascending')
9999

100100
def test_call(self):
@@ -138,30 +138,59 @@ def test_id_remap_in(self):
138138
'Project where project_schema.id in ("{}")'.format(schema['id'])
139139
)
140140

141-
def test_query_remap(self):
141+
def test_id_remap_in_multiple(self):
142142
schema = self.session.ProjectSchema.first()
143-
query = self.session.ProjectSchema.where(id=schema['id'])
144143
self.assertEqual(
145-
str(self.session.Project.where(entity.project_schema == query)),
146-
'Project where project_schema.id is "{}"'.format(schema['id']),
147-
)
148-
self.assertEqual(
149-
str(self.session.Project.where(project_schema=query)),
150-
'Project where project_schema.id is "{}"'.format(schema['id']),
144+
str(self.session.Project.where(entity.project_schema.in_(schema, schema))),
145+
'Project where project_schema.id in ("{s}", "{s}")'.format(s=schema['id'])
151146
)
152147

153-
def test_subquery_in(self):
154-
schema = self.session.ProjectSchema.first()
155-
query = self.session.ProjectSchema.where(id=schema['id'])
148+
149+
class TestQueryComparison(unittest.TestCase):
150+
def setUp(self):
151+
self.session = FTrackQuery(debug=True)
152+
153+
def test_in(self):
154+
query = self.session.ProjectSchema.where(name='My Schema')
156155
self.assertEqual(
157156
str(self.session.Project.where(entity.project_schema.in_(query))),
158-
'Project where project_schema.id in (select id from ProjectSchema where id is "{}")'.format(schema['id']),
157+
'Project where project_schema.id in (select id from ProjectSchema where name is "My Schema")',
159158
)
160159
with self.assertRaises(ValueError):
161-
self.assertEqual(
162-
str(self.session.Project.where(entity.project_schema.in_(query, query))),
163-
'Project where project_schema.id in ("{id}", "{id}")'.format(id=schema['id']),
164-
)
160+
str(self.session.Project.where(entity.project_schema.in_(query, query)))
161+
162+
def test_has_simple(self):
163+
query = self.session.ProjectSchema.where(name='My Schema')
164+
self.assertEqual(
165+
str(self.session.Project.where(entity.project_schema.has(query))),
166+
'Project where project_schema has (name is "My Schema")',
167+
)
168+
169+
def test_has_complex(self):
170+
query = self.session.ProjectSchema.where(~entity.project.has(name='Invalid Project'), name='My Schema')
171+
self.assertEqual(
172+
str(self.session.Project.where(entity.project_schema.has(query))),
173+
'Project where project_schema has (not project has (name is "Invalid Project") and name is "My Schema")',
174+
)
175+
176+
def test_has_multiple(self):
177+
query1 = self.session.ProjectSchema.where(~entity.project.has(name='Invalid Project'))
178+
query2 = self.session.ProjectSchema.where(name='My Schema')
179+
self.assertEqual(
180+
str(self.session.Project.where(entity.project_schema.has(query1, query2))),
181+
'Project where project_schema has (not project has (name is "Invalid Project") and name is "My Schema")',
182+
)
183+
self.assertEqual(
184+
str(self.session.Project.where(entity.project_schema.any(query1, query2))),
185+
'Project where project_schema any (not project has (name is "Invalid Project") and name is "My Schema")',
186+
)
187+
188+
def test_equals(self):
189+
with self.assertRaises(NotImplementedError):
190+
entity.value == self.session.ProjectSchema.where(name='My Schema')
191+
with self.assertRaises(NotImplementedError):
192+
entity.value != self.session.ProjectSchema.where(name='My Schema')
193+
165194

166195

167196
if __name__ == '__main__':

0 commit comments

Comments
 (0)