New code for performing explicit joins with custom join conditions.
* added ExtendedManager.join_custom_field(), which uses the introspection logic from populate_relationships() (now factored out into determine_relationship()) to infer the type of relationship between two models and construct the correct join.  join_custom_field() presents a much simpler, more Django-like interface for this kind of join -- compare it with add_join(), defined just above it.
* changed TKO custom fields code to use join_custom_field()
* added some cases to AFE rpc_interface_unittest to ensure populate_relationships() usage didn't break
* simplified _CustomQuery and got rid of _CustomSqlQ.  _CustomQuery can do the work itself and it's cleaner this way.
* added add_where(), an alternative to extra(where=...) that fits better into Django's normal representation of WHERE clauses, and therefore still allows the resulting query sets to be combined with the & and | operators later

Signed-off-by: Steve Howard <[email protected]>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@4155 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/afe/model_logic.py b/frontend/afe/model_logic.py
index 3b5c70e..7fbdb76 100644
--- a/frontend/afe/model_logic.py
+++ b/frontend/afe/model_logic.py
@@ -6,6 +6,7 @@
 import django.core.exceptions
 from django.db import models as dbmodels, backend, connection
 from django.db.models.sql import query
+import django.db.models.sql.where
 from django.utils import datastructures
 from autotest_lib.frontend.afe import readonly_connection
 
@@ -94,77 +95,97 @@
     """
 
     class _CustomQuery(query.Query):
+        def __init__(self, *args, **kwargs):
+            super(ExtendedManager._CustomQuery, self).__init__(*args, **kwargs)
+            self._custom_joins = []
+
+
         def clone(self, klass=None, **kwargs):
-            obj = super(ExtendedManager._CustomQuery, self).clone(
-                klass, _customSqlQ=self._customSqlQ)
-
-            customQ = kwargs.get('_customSqlQ', None)
-            if customQ is not None:
-                obj._customSqlQ._joins.update(customQ._joins)
-                obj._customSqlQ._where.extend(customQ._where)
-                obj._customSqlQ._params.extend(customQ._params)
-
+            obj = super(ExtendedManager._CustomQuery, self).clone(klass)
+            obj._custom_joins = list(self._custom_joins)
             return obj
 
+
+        def combine(self, rhs, connector):
+            super(ExtendedManager._CustomQuery, self).combine(rhs, connector)
+            if hasattr(rhs, '_custom_joins'):
+                self._custom_joins.extend(rhs._custom_joins)
+
+
+        def add_custom_join(self, table, condition, join_type,
+                            condition_values=(), alias=None):
+            if alias is None:
+                alias = table
+            join_dict = dict(table=table,
+                             condition=condition,
+                             condition_values=condition_values,
+                             join_type=join_type,
+                             alias=alias)
+            self._custom_joins.append(join_dict)
+
+
         def get_from_clause(self):
-            from_, params = super(
-                ExtendedManager._CustomQuery, self).get_from_clause()
+            from_, params = (super(ExtendedManager._CustomQuery, self)
+                             .get_from_clause())
 
-            join_clause = ''
-            for join_alias, join in self._customSqlQ._joins.iteritems():
-                join_table, join_type, condition = join
-                join_clause += ' %s %s AS %s ON (%s)' % (
-                    join_type, _quote_name(join_table),
-                    _quote_name(join_alias), condition)
-
-            if join_clause:
-                from_.append(join_clause)
+            for join_dict in self._custom_joins:
+                from_.append('%s %s AS %s ON (%s)'
+                             % (join_dict['join_type'],
+                                _quote_name(join_dict['table']),
+                                _quote_name(join_dict['alias']),
+                                join_dict['condition']))
+                params.extend(join_dict['condition_values'])
 
             return from_, params
 
 
-    class _CustomSqlQ(dbmodels.Q):
-        def __init__(self):
-            self._joins = datastructures.SortedDict()
-            self._where, self._params = [], []
+        @classmethod
+        def convert_query(cls, query_set):
+            """Return a copy of query_set whose query is a _CustomQuery."""
+            # _custom_joins=[] only initializes a plain Query; an existing
+            # _CustomQuery keeps its joins because its clone() drops kwargs.
+            # copy so the caller's query set is left untouched
+            query_set = query_set.all()
+            query_set.query = query_set.query.clone(
+                    klass=cls,
+                    _custom_joins=[])
+            return query_set
 
 
-        def add_join(self, table, condition, join_type, alias=None):
-            if alias is None:
-                alias = table
-            self._joins[alias] = (table, join_type, condition)
+    class _WhereClause(object):
+        """Object allowing us to inject arbitrary SQL into Django queries.
 
