summaryrefslogtreecommitdiff
path: root/django/db/models/sql/datastructures.py
blob: dadd7c063d706d2fd9a234a066558b143290c77c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
"""
Useful auxiliary data structures for query construction. Not useful outside
the SQL domain.
"""
import warnings

from django.core.exceptions import FullResultSet
from django.db.models.sql.constants import INNER, LOUTER
from django.utils.deprecation import RemovedInDjango60Warning


class MultiJoin(Exception):
    """
    Raised by the join construction code to mark the point where a
    multi-valued join was attempted, for callers that want to handle that
    situation specially.
    """

    def __init__(self, names_pos, path_with_names):
        # names_pos is exposed as ``level``; path_with_names is the path
        # travelled so far, which includes the path to the multijoin itself.
        self.level, self.names_with_path = names_pos, path_with_names


class Empty:
    """Bare class that defines no attributes or methods of its own."""


class Join:
    """
    Describe one JOIN clause of a query's FROM entry.

    sql.Query and sql.SQLCompiler use this class to emit SQL such as
        LEFT OUTER JOIN "sometable" T1
        ON ("othertable"."sometable_id" = "sometable"."id")

    Instances live primarily in Query.alias_map. Every entry of alias_map has
    to be Join compatible, i.e. it must expose:
        - table_name (string)
        - table_alias (possible alias for the table, can be None)
        - join_type (can be None for entries that aren't joined from
          anything)
        - parent_alias (the table this join hangs off; can be None in the
          same cases as join_type)
        - as_sql()
        - relabeled_clone()
    """

    def __init__(
        self,
        table_name,
        parent_alias,
        table_alias,
        join_type,
        join_field,
        nullable,
        filtered_relation=None,
    ):
        # The joined table and the alias of the table it is joined from.
        self.table_name = table_name
        self.parent_alias = parent_alias
        # Note: table_alias is not necessarily known at instantiation time.
        self.table_alias = table_alias
        # LOUTER or INNER
        self.join_type = join_type
        # join_cols is a sequence of 2-tuples; each 2-tuple becomes one
        # condition in the ON clause of the JOIN.
        if not hasattr(join_field, "get_joining_fields"):
            # Deprecated path: the join field only offers column names.
            warnings.warn(
                "The usage of get_joining_columns() in Join is deprecated. Implement "
                "get_joining_fields() instead.",
                RemovedInDjango60Warning,
            )
            self.join_fields = None
            self.join_cols = join_field.get_joining_columns()
        else:
            self.join_fields = join_field.get_joining_fields()
            self.join_cols = tuple(
                (source.column, target.column) for source, target in self.join_fields
            )
        # Along which field (or ForeignObjectRel in the reverse join case)
        self.join_field = join_field
        # Is this join nullable?
        self.nullable = nullable
        self.filtered_relation = filtered_relation

    def as_sql(self, compiler, connection):
        """
        Build the complete
           LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol
        clause for this join and return it as a (sql, params) pair.

        Raise ValueError if neither the joining pairs, the join field's extra
        restriction nor the filtered relation contribute a condition.
        """
        on_conditions = []
        params = []
        quote_alias = compiler.quote_name_unless_alias
        quote_column = connection.ops.quote_name
        # One equality condition per joining pair.
        # RemovedInDjango60Warning: when the deprecation ends, iterate over
        # self.join_fields only.
        for lhs, rhs in self.join_fields or self.join_cols:
            if not isinstance(lhs, str):
                lhs, rhs = connection.ops.prepare_join_on_clause(
                    self.parent_alias, lhs, self.table_alias, rhs
                )
                # The params of each compiled side are interpolated straight
                # into its SQL; they are not added to the returned params.
                lhs_sql, lhs_params = compiler.compile(lhs)
                left_side = lhs_sql % lhs_params
                rhs_sql, rhs_params = compiler.compile(rhs)
                right_side = rhs_sql % rhs_params
            else:
                # RemovedInDjango60Warning: when the deprecation ends, remove
                # this branch for plain column-name strings.
                left_side = f"{quote_alias(self.parent_alias)}.{quote_column(lhs)}"
                right_side = f"{quote_alias(self.table_alias)}.{quote_column(rhs)}"
            on_conditions.append(f"{left_side} = {right_side}")

        # Whatever get_extra_restriction() returns becomes a single
        # parenthesized condition.
        restriction = self.join_field.get_extra_restriction(
            self.table_alias, self.parent_alias
        )
        if restriction:
            restriction_sql, restriction_params = compiler.compile(restriction)
            on_conditions.append(f"({restriction_sql})")
            params.extend(restriction_params)

        if self.filtered_relation:
            try:
                relation_sql, relation_params = compiler.compile(
                    self.filtered_relation
                )
            except FullResultSet:
                # Compiling raised FullResultSet: no condition is added.
                pass
            else:
                on_conditions.append(f"({relation_sql})")
                params.extend(relation_params)

        if not on_conditions:
            # This might be a rel on the other end of an actual declared field.
            declared_field = getattr(self.join_field, "field", self.join_field)
            raise ValueError(
                "Join generated an empty ON clause. %s did not yield either "
                "joining columns or extra restrictions." % declared_field.__class__
            )

        # The alias is only spelled out when it differs from the table name.
        if self.table_alias == self.table_name:
            alias_suffix = ""
        else:
            alias_suffix = f" {self.table_alias}"
        on_clause = " AND ".join(on_conditions)
        table_sql = quote_alias(self.table_name)
        return f"{self.join_type} {table_sql}{alias_suffix} ON ({on_clause})", params

    def relabeled_clone(self, change_map):
        """
        Return a new instance of the same class whose parent_alias and
        table_alias are looked up in change_map (aliases missing from the map
        are kept). A filtered relation, if set, is cloned and its path entries
        are mapped the same way.
        """
        relation = self.filtered_relation
        if relation is not None:
            relation = relation.clone()
            relation.path = [
                change_map.get(step, step) for step in self.filtered_relation.path
            ]
        return self.__class__(
            self.table_name,
            change_map.get(self.parent_alias, self.parent_alias),
            change_map.get(self.table_alias, self.table_alias),
            self.join_type,
            self.join_field,
            self.nullable,
            filtered_relation=relation,
        )

    @property
    def identity(self):
        """Tuple that __eq__(), __hash__() and equals() are based on."""
        # filtered_relation has to stay last: equals() drops the final item.
        return (
            self.__class__,
            self.table_name,
            self.parent_alias,
            self.join_field,
            self.filtered_relation,
        )

    def __eq__(self, other):
        if isinstance(other, Join):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare identities while ignoring filtered_relation."""
        own_key = self.identity[:-1]
        other_key = other.identity[:-1]
        return own_key == other_key

    def demote(self):
        """Return a clone of this join whose join_type is INNER."""
        demoted = self.relabeled_clone({})
        demoted.join_type = INNER
        return demoted

    def promote(self):
        """Return a clone of this join whose join_type is LOUTER."""
        promoted = self.relabeled_clone({})
        promoted.join_type = LOUTER
        return promoted


class BaseTable:
    """
    Represent a base table reference in the FROM clause. For instance, the
    "foo" part of
        SELECT * FROM "foo" WHERE somecond
    could be produced by this class.
    """

    # Class-level None defaults for the three join attributes.
    join_type = None
    parent_alias = None
    filtered_relation = None

    def __init__(self, table_name, alias):
        self.table_name = table_name
        self.table_alias = alias

    def as_sql(self, compiler, connection):
        """
        Return (sql, params): the quoted table name, followed by the alias
        when it differs from the table name, and an empty params list.
        """
        table_sql = compiler.quote_name_unless_alias(self.table_name)
        if self.table_alias != self.table_name:
            table_sql += f" {self.table_alias}"
        return table_sql, []

    def relabeled_clone(self, change_map):
        """
        Return a new instance of the same class whose alias is looked up in
        change_map (kept unchanged when it is not a key of the map).
        """
        new_alias = change_map.get(self.table_alias, self.table_alias)
        return self.__class__(self.table_name, new_alias)

    @property
    def identity(self):
        """Tuple that __eq__(), __hash__() and equals() are based on."""
        return (self.__class__, self.table_name, self.table_alias)

    def __eq__(self, other):
        if isinstance(other, BaseTable):
            return self.identity == other.identity
        return NotImplemented

    def __hash__(self):
        return hash(self.identity)

    def equals(self, other):
        """Compare identities; unlike __eq__(), no type check is done first."""
        own_key = self.identity
        other_key = other.identity
        return own_key == other_key