summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/oracle/types.py
blob: 1c252ab897b85bc6de5b495535d7c9d5090c4712 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from ... import exc
from ...sql import sqltypes
from ...types import NVARCHAR
from ...types import VARCHAR


class RAW(sqltypes._Binary):
    """Oracle ``RAW`` binary type.

    A subclass of the generic binary base type, compiled under the ``RAW``
    visit name.
    """

    __visit_name__ = "RAW"


# Alternate name for RAW; both names refer to the very same class object.
OracleRaw = RAW


class NCLOB(sqltypes.Text):
    """Oracle ``NCLOB`` type: a :class:`_sqltypes.Text` subclass compiled
    under the ``NCLOB`` visit name."""

    __visit_name__ = "NCLOB"


class VARCHAR2(VARCHAR):
    """Oracle ``VARCHAR2`` type: a :class:`_types.VARCHAR` subclass compiled
    under the ``VARCHAR2`` visit name."""

    __visit_name__ = "VARCHAR2"


# No dedicated class here: NVARCHAR2 is just another name for the generic
# NVARCHAR type imported at the top of this module.
NVARCHAR2 = NVARCHAR


class NUMBER(sqltypes.Numeric, sqltypes.Integer):
    """Oracle ``NUMBER`` type.

    Inherits from both :class:`_sqltypes.Numeric` and
    :class:`_sqltypes.Integer`; which of the two it behaves like (for
    ``asdecimal`` defaulting and type affinity) depends on whether a
    positive ``scale`` was given.
    """

    __visit_name__ = "NUMBER"

    def __init__(self, precision=None, scale=None, asdecimal=None):
        # Without an explicit choice, only return Decimal values when the
        # column carries digits to the right of the decimal point.
        if asdecimal is None:
            asdecimal = self._has_fractional_scale(scale)
        super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)

    @staticmethod
    def _has_fractional_scale(scale):
        """Return True only for a truthy scale greater than zero."""
        return bool(scale and scale > 0)

    def adapt(self, impltype):
        adapted = super().adapt(impltype)
        # Flag read by the DBAPI-level handler (defined outside this module).
        adapted._is_oracle_number = True
        return adapted

    @property
    def _type_affinity(self):
        # Positive scale -> fractional values -> Numeric; otherwise Integer.
        return (
            sqltypes.Numeric
            if self._has_fractional_scale(self.scale)
            else sqltypes.Integer
        )


class FLOAT(sqltypes.FLOAT):
    """Oracle FLOAT.

    This is the same as :class:`_sqltypes.FLOAT` except that
    an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
    parameter is accepted, and
    the :paramref:`_sqltypes.Float.precision` parameter is not accepted.

    Oracle FLOAT types indicate precision in terms of "binary precision", which
    defaults to 126. For a REAL type, the value is 63. This parameter does not
    cleanly map to a specific number of decimal places but is roughly
    equivalent to the desired number of decimal places divided by 0.30103
    (i.e. log10(2)).

    .. versionadded:: 2.0

    """

    __visit_name__ = "FLOAT"

    def __init__(
        self,
        binary_precision=None,
        asdecimal=False,
        decimal_return_scale=None,
    ):
        r"""
        Construct a FLOAT

        :param binary_precision: Oracle binary precision value to be rendered
         in DDL. This may be approximated to the number of decimal digits
         using the formula "decimal precision = 0.30103 * binary precision".
         The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.

        :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`

        :param decimal_return_scale: See
         :paramref:`_sqltypes.Float.decimal_return_scale`

        """
        # The generic ``precision`` argument is deliberately not forwarded;
        # Oracle's precision is carried separately as ``binary_precision``.
        super().__init__(
            asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
        )
        self.binary_precision = binary_precision


class BINARY_DOUBLE(sqltypes.Float):
    """Oracle ``BINARY_DOUBLE`` type: a :class:`_sqltypes.Float` subclass
    compiled under the ``BINARY_DOUBLE`` visit name."""

    __visit_name__ = "BINARY_DOUBLE"


class BINARY_FLOAT(sqltypes.Float):
    """Oracle ``BINARY_FLOAT`` type: a :class:`_sqltypes.Float` subclass
    compiled under the ``BINARY_FLOAT`` visit name."""

    __visit_name__ = "BINARY_FLOAT"


class BFILE(sqltypes.LargeBinary):
    """Oracle ``BFILE`` type: a :class:`_sqltypes.LargeBinary` subclass
    compiled under the ``BFILE`` visit name."""

    __visit_name__ = "BFILE"


class LONG(sqltypes.Text):
    """Oracle ``LONG`` type: a :class:`_sqltypes.Text` subclass compiled
    under the ``LONG`` visit name."""

    __visit_name__ = "LONG"


