New code for performing explicit joins with custom join conditions.
* added ExtendedManager.join_custom_field(), which uses the introspection magic from populate_relationships (now factored out into determine_relationship()) to infer the type of relationship between two models and construct the correct join. join_custom_field() presents a much simpler, more Django-y interface for doing this sort of thing -- compare with add_join() above it.
* changed TKO custom fields code to use join_custom_field()
* added some cases to AFE rpc_interface_unittest to ensure populate_relationships() usage didn't break
* simplified _CustomQuery and got rid of _CustomSqlQ. _CustomQuery can do the work itself, and it's cleaner this way.
* added add_where(), an alternative to extra(where=...) that fits better into Django's normal representation of WHERE clauses, and therefore still allows the resulting queries to be combined later with the & and | operators
Signed-off-by: Steve Howard <[email protected]>
git-svn-id: http://test.kernel.org/svn/autotest/trunk@4155 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/afe/model_logic.py b/frontend/afe/model_logic.py
index 3b5c70e..7fbdb76 100644
--- a/frontend/afe/model_logic.py
+++ b/frontend/afe/model_logic.py
@@ -6,6 +6,7 @@
import django.core.exceptions
from django.db import models as dbmodels, backend, connection
from django.db.models.sql import query
+import django.db.models.sql.where
from django.utils import datastructures
from autotest_lib.frontend.afe import readonly_connection
@@ -94,77 +95,97 @@
"""
class _CustomQuery(query.Query):
+ """Query subclass that can emit extra, hand-written JOIN clauses.
+
+ Joins are recorded by add_custom_join() and rendered by
+ get_from_clause(); clone() and combine() carry them along.
+ """
+ def __init__(self, *args, **kwargs):
+ super(ExtendedManager._CustomQuery, self).__init__(*args, **kwargs)
+ # join dicts built by add_custom_join(), in FROM-clause order
+ self._custom_joins = []
+
+
def clone(self, klass=None, **kwargs):
- obj = super(ExtendedManager._CustomQuery, self).clone(
- klass, _customSqlQ=self._customSqlQ)
-
- customQ = kwargs.get('_customSqlQ', None)
- if customQ is not None:
- obj._customSqlQ._joins.update(customQ._joins)
- obj._customSqlQ._where.extend(customQ._where)
- obj._customSqlQ._params.extend(customQ._params)
-
+ # forward **kwargs so the base Query.clone() still applies them instead
+ # of this override silently discarding them
+ obj = super(ExtendedManager._CustomQuery, self).clone(klass, **kwargs)
+ # copy the list so joins added to the clone don't leak back into self;
+ # this also keeps existing joins over any _custom_joins kwarg
+ obj._custom_joins = list(self._custom_joins)
return obj
+
+ def combine(self, rhs, connector):
+ """Merge rhs into this query (for & and |), including its custom joins.
+
+ A join identical to one already present is skipped: query sets derived
+ from the same joined query share their joins, and rendering one twice
+ would repeat the same table alias in the FROM clause.
+ """
+ super(ExtendedManager._CustomQuery, self).combine(rhs, connector)
+ for join_dict in getattr(rhs, '_custom_joins', ()):
+ if join_dict not in self._custom_joins:
+ self._custom_joins.append(join_dict)
+
+
+ def add_custom_join(self, table, condition, join_type,
+ condition_values=(), alias=None):
+ """Record an extra JOIN to be rendered by get_from_clause().
+
+ @param table: table to join to
+ @param condition: SQL for the ON clause; may contain %s placeholders
+ @param join_type: join keyword, e.g. query.INNER or query.LOUTER
+ @param condition_values: values for the placeholders in condition
+ @param alias: alias for the joined table (defaults to table)
+ """
+ if alias is None:
+ alias = table
+ join_dict = dict(table=table,
+ condition=condition,
+ condition_values=condition_values,
+ join_type=join_type,
+ alias=alias)
+ self._custom_joins.append(join_dict)
+
+
def get_from_clause(self):
- from_, params = super(
- ExtendedManager._CustomQuery, self).get_from_clause()
+ from_, params = (super(ExtendedManager._CustomQuery, self)
+ .get_from_clause())
- join_clause = ''
- for join_alias, join in self._customSqlQ._joins.iteritems():
- join_table, join_type, condition = join
- join_clause += ' %s %s AS %s ON (%s)' % (
- join_type, _quote_name(join_table),
- _quote_name(join_alias), condition)
-
- if join_clause:
- from_.append(join_clause)
+ # custom joins follow the regular FROM clause; their condition values
+ # are appended to params in the same order
+ for join_dict in self._custom_joins:
+ from_.append('%s %s AS %s ON (%s)'
+ % (join_dict['join_type'],
+ _quote_name(join_dict['table']),
+ _quote_name(join_dict['alias']),
+ join_dict['condition']))
+ params.extend(join_dict['condition_values'])
return from_, params
- class _CustomSqlQ(dbmodels.Q):
- def __init__(self):
- self._joins = datastructures.SortedDict()
- self._where, self._params = [], []
+ @classmethod
+ def convert_query(cls, query_set):
+ """
+ Convert the query set's "query" attribute to a _CustomQuery.
+
+ Returns a converted copy; query_set itself is not modified. If the
+ query is already a _CustomQuery, its existing custom joins are kept.
+ """
+ # Make a copy of the query set
+ query_set = query_set.all()
+ query_set.query = query_set.query.clone(
+ klass=cls,
+ _custom_joins=[])
+ return query_set
- def add_join(self, table, condition, join_type, alias=None):
- if alias is None:
- alias = table
- self._joins[alias] = (table, join_type, condition)
+ class _WhereClause(object):
+ """Object allowing us to inject arbitrary SQL into Django queries.
-
- def add_where(self, where, params=[]):
- self._where.append(where)
- self._params.extend(params)
-
-
- def add_to_query(self, query, aliases):
- if self._where:
- where = ' AND '.join(self._where)
- query.add_extra(None, None, (where,), self._params, None, None)
-
-
- def _add_customSqlQ(self, query_set, filter_object):
- """\
- Add a _CustomSqlQ to the query set.
+ By using this instead of extra(where=...), we can still freely combine
+ queries with & and |.
"""
- # Make a copy of the query set
- query_set = query_set.all()
+ def __init__(self, clause, values=()):
+ self._clause = clause
+ self._values = values
- query_set.query = query_set.query.clone(
- ExtendedManager._CustomQuery, _customSqlQ=filter_object)
- return query_set.filter(filter_object)
+
+ def as_sql(self, qn=None):
+ # qn is accepted but unused; the clause is emitted verbatim
+ return self._clause, self._values
+
+
+ def relabel_aliases(self, change_map):
+ # the raw SQL clause is opaque, so alias changes are not applied to it
+ return
def add_join(self, query_set, join_table, join_key, join_condition='',
- alias=None, suffix='', exclude=False, force_left_join=False):
- """
- Add a join to query_set.
+ alias=None, suffix='', exclude=False, force_left_join=False,
+ join_condition_values=(), join_from_key=None):
+ """Add a join to query_set.
+
+ Join looks like this:
+ (INNER|LEFT) JOIN <join_table> AS <alias>
+ ON (<this table>.<join_from_key> = <alias>.<join_key>
+ and <join_condition>)
+
@param join_table table to join to
@param join_key field referencing back to this model to use for the join
@param join_condition extra condition for the ON clause of the join
+ @param join_condition_values values to substitute into join_condition
+ @param join_from_key column on this model to join from (defaults to
+ the primary key column)
@param alias alias to use for for join
@param suffix suffix to add to join_table for the join alias, if no
alias is provided
@@ -173,15 +194,15 @@
@param force_left_join - if true, a LEFT OUTER JOIN will be used
instead of an INNER JOIN regardless of other options
+ @returns a copy of query_set, converted to a _CustomQuery, with the
+ join added
"""
- join_from_table = _quote_name(self.model._meta.db_table)
- join_from_key = _quote_name(self.model._meta.pk.name)
- if alias:
- join_alias = alias
- else:
- join_alias = join_table + suffix
- full_join_key = _quote_name(join_alias) + '.' + _quote_name(join_key)
- full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
- join_from_key)
+ join_from_table = query_set.model._meta.db_table
+ if join_from_key is None:
+ # this is emitted as SQL, so use the column name, not the field name
+ join_from_key = self.model._meta.pk.column
+ if not alias:
+ # as before, an empty alias also falls back to the default
+ alias = join_table + suffix
+ full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
+ full_join_condition = '%s = %s.%s' % (full_join_key,
+ _quote_name(join_from_table),
+ _quote_name(join_from_key))
if join_condition:
full_join_condition += ' AND (' + join_condition + ')'
if exclude or force_left_join:
@@ -189,15 +210,128 @@
else:
join_type = query_set.query.INNER
- filter_object = self._CustomSqlQ()
- filter_object.add_join(join_table,
- full_join_condition,
- join_type,
- alias=join_alias)
- if exclude:
- filter_object.add_where(full_join_key + ' IS NULL')
+ query_set = self._CustomQuery.convert_query(query_set)
+ query_set.query.add_custom_join(join_table,
+ full_join_condition,
+ join_type,
+ condition_values=join_condition_values,
+ alias=alias)
- query_set = self._add_customSqlQ(query_set, filter_object)
+ if exclude:
+ # keep only rows for which the join found no match
+ query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
+
+ return query_set
+
+
+ def _info_for_many_to_one_join(self, field, join_to_query, alias):
+ """
+ Compute join parameters for a many-to-one (ForeignKey) relationship.
+
+ @param field: the ForeignKey field on the related model
+ @param join_to_query: the query over the related model that we're
+ joining to
+ @param alias: alias of joined table
+ @returns dict with keys rhs_table, rhs_column, lhs_column, where_clause
+ and values, as consumed by join_custom_field()
+ """
+ info = {}
+ rhs_table = join_to_query.model._meta.db_table
+ info['rhs_table'] = rhs_table
+ info['rhs_column'] = field.column
+ info['lhs_column'] = field.rel.get_related_field().column
+ # relabel a copy of the WHERE tree: relabel_aliases() works in place and
+ # would otherwise rewrite join_to_query itself (Query.clone() copies where)
+ rhs_where = join_to_query.query.clone().where
+ rhs_where.relabel_aliases({rhs_table: alias})
+ initial_clause, values = rhs_where.as_sql()
+ all_clauses = tuple(join_to_query.query.extra_where)
+ if initial_clause:
+ # an unfiltered query has no WHERE clause; don't emit an empty one
+ all_clauses = (initial_clause,) + all_clauses
+ info['where_clause'] = ' AND '.join('(%s)' % clause
+ for clause in all_clauses)
+ info['values'] = list(values) + list(join_to_query.query.extra_params)
+ return info
+
+
+ def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
+ m2m_is_on_this_model):
+ """
+ Compute join parameters for a many-to-many relationship.
+
+ @param m2m_field: a Django field representing the M2M relationship.
+ It uses a pivot table with the following structure:
+ this model table <---> M2M pivot table <---> joined model table
+ @param join_to_query: the query over the related model that we're
+ joining to. It must match exactly one related object.
+ @param alias: alias of joined table
+ @param m2m_is_on_this_model: True if m2m_field is declared on this
+ manager's model, False if it is declared on the related model
+ @returns dict with keys rhs_table, rhs_column, lhs_column, where_clause
+ and values, as consumed by join_custom_field()
+ @raises ValueError if join_to_query does not match exactly one object
+ """
+ if m2m_is_on_this_model:
+ # referenced field on this model
+ lhs_id_field = self.model._meta.pk
+ # foreign key on the pivot table referencing lhs_id_field
+ m2m_lhs_column = m2m_field.m2m_column_name()
+ # foreign key on the pivot table referencing rhs_id_field
+ m2m_rhs_column = m2m_field.m2m_reverse_name()
+ # referenced field on related model
+ rhs_id_field = m2m_field.rel.get_related_field()
+ else:
+ lhs_id_field = m2m_field.rel.get_related_field()
+ m2m_lhs_column = m2m_field.m2m_reverse_name()
+ m2m_rhs_column = m2m_field.m2m_column_name()
+ rhs_id_field = join_to_query.model._meta.pk
+
+ info = {}
+ info['rhs_table'] = m2m_field.m2m_db_table()
+ info['rhs_column'] = m2m_lhs_column
+ info['lhs_column'] = lhs_id_field.column
+
+ # select the ID of related models relevant to this join. we can only do
+ # a single join, so we need to gather this information up front and
+ # include it in the join condition.
+ rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
+ # raise rather than assert so the check also runs under python -O
+ if len(rhs_ids) != 1:
+ raise ValueError('Many-to-many custom field joins can only '
+ 'match a single related object.')
+ rhs_id = rhs_ids[0]
+
+ # pass the ID as a query parameter instead of formatting it into the SQL
+ info['where_clause'] = '%s.%s = %%s' % (_quote_name(alias),
+ _quote_name(m2m_rhs_column))
+ info['values'] = (rhs_id,)
+ return info
+
+
+ def join_custom_field(self, query_set, join_to_query, alias,
+ left_join=True):
+ """Join to a related model to create a custom field in the given query.
+
+ This method is used to construct a custom field on the given query based
+ on a many-valued relationship. join_to_query should be a simple query
+ (no joins) on the related model which returns at most one related row
+ per instance of this model.
+
+ For many-to-one relationships, the joined table contains the matching
+ row from the related model if one is related, NULL otherwise.
+
+ For many-to-many relationships, the joined table contains the matching
+ row if it's related, NULL otherwise. In this case join_to_query must
+ match exactly one related object.
+
+ @param query_set: query set to add the join to
+ @param join_to_query: query over the related model selecting the rows
+ to join
+ @param alias: alias for the joined table
+ @param left_join: if true (the default), use a LEFT OUTER JOIN so rows
+ without a related match are kept; otherwise use an INNER JOIN
+ @returns a new query set with the join added
+ @raises ValueError if the related model has no relation to this model
+ """
+ relationship_type, field = self.determine_relationship(
+ join_to_query.model)
+
+ if relationship_type == self.MANY_TO_ONE:
+ info = self._info_for_many_to_one_join(field, join_to_query, alias)
+ elif relationship_type == self.M2M_ON_RELATED_MODEL:
+ info = self._info_for_many_to_many_join(
+ m2m_field=field, join_to_query=join_to_query, alias=alias,
+ m2m_is_on_this_model=False)
+ elif relationship_type == self.M2M_ON_THIS_MODEL:
+ info = self._info_for_many_to_many_join(
+ m2m_field=field, join_to_query=join_to_query, alias=alias,
+ m2m_is_on_this_model=True)
+
+ return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
+ join_from_key=info['lhs_column'],
+ join_condition=info['where_clause'],
+ join_condition_values=info['values'],
+ alias=alias,
+ force_left_join=left_join)
+
+
+ def add_where(self, query_set, where, values=()):
+ """Return a copy of query_set with the SQL condition `where` ANDed in.
+
+ Unlike extra(where=...), the condition is added to Django's WHERE tree
+ (wrapped in a _WhereClause), so the result can still be combined with
+ & and |.
+
+ @param where: raw SQL condition
+ @param values: values for any placeholders in where
+ """
+ query_set = query_set.all()
+ query_set.query.where.add(self._WhereClause(where, values),
+ django.db.models.sql.where.AND)
return query_set
@@ -235,6 +369,39 @@
return field.rel and field.rel.to is model_class
+ # Distinct sentinel objects naming the relationship types returned by
+ # determine_relationship().
+ MANY_TO_ONE = object()
+ M2M_ON_RELATED_MODEL = object()
+ M2M_ON_THIS_MODEL = object()
+
+ def determine_relationship(self, related_model):
+ """
+ Determine the relationship between this model and related_model.
+
+ related_model must have some sort of many-valued relationship to this
+ manager's model. Fields are checked in this order and the first match
+ wins: relation fields on related_model, many-to-many fields on
+ related_model, then many-to-many fields on this model.
+ @returns (relationship_type, field), where relationship_type is one of
+ MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
+ is the Django field object for the relationship.
+ @raises ValueError if related_model has no relation to this model.
+ """
+ # look for a foreign key field on related_model relating to this model
+ for field in related_model._meta.fields:
+ if self._is_relation_to(field, self.model):
+ return self.MANY_TO_ONE, field
+
+ # look for an M2M field on related_model relating to this model
+ for field in related_model._meta.many_to_many:
+ if self._is_relation_to(field, self.model):
+ return self.M2M_ON_RELATED_MODEL, field
+
+ # maybe this model has the many-to-many field
+ for field in self.model._meta.many_to_many:
+ if self._is_relation_to(field, related_model):
+ return self.M2M_ON_THIS_MODEL, field
+
+ raise ValueError('%s has no relation to %s' %
+ (related_model, self.model))
+
+
def _get_pivot_iterator(self, base_objects_by_id, related_model):
"""
Determine the relationship between this model and related_model, and
@@ -244,33 +411,22 @@
@returns a pivot iterator, which yields a tuple (base_object,
related_object) for each relationship between a base object and a
related object. all base_object instances come from base_objects_by_id.
- Note -- this depends on Django model internals and will likely need to
- be updated when we move to Django 1.x.
+ Note -- this depends on Django model internals.
+ @raises ValueError if related_model has no relation to this model.
"""
- # look for a field on related_model relating to this model
- for field in related_model._meta.fields:
- if self._is_relation_to(field, self.model):
- # many-to-one
- return self._many_to_one_pivot(base_objects_by_id,
- related_model, field)
-
- for field in related_model._meta.many_to_many:
- if self._is_relation_to(field, self.model):
- # many-to-many
- return self._many_to_many_pivot(
+ relationship_type, field = self.determine_relationship(related_model)
+ if relationship_type == self.MANY_TO_ONE:
+ return self._many_to_one_pivot(base_objects_by_id,
+ related_model, field)
+ # for the M2M cases, the pivot table's two columns are passed in opposite
+ # orders depending on which model declares the M2M field
+ elif relationship_type == self.M2M_ON_RELATED_MODEL:
+ return self._many_to_many_pivot(
base_objects_by_id, related_model, field.m2m_db_table(),
field.m2m_reverse_name(), field.m2m_column_name())
-
- # maybe this model has the many-to-many field
- for field in self.model._meta.many_to_many:
- if self._is_relation_to(field, related_model):
- return self._many_to_many_pivot(
+ else:
+ assert relationship_type == self.M2M_ON_THIS_MODEL
+ return self._many_to_many_pivot(
base_objects_by_id, related_model, field.m2m_db_table(),
field.m2m_column_name(), field.m2m_reverse_name())
- raise ValueError('%s has no relation to %s' %
- (related_model, self.model))
-
def _many_to_one_pivot(self, base_objects_by_id, related_model,
foreign_key_field):
diff --git a/frontend/afe/rpc_interface_unittest.py b/frontend/afe/rpc_interface_unittest.py
index c84dacf..7a8b7e2 100755
--- a/frontend/afe/rpc_interface_unittest.py
+++ b/frontend/afe/rpc_interface_unittest.py
@@ -60,6 +60,12 @@
hosts = rpc_interface.get_hosts(hostname='host1')
self._check_hostnames(hosts, ['host1'])
+ # check the related-object fields, guarding the populate_relationships()
+ # refactoring in model_logic
+ host = hosts[0]
+ self.assertEquals(sorted(host['labels']), ['label1', 'myplatform'])
+ self.assertEquals(host['platform'], 'myplatform')
+ self.assertEquals(host['atomic_group'], None)
+ self.assertEquals(host['acls'], ['my_acl'])
+ self.assertEquals(host['attributes'], {})
def test_get_hosts_multiple_labels(self):
diff --git a/frontend/tko/models.py b/frontend/tko/models.py
index 429473d..7348b07 100644
--- a/frontend/tko/models.py
+++ b/frontend/tko/models.py
@@ -327,12 +327,11 @@
second_join_condition = ('%s.id = %s.testlabel_id' %
(second_join_alias,
'tko_test_labels_tests' + suffix))
- filter_object = self._CustomSqlQ()
- filter_object.add_join('tko_test_labels',
- second_join_condition,
- query_set.query.LOUTER,
- alias=second_join_alias)
- return self._add_customSqlQ(query_set, filter_object)
+ # add_custom_join() exists only on _CustomQuery and mutates the query in
+ # place; convert_query() guarantees the type and returns a copy, as the
+ # old _add_customSqlQ() did
+ query_set = self._CustomQuery.convert_query(query_set)
+ query_set.query.add_custom_join('tko_test_labels',
+ second_join_condition,
+ query_set.query.LOUTER,
+ alias=second_join_alias)
+ return query_set
def _get_label_ids_from_names(self, label_names):
@@ -373,12 +372,10 @@
def _join_label_column(self, query_set, label_name, label_id):
- table_name = TestLabel.tests.field.m2m_db_table()
alias = 'label_' + label_name
- condition = "%s.testlabel_id = %s" % (_quote_name(alias), label_id)
- query_set = self.add_join(query_set, table_name,
- join_key='test_id', join_condition=condition,
- alias=alias, force_left_join=True)
+ # label_id is no longer used here: the label is looked up by name, and
+ # join_custom_field() defaults to a LEFT join like the old add_join() call
+ label_query = TestLabel.objects.filter(name=label_name)
+ query_set = Test.objects.join_custom_field(query_set, label_query,
+ alias)
query_set = self._add_select_ifnull(query_set, alias, label_name)
return query_set
@@ -392,23 +389,21 @@
return query_set
- def _join_attribute(self, test_view_query_set, attribute,
- alias=None, extra_join_condition=None):
+ def _join_attribute(self, query_set, attribute, alias=None,
+ extra_join_condition=None):
"""
Join the given TestView QuerySet to TestAttribute. The resulting query
has an additional column for the given attribute named
"attribute_<attribute name>".
"""
- table_name = TestAttribute._meta.db_table
if not alias:
alias = 'attribute_' + attribute
- condition = "%s.attribute = '%s'" % (_quote_name(alias),
- self.escape_user_sql(attribute))
+ # filtering through the ORM passes `attribute` as a query parameter, so
+ # it no longer needs manual escaping
+ attribute_query = TestAttribute.objects.filter(attribute=attribute)
+ # NOTE(review): join_custom_field() relabels only the ORM filter to alias;
+ # extra_join_condition is passed through verbatim, so it presumably must
+ # refer to the joined table by alias -- confirm against callers.
if extra_join_condition:
- condition += ' AND (%s)' % extra_join_condition
- query_set = self.add_join(test_view_query_set, table_name,
- join_key='test_idx', join_condition=condition,
- alias=alias, force_left_join=True)
+ attribute_query = attribute_query.extra(
+ where=[extra_join_condition])
+ query_set = Test.objects.join_custom_field(query_set, attribute_query,
+ alias)
query_set = self._add_select_value(query_set, alias)
return query_set
@@ -427,23 +422,18 @@
def _join_one_iteration_key(self, query_set, result_key, first_alias=None):
- table_name = IterationResult._meta.db_table
alias = 'iteration_' + result_key
- condition_parts = ["%s.attribute = '%s'" %
- (_quote_name(alias),
- self.escape_user_sql(result_key))]
+ iteration_query = IterationResult.objects.filter(attribute=result_key)
if first_alias:
# after the first join, we need to match up iteration indices,
# otherwise each join will expand the query by the number of
# iterations and we'll have extraneous rows
- condition_parts.append('%s.iteration = %s.iteration' %
- (_quote_name(alias),
- _quote_name(first_alias)))
+ iteration_query = iteration_query.extra(
+ where=['%s.iteration = %s.iteration'
+ % (_quote_name(alias), _quote_name(first_alias))])
- condition = ' and '.join(condition_parts)
- # add a join to IterationResult
- query_set = self.add_join(query_set, table_name, join_key='test_idx',
- join_condition=condition, alias=alias)
+ # add a join to IterationResult; left_join=False keeps the INNER join the
+ # old add_join() call used, so tests without this result key are dropped
+ query_set = Test.objects.join_custom_field(query_set, iteration_query,
+ alias, left_join=False)
# select the iteration value and index for this join
query_set = self._add_select_value(query_set, alias)
if not first_alias:
diff --git a/frontend/tko/rpc_interface_unittest.py b/frontend/tko/rpc_interface_unittest.py
index bbae95a..248ea83 100644
--- a/frontend/tko/rpc_interface_unittest.py
+++ b/frontend/tko/rpc_interface_unittest.py
@@ -465,7 +465,7 @@
self.assertEquals(len(tests), 3)
self.assertEquals(tests[0]['label_testlabel1'], 'testlabel1')
- self.assert_(tests[0]['label_testlabel2'], 'testlabel2')
+ # assert_(expr, msg) only checked truthiness; compare the actual value
+ self.assertEquals(tests[0]['label_testlabel2'], 'testlabel2')
for index in (1, 2):
self.assertEquals(tests[index]['label_testlabel1'], None)