-
-        def add_where(self, where, params=[]):
-            self._where.append(where)
-            self._params.extend(params)
-
-
-        def add_to_query(self, query, aliases):
-            if self._where:
-                where = ' AND '.join(self._where)
-                query.add_extra(None, None, (where,), self._params, None, None)
-
-
-    def _add_customSqlQ(self, query_set, filter_object):
-        """\
-        Add a _CustomSqlQ to the query set.
+        By using this instead of extra(where=...), we can still freely combine
+        queries with & and |.
         """
-        # Make a copy of the query set
-        query_set = query_set.all()
+        def __init__(self, clause, values=()):
+            self._clause = clause
+            self._values = values
 
-        query_set.query = query_set.query.clone(
-            ExtendedManager._CustomQuery, _customSqlQ=filter_object)
-        return query_set.filter(filter_object)
+
+        def as_sql(self, qn=None):
+            return self._clause, self._values
+
+
+        def relabel_aliases(self, change_map):
+            return
 
 
     def add_join(self, query_set, join_table, join_key, join_condition='',
-                 alias=None, suffix='', exclude=False, force_left_join=False):
-        """
-        Add a join to query_set.
+                 join_condition_values=(), join_from_key=None, alias=None,
+                 suffix='', exclude=False, force_left_join=False):
+        """Add a join to query_set.
+
+        Join looks like this:
+                (INNER|LEFT) JOIN <join_table> AS <alias>
+                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
+                        and <join_condition>)
+
         @param join_table table to join to
         @param join_key field referencing back to this model to use for the join
         @param join_condition extra condition for the ON clause of the join
+        @param join_condition_values values to substitute into join_condition
+        @param join_from_key column on this model to join from.
         @param alias alias to use for for join
         @param suffix suffix to add to join_table for the join alias, if no
                 alias is provided
