summaryrefslogtreecommitdiff
path: root/test/ext/mypy/test_overloads.py
blob: 4a258a00bf527a4c6ec9053e112f1eabf710b0c0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from sqlalchemy import testing
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import CoreExecuteOptionsParameter
from sqlalchemy.ext.asyncio.engine import AsyncConnection
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
from sqlalchemy.orm.query import Query
from sqlalchemy.sql.base import Executable
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import eq_

# Expected annotation tables: option name -> annotation, spelled as the
# string it is compared against by the tests in this module.  Each table
# is layered on top of the previous one.

# options expected on the engine-level ``execution_options`` overload
engine_execution_options = dict(
    compiled_cache="Optional[CompiledCacheType]",
    logging_token="str",
    isolation_level="IsolationLevel",
    insertmanyvalues_page_size="int",
    schema_translate_map="Optional[SchemaTranslateMapType]",
    opt="Any",
)

# engine options plus the ones expected at the connection / core level
core_execution_options = dict(
    engine_execution_options,
    no_parameters="bool",
    stream_results="bool",
    max_row_buffer="int",
    yield_per="int",
)

# core options plus the ORM query (DQL) specific ones
orm_dql_execution_options = dict(
    core_execution_options,
    populate_existing="bool",
    autoflush="bool",
)

# ORM DML specific options; kept standalone, not layered on the core ones
orm_dml_execution_options = dict(
    synchronize_session="SynchronizeSessionArgument",
    dml_strategy="DMLStrategyArgument",
    is_delete_using="bool",
    is_update_from="bool",
)

# union of the ORM DQL and ORM DML tables
orm_execution_options = dict(
    orm_dql_execution_options,
    **orm_dml_execution_options,
)


class OverloadTest(fixtures.TestBase):
    """Compare ``execution_options`` annotations with the expected tables."""

    # NOTE: get_overloads is new in python 3.11. typing_extensions also
    # implements it, but for that to work the typing_extensions overload
    # needs to be used, and it can only be imported directly from
    # typing_extensions in all modules that use it, otherwise flake8
    # (pyflakes actually) will flag it with F811
    __requires__ = ("python311",)

    @testing.combinations(
        (Engine, engine_execution_options),
        (Connection, core_execution_options),
        (AsyncEngine, engine_execution_options),
        (AsyncConnection, core_execution_options),
        (Query, orm_dql_execution_options),
        (Executable, orm_execution_options),
    )
    def test_methods(self, class_, expected):
        """Check both ``execution_options`` overloads of ``class_``.

        :param class_: class whose ``execution_options`` attribute is
         looked up and passed to ``typing.get_overloads``.
        :param expected: mapping of parameter name to annotation string
         that the first overload must match exactly.
        """
        from typing import get_overloads

        overloads = get_overloads(class_.execution_options)
        eq_(len(overloads), 2)

        # the first overload has to match the full table, the second one
        # is expected to declare nothing but ``opt``
        wanted_per_overload = (expected, {"opt": "Any"})
        for overload_fn, wanted in zip(overloads, wanted_per_overload):
            # "self" and "return" are not options; leave them out of the
            # comparison when present
            hints = {
                name: hint
                for name, hint in overload_fn.__annotations__.items()
                if name not in ("self", "return")
            }
            eq_(hints, wanted)

    @testing.combinations(
        (CoreExecuteOptionsParameter, core_execution_options),
        (OrmExecuteOptionsParameter, orm_execution_options),
    )
    def test_typed_dicts(self, typ, expected):
        """Check the typed dict keys and forward references of ``typ``.

        :param typ: type whose ``__args__[0]`` is taken as the typed dict.
        :param expected: mapping of key to annotation string; its ``opt``
         entry is dropped before comparing.
        """
        # these are currently expected to be union types whose first
        # entry is the typed dict
        typed_dict = typ.__args__[0]

        # work on a copy so that the shared module-level table is not
        # modified
        wanted = expected.copy()
        del wanted["opt"]

        found = {}
        for key, fwd_ref in typed_dict.__annotations__.items():
            found[key] = fwd_ref.__forward_arg__
        eq_(found, wanted)