summaryrefslogtreecommitdiff
path: root/test/ext/mypy/plugin_files/dataclasses_workaround.py
blob: 8ad69dbd0f462717759cbb223e43a29425045558 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# PYTHON_VERSION>=3.7

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING

from sqlalchemy import Column
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.orm import registry
from sqlalchemy.orm import relationship

# Shared ORM registry: the mapped classes in this module are decorated with
# ``mapper_registry.mapped`` and attach their ``Table`` objects to
# ``mapper_registry.metadata``.
mapper_registry: registry = registry()


@mapper_registry.mapped
@dataclass
class User:
    """Dataclass mapped to the ``user`` table through an explicit ``__table__``.

    ``id`` is declared with ``init=False`` and is therefore not a parameter
    of the generated ``__init__``; the remaining fields all have defaults.
    ``addresses`` is configured as a relationship to ``Address`` through the
    ``properties`` entry of ``__mapper_args__``.
    """

    __table__ = Table(
        "user",
        mapper_registry.metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
        Column("fullname", String(50)),
        Column("nickname", String(12)),
    )
    id: int = field(init=False)
    name: Optional[str] = None
    fullname: Optional[str] = None
    nickname: Optional[str] = None
    addresses: List[Address] = field(default_factory=list)

    # Evaluated by type checkers only; names the attributes that should be
    # treated as mapped (presumably read by the SQLAlchemy mypy plugin --
    # confirm against the plugin documentation).
    if TYPE_CHECKING:
        _mypy_mapped_attrs = [id, name, fullname, nickname, addresses]

    # Annotated as ClassVar so that @dataclass does not treat this as an
    # instance field: a plain ``Dict[str, Any]`` annotation with a dict
    # default makes dataclasses raise ValueError (mutable default) when the
    # class is created, and would otherwise add it to ``__init__``.
    __mapper_args__: ClassVar[Dict[str, Any]] = {
        "properties": {"addresses": relationship("Address")}
    }


@mapper_registry.mapped
@dataclass
class Address:
    """Dataclass mapped to the ``address`` table through an explicit ``__table__``.

    ``id`` and ``user_id`` are declared with ``init=False`` and are therefore
    not parameters of the generated ``__init__``; ``email_address`` defaults
    to ``None``.
    """

    __table__ = Table(
        "address",
        mapper_registry.metadata,
        Column("id", Integer, primary_key=True),
        # Foreign key to the "id" column of the "user" table.
        Column("user_id", Integer, ForeignKey("user.id")),
        Column("email_address", String(50)),
    )

    id: int = field(init=False)
    user_id: int = field(init=False)
    email_address: Optional[str] = None

    # Evaluated by type checkers only; names the attributes that should be
    # treated as mapped (presumably read by the SQLAlchemy mypy plugin --
    # confirm against the plugin documentation).
    if TYPE_CHECKING:
        _mypy_mapped_attrs = [id, user_id, email_address]


# Statements built from class-level attributes of the mapped dataclasses.
# They are only constructed here, never executed -- presumably so the type
# checker can verify that the attributes behave as SQL expressions (confirm
# against the test harness that consumes this file).
stmt1 = select(User.name).where(User.id.in_([1, 2, 3]))
# NOTE(review): contains() receives a list here; verify this is intentional.
stmt2 = select(Address).where(Address.email_address.contains(["foo"]))