@@ -173,15 +194,15 @@
         @param force_left_join - if true, a LEFT OUTER JOIN will be used
         instead of an INNER JOIN regardless of other options
         """
-        join_from_table = _quote_name(self.model._meta.db_table)
-        join_from_key = _quote_name(self.model._meta.pk.name)
-        if alias:
-            join_alias = alias
-        else:
-            join_alias = join_table + suffix
-        full_join_key = _quote_name(join_alias) + '.' + _quote_name(join_key)
-        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
-                                              join_from_key)
+        join_from_table = query_set.model._meta.db_table
+        if join_from_key is None:
+            join_from_key = self.model._meta.pk.name
+        if alias is None:
+            alias = join_table + suffix
+        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
+        full_join_condition = '%s = %s.%s' % (full_join_key,
+                                              _quote_name(join_from_table),
+                                              _quote_name(join_from_key))
         if join_condition:
             full_join_condition += ' AND (' + join_condition + ')'
         if exclude or force_left_join:
@@ -189,15 +210,128 @@
         else:
             join_type = query_set.query.INNER
 
-        filter_object = self._CustomSqlQ()
-        filter_object.add_join(join_table,
-                               full_join_condition,
-                               join_type,
-                               alias=join_alias)
-        if exclude:
-            filter_object.add_where(full_join_key + ' IS NULL')
+        query_set = self._CustomQuery.convert_query(query_set)
+        query_set.query.add_custom_join(join_table,
+                                        full_join_condition,
+                                        join_type,
+                                        condition_values=join_condition_values,
+                                        alias=alias)
 
-        query_set = self._add_customSqlQ(query_set, filter_object)
+        if exclude:
+            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
+
+        return query_set
+
+
+    def _info_for_many_to_one_join(self, field, join_to_query, alias):
+        """
+        @param field: the ForeignKey field on the related model
+        @param join_to_query: query over the related model we're joining to;
+                it is not modified (its WHERE tree is cloned before relabeling)
+        @param alias: alias of joined table
+        """
+        info = {}
+        rhs_table = join_to_query.model._meta.db_table
+        info['rhs_table'] = rhs_table
+        info['rhs_column'] = field.column
+        info['lhs_column'] = field.rel.get_related_field().column
+        rhs_where = join_to_query.query.clone().where
+        rhs_where.relabel_aliases({rhs_table: alias})
+        initial_clause, values = rhs_where.as_sql()
+        all_clauses = (initial_clause,) + join_to_query.query.extra_where
+        # an unfiltered query has no WHERE clause; drop empty clauses
+        info['where_clause'] = ' AND '.join(
+                '(%s)' % clause for clause in all_clauses if clause)
+        info['values'] = list(values) + list(join_to_query.query.extra_params)
+        return info
+
+
+    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
+                                    m2m_is_on_this_model):
+        """
+        @param m2m_field: a Django field representing the M2M relationship.
+                It uses a pivot table with the following structure:
+                this model table <---> M2M pivot table <---> joined model table
+        @param join_to_query: query over the related model we're joining to.
+        @param alias: alias of joined table (the M2M pivot table)
+        @param m2m_is_on_this_model: True if this model declares m2m_field
+        """
+        if m2m_is_on_this_model:
+            # referenced field on this model
+            lhs_id_field = self.model._meta.pk
+            # foreign key on the pivot table referencing lhs_id_field
+            m2m_lhs_column = m2m_field.m2m_column_name()
+            # foreign key on the pivot table referencing rhs_id_field
+            m2m_rhs_column = m2m_field.m2m_reverse_name()
+            # referenced field on related model
+            rhs_id_field = m2m_field.rel.get_related_field()
+        else:
+            lhs_id_field = m2m_field.rel.get_related_field()
+            m2m_lhs_column = m2m_field.m2m_reverse_name()
+            m2m_rhs_column = m2m_field.m2m_column_name()
+            rhs_id_field = join_to_query.model._meta.pk
+
+        info = {}
+        info['rhs_table'] = m2m_field.m2m_db_table()
+        info['rhs_column'] = m2m_lhs_column
+        info['lhs_column'] = lhs_id_field.column
+
+        # select the ID of related models relevant to this join.  we can only do
+        # a single join, so we need to gather this information up front and
+        # include it in the join condition.
+        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
+        if len(rhs_ids) != 1:
+            raise ValueError('Many-to-many custom field joins can only '
+                             'match a single related object.')
+        rhs_id = rhs_ids[0]
+        # bind the ID as a query parameter rather than splicing it into SQL
+        info['where_clause'] = '%s.%s = %%s' % (_quote_name(alias),
+                                                _quote_name(m2m_rhs_column))
+        info['values'] = (rhs_id,)
+        return info
+
+
+    def join_custom_field(self, query_set, join_to_query, alias,
+                          left_join=True):
+        """Join to a related model to create a custom field in the given query.
+
+        Constructs a custom field on query_set from a many-valued relationship.
+        join_to_query should be a simple query (no joins) on the related model
+        which returns at most one related row per instance of this model.
+        With left_join=False an INNER JOIN is used instead of a LEFT JOIN.
+
+        For many-to-one relationships, the joined table contains the matching
+        row from the related model if one is related, NULL otherwise.
+
+        For many-to-many relationships, the joined table contains the matching
+        row if it's related, NULL otherwise.
+        """
+        relationship_type, field = self.determine_relationship(
+                join_to_query.model)
+
+        if relationship_type is self.MANY_TO_ONE:
+            info = self._info_for_many_to_one_join(field, join_to_query, alias)
+        elif relationship_type is self.M2M_ON_RELATED_MODEL:
+            info = self._info_for_many_to_many_join(
+                    m2m_field=field, join_to_query=join_to_query, alias=alias,
+                    m2m_is_on_this_model=False)
+        elif relationship_type is self.M2M_ON_THIS_MODEL:
+            info = self._info_for_many_to_many_join(
+                    m2m_field=field, join_to_query=join_to_query, alias=alias,
+                    m2m_is_on_this_model=True)
+
+        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
+                             join_from_key=info['lhs_column'],
+                             join_condition=info['where_clause'],
+                             join_condition_values=info['values'],
+                             alias=alias,
+                             force_left_join=left_join)
+
+
+    def add_where(self, query_set, where, values=()):
+        query_set = query_set.all()
+        query_set.query.where.add(self._WhereClause(where, values),
+                                  django.db.models.sql.where.AND)
         return query_set
 
 
@@ -235,6 +369,39 @@
         return field.rel and field.rel.to is model_class
 
 
+    MANY_TO_ONE = object()
+    M2M_ON_RELATED_MODEL = object()
+    M2M_ON_THIS_MODEL = object()
+
+    def determine_relationship(self, related_model):
+        """
+        Determine the relationship between this model and related_model.
+
+        related_model must have some sort of many-valued relationship to this
+        manager's model.
+        @returns (relationship_type, field), where relationship_type is one of
+                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
+                is the Django field object for the relationship.
+        """
+        # look for a foreign key field on related_model relating to this model
+        for field in related_model._meta.fields:
+            if self._is_relation_to(field, self.model):
+                return self.MANY_TO_ONE, field
+
+        # look for an M2M field on related_model relating to this model
+        for field in related_model._meta.many_to_many:
+            if self._is_relation_to(field, self.model):
+                return self.M2M_ON_RELATED_MODEL, field
+
+        # maybe this model has the many-to-many field
+        for field in self.model._meta.many_to_many:
+            if self._is_relation_to(field, related_model):
+                return self.M2M_ON_THIS_MODEL, field
+
+        raise ValueError('%s has no relation to %s' %
+                         (related_model, self.model))
+
+
     def _get_pivot_iterator(self, base_objects_by_id, related_model):
         """
         Determine the relationship between this model and related_model, and
