summaryrefslogtreecommitdiff
path: root/astroid/_ast.py
blob: c134ae70e714536a36fda59ea0ac5faeb2826557 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt

from __future__ import annotations

import ast
from typing import NamedTuple

from astroid.const import Context


class FunctionType(NamedTuple):
    """Pieces of a parsed function-signature type comment.

    Built by ``parse_function_type_comment`` from the ``argtypes`` and
    ``returns`` attributes of the stdlib node that ``ast.parse`` produces in
    ``"func_type"`` mode.
    """

    # One expression node per annotated argument type, in source order.
    argtypes: list[ast.expr]
    # Expression node for the annotated return type.
    returns: ast.expr


class ParserModule(NamedTuple):
    """Lookup tables that translate stdlib ``ast`` classes for astroid.

    Operator tables map an ``ast`` operator class to its source-code
    spelling. The context table maps an ``ast`` expression-context class to
    an astroid ``Context`` member.
    """

    # Unary operator class -> source spelling.
    unary_op_classes: dict[type[ast.unaryop], str]
    # Comparison operator class -> source spelling.
    cmp_op_classes: dict[type[ast.cmpop], str]
    # Boolean operator class -> source spelling.
    bool_op_classes: dict[type[ast.boolop], str]
    # Binary operator class -> source spelling.
    bin_op_classes: dict[type[ast.operator], str]
    # Expression-context class -> astroid Context member.
    context_classes: dict[type[ast.expr_context], Context]

    def parse(self, string: str, type_comments: bool = True) -> ast.Module:
        """Parse *string* as module source with the stdlib ``ast`` parser.

        :param string: Python source code to parse.
        :param type_comments: Forwarded unchanged to the ``type_comments``
            flag of ``ast.parse``.
        :returns: The ``ast.Module`` produced by ``ast.parse``.
        """
        module_node = ast.parse(string, type_comments=type_comments)
        return module_node


def parse_function_type_comment(type_comment: str) -> FunctionType | None:
    """Turn a function-signature type comment into a ``FunctionType``.

    The comment (e.g. ``"(int, str) -> bool"``) is parsed with ``ast.parse``
    in ``"func_type"`` mode. The resulting node's argument types and return
    type are copied into a ``FunctionType``.

    The input is expected to be well formed. No error handling happens here,
    so any exception raised by ``ast.parse`` propagates to the caller.
    Although the annotation allows ``None``, this implementation always
    returns a ``FunctionType``.
    """
    signature = ast.parse(type_comment, filename="<type_comment>", mode="func_type")  # type: ignore[attr-defined]
    arg_nodes, return_node = signature.argtypes, signature.returns
    return FunctionType(arg_nodes, return_node)


def get_parser_module(type_comments: bool = True) -> ParserModule:
    """Assemble a ``ParserModule`` from freshly built lookup tables.

    :param type_comments: Accepted for API compatibility but not read by this
        function. Pass the flag to ``ParserModule.parse`` instead.
    :returns: A ``ParserModule`` whose tables come from the module-private
        ``_*_from_module`` helpers.
    """
    # Keyword arguments are evaluated left to right, so the helpers still run
    # in the order unary, compare, bool, binary, contexts.
    return ParserModule(
        unary_op_classes=_unary_operators_from_module(),
        cmp_op_classes=_compare_operators_from_module(),
        bool_op_classes=_bool_operators_from_module(),
        bin_op_classes=_binary_operators_from_module(),
        context_classes=_contexts_from_module(),
    )


def _unary_operators_from_module() -> dict[type[ast.unaryop], str]:
    return {ast.UAdd: "+", ast.USub: "-", ast.Not: "not", ast.Invert: "~"}


def _binary_operators_from_module() -> dict[type[ast.operator], str]:
    return {
        ast.Add: "+",
        ast.BitAnd: "&",
        ast.BitOr: "|",
        ast.BitXor: "^",
        ast.Div: "/",
        ast.FloorDiv: "//",
        ast.MatMult: "@",
        ast.Mod: "%",
        ast.Mult: "*",
        ast.Pow: "**",
        ast.Sub: "-",
        ast.LShift: "<<",
        ast.RShift: ">>",
    }


def _bool_operators_from_module() -> dict[type[ast.boolop], str]:
    return {ast.And: "and", ast.Or: "or"}


def _compare_operators_from_module() -> dict[type[ast.cmpop], str]:
    return {
        ast.Eq: "==",
        ast.Gt: ">",
        ast.GtE: ">=",
        ast.In: "in",
        ast.Is: "is",
        ast.IsNot: "is not",
        ast.Lt: "<",
        ast.LtE: "<=",
        ast.NotEq: "!=",
        ast.NotIn: "not in",
    }


def _contexts_from_module() -> dict[type[ast.expr_context], Context]:
    """Map ``ast`` expression-context classes to astroid ``Context`` members.

    ``ast.Load``, ``ast.Store`` and ``ast.Del`` map to the ``Context`` member
    of the same name. The legacy ``ast.Param`` class, when the running
    Python still provides it, is mapped to ``Context.Store``.

    :returns: A new dict on every call. Keys are inserted in the order
        Load, Store, Del, then Param if it is present.
    """
    contexts: dict[type[ast.expr_context], Context] = {
        ast.Load: Context.Load,
        ast.Store: Context.Store,
        ast.Del: Context.Del,
    }
    # ast.Param is documented as deprecated since Python 3.9. Look it up
    # defensively so that building this table does not raise AttributeError
    # if a future Python release removes the class.
    param_cls = getattr(ast, "Param", None)
    if param_cls is not None:
        contexts[param_cls] = Context.Store
    return contexts