class _OracleDateLiteralRender:
    def _literal_processor_datetime(self, dialect):
        def process(value):
            if value is not None:
                if getattr(value, "microsecond", None):
                    value = (
                        f"""TO_TIMESTAMP"""
                        f"""('{value.isoformat().replace("T", " ")}', """
                        """'YYYY-MM-DD HH24:MI:SS.FF')"""
                    )
                else:
                    value = (
                        f"""TO_DATE"""
                        f"""('{value.isoformat().replace("T", " ")}', """
                        """'YYYY-MM-DD HH24:MI:SS')"""
                    )
            return value

        return process

    def _literal_processor_date(self, dialect):
        def process(value):
            if value is not None:
                if getattr(value, "microsecond", None):
                    value = (
                        f"""TO_TIMESTAMP"""
                        f"""('{value.isoformat().split("T")[0]}', """
                        """'YYYY-MM-DD')"""
                    )
                else:
                    value = (
                        f"""TO_DATE"""
                        f"""('{value.isoformat().split("T")[0]}', """
                        """'YYYY-MM-DD')"""
                    )
            return value

        return process


class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
    """Provide the Oracle DATE type.

    Despite its name, this type derives from :class:`_types.DateTime`
    rather than :class:`_types.Date`, because the Oracle ``DATE`` type can
    hold a time-of-day value.  Beyond inline literal rendering and type
    affinity comparison, it adds no special Python-side behavior.

    """

    __visit_name__ = "DATE"

    def literal_processor(self, dialect):
        # Literal values are rendered with their time portion included.
        return self._literal_processor_datetime(dialect)

    def _compare_type_affinity(self, other):
        # Both the generic Date and DateTime types count as compatible.
        affinity = other._type_affinity
        return affinity in (sqltypes.Date, sqltypes.DateTime)


class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
    """Date type whose literal values are rendered date-only, using a
    ``YYYY-MM-DD`` format mask."""

    def literal_processor(self, dialect):
        return self._literal_processor_date(dialect)


class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
    """Oracle ``INTERVAL`` type (DAY TO SECOND form only)."""

    __visit_name__ = "INTERVAL"

    def __init__(self, day_precision=None, second_precision=None):
        """Construct an INTERVAL.

        Only DAY TO SECOND intervals are currently supported; YEAR TO MONTH
        intervals are not, because available DBAPIs lack support for them.

        :param day_precision: the number of digits to store for the day
          field.  Defaults to "2".
        :param second_precision: the number of digits to store for the
          fractional seconds field.  Defaults to "6".

        """
        self.day_precision = day_precision
        self.second_precision = second_precision

    @classmethod
    def _adapt_from_generic_interval(cls, interval):
        # Copy precision settings off a generic Interval.  Note this always
        # builds INTERVAL itself, not ``cls``.
        day = interval.day_precision
        second = interval.second_precision
        return INTERVAL(day_precision=day, second_precision=second)

    @property
    def _type_affinity(self):
        # Compares as the generic Interval type.
        return sqltypes.Interval

    def as_generic(self, allow_nulltype=False):
        # ``allow_nulltype`` is not used here.
        return sqltypes.Interval(
            native=True,
            second_precision=self.second_precision,
            day_precision=self.day_precision,
        )


class TIMESTAMP(sqltypes.TIMESTAMP):
    """Oracle ``TIMESTAMP`` type, adding Oracle-specific time zone modes.

    .. versionadded:: 2.0

    """

    def __init__(self, timezone: bool = False, local_timezone: bool = False):
        """Construct a new :class:`_oracle.TIMESTAMP`.

        :param timezone: boolean.  When true, the type corresponds to
         Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype.

        :param local_timezone: boolean.  When true, the type corresponds to
         Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype.

        :raises ~sqlalchemy.exc.ArgumentError: if both ``timezone`` and
         ``local_timezone`` are true.

        """
        both_requested = timezone and local_timezone
        if both_requested:
            raise exc.ArgumentError(
                "timezone and local_timezone are mutually exclusive"
            )
        super().__init__(timezone=timezone)
        # Stored separately; the base type only knows about ``timezone``.
        self.local_timezone = local_timezone


class ROWID(sqltypes.TypeEngine):
    """Oracle ``ROWID`` type.

    Generates ``ROWID`` when used as the target of a ``cast()`` or a
    similar construct.

    """

    __visit_name__ = "ROWID"


class _OracleBoolean(sqltypes.Boolean):
    """Boolean type that reports the DBAPI module's ``NUMBER`` type object."""

    def get_dbapi_type(self, dbapi):
        # Booleans are associated with the DBAPI's NUMBER type object.
        return dbapi.NUMBER