@@ -244,33 +411,22 @@
         @returns a pivot iterator, which yields a tuple (base_object,
         related_object) for each relationship between a base object and a
         related object.  all base_object instances come from base_objects_by_id.
-        Note -- this depends on Django model internals and will likely need to
-        be updated when we move to Django 1.x.
+        Note -- this depends on Django model internals.
         """
-        # look for a field on related_model relating to this model
-        for field in related_model._meta.fields:
-            if self._is_relation_to(field, self.model):
-                # many-to-one
-                return self._many_to_one_pivot(base_objects_by_id,
-                                               related_model, field)
-
-        for field in related_model._meta.many_to_many:
-            if self._is_relation_to(field, self.model):
-                # many-to-many
-                return self._many_to_many_pivot(
+        relationship_type, field = self.determine_relationship(related_model)
+        if relationship_type == self.MANY_TO_ONE:
+            return self._many_to_one_pivot(base_objects_by_id,
+                                           related_model, field)
+        elif relationship_type == self.M2M_ON_RELATED_MODEL:
+            return self._many_to_many_pivot(
                     base_objects_by_id, related_model, field.m2m_db_table(),
                     field.m2m_reverse_name(), field.m2m_column_name())
-
-        # maybe this model has the many-to-many field
-        for field in self.model._meta.many_to_many:
-            if self._is_relation_to(field, related_model):
-                return self._many_to_many_pivot(
+        else:
+            assert relationship_type == self.M2M_ON_THIS_MODEL
+            return self._many_to_many_pivot(
                     base_objects_by_id, related_model, field.m2m_db_table(),
                     field.m2m_column_name(), field.m2m_reverse_name())
 
-        raise ValueError('%s has no relation to %s' %
-                         (related_model, self.model))
-
 
     def _many_to_one_pivot(self, base_objects_by_id, related_model,
                            foreign_key_field):
diff --git a/frontend/afe/rpc_interface_unittest.py b/frontend/afe/rpc_interface_unittest.py
index c84dacf..7a8b7e2 100755
--- a/frontend/afe/rpc_interface_unittest.py
+++ b/frontend/afe/rpc_interface_unittest.py
@@ -60,6 +60,12 @@
 
         hosts = rpc_interface.get_hosts(hostname='host1')
         self._check_hostnames(hosts, ['host1'])
+        host = hosts[0]
+        self.assertEquals(sorted(host['labels']), ['label1', 'myplatform'])
+        self.assertEquals(host['platform'], 'myplatform')
+        self.assertEquals(host['atomic_group'], None)
+        self.assertEquals(host['acls'], ['my_acl'])
+        self.assertEquals(host['attributes'], {})
 
 
     def test_get_hosts_multiple_labels(self):
diff --git a/frontend/tko/models.py b/frontend/tko/models.py
index 429473d..7348b07 100644
--- a/frontend/tko/models.py
+++ b/frontend/tko/models.py
@@ -327,12 +327,11 @@
         second_join_condition = ('%s.id = %s.testlabel_id' %
                                  (second_join_alias,
                                   'tko_test_labels_tests' + suffix))
-        filter_object = self._CustomSqlQ()
-        filter_object.add_join('tko_test_labels',
-                               second_join_condition,
-                               query_set.query.LOUTER,
-                               alias=second_join_alias)
-        return self._add_customSqlQ(query_set, filter_object)
+        query_set.query.add_custom_join('tko_test_labels',
+                                        second_join_condition,
+                                        query_set.query.LOUTER,
+                                        alias=second_join_alias)
+        return query_set
 
 
     def _get_label_ids_from_names(self, label_names):
@@ -373,12 +372,10 @@
 
 
     def _join_label_column(self, query_set, label_name, label_id):
-        table_name = TestLabel.tests.field.m2m_db_table()
         alias = 'label_' + label_name
-        condition = "%s.testlabel_id = %s" % (_quote_name(alias), label_id)
-        query_set = self.add_join(query_set, table_name,
-                                  join_key='test_id', join_condition=condition,
-                                  alias=alias, force_left_join=True)
+        label_query = TestLabel.objects.filter(name=label_name)
+        query_set = Test.objects.join_custom_field(query_set, label_query,
+                                                   alias)
 
         query_set = self._add_select_ifnull(query_set, alias, label_name)
         return query_set
@@ -392,23 +389,21 @@
         return query_set
 
 
-    def _join_attribute(self, test_view_query_set, attribute,
-                        alias=None, extra_join_condition=None):
+    def _join_attribute(self, query_set, attribute, alias=None,
+                        extra_join_condition=None):
         """
         Join the given TestView QuerySet to TestAttribute.  The resulting query
         has an additional column for the given attribute named
         "attribute_<attribute name>".
         """
-        table_name = TestAttribute._meta.db_table
         if not alias:
             alias = 'attribute_' + attribute
-        condition = "%s.attribute = '%s'" % (_quote_name(alias),
-                                             self.escape_user_sql(attribute))
+        attribute_query = TestAttribute.objects.filter(attribute=attribute)
         if extra_join_condition:
-            condition += ' AND (%s)' % extra_join_condition
-        query_set = self.add_join(test_view_query_set, table_name,
-                                  join_key='test_idx', join_condition=condition,
-                                  alias=alias, force_left_join=True)
+            attribute_query = attribute_query.extra(
+                    where=[extra_join_condition])
+        query_set = Test.objects.join_custom_field(query_set, attribute_query,
+                                                   alias)
 
         query_set = self._add_select_value(query_set, alias)
         return query_set
@@ -427,23 +422,18 @@
 
 
     def _join_one_iteration_key(self, query_set, result_key, first_alias=None):
-        table_name = IterationResult._meta.db_table
         alias = 'iteration_' + result_key
-        condition_parts = ["%s.attribute = '%s'" %
-                           (_quote_name(alias),
-                            self.escape_user_sql(result_key))]
+        iteration_query = IterationResult.objects.filter(attribute=result_key)
         if first_alias:
             # after the first join, we need to match up iteration indices,
             # otherwise each join will expand the query by the number of
             # iterations and we'll have extraneous rows
-            condition_parts.append('%s.iteration = %s.iteration' %
-                                   (_quote_name(alias),
-                                    _quote_name(first_alias)))
+            iteration_query = iteration_query.extra(
+                    where=['%s.iteration = %s.iteration'
+                           % (_quote_name(alias), _quote_name(first_alias))])
 
-        condition = ' and '.join(condition_parts)
-        # add a join to IterationResult
-        query_set = self.add_join(query_set, table_name, join_key='test_idx',
-                                  join_condition=condition, alias=alias)
+        query_set = Test.objects.join_custom_field(query_set, iteration_query,
+                                                   alias, left_join=False)
         # select the iteration value and index for this join
         query_set = self._add_select_value(query_set, alias)
         if not first_alias:
diff --git a/frontend/tko/rpc_interface_unittest.py b/frontend/tko/rpc_interface_unittest.py
index bbae95a..248ea83 100644
--- a/frontend/tko/rpc_interface_unittest.py
+++ b/frontend/tko/rpc_interface_unittest.py
@@ -465,7 +465,7 @@
         self.assertEquals(len(tests), 3)
 
         self.assertEquals(tests[0]['label_testlabel1'], 'testlabel1')
-        self.assert_(tests[0]['label_testlabel2'], 'testlabel2')
+        self.assertEquals(tests[0]['label_testlabel2'], 'testlabel2')
 
         for index in (1, 2):
             self.assertEquals(tests[index]['label_testlabel1'], None)