diff options
-rw-r--r-- | Cython/Compiler/ExprNodes.py | 6 | ||||
-rw-r--r-- | Cython/Compiler/FlowControl.py | 77 | ||||
-rw-r--r-- | Cython/Compiler/MatchCaseNodes.py | 2245 | ||||
-rw-r--r-- | Cython/Compiler/Nodes.py | 11 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.pxd | 1 | ||||
-rw-r--r-- | Cython/Compiler/ParseTreeTransforms.py | 44 | ||||
-rw-r--r-- | Cython/Compiler/Parsing.pxd | 2 | ||||
-rw-r--r-- | Cython/Compiler/Parsing.py | 520 | ||||
-rw-r--r-- | Cython/Compiler/Visitor.py | 3 | ||||
-rw-r--r-- | Cython/TestUtils.py | 24 | ||||
-rw-r--r-- | Cython/Utility/MatchCase.c | 907 | ||||
-rw-r--r-- | Cython/Utility/MatchCase_Cy.pyx | 12 | ||||
-rw-r--r-- | Tools/ci-run.sh | 2 | ||||
-rw-r--r-- | test-requirements-pypy27.txt | 1 | ||||
-rw-r--r-- | tests/run/extra_patma.pyx | 173 | ||||
-rw-r--r-- | tests/run/extra_patma_py.py | 126 | ||||
-rw-r--r-- | tests/run/test_patma.py | 3198 |
17 files changed, 7324 insertions, 28 deletions
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 57adc791e..c4e7679cd 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -7225,6 +7225,7 @@ class AttributeNode(ExprNode): is_memslice_transpose = False is_special_lookup = False is_py_attr = 0 + dont_mangle_private_names = False # skip mangling class.__attr names def as_cython_attribute(self): if (isinstance(self.obj, NameNode) and @@ -7570,8 +7571,9 @@ class AttributeNode(ExprNode): def analyse_as_python_attribute(self, env, obj_type=None, immutable_obj=False): if obj_type is None: obj_type = self.obj.type - # mangle private '__*' Python attributes used inside of a class - self.attribute = env.mangle_class_private_name(self.attribute) + if not self.dont_mangle_private_names: + # mangle private '__*' Python attributes used inside of a class + self.attribute = env.mangle_class_private_name(self.attribute) self.member = self.attribute self.type = py_object_type self.is_py_attr = 1 diff --git a/Cython/Compiler/FlowControl.py b/Cython/Compiler/FlowControl.py index 294bce9ee..d37744609 100644 --- a/Cython/Compiler/FlowControl.py +++ b/Cython/Compiler/FlowControl.py @@ -931,6 +931,83 @@ class ControlFlowAnalysis(CythonTransform): self.flow.block = None return node + def visit_MatchNode(self, node): + # https://peps.python.org/pep-0634/#side-effects-and-undefined-behavior + # The specification of structural pattern matching gives + # a reasonable amount of freedom about when a binding takes + # effect (especially for matches that fail after a few completed + # steps). + # To make it easier to reason about which variables around bound, + # Cython delays the binding until the entire statement has been + # evaluated, but before the guard is evaluated. + # Therefore an unguarded match statement is equivalent to + # + # if case1: + # bound_variable = something + # block + # elif case2... + # + # while a guarded match statement is equivalent to + # + # if case1: + # bound_variable = something + # if guard: + # block + # goto end + # if case2: ... + # + # Note also that some of the cases may have just been transformed into + # if blocks + from .MatchCaseNodes import MatchCaseNode, SubstitutedMatchCaseNode + + self._visit(node.subject) + + next_block = self.flow.newblock() + orig_parent = self.flow.block + guard_block = None + for case in node.cases: + if isinstance(case, MatchCaseNode): + case_block = self.flow.nextblock(orig_parent) + if guard_block: + guard_block.add_child(case_block) + self._visit(case.pattern) + self.flow.nextblock() + if case.target_assignments: + self._visit(case.target_assignments) + if case.guard: + guard_block = self.flow.nextblock() + self._visit(case.guard) + else: + guard_block = None + self.flow.nextblock() + self._visit(case.body) + if self.flow.block: + self.flow.block.add_child(next_block) + elif isinstance(case, SubstitutedMatchCaseNode): + self.flow.nextblock() + if guard_block: + guard_block.add_child(self.flow.block) + guard_block = None + self._visit(case.body) + orig_parent = self.flow.block + else: + assert False, case + + if orig_parent is not None: + orig_parent.add_child(next_block) + if next_block.parents: + self.flow.block = next_block + else: + self.flow.block = None + return node + + def visit_PatternNode(self, node): + # avoid visiting anything that might be a target (since they're + # handled elsewhere) + self.visitchildren(node, attrs=None, + exclude=["as_targets", "target", "double_star_capture_target"]) + return node + def visit_AssertStatNode(self, node): """Essentially an if-condition that wraps a RaiseStatNode. """ diff --git a/Cython/Compiler/MatchCaseNodes.py b/Cython/Compiler/MatchCaseNodes.py new file mode 100644 index 000000000..a7e246bcb --- /dev/null +++ b/Cython/Compiler/MatchCaseNodes.py @@ -0,0 +1,2245 @@ +# Nodes for structural pattern matching. +# +# In a separate file because they're unlikely to be useful for much else. + +from .Nodes import Node, StatNode, ErrorNode +from .Errors import error, local_errors, report_error +from . import Nodes, ExprNodes, PyrexTypes, Builtin +from .Code import UtilityCode, TempitaUtilityCode +from .Options import copy_inherited_directives +from contextlib import contextmanager + + +class MatchNode(StatNode): + """ + subject ExprNode The expression to be matched + cases [MatchCaseBaseNode] list of cases + + sequence_mapping_temp None or AssignableTempNode an int temp to store result of sequence/mapping tests + sequence_mapping_temp is an optimization because determining whether something is a sequence or mapping + is slow on Python <3.10. It should be deleted once that's the lowest version supported + """ + + child_attrs = ["subject", "cases"] + + subject_clonenode = None # set to a value if we require a temp + sequence_mapping_temp = None + + def validate_irrefutable(self): + found_irrefutable_case = None + for case in self.cases: + if isinstance(case, ErrorNode): + # This validation happens before error nodes have been + # transformed into actual errors, so we need to ignore them + continue + if found_irrefutable_case: + error( + found_irrefutable_case.pos, + ( + "%s makes remaining patterns unreachable" + % found_irrefutable_case.pattern.irrefutable_message() + ), + ) + break + if case.is_irrefutable(): + found_irrefutable_case = case + case.validate_irrefutable() + + def refactor_cases(self): + # An early transform - changes cases that can be represented as + # a simple if/else statement into them (giving them maximum chance + # to be optimized by the existing mechanisms). Leaves other cases + # unchanged + from .ExprNodes import CloneNode, ProxyNode, NameNode + + self.subject = ProxyNode(self.subject) + subject = self.subject_clonenode = CloneNode(self.subject) + current_if_statement = None + for n, c in enumerate(self.cases + [None]): # The None is dummy at the end + if c is not None and c.is_simple_value_comparison(): + body = SubstitutedIfStatListNode( + c.body.pos, stats=c.body.stats, match_node=self + ) + if_clause = Nodes.IfClauseNode( + c.pos, + condition=c.pattern.get_simple_comparison_node(subject), + body=body, + ) + assignments = c.pattern.generate_target_assignments(subject, None) + if assignments: + if_clause.body.stats.insert(0, assignments) + if not current_if_statement: + current_if_statement = Nodes.IfStatNode( + c.pos, if_clauses=[], else_clause=None + ) + current_if_statement.if_clauses.append(if_clause) + self.cases[n] = None # remove case + elif current_if_statement: + # this cannot be simplified, but previous case(s) were + self.cases[n - 1] = SubstitutedMatchCaseNode( + current_if_statement.pos, body=current_if_statement + ) + current_if_statement = None + # eliminate optimized cases + self.cases = [c for c in self.cases if c is not None] + + def analyse_declarations(self, env): + self.subject.analyse_declarations(env) + for c in self.cases: + c.analyse_case_declarations(self.subject_clonenode, env) + + def analyse_expressions(self, env): + sequence_mapping_count = 0 + for c in self.cases: + if c.is_sequence_or_mapping(): + sequence_mapping_count += 1 + if sequence_mapping_count >= 2: + self.sequence_mapping_temp = AssignableTempNode( + self.pos, PyrexTypes.c_uint_type + ) + self.sequence_mapping_temp.is_addressable = lambda: True + + self.subject = self.subject.analyse_expressions(env) + assert isinstance(self.subject, ExprNodes.ProxyNode) + if not self.subject.arg.is_literal: + self.subject.arg = self.subject.arg.coerce_to_temp(env) + subject = self.subject_clonenode.analyse_expressions(env) + self.cases = [ + c.analyse_case_expressions(subject, env, self.sequence_mapping_temp) + for c in self.cases + ] + self.cases = [c for c in self.cases if c is not None] + return self + + def generate_execution_code(self, code): + if self.sequence_mapping_temp: + self.sequence_mapping_temp.allocate(code) + code.putln( + "%s = 0; /* sequence/mapping test temp */" + % self.sequence_mapping_temp.result() + ) + # For things that are a sequence at compile-time it's difficult + # to avoid generating the sequence mapping temp. Therefore, silence + # an "unused error" + code.putln("(void)%s;" % self.sequence_mapping_temp.result()) + end_label = self.end_label = code.new_label() + if self.subject_clonenode: + self.subject.generate_evaluation_code(code) + for c in self.cases: + c.generate_execution_code(code, end_label) + if self.sequence_mapping_temp: + self.sequence_mapping_temp.release(code) + if code.label_used(end_label): + code.put_label(end_label) + if self.subject_clonenode: + self.subject.generate_disposal_code(code) + self.subject.free_temps(code) + + +class MatchCaseBaseNode(Node): + """ + Common base for a MatchCaseNode and a + substituted node + """ + + pass + + +class MatchCaseNode(Node): + """ + pattern PatternNode + body StatListNode + guard ExprNode or None + + generated: + target_assignments [ SingleAssignmentNodes ] + comp_node ExprNode that evaluates to bool + """ + + target_assignments = None + comp_node = None + child_attrs = ["pattern", "target_assignments", "comp_node", "guard", "body"] + + def is_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return True # value doesn't really matter + return self.pattern.is_irrefutable() and not self.guard + + def is_simple_value_comparison(self): + if self.guard: + return False + return self.pattern.is_simple_value_comparison() + + def validate_targets(self): + if isinstance(self.pattern, ErrorNode): + return + self.pattern.get_targets() + + def validate_irrefutable(self): + if isinstance(self.pattern, ErrorNode): + return + self.pattern.validate_irrefutable() + + def is_sequence_or_mapping(self): + return self.pattern.is_sequence_or_mapping() + + def analyse_case_declarations(self, subject_node, env): + self.pattern.analyse_declarations(env) + self.target_assignments = self.pattern.generate_target_assignments( + subject_node, env + ) + if self.target_assignments: + self.target_assignments.analyse_declarations(env) + if self.guard: + self.guard.analyse_declarations(env) + self.body.analyse_declarations(env) + + def analyse_case_expressions(self, subject_node, env, sequence_mapping_temp): + with local_errors(True) as errors: + self.pattern = self.pattern.analyse_pattern_expressions(env, sequence_mapping_temp) + self.comp_node = self.pattern.get_comparison_node(subject_node, sequence_mapping_temp) + self.comp_node = self.comp_node.analyse_types(env) + + if self.comp_node and self.comp_node.is_literal: + self.comp_node.calculate_constant_result() + if not self.comp_node.constant_result: + # we know this pattern can't succeed. Ignore any errors and return None + return None + for error in errors: + report_error(error) + + self.comp_node = self.comp_node.coerce_to_boolean(env).coerce_to_simple(env) + + if self.target_assignments: + self.target_assignments = self.target_assignments.analyse_expressions(env) + if self.guard: + self.guard = self.guard.analyse_temp_boolean_expression(env) + self.body = self.body.analyse_expressions(env) + return self + + def generate_execution_code(self, code, end_label): + self.pattern.allocate_subject_temps(code) + self.comp_node.generate_evaluation_code(code) + + end_of_case_label = code.new_label() + + code.putln("if (!%s) { /* !pattern */" % self.comp_node.result()) + self.pattern.dispose_of_subject_temps(code) # failed, don't need the subjects + code.put_goto(end_of_case_label) + + code.putln("} else { /* pattern */") + self.comp_node.generate_disposal_code(code) + self.comp_node.free_temps(code) + if self.target_assignments: + self.target_assignments.generate_execution_code(code) + self.pattern.dispose_of_subject_temps(code) + self.pattern.release_subject_temps(code) # we're done with the subjects here + if self.guard: + self.guard.generate_evaluation_code(code) + code.putln("if (%s) { /* guard */" % self.guard.result()) + self.guard.generate_disposal_code(code) + self.guard.free_temps(code) + # body_insertion_point = code.insertion_point() + self.body.generate_execution_code(code) + if not self.body.is_terminator: + code.put_goto(end_label) + if self.guard: + code.putln("} /* guard */") + code.putln("} /* pattern */") + code.put_label(end_of_case_label) + + +class SubstitutedMatchCaseNode(MatchCaseBaseNode): + # body - Node - The (probably) if statement that it's replaced with + child_attrs = ["body"] + + def is_sequence_or_mapping(self): + return False + + def analyse_case_declarations(self, subject_node, env): + self.analyse_declarations(env) + + def analyse_declarations(self, env): + self.body.analyse_declarations(env) + + def analyse_case_expressions(self, subject_node, env, sequence_mapping_temp): + self.body = self.body.analyse_expressions(env) + return self + + def generate_execution_code(self, code, end_label): + self.body.generate_execution_code(code) + + +class PatternNode(Node): + """ + PatternNode is not an expression because + it does several things (evalutating a boolean expression, + assignment of targets), and they need to be done at different + times. + + as_targets [NameNode] any target assign by "as" + + Generated in analysis: + comp_node ExprNode node to evaluate for the pattern + + ---------------------------------------- + How these nodes are processed: + 1. During "analyse_declarations" PatternNode.generate_target_assignments + is called on the main PatternNode of the case. This calls its + sub-patterns generate_target_assignments recursively. + This creates a StatListNode that is held by the + MatchCaseNode. + 2. In the "analyse_expressions" phases, the MatchCaseNode calls + PatternNode.analyse_pattern_expressions, which calls its + sub-pattern recursively. + 3. At the end of the "analyse_expressions" stage the MatchCaseNode + class PatternNode.get_comparison_node (which calls + PatternNode.get_comparison_node for its sub-patterns). This + returns an ExprNode which can be evaluated to determine if the + pattern has matched. + While generating the comparison we try quite hard not to + analyse it until right at the end, because otherwise it'll lead + to a lot of repeated work for deeply nested patterns. + 4. In the code generation stage, PatternNodes hardly generate any + code themselves. However, they do set up whatever temps they + need (mainly for sub-pattern subjects), with "allocate_subject_temps", + "release_subject_temps", and "dispose_of_subject_temps" (which + they also call recursively on their sub-patterns) + """ + + # useful for type tests + is_match_value_pattern = False + is_match_and_assign_pattern = False + + child_attrs = ["as_targets"] + + def __init__(self, pos, **kwds): + if "as_targets" not in kwds: + kwds["as_targets"] = [] + super(PatternNode, self).__init__(pos, **kwds) + + def is_irrefutable(self): + return False + + def is_sequence_or_mapping(self): + """ + Used for determining whether to allocate a sequence_mapping_temp. + + An OrPattern containing at least one also returns True + """ + return False + + def get_targets(self): + targets = self.get_main_pattern_targets() + for target in self.as_targets: + self.add_target_to_targets(targets, target.name) + return targets + + def update_targets_with_targets(self, targets, other_targets): + for name in targets.intersection(other_targets): + error(self.pos, "multiple assignments to name '%s' in pattern" % name) + targets.update(other_targets) + + def add_target_to_targets(self, targets, target): + if target in targets: + error(self.pos, "multiple assignments to name '%s in pattern" % target) + targets.add(target) + + def get_main_pattern_targets(self): + # exclude "as" target + raise NotImplementedError + + def is_simple_value_comparison(self): + # Can this be converted to an "if ... elif: ..." statement? + # Only worth doing to take advantage of things like SwitchTransform + # so there's little benefit on doing it too widely + return False + + def get_simple_comparison_node(self): + """ + Returns an ExprNode that can be used as the case in an if-statement + + Should only be called if is_simple_value_comparison() is True + """ + raise NotImplementedError + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + error(self.pos, "This type of pattern is not currently supported %s" % self) + raise NotImplementedError + + def validate_irrefutable(self): + for attr in self.child_attrs: + child = getattr(self, attr) + if child is not None and isinstance(child, PatternNode): + child.validate_irrefutable() + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + error(self.pos, "This type of pattern is not currently supported %s" % self) + raise NotImplementedError + + def generate_result_code(self, code): + pass + + def generate_target_assignments(self, subject_node, env): + # Generates the assignment code needed to initialize all the targets. + # Returns either a StatListNode or None + assignments = [] + for target in self.as_targets: + if self.is_match_value_pattern and self.value and self.value.is_simple(): + # in this case we can optimize slightly and just take the value + subject_node = self.value.clone_node() + assignments.append( + Nodes.SingleAssignmentNode( + target.pos, lhs=target.clone_node(), rhs=subject_node + ) + ) + assignments.extend( + self.generate_main_pattern_assignment_list(subject_node, env) + ) + if assignments: + return Nodes.StatListNode(self.pos, stats=assignments) + else: + return None + + def generate_main_pattern_assignment_list(self, subject_node, env): + # generates assignments for everything except the "as_target". + # Override in subclasses. + # Returns a list of Nodes + return [] + + def allocate_subject_temps(self, code): + pass # Implement in nodes that need it + + def release_subject_temps(self, code): + pass # Implement in nodes that need it + + def dispose_of_subject_temps(self, code): + pass # Implement in nodes that need it + + +class MatchValuePatternNode(PatternNode): + """ + value ExprNode + is_is_check bool Picks "is" or equality check + """ + + is_match_value_pattern = True + + child_attrs = PatternNode.child_attrs + ["value"] + + is_is_check = False + + def get_main_pattern_targets(self): + return set() + + def is_simple_value_comparison(self): + return True + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + # for this node the comparison and "simple" comparison are the same + return LazyCoerceToBool(self.pos, + arg=self.get_simple_comparison_node(subject_node) + ) + + def get_simple_comparison_node(self, subject_node): + op = "is" if self.is_is_check else "==" + return ExprNodes.PrimaryCmpNode( + self.pos, operator=op, operand1=subject_node, operand2=self.value + ) + + def analyse_declarations(self, env): + super(MatchValuePatternNode, self).analyse_declarations(env) + if self.value: + self.value.analyse_declarations(env) + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + if self.value: + self.value = self.value.analyse_expressions(env) + return self + + +class MatchAndAssignPatternNode(PatternNode): + """ + target NameNode or None the target to assign to (None = wildcard) + is_star bool + """ + + target = None + is_star = False + is_match_and_assign_pattern = True + + child_attrs = PatternNode.child_attrs + ["target"] + + def is_irrefutable(self): + return True + + def irrefutable_message(self): + if self.target: + return "name capture '%s'" % self.target.name + else: + return "wildcard" + + def get_main_pattern_targets(self): + if self.target: + return {self.target.name} + else: + return set() + + def is_simple_value_comparison(self): + return self.is_irrefutable() # the comparison is to "True" + + def get_simple_comparison_node(self, subject_node): + assert self.is_simple_value_comparison() + return self.get_comparison_node(subject_node, None) + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + return ExprNodes.BoolNode(self.pos, value=True) + + def generate_main_pattern_assignment_list(self, subject_node, env): + if self.target: + return [ + Nodes.SingleAssignmentNode( + self.pos, lhs=self.target.clone_node(), rhs=subject_node + ) + ] + else: + return [] + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + return self # nothing to analyse + + +class OrPatternNode(PatternNode): + """ + alternatives list of PatternNodes + + generated: + which_alternative_temp - an integer temp node. 0 for failed; 1, 2... + identify the alternative that succeeded + """ + which_alternative_temp = None + sequence_mapping_temp = None # used in a similar way to MatchCaseNode, + # to avoid recalcutating if we're a sequence or mapping + + child_attrs = PatternNode.child_attrs + ["alternatives"] + + def get_first_irrefutable(self): + for alternative in self.alternatives: + if alternative.is_irrefutable(): + return alternative + return None + + def is_irrefutable(self): + return self.get_first_irrefutable() is not None + + def irrefutable_message(self): + return self.get_first_irrefutable().irrefutable_message() + + def is_sequence_or_mapping(self): + # this affects if the caller generates a temp for it. If so the + # this node can forward the temp to the relevant alternative + for a in self.alternatives: + if a.is_sequence_or_mapping(): + return True + return False + + def get_main_pattern_targets(self): + child_targets = None + for alternative in self.alternatives: + alternative_targets = alternative.get_targets() + if child_targets is not None and child_targets != alternative_targets: + error(self.pos, "alternative patterns bind different names") + child_targets = alternative_targets + return child_targets + + def validate_irrefutable(self): + super(OrPatternNode, self).validate_irrefutable() + found_irrefutable_case = None + for alternative in self.alternatives: + if found_irrefutable_case: + error( + found_irrefutable_case.pos, + ( + "%s makes remaining patterns unreachable" + % found_irrefutable_case.irrefutable_message() + ), + ) + break + if alternative.is_irrefutable(): + found_irrefutable_case = alternative + alternative.validate_irrefutable() + + def is_simple_value_comparison(self): + return all( + # it turns out to be hard to generate correct assignment code + # for or patterns with targets + a.is_simple_value_comparison() and not a.get_targets() + for a in self.alternatives + ) + + def is_really_simple_value_comparison(self): + # like is_simple_value_comparison but also doesn't have any targets + return (self.is_simple_value_comparison() and + all(not a.get_targets() for a in self.alternatives)) + + def get_simple_comparison_node(self, subject_node): + assert self.is_simple_value_comparison() + assert len(self.alternatives) >= 2, self.alternatives + checks = [] + for a in self.alternatives: + checks.append(a.get_simple_comparison_node(subject_node)) + if any(isinstance(ch, ExprNodes.BoolNode) and ch.value for ch in checks): + # handle the obvious very simple case + return ExprNodes.BoolNode(self.pos, value=True) + return generate_binop_tree_from_list(self.pos, "or", checks) + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + if self.is_really_simple_value_comparison(): + return self.get_simple_comparison_node(subject_node) + + cond_exprs = [] + for n, a in enumerate(self.alternatives, start=1): + a_test = a.get_comparison_node(subject_node, sequence_mapping_temp) + a_value = ExprNodes.IntNode(a.pos, value=str(n)) + if isinstance(a_test, ExprNodes.BoolNode) and a_test.value: + cond_exprs.append(a_value) + break # no point in going further + else: + cond_exprs.append( + ExprNodes.CondExprNode( + self.pos, + test = a_test, + true_val = a_value, + false_val = ExprNodes.IntNode(self.pos, value="0") + ) + ) + + expr = generate_binop_tree_from_list(self.pos, "or", cond_exprs) + + if self.which_alternative_temp: + expr = ExprNodes.AssignmentExpressionNode( + self.pos, + lhs = self.which_alternative_temp, + rhs = expr + ) + return LazyCoerceToBool(expr.pos, arg=expr) + + def analyse_declarations(self, env): + super(OrPatternNode, self).analyse_declarations(env) + for a in self.alternatives: + a.analyse_declarations(env) + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + self.alternatives = [ + a.analyse_pattern_expressions(env, sequence_mapping_temp) + for a in self.alternatives + ] + if not sequence_mapping_temp: + sequence_mapping_count = 0 + for a in self.alternatives: + if a.is_sequence_or_mapping(): + sequence_mapping_count += 1 + if sequence_mapping_count >= 2: + self.sequence_mapping_temp = AssignableTempNode( + self.pos, PyrexTypes.c_uint_type + ) + self.sequence_mapping_temp.is_addressable = lambda: True + sequence_mapping_temp = self.sequence_mapping_temp + return self + + def generate_main_pattern_assignment_list(self, subject_node, env): + assignments = [] + ifclauses = [] + for n, a in enumerate(self.alternatives, start=1): + a_assignment = a.generate_target_assignments(subject_node, env) + if a_assignment: + if not self.which_alternative_temp: + self.which_alternative_temp = AssignableTempNode(self.pos, PyrexTypes.c_int_type) + # Switch code paths depending on which node gets assigned + ifclause = Nodes.IfClauseNode( + a.pos, + condition=ExprNodes.PrimaryCmpNode( + a.pos, + operator="==", + operand1=self.which_alternative_temp, + operand2=ExprNodes.IntNode(a.pos, value=str(n)) + ), + body = a_assignment + ) + ifclauses.append(ifclause) + if ifclauses: + assignments.append( + Nodes.IfStatNode( + self.pos, + if_clauses=ifclauses, + else_clause=None + ) + ) + + return assignments + + def allocate_subject_temps(self, code): + if self.sequence_mapping_temp: + self.sequence_mapping_temp.allocate(code) + code.putln( + "%s = 0; /* sequence/mapping test temp */" + % self.sequence_mapping_temp.result() + ) + # For things that are a sequence at compile-time it's difficult + # to avoid generating the sequence mapping temp. Therefore, silence + # an "unused error" + code.putln("(void)%s;" % self.sequence_mapping_temp.result()) + if self.which_alternative_temp: + self.which_alternative_temp.allocate(code) + for a in self.alternatives: + a.allocate_subject_temps(code) + + def release_subject_temps(self, code): + if self.sequence_mapping_temp: + self.sequence_mapping_temp.release(code) + if self.which_alternative_temp: + self.which_alternative_temp.release(code) + for a in self.alternatives: + a.release_subject_temps(code) + + def dispose_of_subject_temps(self, code): + if self.which_alternative_temp: + self.which_alternative_temp.generate_disposal_code(code) + if self.sequence_mapping_temp: + self.sequence_mapping_temp.generate_disposal_code(code) + for a in self.alternatives: + a.dispose_of_subject_temps(code) + + +class MatchSequencePatternNode(PatternNode): + """ + patterns list of PatternNodes + + generated: + subjects [TrackTypeTempNode] individual subsubjects can be assigned to these + """ + + subjects = None + needs_length_temp = False + + child_attrs = PatternNode.child_attrs + ["patterns"] + + Pyx_sequence_check_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, + [ + PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg( + "sequence_mapping_temp", + PyrexTypes.c_ptr_type(PyrexTypes.c_uint_type), + None, + ), + ], + exception_value="-1", + ) + + def is_sequence_or_mapping(self): + return True + + def __init__(self, pos, **kwds): + super(MatchSequencePatternNode, self).__init__(pos, **kwds) + self.length_temp = AssignableTempNode(self.pos, PyrexTypes.c_py_ssize_t_type) + + def get_main_pattern_targets(self): + targets = set() + star_count = 0 + for pattern in self.patterns: + if pattern.is_match_and_assign_pattern and pattern.is_star: + star_count += 1 + self.update_targets_with_targets(targets, pattern.get_targets()) + if star_count > 1: + error(self.pos, "multiple starred names in sequence pattern") + return targets + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + from .UtilNodes import TempResultFromStatNode, ResultRefNode + + test = None + assert getattr(self, "subject_temps", None) is not None + + seq_test = self.make_sequence_check(subject_node, sequence_mapping_temp) + if isinstance(seq_test, ExprNodes.BoolNode) and not seq_test.value: + return seq_test # no point in proceeding further! + + has_star = False + all_tests = [seq_test] + pattern_tests = [] + for n, pattern in enumerate(self.patterns): + if isinstance(pattern, MatchAndAssignPatternNode) and pattern.is_star: + has_star = True + self.needs_length_temp = True + + if self.subject_temps[n] is None: + # The subject has been identified as unneeded, so don't evaluate it + continue + p_test = pattern.get_comparison_node(self.subject_temps[n]) + + result_ref = ResultRefNode(pos=self.pos, type=PyrexTypes.c_bint_type) + subject_assignment = Nodes.SingleAssignmentNode( + self.pos, + lhs=self.subject_temps[n], # the temp node + rhs=self.subjects[n], # the regular node + ) + test_assignment = Nodes.SingleAssignmentNode( + self.pos, lhs=result_ref, rhs=p_test + ) + stats = Nodes.StatListNode( + self.pos, stats=[subject_assignment, test_assignment] + ) + pattern_tests.append(TempResultFromStatNode(result_ref, stats)) + + min_length = len(self.patterns) + if has_star: + min_length -= 1 + # check whether we need a length call... + if not (self.patterns and len(self.patterns) == 1 and has_star): + length_call = self.make_length_call_node(subject_node) + + if length_call.is_literal and ( + (has_star and min_length < length_call.constant_result) + or (not has_star and min_length != length_call.constant_result) + ): + # definitely failed! + return ExprNodes.BoolNode(self.pos, value=False) + seq_len_test = ExprNodes.PrimaryCmpNode( + self.pos, + operator=">=" if has_star else "==", + operand1=length_call, + operand2=ExprNodes.IntNode(self.pos, value=str(min_length)), + ) + all_tests.append(seq_len_test) + else: + self.needs_length_temp = False + all_tests.extend(pattern_tests) + test = generate_binop_tree_from_list(self.pos, "and", all_tests) + return LazyCoerceToBool(test.pos, arg=test) + + def generate_subjects(self, subject_node, env): + assert self.subjects is None # not called twice + + star_idx = None + for n, pattern in enumerate(self.patterns): + if pattern.is_match_and_assign_pattern and pattern.is_star: + star_idx = n + if star_idx is None: + idxs = list(range(len(self.patterns))) + else: + fwd_idxs = list(range(star_idx)) + backward_idxs = list(range(star_idx - len(self.patterns) + 1, 0)) + star_idx = ( + fwd_idxs[-1] + 1 if fwd_idxs else None, + backward_idxs[0] if backward_idxs else None, + ) + idxs = fwd_idxs + [star_idx] + backward_idxs + + subjects = [] + for pattern, idx in zip(self.patterns, idxs): + indexer = self.make_indexing_node(pattern, subject_node, idx, env) + subjects.append(ExprNodes.ProxyNode(indexer) if indexer else None) + self.subjects = subjects + self.subject_temps = [ + None if p.is_irrefutable() else TrackTypeTempNode(self.pos, s) + for s, p in zip(self.subjects, self.patterns) + ] + + def generate_main_pattern_assignment_list(self, subject_node, env): + assignments = [] + self.generate_subjects(subject_node, env) + for subject_temp, subject, pattern in zip( + self.subject_temps, self.subjects, self.patterns + ): + needs_result_ref = False + if subject_temp is not None: + subject = subject_temp + else: + if subject is None: + assert not pattern.get_targets() + continue + elif not subject.is_literal or subject.is_temp: + from .UtilNodes import ResultRefNode, LetNode + + subject = ResultRefNode(subject) + needs_result_ref = True + p_assignments = pattern.generate_target_assignments(subject, env) + if needs_result_ref: + p_assignments = LetNode(subject, p_assignments) + else: + p_assignments = p_assignments + if p_assignments: + assignments.append(p_assignments) + return assignments + + def make_sequence_check(self, subject_node, sequence_mapping_temp): + # Note: the sequence check code is very quick on Python 3.10+ + # but potentially quite slow on lower versions (although should + # be medium quick for common types). It'd be nice to cache the + # results of it where it's been called on the same object + # multiple times. + # DW has decided that that's too complicated to implement + # for now. + utility_code = UtilityCode.load_cached("IsSequence", "MatchCase.c") + if sequence_mapping_temp is not None: + sequence_mapping_temp = ExprNodes.AmpersandNode( + self.pos, operand=sequence_mapping_temp + ) + else: + sequence_mapping_temp = ExprNodes.NullNode(self.pos) + call = ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_IsSequence", + self.Pyx_sequence_check_type, + utility_code=utility_code, + args=[subject_node, sequence_mapping_temp], + ) + + def type_check(type): + # type-check need not be perfect, it's an optimization + if type in [Builtin.list_type, Builtin.tuple_type]: + return True + if type.is_memoryviewslice or type.is_ctuple: + return True + if type in [ + Builtin.str_type, + Builtin.bytes_type, + Builtin.unicode_type, + Builtin.bytearray_type, + Builtin.dict_type, + Builtin.set_type, + ]: + # non-exhaustive list at this stage, but returning "False" is + # an optimization so it's allowed to be non-exchaustive + return False + if type.is_numeric or type.is_struct or type.is_enum: + # again, not exhaustive + return False + return None + + return StaticTypeCheckNode( + self.pos, arg=subject_node, fallback=call, check=type_check + ) + + def make_length_call_node(self, subject_node): + len_entry = Builtin.builtin_scope.lookup("len") + if subject_node.type.is_memoryviewslice: + len_call = ExprNodes.IndexNode( + self.pos, + base=ExprNodes.AttributeNode( + self.pos, obj=subject_node, attribute="shape" + ), + index=ExprNodes.IntNode(self.pos, value="0"), + ) + elif subject_node.type.is_ctuple: + len_call = ExprNodes.IntNode( + self.pos, value=str(len(subject_node.type.components)) + ) + else: + len_call = ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode(self.pos, name="len", entry=len_entry), + args=[subject_node], + ) + if self.needs_length_temp: + return ExprNodes.AssignmentExpressionNode( + self.pos, lhs=self.length_temp, rhs=len_call + ) + else: + return len_call + + def make_indexing_node(self, pattern, subject_node, idx, env): + if pattern.is_irrefutable() and not pattern.get_targets(): + # Nothing to do - index isn't used + return None + + def get_index_from_int(i): + if i is None: + return None + else: + int_node = ExprNodes.IntNode(pattern.pos, value=str(i)) + if i >= 0: + return int_node + else: + self.needs_length_temp = True + return ExprNodes.binop_node( + pattern.pos, + operator="+", + operand1=self.length_temp, + operand2=int_node, + ) + + if isinstance(idx, tuple): + start = get_index_from_int(idx[0]) + stop = get_index_from_int(idx[1]) + indexer = SliceToListNode( + pattern.pos, + base=subject_node, + start=start, + stop=stop, + length_node=self.length_temp if self.needs_length_temp else None, + ) + else: + indexer = CompilerDirectivesExprNode( + arg=ExprNodes.IndexNode( + pattern.pos, base=subject_node, index=get_index_from_int(idx) + ), + directives=copy_inherited_directives( + env.directives, boundscheck=False, wraparound=False + ), + ) + return indexer + + def analyse_declarations(self, env): + for p in self.patterns: + p.analyse_declarations(env) + return super(MatchSequencePatternNode, self).analyse_declarations(env) + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + for n in range(len(self.subjects)): + if self.subjects[n]: + self.subjects[n] = self.subjects[n].analyse_types(env) + for n in range(len(self.patterns)): + self.patterns[n] = self.patterns[n].analyse_pattern_expressions(env, None) + return self + + def allocate_subject_temps(self, code): + if self.needs_length_temp: + self.length_temp.allocate(code) + for temp in self.subject_temps: + if temp is not None: + temp.allocate(code) + for pattern in self.patterns: + pattern.allocate_subject_temps(code) + + def release_subject_temps(self, code): + if self.needs_length_temp: + self.length_temp.release(code) + for temp in self.subject_temps: + if temp is not None: + temp.release(code) + for pattern in self.patterns: + pattern.release_subject_temps(code) + + def dispose_of_subject_temps(self, code): + if self.needs_length_temp: + code.put_xdecref_clear(self.length_temp.result(), self.length_temp.type) + for temp in self.subject_temps: + if temp is not None: + code.put_xdecref_clear(temp.result(), temp.type) + for pattern in self.patterns: + pattern.dispose_of_subject_temps(code) + + +class MatchMappingPatternNode(PatternNode): + """ + keys list of Literals or AttributeNodes + value_patterns list of PatternNodes of equal length to keys + double_star_capture_target NameNode or None + + needs_runtime_keycheck - bool - are there any keys which can only be resolved at runtime + subjects [temp nodes or None] individual subsubjects can be assigned to these + """ + + keys = [] + value_patterns = [] + double_star_capture_target = None + subject_temps = None + double_star_temp = None + + needs_runtime_keycheck = False + + child_attrs = PatternNode.child_attrs + [ + "keys", + "value_patterns", + "double_star_capture_target", + ] + + Pyx_mapping_check_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, + [ + PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg( + "sequence_mapping_temp", + PyrexTypes.c_ptr_type(PyrexTypes.c_uint_type), + None, + ), + ], + exception_value="-1", + ) + # lie about the types of keys for simplicity + Pyx_mapping_check_duplicates_type = PyrexTypes.CFuncType( + PyrexTypes.c_int_type, + [ + PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None), + ], + exception_value="-1", + ) + # lie about the types of keys and subjects for simplicity + Pyx_mapping_extract_subjects_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, + [ + PyrexTypes.CFuncTypeArg("mapping", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("subjects", PyrexTypes.c_void_ptr_ptr_type, None), + ], + exception_value="-1", + ) + Pyx_mapping_doublestar_type = PyrexTypes.CFuncType( + Builtin.dict_type, + [ + PyrexTypes.CFuncTypeArg("mapping", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("keys", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("nKeys", PyrexTypes.c_py_ssize_t_type, None), + ], + ) + + def is_sequence_or_mapping(self): + return True + + def get_main_pattern_targets(self): + targets = set() + for pattern in self.value_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) + if self.double_star_capture_target: + self.add_target_to_targets(targets, self.double_star_capture_target.name) + return targets + + def validate_keys(self): + # called after constant folding + seen_keys = set() + for k in self.keys: + if k.has_constant_result(): + value = k.constant_result + if k.is_string_literal: + value = repr(value) + if value in seen_keys: + error(k.pos, "mapping pattern checks duplicate key (%s)" % value) + seen_keys.add(value) + else: + self.needs_runtime_keycheck = True + + if self.keys: + # it's very useful to sort keys early so the literal keys + # come first + sorted_keys = sorted( + zip(self.keys, self.value_patterns), + key=lambda kvp: (not kvp[0].is_literal), + ) + self.keys, self.value_patterns = [list(l) for l in zip(*sorted_keys)] + + def analyse_declarations(self, env): + super(MatchMappingPatternNode, self).analyse_declarations(env) + self.validate_keys() + for k in self.keys: + k.analyse_declarations(env) + for vp in self.value_patterns: + vp.analyse_declarations(env) + if self.double_star_capture_target: + self.double_star_capture_target.analyse_declarations(env) + + def generate_subjects(self, subject_node, env): + assert self.subject_temps is None # already calculated + subject_temps = [] + for pattern in self.value_patterns: + if pattern.is_match_and_assign_pattern and not pattern.target: + subject_temps.append(None) + else: + subject_temps.append( + AssignableTempNode(pattern.pos, PyrexTypes.py_object_type) + ) + self.subject_temps = subject_temps + + def generate_main_pattern_assignment_list(self, subject_node, env): + self.generate_subjects(subject_node, env) + assignments = [] + for subject, pattern in zip(self.subject_temps, self.value_patterns): + p_assignments = pattern.generate_target_assignments(subject, env) + if p_assignments: + assignments.extend(p_assignments.stats) + if self.double_star_capture_target: + self.double_star_temp = AssignableTempNode(self.pos, Builtin.dict_type) + assignments.append( + Nodes.SingleAssignmentNode( + self.double_star_temp.pos, + lhs=self.double_star_capture_target, + rhs=self.double_star_temp, + ) + ) + return assignments + + def is_dict_type_check(self, type): + # Returns true if it's an exact dict, False if it's definitely not + # an exact dict, None if it might be + # type-check need not be perfect, it's an optimization + if type is Builtin.dict_type: + return True + if type in Builtin.builtin_types: + # all other builtin types aren't mappings (except DictProxyType, but + # Cython doesn't know about that) + return False + if not type.is_pyobject: + # for now any non-pyobject type is False + return False + return None + + def make_mapping_check(self, subject_node, sequence_mapping_temp): + # Note: the mapping check code is very quick on Python 3.10+ + # but potentially quite slow on lower versions (although should + # be medium quick for common types). It'd be nice to cache the + # results of it where it's been called on the same object + # multiple times. + # DW has decided that that's too complicated to implement + # for now. + utility_code = UtilityCode.load_cached("IsMapping", "MatchCase.c") + if sequence_mapping_temp is not None: + sequence_mapping_temp = ExprNodes.AmpersandNode( + self.pos, operand=sequence_mapping_temp + ) + else: + sequence_mapping_temp = ExprNodes.NullNode(self.pos) + call = ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_IsMapping", + self.Pyx_mapping_check_type, + utility_code=utility_code, + args=[subject_node, sequence_mapping_temp], + ) + + return StaticTypeCheckNode( + self.pos, arg=subject_node, fallback=call, check=self.is_dict_type_check + ) + + def make_duplicate_keys_check(self, n_fixed_keys): + utility_code = UtilityCode.load_cached("MappingKeyCheck", "MatchCase.c") + if n_fixed_keys == len(self.keys): + return None # nothing to check + + return Nodes.ExprStatNode( + self.pos, + expr=ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_CheckMappingDuplicateKeys", + self.Pyx_mapping_check_duplicates_type, + utility_code=utility_code, + args=[ + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode(self.pos, value=str(n_fixed_keys)), + ExprNodes.IntNode(self.pos, value=str(len(self.keys))) + ], + ), + ) + + def check_all_keys(self, subject_node): + # It's debatable here whether to go for individual unpacking or a function. + # Current implementation is a function that's loosely copied from CPython. + # For small numbers of keys it might be better to generate the code instead. + # There's three versions depending on if we know that the type is exactly + # a dict, definitely not or dict, or unknown. + # The advantages of generating a function are: + # * more compact code + # * easier to check the type once then branch the implementation + # * faster in the cases that are more likely to fail due to wrong keys being + # present than due to the values not matching the patterns + if not self.keys: + return ExprNodes.BoolNode(self.pos, value=True) + + is_dict = self.is_dict_type_check(subject_node.type) + if is_dict: + util_code = UtilityCode.load_cached("ExtractExactDict", "MatchCase.c") + func_name = "__Pyx_MatchCase_Mapping_ExtractDict" + elif is_dict is False: # exact False... None indicates "might be dict" + # For any other non-generic PyObject type + util_code = UtilityCode.load_cached("ExtractNonDict", "MatchCase.c") + func_name = "__Pyx_MatchCase_Mapping_ExtractNonDict" + else: + util_code = UtilityCode.load_cached("ExtractGeneric", "MatchCase.c") + func_name = "__Pyx_MatchCase_Mapping_Extract" + + return ExprNodes.PythonCapiCallNode( + self.pos, + func_name, + self.Pyx_mapping_extract_subjects_type, + utility_code=util_code, + args=[ + subject_node, + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode( + self.pos, + value=str(len(self.keys)) + ), + MappingOrClassComparisonNode.make_subjects_node(self.pos), + ], + ) + + def make_double_star_capture(self, subject_node, test_result): + # test_result being the variable that holds "case check passed until now" + is_dict = self.is_dict_type_check(subject_node.type) + if is_dict: + tag = "ExactDict" + elif is_dict is False: + tag = "NotDict" + else: + tag = "" + utility_code = TempitaUtilityCode.load_cached( + "DoubleStarCapture", "MatchCase.c", context={"tag": tag} + ) + func = ExprNodes.PythonCapiCallNode( + self.double_star_capture_target.pos, + "__Pyx_MatchCase_DoubleStarCapture" + tag, + self.Pyx_mapping_doublestar_type, + utility_code=utility_code, + args=[ + subject_node, + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode(self.pos, value=str(len(self.keys))) + ], + ) + assignment = Nodes.SingleAssignmentNode( + self.double_star_capture_target.pos, lhs=self.double_star_temp, rhs=func + ) + if_clause = Nodes.IfClauseNode( + self.double_star_capture_target.pos, condition=test_result, body=assignment + ) + return Nodes.IfStatNode( + self.double_star_capture_target.pos, + if_clauses=[if_clause], + else_clause=None, + ) + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + from . import UtilNodes + + var_keys = [] + n_literal_keys = 0 + for k in self.keys: + if not k.is_literal: + var_keys.append(k) + else: + n_literal_keys += 1 + + all_tests = [] + all_tests.append(self.make_mapping_check(subject_node, sequence_mapping_temp)) + all_tests.append(self.check_all_keys(subject_node)) + + if any(isinstance(test, ExprNodes.BoolNode) and not test.value for test in all_tests): + # identify automatic-failure + return ExprNodes.BoolNode(self.pos, value=False) + + for pattern, subject in zip(self.value_patterns, self.subject_temps): + if pattern.is_irrefutable(): + continue + assert subject + all_tests.append(pattern.get_comparison_node(subject)) + + all_tests = generate_binop_tree_from_list(self.pos, "and", all_tests) + + test_result = UtilNodes.ResultRefNode(pos=self.pos, type=PyrexTypes.c_bint_type) + duplicate_check = self.make_duplicate_keys_check(n_literal_keys) + body = Nodes.StatListNode( + self.pos, + stats=([duplicate_check] if duplicate_check else []) + [ + Nodes.SingleAssignmentNode(self.pos, lhs=test_result, rhs=all_tests), + ], + ) + if self.double_star_capture_target: + assert self.double_star_temp + body.stats.append( + # make_double_star_capture wraps itself in an if + self.make_double_star_capture(subject_node, test_result) + ) + + if duplicate_check or self.double_star_capture_target: + body = UtilNodes.TempResultFromStatNode(test_result, body) + else: + body = all_tests + if self.keys or self.double_star_capture_target: + body = MappingOrClassComparisonNode( + body.pos, + arg=LazyCoerceToBool(body.pos, arg=body), + keys_array=self.keys, + subjects_array=self.subject_temps + ) + return LazyCoerceToBool(body.pos, arg=body) + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + def to_temp_or_literal(node): + if node.is_literal: + return node + else: + return node.coerce_to_temp(env) + + self.keys = [ + to_temp_or_literal(k.analyse_expressions(env)) + for k in self.keys + ] + + self.value_patterns = [ p.analyse_pattern_expressions(env, None) for p in self.value_patterns ] + return self + + def allocate_subject_temps(self, code): + for temp in self.subject_temps: + if temp is not None: + temp.allocate(code) + for pattern in self.value_patterns: + pattern.allocate_subject_temps(code) + if self.double_star_temp: + self.double_star_temp.allocate(code) + + def release_subject_temps(self, code): + for temp in self.subject_temps: + if temp is not None: + temp.release(code) + for pattern in self.value_patterns: + pattern.release_subject_temps(code) + if self.double_star_temp: + self.double_star_temp.release(code) + + def dispose_of_subject_temps(self, code): + for temp in self.subject_temps: + if temp is not None: + code.put_xdecref_clear(temp.result(), temp.type) + for pattern in self.value_patterns: + pattern.dispose_of_subject_temps(code) + if self.double_star_temp: + code.put_xdecref_clear( + self.double_star_temp.result(), self.double_star_temp.type + ) + + +class ClassPatternNode(PatternNode): + """ + class_ NameNode or AttributeNode + positional_patterns list of PatternNodes + keyword_pattern_names list of NameNodes + keyword_pattern_patterns list of PatternNodes + (same length as keyword_pattern_names) + """ + + class_ = None + positional_patterns = [] + keyword_pattern_names = [] + keyword_pattern_patterns = [] + + # as with the mapping functions, lie a little about some of the types for + # ease of declaration + Pyx_positional_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, + [ + PyrexTypes.CFuncTypeArg("subject", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None), + PyrexTypes.CFuncTypeArg("fixed_names", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("n_fixed", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("match_self", PyrexTypes.c_int_type, None), + PyrexTypes.CFuncTypeArg("subjects", PyrexTypes.c_void_ptr_ptr_type, None), + PyrexTypes.CFuncTypeArg("n_subjects", PyrexTypes.c_int_type, None), + ], + exception_value="-1", + ) + + Pyx_istype_type = PyrexTypes.CFuncType( + Builtin.type_type, + [ + PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), + ], + ) + + child_attrs = PatternNode.child_attrs + [ + "class_", + "positional_patterns", + "keyword_pattern_patterns", + # keyword_pattern_names are deliberately excluded. They're only NameNodes as a + # convenient way of storing a name and a pos. There's nothing to be gained from + # processing them + ] + + def generate_subjects(self, subject_node): + assert not hasattr(self, "keyword_subject_temps") + + if self.class_known_type: + # maximizes type inference + subject_node = ExprNodes.TypecastNode( + subject_node.pos, + operand=subject_node, + type=self.class_known_type, + typecheck=False, + ) + + self.keyword_subject_temps = [] + self.keyword_subject_attrs = [] + for p, p_name in zip(self.keyword_pattern_patterns, self.keyword_pattern_names): + # The attribute lookups are calculated here to maximize chance of type interference + attr_lookup = ExprNodes.AttributeNode( + p_name.pos, obj=subject_node, attribute=p_name.name, dont_mangle_private_names=True + ) + self.keyword_subject_attrs.append(attr_lookup) + if not p.get_targets() and p.is_irrefutable(): + self.keyword_subject_temps.append(None) + else: + # Hopefully the type can be assigned later + self.keyword_subject_temps.append(TrackTypeTempNode(p.pos, attr_lookup)) + + self.positional_subject_temps = [] + for p in self.positional_patterns: + if not p.get_targets() and p.is_irrefutable(): + self.positional_subject_temps.append(None) + else: + self.positional_subject_temps.append( + AssignableTempNode(p.pos, PyrexTypes.py_object_type) + ) + + def get_main_pattern_targets(self): + targets = set() + for pattern in self.positional_patterns + self.keyword_pattern_patterns: + self.update_targets_with_targets(targets, pattern.get_targets()) + return targets + + def generate_main_pattern_assignment_list(self, subject_node, env): + self.generate_subjects(subject_node) + assignments = [] + patterns = self.keyword_pattern_patterns + self.positional_patterns + temps = self.keyword_subject_temps + self.positional_subject_temps + for pattern, temp in zip(patterns, temps): + pattern_assignments = pattern.generate_target_assignments(temp, env) + if pattern_assignments: + assignments.extend(pattern_assignments.stats) + return assignments + + def make_typecheck_call(self, subject_node, class_node): + if not subject_node.type.is_pyobject: + with local_errors(True) as errors: + # TODO - it'd be nice to be able to match up simple c types + # e.g. "int" to "int", "double" to "double" + # without having to go through this + subject_node = LazyCoerceToPyObject(subject_node.pos, arg=subject_node) + if errors: + return ExprNodes.BoolNode(self.pos, value=False) + if self.class_known_type: + if not self.class_known_type.is_pyobject: + error(self.pos, "class must be a Python object") + return ExprNodes.BoolNode(self.pos, value=False) + + if subject_node.type.subtype_of_resolved_type(self.class_known_type): + if subject_node.may_be_none(): + return ExprNodes.PrimaryCmpNode( + self.pos, + operator="is_not", + operand1=subject_node, + operand2=ExprNodes.NoneNode(self.pos), + ) + else: + return ExprNodes.BoolNode(self.pos, value=True) + # if subject_node.type is not PyrexTypes.py_object_type + # I suspect the value is false, but possibly can't prove it + + return ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode( + self.pos, + name="isinstance", + entry=Builtin.builtin_scope.lookup("isinstance"), + ), + args=[subject_node, class_node], + ) + + def make_keyword_pattern_lookups(self): + # These are always looking up fixed names. + # Therefore, get best efficiency by letting Cython do the lookup + # and so infer the types + assert self.keyword_pattern_names + + from .UtilNodes import ResultRefNode, TempResultFromStatNode + + passed_rr = ResultRefNode(pos=self.pos, type=PyrexTypes.c_bint_type) + stats = [] + for pattern_name, subject_temp, lookup in zip( + self.keyword_pattern_names, + self.keyword_subject_temps, + self.keyword_subject_attrs, + ): + if subject_temp: + subject_temp.arg = lookup # it should now know the type + stat = Nodes.SingleAssignmentNode( + pattern_name.pos, lhs=subject_temp, rhs=lookup + ) + else: + stat = Nodes.ExprStatNode(pattern_name.pos, expr=lookup) + stats.append(stat) + except_clause = Nodes.ExceptClauseNode( + self.pos, + pattern=[ + ExprNodes.NameNode( + self.pos, + name="AttributeError", + entry=Builtin.builtin_scope.lookup("AttributeError"), + ) + ], + body=Nodes.StatListNode( + self.pos, + stats=[ + Nodes.SingleAssignmentNode( + self.pos, + lhs=passed_rr, + rhs=ExprNodes.BoolNode(self.pos, value=False), + ) + ], + ), + target=None, + ) + else_clause = Nodes.SingleAssignmentNode( + self.pos, lhs=passed_rr, rhs=ExprNodes.BoolNode(self.pos, value=True) + ) + try_except = Nodes.TryExceptStatNode( + self.pos, + body=Nodes.StatListNode(self.pos, stats=stats), + except_clauses=[except_clause], + else_clause=else_clause, + ) + return TempResultFromStatNode(passed_rr, try_except) + + def make_positional_args_call(self, subject_node, class_node): + assert self.positional_patterns + util_code = UtilityCode.load_cached("ClassPositionalPatterns", "MatchCase.c") + keynames = [ + ExprNodes.StringNode(n.pos, value=n.name) + for n in self.keyword_pattern_names + ] + # -1 is "unknown" + match_self = ( + -1 + if (len(self.positional_patterns) == 1 and not self.keyword_pattern_names) + else 0 + ) + if match_self and self.class_known_type: + for t in [ + # Builtin.bool_type ends up being py_object_type + Builtin.bytearray_type, + Builtin.bytes_type, + Builtin.dict_type, + Builtin.float_type, + Builtin.frozenset_type, + Builtin.long_type, + Builtin.list_type, + Builtin.set_type, + Builtin.unicode_type, + Builtin.str_type, + Builtin.tuple_type, + ]: + if self.class_known_type.subtype_of_resolved_type(t): + match_self = 1 + break + else: + if self.class_known_type.is_extension_type and not ( + self.class_known_type.is_external + or not self.class_known_type.scope.method_table_cname + ): # effectively extern visibility + match_self = 0 # I think... Relies on knowing the bases + + match_self = ExprNodes.IntNode(self.pos, value=str(match_self)) + n_subjects = ExprNodes.IntNode(self.pos, value=str(len(self.positional_patterns))) + return MappingOrClassComparisonNode( + self.pos, + arg=ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_ClassPositional", + self.Pyx_positional_type, + utility_code=util_code, + args=[ + subject_node, + class_node, + MappingOrClassComparisonNode.make_keys_node(self.pos), + ExprNodes.IntNode(self.pos, value=str(len(keynames))), + match_self, + MappingOrClassComparisonNode.make_subjects_node(self.pos), + n_subjects, + ] + ), + subjects_array=self.positional_subject_temps, + keys_array=keynames, + ) + return + + def make_subpattern_checks(self): + patterns = self.keyword_pattern_patterns + self.positional_patterns + temps = self.keyword_subject_temps + self.positional_subject_temps + checks = [] + for temp, pattern in zip(temps, patterns): + if temp: + checks.append(pattern.get_comparison_node(temp)) + return checks + + def get_comparison_node(self, subject_node, sequence_mapping_temp=None): + from .UtilNodes import ResultRefNode, EvalWithTempExprNode + + if self.class_known_type: + class_node = self.class_.clone_node() + class_node.entry = self.class_known_type.entry + else: + if not self.class_.type is Builtin.type_type: + util_code = UtilityCode.load_cached("MatchClassIsType", "MatchCase.c") + class_node = ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_IsType", + self.Pyx_istype_type, + utility_code=util_code, + args=[self.class_], + ) + class_node = ResultRefNode(class_node) + + all_checks = [] + all_checks.append(self.make_typecheck_call(subject_node, class_node)) + + if self.class_known_type: + # From this point on we know the type of the subject + subject_node = ExprNodes.TypecastNode( + self.class_.pos, + operand=subject_node, + type=self.class_known_type, + typecheck=False, + ) + if self.positional_patterns: + all_checks.append(self.make_positional_args_call(subject_node, class_node)) + if self.keyword_pattern_names: + all_checks.append(self.make_keyword_pattern_lookups()) + + all_checks.extend(self.make_subpattern_checks()) + + + if any(isinstance(ch, ExprNodes.BoolNode) and not ch.value for ch in all_checks): + # handle any obvious failures + return ExprNodes.BoolNode(self.pos, value=False) + + all_checks = generate_binop_tree_from_list(self.pos, "and", all_checks) + + if isinstance(class_node, ResultRefNode) and not all_checks.is_literal: + return LazyCoerceToBool(class_node.pos, arg=EvalWithTempExprNode(class_node, all_checks)) + else: + return LazyCoerceToBool(all_checks.pos, arg=all_checks) + + def analyse_declarations(self, env): + self.validate_keywords() + # Try to work out the type early + self.class_.analyse_declarations(env) + self.class_known_type = self.class_.analyse_as_extension_type(env) + for p in self.positional_patterns: + p.analyse_declarations(env) + for p_name, p in zip(self.keyword_pattern_names, self.keyword_pattern_patterns): + p_name.analyse_declarations(env) + p.analyse_declarations(env) + super(ClassPatternNode, self).analyse_declarations(env) + + def analyse_pattern_expressions(self, env, sequence_mapping_temp): + self.class_ = self.class_.analyse_types(env) + + self.keyword_subject_attrs = [ a.analyse_types(env) for a in self.keyword_subject_attrs ] + self.keyword_pattern_patterns = [ p.analyse_pattern_expressions(env, None) for p in self.keyword_pattern_patterns ] + self.positional_patterns = [ p.analyse_pattern_expressions(env, None) for p in self.positional_patterns ] + + return self + + def allocate_subject_temps(self, code): + for temp in self.keyword_subject_temps + self.positional_subject_temps: + if temp is not None: + temp.allocate(code) + for pattern in self.keyword_pattern_patterns + self.positional_patterns: + pattern.allocate_subject_temps(code) + + def release_subject_temps(self, code): + for temp in self.keyword_subject_temps + self.positional_subject_temps: + if temp is not None: + temp.release(code) + for pattern in self.keyword_pattern_patterns + self.positional_patterns: + pattern.release_subject_temps(code) + + def dispose_of_subject_temps(self, code): + for temp in self.keyword_subject_temps + self.positional_subject_temps: + if temp is not None: + code.put_xdecref_clear(temp.result(), temp.type) + for pattern in self.keyword_pattern_patterns + self.positional_patterns: + pattern.dispose_of_subject_temps(code) + + def validate_keywords(self): + seen = set() + for kw in self.keyword_pattern_names: + if kw.name in seen: + error( + kw.name, "attribute name repeated in class pattern: '%s" % kw.name + ) + seen.add(kw.name) + + +class SubstitutedIfStatListNode(Nodes.StatListNode): + """ + Like StatListNode but with a "goto end of match" at the + end of it + + match_node - the enclosing match statement + """ + + def generate_execution_code(self, code): + super(SubstitutedIfStatListNode, self).generate_execution_code(code) + if not self.is_terminator: + code.put_goto(self.match_node.end_label) + + +class StaticTypeCheckNode(ExprNodes.ExprNode): + """ + Useful for structural pattern matching, where we + can skip the "is_seqeunce/is_mapping" checks if + we know the type in advantage (or reduce it to a + None check). + + This should optimize itself out at the analyse_expressions + stage + + arg ExprNode + fallback ExprNode Function to be called if the static + typecheck isn't optimized out + check callable Returns True, False, or None (for "can't tell") + """ + + child_attrs = ["fallback"] # arg in not included since it's in "fallback" + + def analyse_types(self, env): + check = self.check(self.arg.type) + if check: + if self.arg.may_be_none(): + return ExprNodes.PrimaryCmpNode( + self.pos, + operand1=self.arg, + operand2=ExprNodes.NoneNode(self.pos), + operator="is_not", + ).analyse_expressions(env) + else: + return ExprNodes.BoolNode(pos=self.pos, value=True).analyse_expressions( + env + ) + elif check is None: + return self.fallback.analyse_expressions(env) + else: + return ExprNodes.BoolNode(pos=self.pos, value=False).analyse_expressions( + env + ) + + +class AssignableTempNode(ExprNodes.TempNode): + lhs_of_first_assignment = True # assume it can be assigned to once + _assigned_twice = False + + def infer_type(self, env): + return self.type + + def generate_assignment_code(self, rhs, code, overloaded_assignment=False): + assert ( + not self._assigned_twice + ) # if this happens it's not a disaster but it needs a refactor + self._assigned_twice = True + if self.type.is_pyobject: + rhs.make_owned_reference(code) + if not self.lhs_of_first_assignment: + code.put_decref(self.result(), self.ctype()) + code.putln( + "%s = %s;" + % ( + self.result(), + rhs.result() if overloaded_assignment else rhs.result_as(self.ctype()), + ) + ) + rhs.generate_post_assignment_code(code) + rhs.free_temps(code) + + def generate_post_assignment_code(self, code): + code.put_incref(self.result(), self.type) + + def generate_disposal_code(self, code): + pass # handled elsewhere - we expect to use this temp multiple times + + def clone_node(self): + return self # temps break if you make a copy! + + +class TrackTypeTempNode(AssignableTempNode): + # Like a temp node, but type is set from arg + + lhs_of_first_assignment = True # assume it can be assigned to once + _assigned_twice = False + + @property + def type(self): + return getattr(self.arg, "type", None) + + def __init__(self, pos, arg): + ExprNodes.ExprNode.__init__(self, pos) # skip a level + self.arg = arg + + def infer_type(self, env): + return self.arg.infer_type(env) + + +class SliceToListNode(ExprNodes.ExprNode): + """ + Used as a brief temporary node to optimize + case [..., *_, ...]. + Always reduces to something else after analyse_types + """ + + subexprs = ["base", "start", "stop", "length_node"] + + type = Builtin.list_type + + Pyx_iterable_to_list_type = PyrexTypes.CFuncType( + Builtin.list_type, + [ + PyrexTypes.CFuncTypeArg("iterable", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + ], + ) + + def generate_via_slicing(self, env): + # for any more complicated type that doesn't have a specialized path + # we can simply slice it and copy it to list + res = CompilerDirectivesExprNode( + arg=ExprNodes.SliceIndexNode( + self.pos, base=self.base, start=self.start, stop=self.stop + ), + directives=copy_inherited_directives( + env.directives, boundcheck=False, wraparound=False + ), + ) + res = ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode( + self.pos, + name="list", + entry=Builtin.builtin_scope.lookup("list"), + ), + args=[res], + ) + return res + + def get_stop(self): + if not self.stop: + if self.length_node: + return self.length_node + else: + return ExprNodes.SimpleCallNode( + self.pos, + function=ExprNodes.NameNode( + self.pos, name="len", entry=Builtin.builtin_scope.lookup("len") + ), + args=[self.base], + ) + else: + return self.stop + + def generate_for_memoryview(self, env): + # Requires Cython code generation... + # A list comprehension with indexing turns out to be a good option + from .UtilityCode import CythonUtilityCode + + suffix = self.base.type.specialization_suffix() + util_code = CythonUtilityCode.load( + "MemoryviewSliceToList", + "MatchCase_Cy.pyx", + context={ + "decl_code": self.base.type.empty_declaration_code(pyrex=True), + "suffix": suffix, + }, + ) + func_type = PyrexTypes.CFuncType( + Builtin.list_type, + [ + PyrexTypes.CFuncTypeArg("x", self.base.type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + ], + ) + env.use_utility_code( + util_code + ) # attaching it to the call node doesn't seem enough + return ExprNodes.PythonCapiCallNode( + self.pos, + "__Pyx_MatchCase_SliceMemoryview_%s" % suffix, + func_type, + utility_code=util_code, + args=[ + self.base, + self.start if self.start else ExprNodes.IntNode(self.pos, value="0"), + self.get_stop(), + ], + ) + + def generate_for_pyobject(self): + util_code_name = None + func_name = None + if self.base.type is Builtin.tuple_type: + util_code_name = "TupleSliceToList" + elif self.base.type is Builtin.list_type: + func_name = "PyList_GetSlice" + elif ( + self.base.type.is_pyobject + and not self.base.type is PyrexTypes.py_object_type + ): + # some specialized type that almost certainly isn't a list. Just go straight + # to the "other" version of it + util_code_name = "OtherSequenceSliceToList" + else: + util_code_name = "UnknownTypeSliceToList" + if not func_name: + func_name = "__Pyx_MatchCase_%s" % util_code_name + if util_code_name: + util_code = UtilityCode.load_cached( + util_code_name, + "MatchCase.c" + ) + else: + util_code = None + start = self.start if self.start else ExprNodes.IntNode(self.pos, value="0") + stop = self.get_stop() + return ExprNodes.PythonCapiCallNode( + self.pos, + func_name, + self.Pyx_iterable_to_list_type, + utility_code=util_code, + args=[self.base, start, stop], + ) + + def analyse_types(self, env): + self.base = self.base.analyse_types(env) + if self.base.type.is_memoryviewslice: + result = self.generate_for_memoryview(env) + elif self.base.type.is_pyobject: + result = self.generate_for_pyobject() + else: + # Some other type (probably a ctuple). + # Just slice it, copy it to a list and hope it works + result = self.generate_via_slicing(env) + return result.analyse_types(env) + + +class CompilerDirectivesExprNode(ExprNodes.ProxyNode): + # Like compiler directives node, but for an expression + # directives {string:value} A dictionary holding the right value for + # *all* possible directives. + # arg ExprNode + + def __init__(self, arg, directives): + super(CompilerDirectivesExprNode, self).__init__(arg) + self.directives = directives + + @contextmanager + def _apply_directives(self, obj): + old = obj.directives + obj.directives = self.directives + yield + obj.directives = old + + @property + def is_temp(self): + return self.arg.is_temp + + def infer_type(self, env): + with self._apply_directives(env): + return super(CompilerDirectivesExprNode, self).infer_type(env) + + def analyse_declarations(self, env): + with self._apply_directives(env): + self.arg.analyse_declarations(env) + + def analyse_types(self, env): + with self._apply_directives(env): + return super(CompilerDirectivesExprNode, self).analyse_types(env) + + def generate_result_code(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).generate_result_code(code) + + def generate_evaluation_code(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).generate_evaluation_code(code) + + def generate_disposal_code(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).generate_disposal_code(code) + + def free_temps(self, code): + with self._apply_directives(code.globalstate): + super(CompilerDirectivesExprNode, self).free_temps(code) + + def annotate(self, code): + with self._apply_directives(code.globalstate): + self.arg.annotate(code) + + +class LazyCoerceToPyObject(ExprNodes.ExprNode): + """ + Just calls "self.arg.coerce_to_pyobject" when it's analysed, + so doesn't need 'env' when it's created + arg - ExprNode + """ + subexprs = ["arg"] + type = PyrexTypes.py_object_type + + def analyse_types(self, env): + return self.arg.analyse_types(env).coerce_to_pyobject(env) + + +class LazyCoerceToBool(ExprNodes.ExprNode): + """ + Just calls "self.arg.coerce_to_bool" when it's analysed, + so doesn't need 'env' when it's created + arg - ExprNode + """ + subexprs = ["arg"] + type = PyrexTypes.c_bint_type + + def analyse_types(self, env): + return self.arg.analyse_boolean_expression(env) + +def generate_binop_tree_from_list(pos, operator, list_of_tests): + """ + Given a list of operands generates a roughly balanced tree: + (test1 op test2) op (test3 op test4) + This is better than (((test1 op test2) op test3) op test4) + because it generates a shallower tree of nodes so is + less likely to overflow the compiler + """ + len_tests = len(list_of_tests) + if len_tests == 1: + return list_of_tests[0] + else: + split_idx = len_tests // 2 + operand1 = generate_binop_tree_from_list( + pos, operator, list_of_tests[:split_idx] + ) + operand2 = generate_binop_tree_from_list( + pos, operator, list_of_tests[split_idx:] + ) + return ExprNodes.binop_node( + pos, + operator=operator, + operand1=operand1, + operand2=operand2 + ) + + +class MappingOrClassComparisonNode(ExprNodes.ExprNode): + """ + Combined with MappingOrClassComparisonNodeInner this is responsible + for setting up up the arrays of subjects and keys that are used in + the function calls that handle these types of patterns + + Note that self.keys_array is owned by this but used by + MappingOrClassComparisonNodeInner - that's mainly to ensure that + it gets evaluated in the correct order + """ + subexprs = ["keys_array", "inner"] + + keys_array_cname = "__pyx_match_mapping_keys" + subjects_array_cname = "__pyx_match_mapping_subjects" + + @property + def type(self): + return self.inner.type + + @classmethod + def make_keys_node(cls, pos): + return ExprNodes.RawCNameExprNode( + pos, + type=PyrexTypes.c_void_ptr_type, + cname=cls.keys_array_cname + ) + + @classmethod + def make_subjects_node(cls, pos): + return ExprNodes.RawCNameExprNode( + pos, + type=PyrexTypes.c_void_ptr_ptr_type, + cname=cls.subjects_array_cname + ) + + def __init__(self, pos, arg, subjects_array, **kwds): + super(MappingOrClassComparisonNode, self).__init__(pos, **kwds) + self.inner = MappingOrClassComparisonNodeInner( + pos, + arg=arg, + keys_array = self.keys_array, + subjects_array = subjects_array + ) + + def analyse_types(self, env): + self.inner = self.inner.analyse_types(env) + self.keys_array = [ + key.analyse_types(env).coerce_to_simple(env) for key in self.keys_array + ] + return self + + def generate_result_code(self, code): + pass + + def calculate_result_code(self): + return self.inner.calculate_result_code() + + +class MappingOrClassComparisonNodeInner(ExprNodes.ExprNode): + """ + Sets up the arrays of subjects and keys + + Created by the constructor of MappingComparisonNode + (no need to create directly) + + has attributes: + * arg - the main comparison node + * keys_array - list of ExprNodes representing keys + * subjects_array - list of ExprNodes representing subjects + """ + subexprs = ['arg'] + + @property + def type(self): + return self.arg.type + + def analyse_types(self, env): + self.arg = self.arg.analyse_types(env) + for n in range(len(self.keys_array)): + key = self.keys_array[n].analyse_types(env) + key = key.coerce_to_pyobject(env) + self.keys_array[n] = key + assert self.arg.type is PyrexTypes.c_bint_type + return self + + def generate_evaluation_code(self, code): + code.putln("{") + keys_str = ", ".join(k.result() for k in self.keys_array) + if not keys_str: + # GCC gets worried about overflow if we pass + # a genuinely empty array + keys_str = "NULL" + code.putln("PyObject *%s[] = {%s};" % ( + MappingOrClassComparisonNode.keys_array_cname, + keys_str, + )) + subjects_str = ", ".join( + "&"+subject.result() if subject is not None else "NULL" for subject in self.subjects_array + ) + if not subjects_str: + # GCC gets worried about overflow if we pass + # a genuinely empty array + subjects_str = "NULL" + code.putln("PyObject **%s[] = {%s};" % ( + MappingOrClassComparisonNode.subjects_array_cname, + subjects_str + )) + super(MappingOrClassComparisonNodeInner, self).generate_evaluation_code(code) + + code.putln("}") + + def generate_result_code(self, code): + pass + + def calculate_result_code(self): + return self.arg.result()
\ No newline at end of file diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 66b4f97cf..006d1023c 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -10154,6 +10154,17 @@ class CnameDecoratorNode(StatNode): self.node.generate_execution_code(code) +class ErrorNode(Node): + """ + Node type for things that we want to get through the parser + (especially for things that are being scanned in "tentative_scan" + blocks), but should immediately raise and error afterwards. + + what str + """ + child_attrs = [] + + #------------------------------------------------------------------------------------ # # Runtime support code diff --git a/Cython/Compiler/ParseTreeTransforms.pxd b/Cython/Compiler/ParseTreeTransforms.pxd index efbb14f70..2778be4ef 100644 --- a/Cython/Compiler/ParseTreeTransforms.pxd +++ b/Cython/Compiler/ParseTreeTransforms.pxd @@ -18,6 +18,7 @@ cdef class PostParse(ScopeTrackingTransform): cdef dict specialattribute_handlers cdef size_t lambda_counter cdef size_t genexpr_counter + cdef bint in_pattern_node cdef _visit_assignment_node(self, node, list expr_list) diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index bc4943b79..301d93335 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -193,6 +193,7 @@ class PostParse(ScopeTrackingTransform): self.specialattribute_handlers = { '__cythonbufferdefaults__' : self.handle_bufferdefaults } + self.in_pattern_node = False def visit_LambdaNode(self, node): # unpack a lambda expression into the corresponding DefNode @@ -385,6 +386,32 @@ class PostParse(ScopeTrackingTransform): self.visitchildren(node) return node + def visit_ErrorNode(self, node): + error(node.pos, node.what) + return None + + def visit_MatchCaseNode(self, node): + node.validate_targets() + self.visitchildren(node) + return node + + def visit_MatchNode(self, node): + node.validate_irrefutable() + self.visitchildren(node) + return node + + def visit_PatternNode(self, node): + in_pattern_node, self.in_pattern_node = self.in_pattern_node, True + self.visitchildren(node) + self.in_pattern_node = in_pattern_node + return node + + def visit_JoinedStrNode(self, node): + if self.in_pattern_node: + error(node.pos, "f-strings are not accepted for pattern matching") + self.visitchildren(node) + return node + class _AssignmentExpressionTargetNameFinder(TreeVisitor): def __init__(self): super(_AssignmentExpressionTargetNameFinder, self).__init__() @@ -918,6 +945,9 @@ class InterpretCompilerDirectives(CythonTransform): self.directives = old_directives return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + # The following four functions track imports and cimports that # begin with "cython" def is_cython_directive(self, name): @@ -1550,6 +1580,9 @@ class ParallelRangeTransform(CythonTransform, SkipDeclarations): class WithTransform(VisitorTransform, SkipDeclarations): + # also includes some transforms for MatchCase + # (because this is a convenient time to do them, before constant folding and + # branch elimination) def visit_WithStatNode(self, node): self.visitchildren(node, 'body') pos = node.pos @@ -1611,6 +1644,11 @@ class WithTransform(VisitorTransform, SkipDeclarations): ) return node + def visit_MatchNode(self, node): + node.refactor_cases() + self.visitchildren(node) + return node + def visit_ExprNode(self, node): # With statements are never inside expressions. return node @@ -1990,6 +2028,9 @@ class ForwardDeclareTypes(CythonTransform): env.directives = old return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + def visit_ModuleNode(self, node): self.module_scope = node.scope self.module_scope.directives = node.directives @@ -2863,6 +2904,9 @@ class AdjustDefByDirectives(CythonTransform, SkipDeclarations): self.directives = old_directives return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + def visit_DefNode(self, node): modifiers = [] if 'inline' in self.directives: diff --git a/Cython/Compiler/Parsing.pxd b/Cython/Compiler/Parsing.pxd index 72a855fd4..fc3e2749f 100644 --- a/Cython/Compiler/Parsing.pxd +++ b/Cython/Compiler/Parsing.pxd @@ -62,6 +62,8 @@ cdef expect_ellipsis(PyrexScanner s) cdef make_slice_nodes(pos, subscripts) cpdef make_slice_node(pos, start, stop = *, step = *) cdef p_atom(PyrexScanner s) +cdef p_atom_string(PyrexScanner s) +cdef p_atom_ident_constants(PyrexScanner s, bint bools_are_pybool = *) @cython.locals(value=unicode) cdef p_int_literal(PyrexScanner s) cdef p_name(PyrexScanner s, name) diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 7c7b7f8a8..75eb194c6 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -25,6 +25,7 @@ from functools import partial, reduce from .Scanning import PyrexScanner, FileSourceDescriptor, tentatively_scan from . import Nodes from . import ExprNodes +from . import MatchCaseNodes from . import Builtin from . import StringEncoding from .StringEncoding import EncodedString, bytes_literal, _unicode, _bytes @@ -717,36 +718,59 @@ def p_atom(s): s.next() return ExprNodes.ImagNode(pos, value = value) elif sy == 'BEGIN_STRING': - kind, bytes_value, unicode_value = p_cat_string_literal(s) - if kind == 'c': - return ExprNodes.CharNode(pos, value = bytes_value) - elif kind == 'u': - return ExprNodes.UnicodeNode(pos, value = unicode_value, bytes_value = bytes_value) - elif kind == 'b': - return ExprNodes.BytesNode(pos, value = bytes_value) - elif kind == 'f': - return ExprNodes.JoinedStrNode(pos, values = unicode_value) - elif kind == '': - return ExprNodes.StringNode(pos, value = bytes_value, unicode_value = unicode_value) - else: - s.error("invalid string kind '%s'" % kind) + return p_atom_string(s) elif sy == 'IDENT': - name = s.systring - if name == "None": - result = ExprNodes.NoneNode(pos) - elif name == "True": - result = ExprNodes.BoolNode(pos, value=True) - elif name == "False": - result = ExprNodes.BoolNode(pos, value=False) - elif name == "NULL" and not s.in_python_file: - result = ExprNodes.NullNode(pos) - else: - result = p_name(s, name) - s.next() + result = p_atom_ident_constants(s) + if result is None: + result = p_name(s, s.systring) + s.next() return result else: s.error("Expected an identifier or literal") + +def p_atom_string(s): + pos = s.position() + kind, bytes_value, unicode_value = p_cat_string_literal(s) + if kind == 'c': + return ExprNodes.CharNode(pos, value=bytes_value) + elif kind == 'u': + return ExprNodes.UnicodeNode(pos, value=unicode_value, bytes_value=bytes_value) + elif kind == 'b': + return ExprNodes.BytesNode(pos, value=bytes_value) + elif kind == 'f': + return ExprNodes.JoinedStrNode(pos, values=unicode_value) + elif kind == '': + return ExprNodes.StringNode(pos, value=bytes_value, unicode_value=unicode_value) + else: + s.error("invalid string kind '%s'" % kind) + + +def p_atom_ident_constants(s, bools_are_pybool=False): + """ + Returns None if it isn't one special-cased named constants. + Only calls s.next() if it successfully matches a matches. + """ + pos = s.position() + name = s.systring + result = None + if bools_are_pybool: + extra_kwds = {'type': Builtin.bool_type} + else: + extra_kwds = {} + if name == "None": + result = ExprNodes.NoneNode(pos) + elif name == "True": + result = ExprNodes.BoolNode(pos, value=True, **extra_kwds) + elif name == "False": + result = ExprNodes.BoolNode(pos, value=False, **extra_kwds) + elif name == "NULL" and not s.in_python_file: + result = ExprNodes.NullNode(pos) + if result: + s.next() + return result + + def p_int_literal(s): pos = s.position() value = s.systring @@ -2443,6 +2467,11 @@ def p_statement(s, ctx, first_statement = 0): elif decorators: s.error("Decorators can only be followed by functions or classes") s.put_back(u'IDENT', ident_name, ident_pos) # re-insert original token + if s.sy == 'IDENT' and s.systring == 'match': + # p_match_statement returns None on a "soft" initial failure + match_statement = p_match_statement(s, ctx) + if match_statement: + return match_statement return p_simple_statement_list(s, ctx, first_statement=first_statement) @@ -4025,6 +4054,447 @@ def p_cpp_class_attribute(s, ctx): return node +def p_match_statement(s, ctx): + assert s.sy == "IDENT" and s.systring == "match" + pos = s.position() + with tentatively_scan(s) as errors: + s.next() + subject = p_namedexpr_test(s) + subjects = None + if s.sy == ",": + subjects = [subject] + while s.sy == ",": + s.next() + if s.sy == ":": + break + subjects.append(p_test(s)) + if subjects is not None: + subject = ExprNodes.TupleNode(pos, args=subjects) + s.expect(":") + if errors: + return None + + # at this stage were commited to it being a match block so continue + # outside "with tentatively_scan" + # (I think this deviates from the PEG parser slightly, and it'd + # backtrack on the whole thing) + s.expect_newline() + s.expect_indent() + cases = [] + while s.sy != "DEDENT": + cases.append(p_case_block(s, ctx)) + s.expect_dedent() + return MatchCaseNodes.MatchNode(pos, subject=subject, cases=cases) + + +def p_case_block(s, ctx): + if not (s.sy == "IDENT" and s.systring == "case"): + s.error("Expected 'case'") + s.next() + pos = s.position() + pattern = p_patterns(s) + guard = None + if s.sy == 'if': + s.next() + guard = p_test(s) + body = p_suite(s, ctx) + + return MatchCaseNodes.MatchCaseNode(pos, pattern=pattern, body=body, guard=guard) + + +def p_patterns(s): + # note - in slight contrast to the name (which comes from the Python grammar), + # returns a single pattern + patterns = [] + seq = False + pos = s.position() + while True: + with tentatively_scan(s) as errors: + pattern = p_maybe_star_pattern(s) + if errors: + if patterns: + break # all is good provided we have at least 1 pattern + else: + e = errors[0] + s.error(e.args[1], pos=e.args[0]) + patterns.append(pattern) + + if s.sy == ",": + seq = True + s.next() + if s.sy in [":", "if"]: + break # common reasons to break + else: + break + + if seq: + return MatchCaseNodes.MatchSequencePatternNode(pos, patterns=patterns) + else: + return patterns[0] + + +def p_maybe_star_pattern(s): + # For match case. Either star_pattern or pattern + if s.sy == "*": + # star pattern + s.next() + target = None + if s.systring != "_": # for match-case '_' is treated as a special wildcard + target = p_pattern_capture_target(s) + else: + s.next() + pattern = MatchCaseNodes.MatchAndAssignPatternNode( + s.position(), target=target, is_star=True + ) + return pattern + else: + pattern = p_pattern(s) + return pattern + + +def p_pattern(s): + # try "as_pattern" then "or_pattern" + # (but practically "as_pattern" starts with "or_pattern" too) + patterns = [] + pos = s.position() + while True: + patterns.append(p_closed_pattern(s)) + if s.sy == "|": + s.next() + else: + break + + if len(patterns) > 1: + pattern = MatchCaseNodes.OrPatternNode( + pos, + alternatives=patterns + ) + else: + pattern = patterns[0] + + if s.sy == 'IDENT' and s.systring == 'as': + s.next() + with tentatively_scan(s) as errors: + pattern.as_targets.append(p_pattern_capture_target(s)) + if errors and s.sy == "_": + s.next() + # make this a specific error + return Nodes.ErrorNode(errors[0].args[0], what=errors[0].args[1]) + elif errors: + with tentatively_scan(s): + expr = p_test(s) + return Nodes.ErrorNode(expr.pos, what="Invalid pattern target") + s.error(errors[0]) + return pattern + + +def p_closed_pattern(s): + """ + The PEG parser specifies it as + | literal_pattern + | capture_pattern + | wildcard_pattern + | value_pattern + | group_pattern + | sequence_pattern + | mapping_pattern + | class_pattern + + For the sake avoiding too much backtracking, we know: + * starts with "{" is a sequence_pattern + * starts with "[" is a mapping_pattern + * starts with "(" is a group_pattern or sequence_pattern + * wildcard pattern is just identifier=='_' + The rest are then tried in order with backtracking + """ + if s.sy == 'IDENT' and s.systring == '_': + pos = s.position() + s.next() + return MatchCaseNodes.MatchAndAssignPatternNode(pos) + elif s.sy == '{': + return p_mapping_pattern(s) + elif s.sy == '[': + return p_sequence_pattern(s) + elif s.sy == '(': + with tentatively_scan(s) as errors: + result = p_group_pattern(s) + if not errors: + return result + return p_sequence_pattern(s) + + with tentatively_scan(s) as errors: + result = p_literal_pattern(s) + if not errors: + return result + with tentatively_scan(s) as errors: + result = p_capture_pattern(s) + if not errors: + return result + with tentatively_scan(s) as errors: + result = p_value_pattern(s) + if not errors: + return result + return p_class_pattern(s) + + +def p_literal_pattern(s): + # a lot of duplication in this function with "p_atom" + next_must_be_a_number = False + sign = '' + if s.sy == '-': + sign = s.sy + sign_pos = s.position() + s.next() + next_must_be_a_number = True + + sy = s.sy + pos = s.position() + + res = None + if sy == 'INT': + res = p_int_literal(s) + elif sy == 'FLOAT': + value = s.systring + s.next() + res = ExprNodes.FloatNode(pos, value=value) + + if res and sign == "-": + res = ExprNodes.UnaryMinusNode(sign_pos, operand=res) + + if res and s.sy in ['+', '-']: + sign = s.sy + s.next() + if s.sy != 'IMAG': + s.error("Expected imaginary number") + else: + add_pos = s.position() + value = s.systring[:-1] + s.next() + res = ExprNodes.binop_node( + add_pos, + sign, + operand1=res, + operand2=ExprNodes.ImagNode(s.position(), value=value) + ) + + if not res and sy == 'IMAG': + value = s.systring[:-1] + s.next() + res = ExprNodes.ImagNode(pos, value=sign+value) + if sign == "-": + res = ExprNodes.UnaryMinusNode(sign_pos, operand=res) + + if res: + return MatchCaseNodes.MatchValuePatternNode(pos, value=res) + + if next_must_be_a_number: + s.error("Expected a number") + if sy == 'BEGIN_STRING': + res = p_atom_string(s) + # f-strings not being accepted is validated in PostParse + return MatchCaseNodes.MatchValuePatternNode(pos, value=res) + elif sy == 'IDENT': + # Note that p_atom_ident_constants includes NULL. + # This is a deliberate Cython addition to the pattern matching specification + result = p_atom_ident_constants(s, bools_are_pybool=True) + if result: + return MatchCaseNodes.MatchValuePatternNode(pos, value=result, is_is_check=True) + + s.error("Failed to match literal") + + +def p_capture_pattern(s): + return MatchCaseNodes.MatchAndAssignPatternNode( + s.position(), + target=p_pattern_capture_target(s) + ) + + +def p_value_pattern(s): + if s.sy != "IDENT": + s.error("Expected identifier") + pos = s.position() + res = p_name(s, s.systring) + s.next() + if s.sy != '.': + s.error("Expected '.'") + while s.sy == '.': + attr_pos = s.position() + s.next() + attr = p_ident(s) + res = ExprNodes.AttributeNode(attr_pos, obj=res, attribute=attr) + if s.sy in ['(', '=']: + s.error("Unexpected symbol '%s'" % s.sy) + return MatchCaseNodes.MatchValuePatternNode(pos, value=res) + + +def p_group_pattern(s): + s.expect("(") + pattern = p_pattern(s) + s.expect(")") + return pattern + + +def p_sequence_pattern(s): + opener = s.sy + pos = s.position() + if opener in ['[', '(']: + closer = ']' if opener == '[' else ')' + s.next() + # maybe_sequence_pattern and open_sequence_pattern + patterns = [] + if s.sy == closer: + s.next() + else: + while True: + patterns.append(p_maybe_star_pattern(s)) + if s.sy == ",": + s.next() + if s.sy == closer: + break + else: + if opener == ')' and len(patterns) == 1: + s.error("tuple-like pattern of length 1 must finish with ','") + break + s.expect(closer) + return MatchCaseNodes.MatchSequencePatternNode(pos, patterns=patterns) + else: + s.error("Expected '[' or '('") + + +def p_mapping_pattern(s): + pos = s.position() + s.expect('{') + if s.sy == '}': + # trivial empty mapping + s.next() + return MatchCaseNodes.MatchMappingPatternNode(pos) + + double_star_capture_target = None + items_patterns = [] + double_star_set_twice = None + pattern_after_double_star = None + star_star_arg_pos = None + while True: + if double_star_capture_target and not star_star_arg_pos: + star_star_arg_pos = s.position() + if s.sy == '**': + s.next() + double_star_capture_target = p_pattern_capture_target(s) + else: + # key=(literal_expr | attr) + with tentatively_scan(s) as errors: + pattern = p_literal_pattern(s) + key = pattern.value + if errors: + pattern = p_value_pattern(s) + key = pattern.value + s.expect(':') + value = p_pattern(s) + items_patterns.append((key, value)) + if double_star_capture_target: + pattern_after_double_star = value.pos + if s.sy==',': + s.next() + else: + break + if s.sy=='}': + break + if s.sy != '}': + s.error("Expected '}'") + s.next() + if double_star_set_twice is not None: + return Nodes.ErrorNode(double_star_set_twice, what = "Double star capture set twice") + if pattern_after_double_star: + return Nodes.ErrorNode(pattern_after_double_star, what = "pattern follows ** capture") + if star_star_arg_pos is not None: + return Nodes.ErrorNode( + star_star_arg_pos, + what = "** pattern must be the final part of a mapping pattern." + ) + return MatchCaseNodes.MatchMappingPatternNode( + pos, + keys = [kv[0] for kv in items_patterns], + value_patterns = [kv[1] for kv in items_patterns], + double_star_capture_target = double_star_capture_target + ) + + +def p_class_pattern(s): + # start by parsing the class as name_or_attr + pos = s.position() + res = p_name(s, s.systring) + s.next() + while s.sy == '.': + attr_pos = s.position() + s.next() + attr = p_ident(s) + res = ExprNodes.AttributeNode(attr_pos, obj=res, attribute=attr) + class_ = res + + s.expect("(") + if s.sy == ")": + # trivial case with no arguments matched + s.next() + return MatchCaseNodes.ClassPatternNode(pos, class_=class_) + + # parse the arguments + positional_patterns = [] + keyword_patterns = [] + keyword_patterns_error = None + while True: + with tentatively_scan(s) as errors: + positional_patterns.append(p_pattern(s)) + if not errors: + if keyword_patterns: + keyword_patterns_error = s.position() + else: + with tentatively_scan(s) as errors: + keyword_patterns.append(p_keyword_pattern(s)) + if s.sy != ",": + break + s.next() + if s.sy == ")": + break # Allow trailing comma. + s.expect(")") + + if keyword_patterns_error is not None: + return Nodes.ErrorNode( + keyword_patterns_error, + what="Positional patterns follow keyword patterns" + ) + return MatchCaseNodes.ClassPatternNode( + pos, class_ = class_, + positional_patterns = positional_patterns, + keyword_pattern_names = [kv[0] for kv in keyword_patterns], + keyword_pattern_patterns = [kv[1] for kv in keyword_patterns], + ) + + +def p_keyword_pattern(s): + if s.sy != "IDENT": + s.error("Expected identifier") + arg = p_name(s, s.systring) + s.next() + s.expect("=") + value = p_pattern(s) + return arg, value + + +def p_pattern_capture_target(s): + # any name but '_', and with some constraints on what follows + if s.sy != 'IDENT': + s.error("Expected identifier") + if s.systring == '_': + s.error("Pattern capture target cannot be '_'") + target = p_name(s, s.systring) + s.next() + if s.sy in ['.', '(', '=']: + s.error("Illegal next symbol '%s'" % s.sy) + return target + + + #---------------------------------------------- # # Debugging diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 92e2eb9c0..e9545d76e 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -318,6 +318,9 @@ class CythonTransform(VisitorTransform): self.current_directives = old return node + def visit_CompilerDirectivesExprNode(self, node): + return self.visit_CompilerDirectivesNode(node) + def visit_Node(self, node): self._process_children(node) return node diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index 8bcd26b6f..45a8e6f59 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -12,9 +12,10 @@ from functools import partial from .Compiler import Errors from .CodeWriter import CodeWriter -from .Compiler.TreeFragment import TreeFragment, strip_common_indent +from .Compiler.TreeFragment import TreeFragment, strip_common_indent, StringParseContext from .Compiler.Visitor import TreeVisitor, VisitorTransform from .Compiler import TreePath +from .Compiler.ParseTreeTransforms import PostParse class NodeTypeWriter(TreeVisitor): @@ -357,3 +358,24 @@ def write_newer_file(file_path, newer_than, content, dedent=False, encoding=None while other_time is None or other_time >= os.path.getmtime(file_path): write_file(file_path, content, dedent=dedent, encoding=encoding) + + +def py_parse_code(code): + """ + Compiles code far enough to get errors from the parser and post-parse stage. + + Is useful for checking for syntax errors, however it doesn't generate runable + code. + """ + context = StringParseContext("test") + # all the errors we care about are in the parsing or postparse stage + try: + with Errors.local_errors() as errors: + result = TreeFragment(code, pipeline=[PostParse(context)]) + result = result.substitute() + if errors: + raise errors[0] # compile error, which should get caught + else: + return result + except Errors.CompileError as e: + raise SyntaxError(e.message_only) diff --git a/Cython/Utility/MatchCase.c b/Cython/Utility/MatchCase.c new file mode 100644 index 000000000..55c8b99b4 --- /dev/null +++ b/Cython/Utility/MatchCase.c @@ -0,0 +1,907 @@ +///////////////////////////// ABCCheck ////////////////////////////// + +#if PY_VERSION_HEX < 0x030A0000 +static CYTHON_INLINE int __Pyx_MatchCase_IsExactSequence(PyObject *o) { + // is one of the small list of builtin types known to be a sequence + if (PyList_CheckExact(o) || PyTuple_CheckExact(o) || + PyType_CheckExact(o, PyRange_Type) || PyType_CheckExact(o, PyMemoryView_Type)) { + // Use exact type match for these checks. I in the event of inheritence we need to make sure + // that it isn't a mapping too + return 1; + } + return 0; +} + +static CYTHON_INLINE int __Pyx_MatchCase_IsExactMapping(PyObject *o) { + // Py_Dict is the only regularly used mapping type + // "types.MappingProxyType" also exists but is correctly covered by + // the isinstance(o, Mapping) check + return PyDict_CheckExact(o); +} + +static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) { + if (PyType_GetFlags(Py_TYPE(o)) & (Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS)) || + PyByteArray_Check(o)) { + return 1; // these types are deliberately excluded from the sequence test + // even though they look like sequences for most other purposes. + // Leave them as inexact checks since they do pass + // "isinstance(o, collections.abc.Sequence)" so it's very hard to + // reason about their subclasses + } + if (o == Py_None || PyLong_CheckExact(o) || PyFloat_CheckExact(o)) { + return 1; + } + #if PY_MAJOR_VERSION < 3 + if (PyInt_CheckExact(o)) { + return 1; + } + #endif + + return 0; +} + +// sequence_mapping_temp: For Python 3.10 testing sequences and mappings are +// really quick and this is ignored. For lower versions of Python they're +// slow, especially in the "fail" case. +// Therefore, we store an int temp to avoid duplicating tests. +// The bits of it in order are: +// 0. definitely a sequence +// 1. definitely a mapping +// - note that both of the above and be true when +// the type is registered with both abc types (not via inheritance) +// and in this case we return true for both IsSequence or IsMapping +// (which seems like the best handling of an ambiguous situation) +// 2. definitely not a sequence +// 3. definitely not a mapping + +#if PY_VERSION_HEX < 0x030A0000 +#define __PYX_DEFINITELY_SEQUENCE_FLAG 1U +#define __PYX_DEFINITELY_MAPPING_FLAG (1U<<1) +#define __PYX_DEFINITELY_NOT_SEQUENCE_FLAG (1U<<2) +#define __PYX_DEFINITELY_NOT_MAPPING_FLAG (1U<<3) +#define __PYX_SEQUENCE_MAPPING_ERROR (1U<<4) // only used by the ABCCheck function +#endif + +static int __Pyx_MatchCase_InitAndIsInstanceAbc(PyObject *o, PyObject *abc_module, + PyObject **abc_type, PyObject *name) { + assert(!abc_type); + abc_type = PyObject_GetAttr(abc_module, name); + if (!abc_type) { + return -1; + } + return PyObject_IsInstance(o, abc_type); +} + +// the result is defined using the specification for sequence_mapping_temp +// (detailed in "is_sequence") +static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, int definitely_not_sequence, int definitely_not_mapping) { + // in Python 3.10 objects can have their sequence bit set or their mapping bit set + // but not both. Practically this translates to "which type is registered first". + // In Python < 3.10 we can only determine this if they're direct bases (by looking + // at the MRO order). If they're registered manually then we can't tell + + PyObject *abc_module=NULL, *sequence_type=NULL, *mapping_type=NULL; + PyObject *mro; + int sequence_result=0, mapping_result=0; + unsigned int result = 0; + + abc_module = PyImport_ImportModule( +#if PY_VERSION_HEX > 0x03030000 + "collections.abc" +#else + "collections" +#endif + ); + if (!abc_module) { + return __PYX_SEQUENCE_MAPPING_ERROR; + } + if (sequence_first) { + if (definitely_not_sequence) { + result = __PYX_DEFINITELY_SEQUENCE_FLAG; + goto end; + } + sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence")); + if (sequence_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; + goto end; + } else if (sequence_result == 0) { + result |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG; + goto end; + } + // else wait to see what mapping is + } + if (!definitely_not_mapping) { + mapping_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &mapping_type, PYIDENT("Mapping")); + if (mapping_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; + goto end; + } else if (mapping_result == 0) { + result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG; + if (sequence_first) { + assert(sequence_result); + result |= __PYX_DEFINITELY_SEQUENCE_FLAG; + } + goto end; + } else /* mapping_result == 1 */ { + if (sequence_first && !sequence_result) { + result |= __PYX_DEFINITELY_MAPPING_FLAG; + goto end; + } + } + } + if (!sequence_first) { + // here we know mapping_result is true because we'd have returned otherwise + assert(mapping_result); + if (!definitely_not_sequence) { + sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence")); + } + if (sequence_result < 0) { + result = __PYX_SEQUENCE_MAPPING_ERROR; + goto end; + } else if (sequence_result == 0) { + result |= (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_MAPPING_FLAG); + goto end; + } /* else sequence_result == 1, continue to check both */ + } + + // It's an instance of both types. Look up the MRO order. + // In event of failure treat it as "could be either" + result = __PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_MAPPING_FLAG; + mro = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__mro__"); + Py_ssize_t i; + if (!mro) { + PyErr_Clear(); + goto end; + } + if (!PyTuple_Check(mro)) { + Py_DECREF(mro); + goto end; + } + for (i=1; i < PyTuple_GET_SIZE(mro); ++i) { + int is_subclass_sequence, is_subclass_mapping; + PyObject *mro_item = PyTuple_GET_ITEM(mro, i); + is_subclass_sequence = PyObject_IsSubclass(mro_item, sequence_type); + if (is_subclass_sequence < 0) goto loop_error; + is_subclass_mapping = PyObject_IsSubclass(mro_item, mapping_type); + if (is_subclass_mapping < 0) goto loop_error; + if (is_subclass_sequence && !is_subclass_mapping) { + result = (__PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + break; + } else if (is_subclass_mapping && !is_subclass_sequence) { + result = (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_MAPPING_FLAG); + break; + } + } + // If we get to the end of the loop without breaking then neither type is in + // the MRO, so they've both been registered manually. We don't know which was + // registered first so accept the object as either as a compromise + if (0) { + loop_error: + PyErr_Clear(); + } + Py_DECREF(mro); + + end: + Py_XDECREF(abc_module); + Py_XDECREF(sequence_type); + Py_XDECREF(mapping_type); + return result; +} +#endif + +///////////////////////////// IsSequence.proto ////////////////////// + +static int __Pyx_MatchCase_IsSequence(PyObject *o, unsigned int *sequence_mapping_temp); /* proto */ + +//////////////////////////// IsSequence ///////////////////////// +//@requires: ABCCheck + +static int __Pyx_MatchCase_IsSequence(PyObject *o, unsigned int *sequence_mapping_temp) { +#if PY_VERSION_HEX >= 0x030A0000 + return __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_SEQUENCE); +#else + // Py_TPFLAGS_SEQUENCE doesn't exit. + PyObject *o_module_name; + unsigned int abc_result, dummy=0; + + if (sequence_mapping_temp) { + // maybe we already know the answer + if (*sequence_mapping_temp & __PYX_DEFINITELY_SEQUENCE_FLAG) { + return 1; + } + if (*sequence_mapping_temp & __PYX_DEFINITELY_NOT_SEQUENCE_FLAG) { + return 0; + } + } else { + // Probably quicker to just assign it and not check from here + sequence_mapping_temp = &dummy; + } + + // Start by check a known list of types + if (__Pyx_MatchCase_IsExactSequence(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + return 1; + } + if (__Pyx_MatchCase_IsExactMapping(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_MAPPING_FLAG | __PYX_DEFINITELY_NOT_SEQUENCE_FLAG); + return 0; + } + if (__Pyx_MatchCase_IsExactNeitherSequenceNorMapping(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + return 0; + } + + abc_result = __Pyx_MatchCase_ABCCheck( + o, 1, + *sequence_mapping_temp & __PYX_DEFINITELY_NOT_SEQUENCE_FLAG, + *sequence_mapping_temp & __PYX_DEFINITELY_NOT_MAPPING_FLAG + ); + if (abc_result & __PYX_SEQUENCE_MAPPING_ERROR) { + return -1; + } + *sequence_mapping_temp = abc_result; + if (*sequence_mapping_temp & __PYX_DEFINITELY_SEQUENCE_FLAG) { + return 1; + } + + // array.array is a more complicated check (and unfortunately isn't covered by + // collections.abc.Sequence on Python <3.10). + // Do the test by checking the module name, and then importing/testing the class + // It also doesn't give perfect results for classes that inherit from both array.array + // and a mapping + o_module_name = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__module__"); + if (!o_module_name) { + return -1; + } +#if PY_MAJOR_VERSION >= 3 + if (PyUnicode_Check(o_module_name) && PyUnicode_CompareWithASCIIString(o_module_name, "array") == 0) +#else + if (PyBytes_Check(o_module_name) && PyBytes_AS_STRING(o_module_name)[0] == 'a' && + PyBytes_AS_STRING(o_module_name)[1] == 'r' && PyBytes_AS_STRING(o_module_name)[2] == 'r' && + PyBytes_AS_STRING(o_module_name)[3] == 'a' && PyBytes_AS_STRING(o_module_name)[4] == 'y' && + PyBytes_AS_STRING(o_module_name)[5] == '\0') +#endif + { + int is_array; + PyObject *array_module, *array_object; + Py_DECREF(o_module_name); + array_module = PyImport_ImportModule("array"); + if (!array_module) { + PyErr_Clear(); + return 0; // treat these tests as "soft" and don't cause an exception + } + array_object = PyObject_GetAttrString(array_module, "array"); + Py_DECREF(array_module); + if (!array_object) { + PyErr_Clear(); + return 0; + } + is_array = PyObject_IsInstance(o, array_object); + Py_DECREF(array_object); + if (is_array) { + *sequence_mapping_temp |= __PYX_DEFINITELY_SEQUENCE_FLAG; + return 1; + } + PyErr_Clear(); + } else { + Py_DECREF(o_module_name); + } + *sequence_mapping_temp |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG; + return 0; +#endif +} + +////////////////////// OtherSequenceSliceToList.proto ////////////////////// + +static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end); /* proto */ + +////////////////////// OtherSequenceSliceToList ////////////////////////// + +// This is substantially based off ceval unpack_iterable. +// It's also pretty similar to itertools.islice +// Indices must be postive - there's no wraparound or boundschecking + +static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end) { + int total = end-start; + int i; + PyObject *list; + ssizeargfunc slot; + PyTypeObject *type = Py_TYPE(x); + + list = PyList_New(total); + if (!list) { + return NULL; + } + +#if CYTHON_USE_TYPE_SLOTS || PY_MAJOR_VERSION < 3 || CYTHON_COMPILING_IN_PYPY + slot = type->tp_as_sequence ? type->tp_as_sequence->sq_item : NULL; +#else + if ((PY_VERSION_HEX >= 0x030A0000) || __Pyx_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE)) { + // PyType_GetSlot only works on heap types in Python <3.10 + slot = (ssizeargfunc) PyType_GetSlot(type, Py_sq_item); + } +#endif + if (!slot) { + #if !defined(Py_LIMITED_API) && !defined(PySequence_ITEM) + // PyPy (and maybe others?) implements PySequence_ITEM as a function. In this case + // it's slightly more efficient than using PySequence_GetItem since it skips negative indices + slot = PySequence_ITEM; + #else + slot = PySequence_GetItem; + #endif + } + + for (i=start; i<end; ++i) { + PyObject *obj = slot(x, i); + if (!obj) { + Py_DECREF(list); + return NULL; + } + PyList_SET_ITEM(list, i-start, obj); + } + return list; +} + +////////////////////// TupleSliceToList.proto ////////////////////// + +static PyObject *__Pyx_MatchCase_TupleSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end); /* proto */ + +////////////////////// TupleSliceToList ////////////////////////// +//@requires: OtherSequenceSliceToList +//@requires: ObjectHandling.c::TupleAndListFromArray + +// Note that this should also work fine on lists (if needed) +// Indices must be postive - there's no wraparound or boundschecking + +static PyObject *__Pyx_MatchCase_TupleSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end) { +#if !CYTHON_COMPILING_IN_CPYTHON + return __Pyx_MatchCase_OtherSequenceSliceToList(x, start, end); +#else + PyObject **array; + + (void)__Pyx_MatchCase_OtherSequenceSliceToList; // clear unused warning + + array = PySequence_Fast_ITEMS(x); + return __Pyx_PyList_FromArray(array+start, end-start); +#endif +} + +////////////////////////// UnknownTypeSliceToList.proto ////////////////////// + +static PyObject *__Pyx_MatchCase_UnknownTypeSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end); /* proto */ + +////////////////////////// UnknownTypeSliceToList.proto ////////////////////// +//@requires: TupleSliceToList +//@requires: OtherSequenceSliceToList + +static PyObject *__Pyx_MatchCase_UnknownTypeSliceToList(PyObject *x, Py_ssize_t start, Py_ssize_t end) { + if (PyList_CheckExact(x)) { + return PyList_GetSlice(x, start, end); + } +#if !CYTHON_COMPILING_IN_CPYTHON + // since __Pyx_MatchCase_TupleToList only does anything special in CPython, skip the check otherwise + if (PyTuple_CheckExact(x)) { + return __Pyx_MatchCase_TupleSliceToList(x, start, end); + } +#else + (void)__Pyx_MatchCase_TupleSliceToList; +#endif + return __Pyx_MatchCase_OtherSequenceSliceToList(x, start, end); +} + +///////////////////////////// IsMapping.proto ////////////////////// + +static int __Pyx_MatchCase_IsMapping(PyObject *o, unsigned int *sequence_mapping_temp); /* proto */ + +//////////////////////////// IsMapping ///////////////////////// +//@requires: ABCCheck + +static int __Pyx_MatchCase_IsMapping(PyObject *o, unsigned int *sequence_mapping_temp) { +#if PY_VERSION_HEX >= 0x030A0000 + return __Pyx_PyType_HasFeature(Py_TYPE(o), Py_TPFLAGS_MAPPING); +#else + unsigned int abc_result, dummy=0; + if (sequence_mapping_temp) { + // do we already know the answer? + if (*sequence_mapping_temp & __PYX_DEFINITELY_MAPPING_FLAG) { + return 1; + } else if (*sequence_mapping_temp & __PYX_DEFINITELY_NOT_MAPPING_FLAG) { + return 0; + } + } else { + sequence_mapping_temp = &dummy; // just so we can assign freely without checking + } + + if (__Pyx_MatchCase_IsExactMapping(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_MAPPING_FLAG | __PYX_DEFINITELY_NOT_SEQUENCE_FLAG); + return 1; + } + if (__Pyx_MatchCase_IsExactSequence(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + return 0; + } + if (__Pyx_MatchCase_IsExactNeitherSequenceNorMapping(o)) { + *sequence_mapping_temp |= (__PYX_DEFINITELY_NOT_SEQUENCE_FLAG | __PYX_DEFINITELY_NOT_MAPPING_FLAG); + return 0; + } + + // otherwise check against collections.abc.Mapping + abc_result = __Pyx_MatchCase_ABCCheck( + o, 0, + *sequence_mapping_temp & __PYX_DEFINITELY_NOT_SEQUENCE_FLAG, + *sequence_mapping_temp & __PYX_DEFINITELY_NOT_MAPPING_FLAG + ); + if (abc_result & __PYX_SEQUENCE_MAPPING_ERROR) { + return -1; + } + *sequence_mapping_temp = abc_result; + return *sequence_mapping_temp & __PYX_DEFINITELY_MAPPING_FLAG; +#endif +} + +//////////////////////// MappingKeyCheck.proto ///////////////////////// + +static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *keys[], Py_ssize_t nFixedKeys, Py_ssize_t nKeys); + +//////////////////////// MappingKeyCheck /////////////////////////////// + +static int __Pyx_MatchCase_CheckMappingDuplicateKeys(PyObject *keys[], Py_ssize_t nFixedKeys, Py_ssize_t nKeys) { + // Inputs are arrays, and typically fairly small. It may be more efficient to + // loop over the array than create a set. + + // The CPython implementation (match_keys in ceval.c) does this concurrently with + // taking the keys out of the dictionary. I'm choosing to do it separately since the + // majority of the time the keys will be known at compile-time so Cython can skip + // this step completely. + + PyObject *var_keys_set; + PyObject *key; + Py_ssize_t n; + int contains; + + var_keys_set = PySet_New(NULL); + if (!var_keys_set) return -1; + + for (n=nFixedKeys; n < nKeys; ++n) { + key = keys[n]; + contains = PySet_Contains(var_keys_set, key); + if (contains < 0) { + goto bad; + } else if (contains == 1) { + goto raise_error; + } else { + if (PySet_Add(var_keys_set, key)) { + goto bad; + } + } + } + for (n=0; n < nFixedKeys; ++n) { + key = keys[n]; + contains = PySet_Contains(var_keys_set, key); + if (contains < 0) { + goto bad; + } else if (contains == 1) { + goto raise_error; + } + } + Py_DECREF(var_keys_set); + return 0; + + raise_error: + #if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_ValueError, + "mapping pattern checks duplicate key (%R)", key); + #else + // DW really can't be bothered working around features that don't exist in + // Python 2, so just provide less information! + PyErr_SetString(PyExc_ValueError, + "mapping pattern checks duplicate key"); + #endif + bad: + Py_DECREF(var_keys_set); + return -1; +} + +/////////////////////////// ExtractExactDict.proto //////////////// + +// the variadic arguments are a list of PyObject** to subjects to be filled. They may be NULL +// in which case they're ignored. +// +// This is a specialized version for when we have an exact dict (which is likely to be pretty common) + +#if CYTHON_REFNANNY +#define __Pyx_MatchCase_Mapping_ExtractDict(...) __Pyx__MatchCase_Mapping_ExtractDict(__pyx_refnanny, __VA_ARGS__) +#else +#define __Pyx_MatchCase_Mapping_ExtractDict(...) __Pyx__MatchCase_Mapping_ExtractDict(NULL, __VA_ARGS__) +#endif +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */ + +/////////////////////////// ExtractExactDict //////////////// + +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractDict(void *__pyx_refnanny, PyObject *dict, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) { + Py_ssize_t i; + + for (i=0; i<nKeys; ++i) { + PyObject *key = keys[i]; + PyObject **subject = subjects[i]; + if (!subject) { + int contains = PyDict_Contains(dict, key); + if (contains <= 0) { + return -1; // any subjects that were already set will be cleaned up externally + } + } else { + PyObject *value = __Pyx_PyDict_GetItemStrWithError(dict, key); + if (!value) { + return (PyErr_Occurred()) ? -1 : 0; // any subjects that were already set will be cleaned up externally + } + __Pyx_XDECREF_SET(*subject, value); + __Pyx_INCREF(*subject); // capture this incref with refnanny! + } + } + return 1; // success +} + +///////////////////////// ExtractNonDict.proto //////////////////////////////// + +// the variadic arguments are a list of PyObject** to subjects to be filled. They may be NULL +// in which case they're ignored. +// +// This is a specialized version for the rarer case when the type isn't an exact dict. + +#if CYTHON_REFNANNY +#define __Pyx_MatchCase_Mapping_ExtractNonDict(...) __Pyx__MatchCase_Mapping_ExtractNonDict(__pyx_refnanny, __VA_ARGS__) +#else +#define __Pyx_MatchCase_Mapping_ExtractNonDict(...) __Pyx__MatchCase_Mapping_ExtractNonDict(NULL, __VA_ARGS__) +#endif +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */ + +///////////////////////// ExtractNonDict ////////////////////////////////////// +//@requires: ObjectHandling.c::PyObjectCall2Args + +// largely adapted from match_keys in CPython ceval.c + +static int __Pyx__MatchCase_Mapping_ExtractNonDict(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) { + PyObject *dummy=NULL, *get=NULL; + Py_ssize_t i; + int result = 0; +#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL + PyObject *get_method = NULL, *get_self = NULL; +#endif + + dummy = PyObject_CallObject((PyObject *)&PyBaseObject_Type, NULL); + if (!dummy) { + return -1; + } + get = PyObject_GetAttrString(mapping, "get"); + if (!get) { + result = -1; + goto end; + } +#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL + if (likely(PyMethod_Check(get))) { + // both of these are borrowed + get_method = PyMethod_GET_FUNCTION(get); + get_self = PyMethod_GET_SELF(get); + } +#endif + + for (i=0; i<nKeys; ++i) { + PyObject **subject; + PyObject *value = NULL; + PyObject *key = keys[i]; + + // TODO - there's an optimization here (although it deviates from the strict definition of pattern matching). + // If we don't need the values then we can call PyObject_Contains instead of "get". If we don't need *any* + // of the values then we can skip initialization "get" and "dummy" +#if CYTHON_UNPACK_METHODS && CYTHON_VECTORCALL + if (likely(get_method)) { + PyObject *args[] = { get_self, key, dummy }; + value = _PyObject_Vectorcall(get_method, args, 3, NULL); + } + else +#endif + { + value = __Pyx_PyObject_Call2Args(get, key, dummy); + } + if (!value) { + result = -1; + goto end; + } else if (value == dummy) { + Py_DECREF(value); + goto end; // failed + } else { + subject = subjects[i]; + if (subject) { + __Pyx_XDECREF_SET(*subject, value); + __Pyx_GOTREF(*subject); + } else { + Py_DECREF(value); + } + } + } + result = 1; + + end: + Py_XDECREF(dummy); + Py_XDECREF(get); + return result; +} + +///////////////////////// ExtractGeneric.proto //////////////////////////////// + +#if CYTHON_REFNANNY +#define __Pyx_MatchCase_Mapping_Extract(...) __Pyx__MatchCase_Mapping_Extract(__pyx_refnanny, __VA_ARGS__) +#else +#define __Pyx_MatchCase_Mapping_Extract(...) __Pyx__MatchCase_Mapping_Extract(NULL, __VA_ARGS__) +#endif +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]); /* proto */ + +////////////////////// ExtractGeneric ////////////////////////////////////// +//@requires: ExtractExactDict +//@requires: ExtractNonDict + +static CYTHON_INLINE int __Pyx__MatchCase_Mapping_Extract(void *__pyx_refnanny, PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys, PyObject **subjects[]) { + if (PyDict_CheckExact(mapping)) { + return __Pyx_MatchCase_Mapping_ExtractDict(mapping, keys, nKeys, subjects); + } else { + return __Pyx_MatchCase_Mapping_ExtractNonDict(mapping, keys, nKeys, subjects); + } +} + +///////////////////////////// DoubleStarCapture.proto ////////////////////// + +static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys); /* proto */ + +//////////////////////////// DoubleStarCapture ////////////////////////////// + +// The implementation is largely copied from the original COPY_DICT_WITHOUT_KEYS opcode +// implementation of CPython +// https://github.com/python/cpython/blob/145bf269df3530176f6ebeab1324890ef7070bf8/Python/ceval.c#L3977 +// (now removed in favour of building the same thing from a combination of opcodes) +// The differences are: +// 1. We use an array of keys rather than a tuple of keys +// 2. We add a shortcut for when there will be no left over keys (because I guess it's pretty common) +// +// Tempita variable 'tag' can be "NonDict", "ExactDict" or empty + +static PyObject* __Pyx_MatchCase_DoubleStarCapture{{tag}}(PyObject *mapping, PyObject *keys[], Py_ssize_t nKeys) { + PyObject *dict_out; + Py_ssize_t i; + + {{if tag != "NonDict"}} + // shortcut for when there are no left-over keys + if ({{if tag=="ExactDict"}}(1){{else}}PyDict_CheckExact(mapping){{endif}}) { + Py_ssize_t s = PyDict_Size(mapping); + if (s == -1) { + return NULL; + } + if (s == nKeys) { + return PyDict_New(); + } + } + {{endif}} + + {{if tag=="ExactDict"}} + dict_out = PyDict_Copy(mapping); + {{else}} + dict_out = PyDict_New(); + {{endif}} + if (!dict_out) { + return NULL; + } + {{if tag!="ExactDict"}} + if (PyDict_Update(dict_out, mapping)) { + Py_DECREF(dict_out); + return NULL; + } + {{endif}} + + for (i=0; i<nKeys; ++i) { + if (PyDict_DelItem(dict_out, keys[i])) { + Py_DECREF(dict_out); + return NULL; + } + } + return dict_out; +} + +////////////////////////////// ClassPositionalPatterns.proto //////////////////////// + +#if CYTHON_REFNANNY +#define __Pyx_MatchCase_ClassPositional(...) __Pyx__MatchCase_ClassPositional(__pyx_refnanny, __VA_ARGS__) +#else +#define __Pyx_MatchCase_ClassPositional(...) __Pyx__MatchCase_ClassPositional(NULL, __VA_ARGS__) +#endif +static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *fixed_names[], Py_ssize_t n_fixed, int match_self, PyObject **subjects[], Py_ssize_t n_subjects); /* proto */ + +/////////////////////////////// ClassPositionalPatterns ////////////////////////////// + +static int __Pyx_MatchCase_ClassCheckDuplicateAttrs(const char *tp_name, PyObject *fixed_names[], Py_ssize_t n_fixed, PyObject *match_args, Py_ssize_t num_args) { + // a lot of the basic logic of this is shared with __Pyx_MatchCase_CheckMappingDuplicateKeys + // but they take different input types so it isn't easy to actually share the code. + + // Inputs are tuples, and typically fairly small. It may be more efficient to + // loop over the tuple than create a set. + + PyObject *attrs_set; + PyObject *attr = NULL; + Py_ssize_t n; + int contains; + + attrs_set = PySet_New(NULL); + if (!attrs_set) return -1; + + num_args = PyTuple_GET_SIZE(match_args) < num_args ? PyTuple_GET_SIZE(match_args) : num_args; + for (n=0; n < num_args; ++n) { + attr = PyTuple_GET_ITEM(match_args, n); + contains = PySet_Contains(attrs_set, attr); + if (contains < 0) { + goto bad; + } else if (contains == 1) { + goto raise_error; + } else { + if (PySet_Add(attrs_set, attr)) { + goto bad; + } + } + } + for (n=0; n < n_fixed; ++n) { + attr = fixed_names[n]; + contains = PySet_Contains(attrs_set, attr); + if (contains < 0) { + goto bad; + } else if (contains == 1) { + goto raise_error; + } + } + Py_DECREF(attrs_set); + return 0; + + raise_error: + #if PY_MAJOR_VERSION > 2 + PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute %R", + tp_name, attr); + #else + // DW has no interest in working around the lack of %R in Python 2.7 + PyErr_Format(PyExc_TypeError, "%s() got multiple sub-patterns for attribute", + tp_name); + #endif + bad: + Py_DECREF(attrs_set); + return -1; +} + +// Adapted from ceval.c "match_class" in CPython +// +// The argument match_self can equal 1 for "known to be true" +// 0 for "known to be false" +// -1 for "unknown", runtime test +// nargs is >= 0 otherwise this function will be skipped +static int __Pyx__MatchCase_ClassPositional(void *__pyx_refnanny, PyObject *subject, PyTypeObject *type, PyObject *fixed_names[], Py_ssize_t n_fixed, int match_self, PyObject **subjects[], Py_ssize_t n_subjects) +{ + PyObject *match_args; + Py_ssize_t allowed, i; + int result; + + match_args = PyObject_GetAttrString((PyObject*)type, "__match_args__"); + if (!match_args) { + if (PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + + if (match_self == -1) { + #if defined(_Py_TPFLAGS_MATCH_SELF) + match_self = PyType_HasFeature(type, + _Py_TPFLAGS_MATCH_SELF); + #else + // probably an earlier version of Python. Go off the known list in the specification + match_self = ((PyType_GetFlags(type) & + // long should capture bool too + (Py_TPFLAGS_LONG_SUBCLASS | Py_TPFLAGS_LIST_SUBCLASS | Py_TPFLAGS_TUPLE_SUBCLASS | + Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS | Py_TPFLAGS_DICT_SUBCLASS + #if PY_MAJOR_VERSION < 3 + | Py_TPFLAGS_IN_SUBCLASS + #endif + )) || + PyType_IsSubtype(type, &PyByteArray_Type) || + PyType_IsSubtype(type, &PyFloat_Type) || + PyType_IsSubtype(type, &PyFrozenSet_Type) || + ); + #endif + } + } else { + return -1; + } + } else { + match_self = 0; + if (!PyTuple_CheckExact(match_args)) { + PyErr_Format(PyExc_TypeError, "%s.__match_args__ must be a tuple (got %s)", + type->tp_name, + Py_TYPE(match_args)->tp_name + ); + Py_DECREF(match_args); + return -1; + } + } + + allowed = match_self ? + 1 : (match_args ? PyTuple_GET_SIZE(match_args) : 0); + if (allowed < n_subjects) { + const char *plural = (allowed == 1) ? "" : "s"; + PyErr_Format(PyExc_TypeError, + "%s() accepts %d positional sub-pattern%s (%d given)", + type->tp_name, + allowed, plural, n_subjects); + Py_XDECREF(match_args); + return -1; + } + if (match_self) { + PyObject **self_subject = subjects[0]; + if (self_subject) { + // Easy. Copy the subject itself, and move on to kwargs. + __Pyx_XDECREF_SET(*self_subject, subject); + __Pyx_INCREF(*self_subject); + } + result = 1; + goto end_match_self; + } + // next stage is to check for duplicate attributes. + if (__Pyx_MatchCase_ClassCheckDuplicateAttrs(type->tp_name, fixed_names, n_fixed, match_args, n_subjects)) { + result = -1; + goto end; + } + + for (i = 0; i < n_subjects; i++) { + PyObject *attr; + PyObject **subject_i; + PyObject *name = PyTuple_GET_ITEM(match_args, i); + if (!PyUnicode_CheckExact(name)) { + PyErr_Format(PyExc_TypeError, + "__match_args__ elements must be strings " + "(got %s)", Py_TYPE(name)->tp_name); + result = -1; + goto end; + } + + attr = PyObject_GetAttr(subject, name); + if (attr == NULL && PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + result = 0; + goto end; + } + subject_i = subjects[i]; + if (subject_i) { + __Pyx_XDECREF_SET(*subject_i, attr); + __Pyx_GOTREF(attr); + } else { + Py_DECREF(attr); + } + } + result = 1; + + end: + Py_DECREF(match_args); + end_match_self: // because match_args isn't set + return result; +} + +//////////////////////// MatchClassIsType.proto ///////////////////////////// + +static PyTypeObject* __Pyx_MatchCase_IsType(PyObject* type); /* proto */ + +//////////////////////// MatchClassIsType ///////////////////////////// + +static PyTypeObject* __Pyx_MatchCase_IsType(PyObject* type) { + #if PY_MAJOR_VERSION < 3 + if (PyClass_Check(type)) { + // I don't really think it's worth the effort getting this to work! + PyErr_Format(PyExc_TypeError, "called match pattern must be a new-style class."); + return NULL; + } + #endif + if (!PyType_Check(type)) { + PyErr_Format(PyExc_TypeError, "called match pattern must be a type"); + return NULL; + } + Py_INCREF(type); + return (PyTypeObject*)type; +} diff --git a/Cython/Utility/MatchCase_Cy.pyx b/Cython/Utility/MatchCase_Cy.pyx new file mode 100644 index 000000000..dbb478ffe --- /dev/null +++ b/Cython/Utility/MatchCase_Cy.pyx @@ -0,0 +1,12 @@ +################### MemoryviewSliceToList ####################### + +cimport cython + +@cname("__Pyx_MatchCase_SliceMemoryview_{{suffix}}") +cdef list slice_to_list({{decl_code}} x, Py_ssize_t start, Py_ssize_t stop): + if stop < 0: + # use -1 as a flag for "end" + stop = x.shape[0] + # This code performs slightly better than [ xi for xi in x ] + with cython.boundscheck(False), cython.wraparound(False): + return [ x[i] for i in range(start, stop) ] diff --git a/Tools/ci-run.sh b/Tools/ci-run.sh index 905a9d1e3..ffde4cbe1 100644 --- a/Tools/ci-run.sh +++ b/Tools/ci-run.sh @@ -83,6 +83,8 @@ else python -m pip install -r test-requirements.txt || exit 1 if [[ $PYTHON_VERSION != "pypy"* && $PYTHON_VERSION != "3."[1]* ]]; then python -m pip install -r test-requirements-cpython.txt || exit 1 + elif [[ $PYTHON_VERSION == "pypy-2.7" ]]; then + python -m pip install -r test-requirements-pypy27.txt || exit 1 fi fi fi diff --git a/test-requirements-pypy27.txt b/test-requirements-pypy27.txt index 9f9505240..6d4f83bca 100644 --- a/test-requirements-pypy27.txt +++ b/test-requirements-pypy27.txt @@ -1,2 +1,3 @@ -r test-requirements.txt +enum34==1.1.10 mock==3.0.5 diff --git a/tests/run/extra_patma.pyx b/tests/run/extra_patma.pyx new file mode 100644 index 000000000..76357f36f --- /dev/null +++ b/tests/run/extra_patma.pyx @@ -0,0 +1,173 @@ +# mode: run + +# Extra pattern matching test for Cython-specific features, optimizations, etc. + +cimport cython + +import array +import sys + +__doc__ = "" + + +cdef bint is_null(int* x): + return False # disabled - currently just a parser test + match x: + case NULL: + return True + case _: + return False + + +def test_is_null(): + """ + >>> test_is_null() + """ + cdef int some_int = 1 + return # disabled - currently just a parser test + assert is_null(&some_int) == False + assert is_null(NULL) == True + + +if sys.version_info[0] > 2: + __doc__ += """ + array.array doesn't have the buffer protocol in Py2 and + this doesn't really feel worth working around to test + >>> print(test_memoryview(array.array('i', [0, 1, 2]))) + a 1 + >>> print(test_memoryview(array.array('i', []))) + b + >>> print(test_memoryview(array.array('i', [5]))) + c [5] + """ + +# goes via .shape instead +@cython.test_fail_if_path_exists("//CallNode//NameNode[@name = 'len']") +# No need for "is Sequence check" +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +def test_memoryview(int[:] x): + """ + >>> print(test_memoryview(None)) + no! + """ + match x: + case [0, y, 2]: + assert cython.typeof(y) == "int", cython.typeof(y) # type inference works + return f"a {y}" + case []: + return "b" + case [*z]: + return f"c {z}" + return "no!" + +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +def test_list_to_sequence(list x): + """ + >>> test_list_to_sequence([1,2,3]) + True + >>> test_list_to_sequence(None) + False + """ + match x: + case [*_]: + return True + case _: + return False + + +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +@cython.test_fail_if_path_exists("//CmpNode") # There's nothing to compare - it always succeeds! +def test_list_not_None_to_sequence(list x not None): + """ + >>> test_list_not_None_to_sequence([1,2,3]) + True + """ + match x: + case [*_]: + return True + case _: + return False + +@cython.test_fail_if_path_exists("//PythonCapiCallNode//PythonCapiFunctionNode[@cname = '__Pyx_MatchCase_IsSequence']") +@cython.test_fail_if_path_exists("//CmpNode") # There's nothing to compare - it always succeeds! +def test_ctuple_to_sequence((int, int) x): + """ + >>> test_ctuple_to_sequence((1, 2)) + (1, 2) + """ + match x: + case [a, b, c]: # can't possibly succeed! + return a, b, c + case [a, b]: + assert cython.typeof(a) == "int", cython.typeof(a) # test that types have inferred + return a, b + +cdef class C: + cdef double x + def __init__(self, x): + self.x = x + +def class_attr_lookup(x): + """ + >>> class_attr_lookup(C(5)) + 5.0 + >>> class_attr_lookup([1]) + >>> class_attr_lookup(None) + """ + match x: + case C(x=y): # This can only work with cdef attribute lookup + assert cython.typeof(y) == "double", cython.typeof(y) + return y + +class PyClass(object): + pass + +@cython.test_assert_path_exists("//PythonCapiFunctionNode[@cname='__Pyx_TypeCheck']") +def class_typecheck_exists(x): + """ + Test exists to confirm that the unoptimized case makes an isinstance check + (and thus the optimized class_typecheck_exists is testing the right thing). + If the implementation changes to not use a call to "isinstance" this test + can happily be deleted + >>> class_typecheck_exists(5) + False + >>> class_typecheck_exists(PyClass()) + True + """ + match x: + case PyClass(): + return True + case _: + return False + + +@cython.test_fail_if_path_exists("//NameNode[@name='isinstance']") +@cython.test_fail_if_path_exists("//PythonCapiFunctionNode[@cname='__Pyx_TypeCheck']") +def class_typecheck_doesnt_exist(C x): + """ + >>> class_typecheck_doesnt_exist(C(5)) + True + >>> class_typecheck_doesnt_exist(None) # it is None-safe though! + False + """ + match x: + case C(): + return True + case _: + return False + +def simple_or_with_targets(x): + """ + This was being mishandled by being converted to an if statement + without accounting for target assignment + >>> simple_or_with_targets(1) + 1 + >>> simple_or_with_targets(2) + 2 + >>> simple_or_with_targets(3) + 3 + >>> simple_or_with_targets(4) + """ + match x: + case ((1 as y)|(2 as y)|(3 as y)): + return y diff --git a/tests/run/extra_patma_py.py b/tests/run/extra_patma_py.py new file mode 100644 index 000000000..a5046b997 --- /dev/null +++ b/tests/run/extra_patma_py.py @@ -0,0 +1,126 @@ +# mode: run +# tag: pure3.10 + +from __future__ import print_function + +import array +import sys + +__doc__ = "" + +def test_type_inference(x): + """ + The type should not be infered to be anything specific + >>> test_type_inference(1) + one 1 + >>> test_type_inference([]) + any object [] + """ + match x: + case 1 as a: + print("one", a) + case a: + print("any object", a) + + +def test_assignment_and_guards(x): + """ + Tests that the flow control is right. The second case can be + reached either by failing the pattern or by failing the guard, + and this affects whether variables are assigned + >>> test_assignment_and_guards([1]) + ('first', 1) + >>> test_assignment_and_guards([1, 2]) + ('second', 1) + >>> test_assignment_and_guards([-1, 2]) + ('second', -1) + """ + match x: + case [a] if a>0: + return "first", a + case [a, *_]: + return "second", a + + +def test_array_is_sequence(x): + """ + Because this has to be specifically special-cased on early Python versions + >>> test_array_is_sequence(array.array('i', [0, 1, 2])) + 1 + >>> test_array_is_sequence(array.array('i', [0, 1, 2, 3, 4])) + [0, 1, 2, 3, 4] + """ + match x: + case [0, y, 2]: + return y + case [*z]: + return z + case _: + return "Not a sequence" + + +def test_duplicate_keys(key1, key2): + """ + Extra to TestValueErrors in test_patma + Cython sorts keys into literal and runtime. This tests when two runtime keys clash + + >>> test_duplicate_keys("a", "b") + True + + Slightly awkward doctest to work around Py2 incompatibility + >>> try: + ... test_duplicate_keys("a", "a") + ... except ValueError as e: + ... if sys.version_info[0] > 2: + ... assert e.args[0] == "mapping pattern checks duplicate key ('a')", e.args[0] + ... else: + ... assert e.args[0] == "mapping pattern checks duplicate key" + """ + class Keys: + KEY_1 = key1 + KEY_2 = key2 + + match {"a": 1, "b": 2}: + case {Keys.KEY_1: _, Keys.KEY_2: _}: + return True + case _: + return False + + +class PyClass(object): + pass + + +class PrivateAttrLookupOuter: + """ + CPython doesn't mangle private names in class patterns + (so Cython should do the same) + + >>> py_class_inst = PyClass() + >>> py_class_inst._PyClass__something = 1 + >>> py_class_inst._PrivateAttrLookupOuter__something = 2 + >>> py_class_inst.__something = 3 + >>> PrivateAttrLookupOuter().f(py_class_inst) + 3 + """ + def f(self, x): + match x: + case PyClass(__something=y): + return y + + +if sys.version_info[0] < 3: + class OldStyleClass: + pass + + def test_oldstyle_class_failure(x): + match x: + case OldStyleClass(): + return True + + __doc__ += """ + >>> test_oldstyle_class_failure(1) + Traceback (most recent call last): + ... + TypeError: called match pattern must be a new-style class. + """ diff --git a/tests/run/test_patma.py b/tests/run/test_patma.py new file mode 100644 index 000000000..e51ba0dbd --- /dev/null +++ b/tests/run/test_patma.py @@ -0,0 +1,3198 @@ +### COPIED FROM CPython 3.12 alpha (July 2022) +### Original part after ############ +# cython: language_level=3 + +# new code +import cython +from Cython.TestUtils import py_parse_code + + +if cython.compiled: + def compile(code, name, what): + assert what == 'exec' + py_parse_code(code) + + +def disable(func): + pass + + +############## SLIGHTLY MODIFIED ORIGINAL CODE +import array +import collections +import enum +import inspect +import sys +import unittest + +if sys.version_info > (3, 10): + import dataclasses + @dataclasses.dataclass + class Point: + x: int + y: int +else: + # predates dataclasses with match args + class Point: + __match_args__ = ("x", "y") + x: int + y: int + def __init__(self, x, y): + self.x = x + self.y = y + def __eq__(self, other): + if not isinstance(other, Point): + return False + return self.x == other.x and self.y == other.y + +# TestCompiler removed - it's very CPython-specific +# TestTracing also mainly removed - doesn't seem like a core test +# except for one test that seems misplaced in CPython (which is below) + +class TestTracing(unittest.TestCase): + if sys.version_info < (3, 4): + class SubTestClass(object): + def __enter__(self): + return self + def __exit__(self, exc_type, exc_value, traceback): + return + def __call__(self, *args): + return self + subTest = SubTestClass() + + def test_parser_deeply_nested_patterns(self): + # Deeply nested patterns can cause exponential backtracking when parsing. + # See CPython gh-93671 for more information. + # + # DW: Cython note - this doesn't break the parser but may cause a + # RecursionError later in the code-generation. I don't believe that's + # easily avoidable with the way Cython visitors currently work + + levels = 100 + + patterns = [ + "A" + "(" * levels + ")" * levels, + "{1:" * levels + "1" + "}" * levels, + "[" * levels + "1" + "]" * levels, + ] + + for pattern in patterns: + with self.subTest(pattern): + code = inspect.cleandoc(""" + match None: + case {}: + pass + """.format(pattern)) + compile(code, "<string>", "exec") + + +# FIXME - remove all the "return"s added to cause code to be dropped +############## ORIGINAL PART FROM CPYTHON + + +class TestInheritance(unittest.TestCase): + + @staticmethod + def check_sequence_then_mapping(x): + match x: + case [*_]: + return "seq" + case {}: + return "map" + + @staticmethod + def check_mapping_then_sequence(x): + match x: + case {}: + return "map" + case [*_]: + return "seq" + + def test_multiple_inheritance_mapping(self): + class C: + pass + class M1(collections.UserDict, collections.abc.Sequence): + pass + class M2(C, collections.UserDict, collections.abc.Sequence): + pass + class M3(collections.UserDict, C, list): + pass + class M4(dict, collections.abc.Sequence, C): + pass + self.assertEqual(self.check_sequence_then_mapping(M1()), "map") + self.assertEqual(self.check_sequence_then_mapping(M2()), "map") + self.assertEqual(self.check_sequence_then_mapping(M3()), "map") + self.assertEqual(self.check_sequence_then_mapping(M4()), "map") + self.assertEqual(self.check_mapping_then_sequence(M1()), "map") + self.assertEqual(self.check_mapping_then_sequence(M2()), "map") + self.assertEqual(self.check_mapping_then_sequence(M3()), "map") + self.assertEqual(self.check_mapping_then_sequence(M4()), "map") + + def test_multiple_inheritance_sequence(self): + class C: + pass + class S1(collections.UserList, collections.abc.Mapping): + pass + class S2(C, collections.UserList, collections.abc.Mapping): + pass + class S3(list, C, collections.abc.Mapping): + pass + class S4(collections.UserList, dict, C): + pass + self.assertEqual(self.check_sequence_then_mapping(S1()), "seq") + self.assertEqual(self.check_sequence_then_mapping(S2()), "seq") + self.assertEqual(self.check_sequence_then_mapping(S3()), "seq") + self.assertEqual(self.check_sequence_then_mapping(S4()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S1()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S2()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S3()), "seq") + self.assertEqual(self.check_mapping_then_sequence(S4()), "seq") + + def test_late_registration_mapping(self): + class Parent: + pass + class ChildPre(Parent): + pass + class GrandchildPre(ChildPre): + pass + collections.abc.Mapping.register(Parent) + class ChildPost(Parent): + pass + class GrandchildPost(ChildPost): + pass + self.assertEqual(self.check_sequence_then_mapping(Parent()), "map") + self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "map") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "map") + self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "map") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "map") + self.assertEqual(self.check_mapping_then_sequence(Parent()), "map") + self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "map") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "map") + self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "map") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "map") + + def test_late_registration_sequence(self): + class Parent: + pass + class ChildPre(Parent): + pass + class GrandchildPre(ChildPre): + pass + collections.abc.Sequence.register(Parent) + class ChildPost(Parent): + pass + class GrandchildPost(ChildPost): + pass + self.assertEqual(self.check_sequence_then_mapping(Parent()), "seq") + self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "seq") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "seq") + self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "seq") + self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "seq") + self.assertEqual(self.check_mapping_then_sequence(Parent()), "seq") + self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "seq") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "seq") + self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "seq") + self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "seq") + + +class TestPatma(unittest.TestCase): + + def test_patma_000(self): + match 0: + case 0: + x = True + self.assertIs(x, True) + + def test_patma_001(self): + match 0: + case 0 if False: + x = False + case 0 if True: + x = True + self.assertIs(x, True) + + def test_patma_002(self): + match 0: + case 0: + x = True + case 0: + x = False + self.assertIs(x, True) + + def test_patma_003(self): + x = False + match 0: + case 0 | 1 | 2 | 3: + x = True + self.assertIs(x, True) + + def test_patma_004(self): + x = False + match 1: + case 0 | 1 | 2 | 3: + x = True + self.assertIs(x, True) + + def test_patma_005(self): + x = False + match 2: + case 0 | 1 | 2 | 3: + x = True + self.assertIs(x, True) + + def test_patma_006(self): + x = False + match 3: + case 0 | 1 | 2 | 3: + x = True + self.assertIs(x, True) + + def test_patma_007(self): + x = False + match 4: + case 0 | 1 | 2 | 3: + x = True + self.assertIs(x, False) + + def test_patma_008(self): + x = 0 + class A: + y = 1 + match x: + case A.y as z: + pass + self.assertEqual(x, 0) + self.assertEqual(A.y, 1) + + def test_patma_009(self): + class A: + B = 0 + match 0: + case x if x: + z = 0 + case _ as y if y == x and y: + z = 1 + case A.B: + z = 2 + self.assertEqual(A.B, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertEqual(z, 2) + + def test_patma_010(self): + match (): + case []: + x = 0 + self.assertEqual(x, 0) + + def test_patma_011(self): + match (0, 1, 2): + case [*x]: + y = 0 + self.assertEqual(x, [0, 1, 2]) + self.assertEqual(y, 0) + + def test_patma_012(self): + match (0, 1, 2): + case [0, *x]: + y = 0 + self.assertEqual(x, [1, 2]) + self.assertEqual(y, 0) + + def test_patma_013(self): + match (0, 1, 2): + case [0, 1, *x,]: + y = 0 + self.assertEqual(x, [2]) + self.assertEqual(y, 0) + + def test_patma_014(self): + match (0, 1, 2): + case [0, 1, 2, *x]: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_015(self): + match (0, 1, 2): + case [*x, 2,]: + y = 0 + self.assertEqual(x, [0, 1]) + self.assertEqual(y, 0) + + def test_patma_016(self): + match (0, 1, 2): + case [*x, 1, 2]: + y = 0 + self.assertEqual(x, [0]) + self.assertEqual(y, 0) + + def test_patma_017(self): + match (0, 1, 2): + case [*x, 0, 1, 2,]: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_018(self): + match (0, 1, 2): + case [0, *x, 2]: + y = 0 + self.assertEqual(x, [1]) + self.assertEqual(y, 0) + + def test_patma_019(self): + match (0, 1, 2): + case [0, 1, *x, 2,]: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_020(self): + match (0, 1, 2): + case [0, *x, 1, 2]: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_021(self): + match (0, 1, 2): + case [*x,]: + y = 0 + self.assertEqual(x, [0, 1, 2]) + self.assertEqual(y, 0) + + def test_patma_022(self): + x = {} + match x: + case {}: + y = 0 + self.assertEqual(x, {}) + self.assertEqual(y, 0) + + def test_patma_023(self): + x = {0: 0} + match x: + case {}: + y = 0 + self.assertEqual(x, {0: 0}) + self.assertEqual(y, 0) + + def test_patma_024(self): + x = {} + y = None + match x: + case {0: 0}: + y = 0 + self.assertEqual(x, {}) + self.assertIs(y, None) + + def test_patma_025(self): + x = {0: 0} + match x: + case {0: (0 | 1 | 2 as z)}: + y = 0 + self.assertEqual(x, {0: 0}) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_026(self): + x = {0: 1} + match x: + case {0: (0 | 1 | 2 as z)}: + y = 0 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 0) + self.assertEqual(z, 1) + + def test_patma_027(self): + x = {0: 2} + match x: + case {0: (0 | 1 | 2 as z)}: + y = 0 + self.assertEqual(x, {0: 2}) + self.assertEqual(y, 0) + self.assertEqual(z, 2) + + def test_patma_028(self): + x = {0: 3} + y = None + match x: + case {0: (0 | 1 | 2 as z)}: + y = 0 + self.assertEqual(x, {0: 3}) + self.assertIs(y, None) + + def test_patma_029(self): + x = {} + y = None + match x: + case {0: [1, 2, {}]}: + y = 0 + case {0: [1, 2, {}], 1: [[]]}: + y = 1 + case []: + y = 2 + self.assertEqual(x, {}) + self.assertIs(y, None) + + def test_patma_030(self): + x = {False: (True, 2.0, {})} + match x: + case {0: [1, 2, {}]}: + y = 0 + case {0: [1, 2, {}], 1: [[]]}: + y = 1 + case []: + y = 2 + self.assertEqual(x, {False: (True, 2.0, {})}) + self.assertEqual(y, 0) + + def test_patma_031(self): + x = {False: (True, 2.0, {}), 1: [[]], 2: 0} + match x: + case {0: [1, 2, {}]}: + y = 0 + case {0: [1, 2, {}], 1: [[]]}: + y = 1 + case []: + y = 2 + self.assertEqual(x, {False: (True, 2.0, {}), 1: [[]], 2: 0}) + self.assertEqual(y, 0) + + def test_patma_032(self): + x = {False: (True, 2.0, {}), 1: [[]], 2: 0} + match x: + case {0: [1, 2]}: + y = 0 + case {0: [1, 2, {}], 1: [[]]}: + y = 1 + case []: + y = 2 + self.assertEqual(x, {False: (True, 2.0, {}), 1: [[]], 2: 0}) + self.assertEqual(y, 1) + + def test_patma_033(self): + x = [] + match x: + case {0: [1, 2, {}]}: + y = 0 + case {0: [1, 2, {}], 1: [[]]}: + y = 1 + case []: + y = 2 + self.assertEqual(x, []) + self.assertEqual(y, 2) + + def test_patma_034(self): + x = {0: 0} + match x: + case {0: [1, 2, {}]}: + y = 0 + case {0: ([1, 2, {}] | False)} | {1: [[]]} | {0: [1, 2, {}]} | [] | "X" | {}: + y = 1 + case []: + y = 2 + self.assertEqual(x, {0: 0}) + self.assertEqual(y, 1) + + def test_patma_035(self): + x = {0: 0} + match x: + case {0: [1, 2, {}]}: + y = 0 + case {0: [1, 2, {}] | True} | {1: [[]]} | {0: [1, 2, {}]} | [] | "X" | {}: + y = 1 + case []: + y = 2 + self.assertEqual(x, {0: 0}) + self.assertEqual(y, 1) + + def test_patma_036(self): + x = 0 + match x: + case 0 | 1 | 2: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_037(self): + x = 1 + match x: + case 0 | 1 | 2: + y = 0 + self.assertEqual(x, 1) + self.assertEqual(y, 0) + + def test_patma_038(self): + x = 2 + match x: + case 0 | 1 | 2: + y = 0 + self.assertEqual(x, 2) + self.assertEqual(y, 0) + + def test_patma_039(self): + x = 3 + y = None + match x: + case 0 | 1 | 2: + y = 0 + self.assertEqual(x, 3) + self.assertIs(y, None) + + def test_patma_040(self): + x = 0 + match x: + case (0 as z) | (1 as z) | (2 as z) if z == x % 2: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_041(self): + x = 1 + match x: + case (0 as z) | (1 as z) | (2 as z) if z == x % 2: + y = 0 + self.assertEqual(x, 1) + self.assertEqual(y, 0) + self.assertEqual(z, 1) + + def test_patma_042(self): + x = 2 + y = None + match x: + case (0 as z) | (1 as z) | (2 as z) if z == x % 2: + y = 0 + self.assertEqual(x, 2) + self.assertIs(y, None) + self.assertEqual(z, 2) + + def test_patma_043(self): + x = 3 + y = None + match x: + case (0 as z) | (1 as z) | (2 as z) if z == x % 2: + y = 0 + self.assertEqual(x, 3) + self.assertIs(y, None) + + def test_patma_044(self): + x = () + match x: + case []: + y = 0 + self.assertEqual(x, ()) + self.assertEqual(y, 0) + + def test_patma_045(self): + x = () + match x: + case (): + y = 0 + self.assertEqual(x, ()) + self.assertEqual(y, 0) + + def test_patma_046(self): + x = (0,) + match x: + case [0]: + y = 0 + self.assertEqual(x, (0,)) + self.assertEqual(y, 0) + + def test_patma_047(self): + x = ((),) + match x: + case [[]]: + y = 0 + self.assertEqual(x, ((),)) + self.assertEqual(y, 0) + + def test_patma_048(self): + x = [0, 1] + match x: + case [0, 1] | [1, 0]: + y = 0 + self.assertEqual(x, [0, 1]) + self.assertEqual(y, 0) + + def test_patma_049(self): + x = [1, 0] + match x: + case [0, 1] | [1, 0]: + y = 0 + self.assertEqual(x, [1, 0]) + self.assertEqual(y, 0) + + def test_patma_050(self): + x = [0, 0] + y = None + match x: + case [0, 1] | [1, 0]: + y = 0 + self.assertEqual(x, [0, 0]) + self.assertIs(y, None) + + def test_patma_051(self): + w = None + x = [1, 0] + match x: + case [(0 as w)]: + y = 0 + case [z] | [1, (0 | 1 as z)] | [z]: + y = 1 + self.assertIs(w, None) + self.assertEqual(x, [1, 0]) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + + def test_patma_052(self): + x = [1, 0] + match x: + case [0]: + y = 0 + case [1, 0] if (x := x[:0]): + y = 1 + case [1, 0]: + y = 2 + self.assertEqual(x, []) + self.assertEqual(y, 2) + + def test_patma_053(self): + x = {0} + y = None + match x: + case [0]: + y = 0 + self.assertEqual(x, {0}) + self.assertIs(y, None) + + def test_patma_054(self): + x = set() + y = None + match x: + case []: + y = 0 + self.assertEqual(x, set()) + self.assertIs(y, None) + + def test_patma_055(self): + x = iter([1, 2, 3]) + y = None + match x: + case []: + y = 0 + self.assertEqual([*x], [1, 2, 3]) + self.assertIs(y, None) + + def test_patma_056(self): + x = {} + y = None + match x: + case []: + y = 0 + self.assertEqual(x, {}) + self.assertIs(y, None) + + def test_patma_057(self): + x = {0: False, 1: True} + y = None + match x: + case [0, 1]: + y = 0 + self.assertEqual(x, {0: False, 1: True}) + self.assertIs(y, None) + + def test_patma_058(self): + x = 0 + match x: + case 0: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_059(self): + x = 0 + y = None + match x: + case False: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, None) + + def test_patma_060(self): + x = 0 + y = None + match x: + case 1: + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_061(self): + x = 0 + y = None + match x: + case None: + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_062(self): + x = 0 + match x: + case 0: + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_063(self): + x = 0 + y = None + match x: + case 1: + y = 0 + case 1: + y = 1 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_064(self): + x = "x" + match x: + case "x": + y = 0 + case "y": + y = 1 + self.assertEqual(x, "x") + self.assertEqual(y, 0) + + def test_patma_065(self): + x = "x" + match x: + case "y": + y = 0 + case "x": + y = 1 + self.assertEqual(x, "x") + self.assertEqual(y, 1) + + def test_patma_066(self): + x = "x" + match x: + case "": + y = 0 + case "x": + y = 1 + self.assertEqual(x, "x") + self.assertEqual(y, 1) + + def test_patma_067(self): + x = b"x" + match x: + case b"y": + y = 0 + case b"x": + y = 1 + self.assertEqual(x, b"x") + self.assertEqual(y, 1) + + def test_patma_068(self): + x = 0 + match x: + case 0 if False: + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + + def test_patma_069(self): + x = 0 + y = None + match x: + case 0 if 0: + y = 0 + case 0 if 0: + y = 1 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_070(self): + x = 0 + match x: + case 0 if True: + y = 0 + case 0 if True: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_071(self): + x = 0 + match x: + case 0 if 1: + y = 0 + case 0 if 1: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_072(self): + x = 0 + match x: + case 0 if True: + y = 0 + case 0 if True: + y = 1 + y = 2 + self.assertEqual(x, 0) + self.assertEqual(y, 2) + + def test_patma_073(self): + x = 0 + match x: + case 0 if 0: + y = 0 + case 0 if 1: + y = 1 + y = 2 + self.assertEqual(x, 0) + self.assertEqual(y, 2) + + def test_patma_074(self): + x = 0 + y = None + match x: + case 0 if not (x := 1): + y = 0 + case 1: + y = 1 + self.assertEqual(x, 1) + self.assertIs(y, None) + + def test_patma_075(self): + x = "x" + match x: + case ["x"]: + y = 0 + case "x": + y = 1 + self.assertEqual(x, "x") + self.assertEqual(y, 1) + + def test_patma_076(self): + x = b"x" + match x: + case [b"x"]: + y = 0 + case ["x"]: + y = 1 + case [120]: + y = 2 + case b"x": + y = 4 + self.assertEqual(x, b"x") + self.assertEqual(y, 4) + + def test_patma_077(self): + x = bytearray(b"x") + y = None + match x: + case [120]: + y = 0 + case 120: + y = 1 + self.assertEqual(x, b"x") + self.assertIs(y, None) + + def test_patma_078(self): + x = "" + match x: + case []: + y = 0 + case [""]: + y = 1 + case "": + y = 2 + self.assertEqual(x, "") + self.assertEqual(y, 2) + + def test_patma_079(self): + x = "xxx" + match x: + case ["x", "x", "x"]: + y = 0 + case ["xxx"]: + y = 1 + case "xxx": + y = 2 + self.assertEqual(x, "xxx") + self.assertEqual(y, 2) + + def test_patma_080(self): + x = b"xxx" + match x: + case [120, 120, 120]: + y = 0 + case [b"xxx"]: + y = 1 + case b"xxx": + y = 2 + self.assertEqual(x, b"xxx") + self.assertEqual(y, 2) + + def test_patma_081(self): + x = 0 + match x: + case 0 if not (x := 1): + y = 0 + case (0 as z): + y = 1 + self.assertEqual(x, 1) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + + def test_patma_082(self): + x = 0 + match x: + case (1 as z) if not (x := 1): + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + + def test_patma_083(self): + x = 0 + match x: + case (0 as z): + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_084(self): + x = 0 + y = None + match x: + case (1 as z): + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_085(self): + x = 0 + y = None + match x: + case (0 as z) if (w := 0): + y = 0 + self.assertEqual(w, 0) + self.assertEqual(x, 0) + self.assertIs(y, None) + self.assertEqual(z, 0) + + def test_patma_086(self): + x = 0 + match x: + case ((0 as w) as z): + y = 0 + self.assertEqual(w, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_087(self): + x = 0 + match x: + case (0 | 1) | 2: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_088(self): + x = 1 + match x: + case (0 | 1) | 2: + y = 0 + self.assertEqual(x, 1) + self.assertEqual(y, 0) + + def test_patma_089(self): + x = 2 + match x: + case (0 | 1) | 2: + y = 0 + self.assertEqual(x, 2) + self.assertEqual(y, 0) + + def test_patma_090(self): + x = 3 + y = None + match x: + case (0 | 1) | 2: + y = 0 + self.assertEqual(x, 3) + self.assertIs(y, None) + + def test_patma_091(self): + x = 0 + match x: + case 0 | (1 | 2): + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_092(self): + x = 1 + match x: + case 0 | (1 | 2): + y = 0 + self.assertEqual(x, 1) + self.assertEqual(y, 0) + + def test_patma_093(self): + x = 2 + match x: + case 0 | (1 | 2): + y = 0 + self.assertEqual(x, 2) + self.assertEqual(y, 0) + + def test_patma_094(self): + x = 3 + y = None + match x: + case 0 | (1 | 2): + y = 0 + self.assertEqual(x, 3) + self.assertIs(y, None) + + def test_patma_095(self): + x = 0 + match x: + case -0: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_096(self): + x = 0 + match x: + case -0.0: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_097(self): + x = 0 + match x: + case -0j: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_098(self): + x = 0 + match x: + case -0.0j: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_099(self): + x = -1 + match x: + case -1: + y = 0 + self.assertEqual(x, -1) + self.assertEqual(y, 0) + + def test_patma_100(self): + x = -1.5 + match x: + case -1.5: + y = 0 + self.assertEqual(x, -1.5) + self.assertEqual(y, 0) + + def test_patma_101(self): + x = -1j + match x: + case -1j: + y = 0 + self.assertEqual(x, -1j) + self.assertEqual(y, 0) + + def test_patma_102(self): + x = -1.5j + match x: + case -1.5j: + y = 0 + self.assertEqual(x, -1.5j) + self.assertEqual(y, 0) + + def test_patma_103(self): + x = 0 + match x: + case 0 + 0j: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_104(self): + x = 0 + match x: + case 0 - 0j: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_105(self): + x = 0 + match x: + case -0 + 0j: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_106(self): + x = 0 + match x: + case -0 - 0j: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_107(self): + x = 0.25 + 1.75j + match x: + case 0.25 + 1.75j: + y = 0 + self.assertEqual(x, 0.25 + 1.75j) + self.assertEqual(y, 0) + + def test_patma_108(self): + x = 0.25 - 1.75j + match x: + case 0.25 - 1.75j: + y = 0 + self.assertEqual(x, 0.25 - 1.75j) + self.assertEqual(y, 0) + + def test_patma_109(self): + x = -0.25 + 1.75j + match x: + case -0.25 + 1.75j: + y = 0 + self.assertEqual(x, -0.25 + 1.75j) + self.assertEqual(y, 0) + + def test_patma_110(self): + x = -0.25 - 1.75j + match x: + case -0.25 - 1.75j: + y = 0 + self.assertEqual(x, -0.25 - 1.75j) + self.assertEqual(y, 0) + + def test_patma_111(self): + class A: + B = 0 + x = 0 + match x: + case A.B: + y = 0 + self.assertEqual(A.B, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_112(self): + class A: + class B: + C = 0 + x = 0 + match x: + case A.B.C: + y = 0 + self.assertEqual(A.B.C, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_113(self): + class A: + class B: + C = 0 + D = 1 + x = 1 + match x: + case A.B.C: + y = 0 + case A.B.D: + y = 1 + self.assertEqual(A.B.C, 0) + self.assertEqual(A.B.D, 1) + self.assertEqual(x, 1) + self.assertEqual(y, 1) + + def test_patma_114(self): + class A: + class B: + class C: + D = 0 + x = 0 + match x: + case A.B.C.D: + y = 0 + self.assertEqual(A.B.C.D, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_115(self): + class A: + class B: + class C: + D = 0 + E = 1 + x = 1 + match x: + case A.B.C.D: + y = 0 + case A.B.C.E: + y = 1 + self.assertEqual(A.B.C.D, 0) + self.assertEqual(A.B.C.E, 1) + self.assertEqual(x, 1) + self.assertEqual(y, 1) + + def test_patma_116(self): + match = case = 0 + match match: + case case: + x = 0 + self.assertEqual(match, 0) + self.assertEqual(case, 0) + self.assertEqual(x, 0) + + def test_patma_117(self): + match = case = 0 + match case: + case match: + x = 0 + self.assertEqual(match, 0) + self.assertEqual(case, 0) + self.assertEqual(x, 0) + + def test_patma_118(self): + x = [] + match x: + case [*_, _]: + y = 0 + case []: + y = 1 + self.assertEqual(x, []) + self.assertEqual(y, 1) + + def test_patma_119(self): + x = collections.defaultdict(int) + match x: + case {0: 0}: + y = 0 + case {}: + y = 1 + self.assertEqual(x, {}) + self.assertEqual(y, 1) + + def test_patma_120(self): + x = collections.defaultdict(int) + match x: + case {0: 0}: + y = 0 + case {**z}: + y = 1 + self.assertEqual(x, {}) + self.assertEqual(y, 1) + self.assertEqual(z, {}) + + def test_patma_121(self): + match (): + case (): + x = 0 + self.assertEqual(x, 0) + + def test_patma_122(self): + match (0, 1, 2): + case (*x,): + y = 0 + self.assertEqual(x, [0, 1, 2]) + self.assertEqual(y, 0) + + def test_patma_123(self): + match (0, 1, 2): + case 0, *x: + y = 0 + self.assertEqual(x, [1, 2]) + self.assertEqual(y, 0) + + def test_patma_124(self): + match (0, 1, 2): + case (0, 1, *x,): + y = 0 + self.assertEqual(x, [2]) + self.assertEqual(y, 0) + + def test_patma_125(self): + match (0, 1, 2): + case 0, 1, 2, *x: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_126(self): + match (0, 1, 2): + case *x, 2,: + y = 0 + self.assertEqual(x, [0, 1]) + self.assertEqual(y, 0) + + def test_patma_127(self): + match (0, 1, 2): + case (*x, 1, 2): + y = 0 + self.assertEqual(x, [0]) + self.assertEqual(y, 0) + + def test_patma_128(self): + match (0, 1, 2): + case *x, 0, 1, 2,: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_129(self): + match (0, 1, 2): + case (0, *x, 2): + y = 0 + self.assertEqual(x, [1]) + self.assertEqual(y, 0) + + def test_patma_130(self): + match (0, 1, 2): + case 0, 1, *x, 2,: + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_131(self): + match (0, 1, 2): + case (0, *x, 1, 2): + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + + def test_patma_132(self): + match (0, 1, 2): + case *x,: + y = 0 + self.assertEqual(x, [0, 1, 2]) + self.assertEqual(y, 0) + + def test_patma_133(self): + x = collections.defaultdict(int, {0: 1}) + match x: + case {1: 0}: + y = 0 + case {0: 0}: + y = 1 + case {}: + y = 2 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 2) + + def test_patma_134(self): + x = collections.defaultdict(int, {0: 1}) + match x: + case {1: 0}: + y = 0 + case {0: 0}: + y = 1 + case {**z}: + y = 2 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 2) + self.assertEqual(z, {0: 1}) + + def test_patma_135(self): + x = collections.defaultdict(int, {0: 1}) + match x: + case {1: 0}: + y = 0 + case {0: 0}: + y = 1 + case {0: _, **z}: + y = 2 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 2) + self.assertEqual(z, {}) + + def test_patma_136(self): + x = {0: 1} + match x: + case {1: 0}: + y = 0 + case {0: 0}: + y = 0 + case {}: + y = 1 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 1) + + def test_patma_137(self): + x = {0: 1} + match x: + case {1: 0}: + y = 0 + case {0: 0}: + y = 0 + case {**z}: + y = 1 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 1) + self.assertEqual(z, {0: 1}) + + def test_patma_138(self): + x = {0: 1} + match x: + case {1: 0}: + y = 0 + case {0: 0}: + y = 0 + case {0: _, **z}: + y = 1 + self.assertEqual(x, {0: 1}) + self.assertEqual(y, 1) + self.assertEqual(z, {}) + + def test_patma_139(self): + x = False + match x: + case bool(z): + y = 0 + self.assertIs(x, False) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_140(self): + x = True + match x: + case bool(z): + y = 0 + self.assertIs(x, True) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_141(self): + x = bytearray() + match x: + case bytearray(z): + y = 0 + self.assertEqual(x, bytearray()) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_142(self): + x = b"" + match x: + case bytes(z): + y = 0 + self.assertEqual(x, b"") + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_143(self): + x = {} + match x: + case dict(z): + y = 0 + self.assertEqual(x, {}) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_144(self): + x : object = 0.0 # Cython-specific change. Otherwise x is inferred as int + # which makes assertIs(z, x) fail + match x: + case float(z): + y = 0 + self.assertEqual(x, 0.0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_145(self): + x = frozenset() + match x: + case frozenset(z): + y = 0 + self.assertEqual(x, frozenset()) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_146(self): + x = 0 + match x: + case int(z): + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_147(self): + x = [] + match x: + case list(z): + y = 0 + self.assertEqual(x, []) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_148(self): + x = set() + match x: + case set(z): + y = 0 + self.assertEqual(x, set()) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_149(self): + x = "" + match x: + case str(z): + y = 0 + self.assertEqual(x, "") + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_150(self): + x = () + match x: + case tuple(z): + y = 0 + self.assertEqual(x, ()) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_151(self): + x = 0 + match x,: + case y,: + z = 0 + self.assertEqual(x, 0) + self.assertIs(y, x) + self.assertIs(z, 0) + + def test_patma_152(self): + w = 0 + x = 0 + match w, x: + case y, z: + v = 0 + self.assertEqual(w, 0) + self.assertEqual(x, 0) + self.assertIs(y, w) + self.assertIs(z, x) + self.assertEqual(v, 0) + + def test_patma_153(self): + x = 0 + match w := x,: + case y as v,: + z = 0 + self.assertEqual(x, 0) + self.assertIs(y, x) + self.assertEqual(z, 0) + self.assertIs(w, x) + self.assertIs(v, y) + + def test_patma_154(self): + x = 0 + y = None + match x: + case 0 if x: + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_155(self): + x = 0 + y = None + match x: + case 1e1000: + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_156(self): + x = 0 + match x: + case z: + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_157(self): + x = 0 + y = None + match x: + case _ if x: + y = 0 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_158(self): + x = 0 + match x: + case -1e1000: + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + + def test_patma_159(self): + x = 0 + match x: + case 0 if not x: + y = 0 + case 1: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_160(self): + x = 0 + z = None + match x: + case 0: + y = 0 + case z if x: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, None) + + def test_patma_161(self): + x = 0 + match x: + case 0: + y = 0 + case _: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_162(self): + x = 0 + match x: + case 1 if x: + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + + def test_patma_163(self): + x = 0 + y = None + match x: + case 1: + y = 0 + case 1 if not x: + y = 1 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_164(self): + x = 0 + match x: + case 1: + y = 0 + case z: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertIs(z, x) + + def test_patma_165(self): + x = 0 + match x: + case 1 if x: + y = 0 + case _: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + + def test_patma_166(self): + x = 0 + match x: + case z if not z: + y = 0 + case 0 if x: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_167(self): + x = 0 + match x: + case z if not z: + y = 0 + case 1: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_168(self): + x = 0 + match x: + case z if not x: + y = 0 + case z: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_169(self): + x = 0 + match x: + case z if not z: + y = 0 + case _ if x: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, x) + + def test_patma_170(self): + x = 0 + match x: + case _ if not x: + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_171(self): + x = 0 + y = None + match x: + case _ if x: + y = 0 + case 1: + y = 1 + self.assertEqual(x, 0) + self.assertIs(y, None) + + def test_patma_172(self): + x = 0 + z = None + match x: + case _ if not x: + y = 0 + case z if not x: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertIs(z, None) + + def test_patma_173(self): + x = 0 + match x: + case _ if not x: + y = 0 + case _: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_174(self): + def http_error(status): + match status: + case 400: + return "Bad request" + case 401: + return "Unauthorized" + case 403: + return "Forbidden" + case 404: + return "Not found" + case 418: + return "I'm a teapot" + case _: + return "Something else" + self.assertEqual(http_error(400), "Bad request") + self.assertEqual(http_error(401), "Unauthorized") + self.assertEqual(http_error(403), "Forbidden") + self.assertEqual(http_error(404), "Not found") + self.assertEqual(http_error(418), "I'm a teapot") + self.assertEqual(http_error(123), "Something else") + self.assertEqual(http_error("400"), "Something else") + self.assertEqual(http_error(401 | 403 | 404), "Something else") # 407 + + def test_patma_175(self): + def http_error(status): + match status: + case 400: + return "Bad request" + case 401 | 403 | 404: + return "Not allowed" + case 418: + return "I'm a teapot" + self.assertEqual(http_error(400), "Bad request") + self.assertEqual(http_error(401), "Not allowed") + self.assertEqual(http_error(403), "Not allowed") + self.assertEqual(http_error(404), "Not allowed") + self.assertEqual(http_error(418), "I'm a teapot") + self.assertIs(http_error(123), None) + self.assertIs(http_error("400"), None) + self.assertIs(http_error(401 | 403 | 404), None) # 407 + + def test_patma_176(self): + def whereis(point): + match point: + case (0, 0): + return "Origin" + case (0, y): + return f"Y={y}" + case (x, 0): + return f"X={x}" + case (x, y): + return f"X={x}, Y={y}" + case _: + return "Not a point" + self.assertEqual(whereis((0, 0)), "Origin") + self.assertEqual(whereis((0, -1.0)), "Y=-1.0") + self.assertEqual(whereis(("X", 0)), "X=X") + self.assertEqual(whereis((None, 1j)), "X=None, Y=1j") + self.assertEqual(whereis(42), "Not a point") + + def test_patma_177(self): + def whereis(point): + match point: + case Point(0, 0): + return "Origin" + case Point(0, y): + return f"Y={y}" + case Point(x, 0): + return f"X={x}" + case Point(): + return "Somewhere else" + case _: + return "Not a point" + self.assertEqual(whereis(Point(1, 0)), "X=1") + self.assertEqual(whereis(Point(0, 0)), "Origin") + self.assertEqual(whereis(10), "Not a point") + self.assertEqual(whereis(Point(False, False)), "Origin") + self.assertEqual(whereis(Point(0, -1.0)), "Y=-1.0") + self.assertEqual(whereis(Point("X", 0)), "X=X") + self.assertEqual(whereis(Point(None, 1j)), "Somewhere else") + self.assertEqual(whereis(Point), "Not a point") + self.assertEqual(whereis(42), "Not a point") + + def test_patma_178(self): + def whereis(point): + match point: + case Point(1, var): + return var + self.assertEqual(whereis(Point(1, 0)), 0) + self.assertIs(whereis(Point(0, 0)), None) + + def test_patma_179(self): + def whereis(point): + match point: + case Point(1, y=var): + return var + self.assertEqual(whereis(Point(1, 0)), 0) + self.assertIs(whereis(Point(0, 0)), None) + + def test_patma_180(self): + def whereis(point): + match point: + case Point(x=1, y=var): + return var + self.assertEqual(whereis(Point(1, 0)), 0) + self.assertIs(whereis(Point(0, 0)), None) + + def test_patma_181(self): + def whereis(point): + match point: + case Point(y=var, x=1): + return var + self.assertEqual(whereis(Point(1, 0)), 0) + self.assertIs(whereis(Point(0, 0)), None) + + def test_patma_182(self): + def whereis(points): + match points: + case []: + return "No points" + case [Point(0, 0)]: + return "The origin" + case [Point(x, y)]: + return f"Single point {x}, {y}" + case [Point(0, y1), Point(0, y2)]: + return f"Two on the Y axis at {y1}, {y2}" + case _: + return "Something else" + self.assertEqual(whereis([]), "No points") + self.assertEqual(whereis([Point(0, 0)]), "The origin") + self.assertEqual(whereis([Point(0, 1)]), "Single point 0, 1") + self.assertEqual(whereis([Point(0, 0), Point(0, 0)]), "Two on the Y axis at 0, 0") + self.assertEqual(whereis([Point(0, 1), Point(0, 1)]), "Two on the Y axis at 1, 1") + self.assertEqual(whereis([Point(0, 0), Point(1, 0)]), "Something else") + self.assertEqual(whereis([Point(0, 0), Point(0, 0), Point(0, 0)]), "Something else") + self.assertEqual(whereis([Point(0, 1), Point(0, 1), Point(0, 1)]), "Something else") + + def test_patma_183(self): + def whereis(point): + match point: + case Point(x, y) if x == y: + return f"Y=X at {x}" + case Point(x, y): + return "Not on the diagonal" + self.assertEqual(whereis(Point(0, 0)), "Y=X at 0") + self.assertEqual(whereis(Point(0, False)), "Y=X at 0") + self.assertEqual(whereis(Point(False, 0)), "Y=X at False") + self.assertEqual(whereis(Point(-1 - 1j, -1 - 1j)), "Y=X at (-1-1j)") + self.assertEqual(whereis(Point("X", "X")), "Y=X at X") + self.assertEqual(whereis(Point("X", "x")), "Not on the diagonal") + + def test_patma_184(self): + class Seq(collections.abc.Sequence): + __getitem__ = None + def __len__(self): + return 0 + match Seq(): + case []: + y = 0 + self.assertEqual(y, 0) + + def test_patma_185(self): + class Seq(collections.abc.Sequence): + __getitem__ = None + def __len__(self): + return 42 + match Seq(): + case [*_]: + y = 0 + self.assertEqual(y, 0) + + def test_patma_186(self): + class Seq(collections.abc.Sequence): + def __getitem__(self, i): + return i + def __len__(self): + return 42 + match Seq(): + case [x, *_, y]: + z = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 41) + self.assertEqual(z, 0) + + def test_patma_187(self): + w = range(10) + match w: + case [x, y, *rest]: + z = 0 + self.assertEqual(w, range(10)) + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + self.assertEqual(rest, list(range(2, 10))) + + def test_patma_188(self): + w = range(100) + match w: + case (x, y, *rest): + z = 0 + self.assertEqual(w, range(100)) + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + self.assertEqual(rest, list(range(2, 100))) + + def test_patma_189(self): + w = range(1000) + match w: + case x, y, *rest: + z = 0 + self.assertEqual(w, range(1000)) + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + self.assertEqual(rest, list(range(2, 1000))) + + def test_patma_190(self): + w = range(1 << 10) + match w: + case [x, y, *_]: + z = 0 + self.assertEqual(w, range(1 << 10)) + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + + def test_patma_191(self): + w = range(1 << 20) + match w: + case (x, y, *_): + z = 0 + self.assertEqual(w, range(1 << 20)) + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + + def test_patma_192(self): + w = range(1 << 30) + match w: + case x, y, *_: + z = 0 + self.assertEqual(w, range(1 << 30)) + self.assertEqual(x, 0) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + + def test_patma_193(self): + x = {"bandwidth": 0, "latency": 1} + match x: + case {"bandwidth": b, "latency": l}: + y = 0 + self.assertEqual(x, {"bandwidth": 0, "latency": 1}) + self.assertIs(b, x["bandwidth"]) + self.assertIs(l, x["latency"]) + self.assertEqual(y, 0) + + def test_patma_194(self): + x = {"bandwidth": 0, "latency": 1, "key": "value"} + match x: + case {"latency": l, "bandwidth": b}: + y = 0 + self.assertEqual(x, {"bandwidth": 0, "latency": 1, "key": "value"}) + self.assertIs(l, x["latency"]) + self.assertIs(b, x["bandwidth"]) + self.assertEqual(y, 0) + + def test_patma_195(self): + x = {"bandwidth": 0, "latency": 1, "key": "value"} + match x: + case {"bandwidth": b, "latency": l, **rest}: + y = 0 + self.assertEqual(x, {"bandwidth": 0, "latency": 1, "key": "value"}) + self.assertIs(b, x["bandwidth"]) + self.assertIs(l, x["latency"]) + self.assertEqual(rest, {"key": "value"}) + self.assertEqual(y, 0) + + def test_patma_196(self): + x = {"bandwidth": 0, "latency": 1} + match x: + case {"latency": l, "bandwidth": b, **rest}: + y = 0 + self.assertEqual(x, {"bandwidth": 0, "latency": 1}) + self.assertIs(l, x["latency"]) + self.assertIs(b, x["bandwidth"]) + self.assertEqual(rest, {}) + self.assertEqual(y, 0) + + def test_patma_197(self): + w = [Point(-1, 0), Point(1, 2)] + match w: + case (Point(x1, y1), Point(x2, y2) as p2): + z = 0 + self.assertEqual(w, [Point(-1, 0), Point(1, 2)]) + self.assertIs(x1, w[0].x) + self.assertIs(y1, w[0].y) + self.assertIs(p2, w[1]) + self.assertIs(x2, w[1].x) + self.assertIs(y2, w[1].y) + self.assertIs(z, 0) + + def test_patma_198(self): + class Color(enum.Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + def f(color): + match color: + case Color.RED: + return "I see red!" + case Color.GREEN: + return "Grass is green" + case Color.BLUE: + return "I'm feeling the blues :(" + self.assertEqual(f(Color.RED), "I see red!") + self.assertEqual(f(Color.GREEN), "Grass is green") + self.assertEqual(f(Color.BLUE), "I'm feeling the blues :(") + self.assertIs(f(Color), None) + self.assertIs(f(0), None) + self.assertIs(f(1), None) + self.assertIs(f(2), None) + self.assertIs(f(3), None) + self.assertIs(f(False), None) + self.assertIs(f(True), None) + self.assertIs(f(2+0j), None) + self.assertIs(f(3.0), None) + + def test_patma_199(self): + class Color(int, enum.Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + def f(color): + match color: + case Color.RED: + return "I see red!" + case Color.GREEN: + return "Grass is green" + case Color.BLUE: + return "I'm feeling the blues :(" + self.assertEqual(f(Color.RED), "I see red!") + self.assertEqual(f(Color.GREEN), "Grass is green") + self.assertEqual(f(Color.BLUE), "I'm feeling the blues :(") + self.assertIs(f(Color), None) + self.assertEqual(f(0), "I see red!") + self.assertEqual(f(1), "Grass is green") + self.assertEqual(f(2), "I'm feeling the blues :(") + self.assertIs(f(3), None) + self.assertEqual(f(False), "I see red!") + self.assertEqual(f(True), "Grass is green") + self.assertEqual(f(2+0j), "I'm feeling the blues :(") + self.assertIs(f(3.0), None) + + def test_patma_200(self): + class Class: + __match_args__ = ("a", "b") + c = Class() + c.a = 0 + c.b = 1 + match c: + case Class(x, y): + z = 0 + self.assertIs(x, c.a) + self.assertIs(y, c.b) + self.assertEqual(z, 0) + + def test_patma_201(self): + class Class: + __match_args__ = ("a", "b") + c = Class() + c.a = 0 + c.b = 1 + match c: + case Class(x, b=y): + z = 0 + self.assertIs(x, c.a) + self.assertIs(y, c.b) + self.assertEqual(z, 0) + + def test_patma_202(self): + class Parent: + __match_args__ = "a", "b" + class Child(Parent): + __match_args__ = ("c", "d") + c = Child() + c.a = 0 + c.b = 1 + match c: + case Parent(x, y): + z = 0 + self.assertIs(x, c.a) + self.assertIs(y, c.b) + self.assertEqual(z, 0) + + def test_patma_203(self): + class Parent: + __match_args__ = ("a", "b") + class Child(Parent): + __match_args__ = "c", "d" + c = Child() + c.a = 0 + c.b = 1 + match c: + case Parent(x, b=y): + z = 0 + self.assertIs(x, c.a) + self.assertIs(y, c.b) + self.assertEqual(z, 0) + + def test_patma_204(self): + def f(w): + match w: + case 42: + out = locals() + del out["w"] + return out + self.assertEqual(f(42), {}) + self.assertIs(f(0), None) + self.assertEqual(f(42.0), {}) + self.assertIs(f("42"), None) + + def test_patma_205(self): + def f(w): + match w: + case 42.0: + out = locals() + del out["w"] + return out + self.assertEqual(f(42.0), {}) + self.assertEqual(f(42), {}) + self.assertIs(f(0.0), None) + self.assertIs(f(0), None) + + def test_patma_206(self): + def f(w): + match w: + case 1 | 2 | 3: + out = locals() + del out["w"] + return out + self.assertEqual(f(1), {}) + self.assertEqual(f(2), {}) + self.assertEqual(f(3), {}) + self.assertEqual(f(3.0), {}) + self.assertIs(f(0), None) + self.assertIs(f(4), None) + self.assertIs(f("1"), None) + + def test_patma_207(self): + def f(w): + match w: + case [1, 2] | [3, 4]: + out = locals() + del out["w"] + return out + self.assertEqual(f([1, 2]), {}) + self.assertEqual(f([3, 4]), {}) + self.assertIs(f(42), None) + self.assertIs(f([2, 3]), None) + self.assertIs(f([1, 2, 3]), None) + self.assertEqual(f([1, 2.0]), {}) + + def test_patma_208(self): + def f(w): + match w: + case x: + out = locals() + del out["w"] + return out + self.assertEqual(f(42), {"x": 42}) + self.assertEqual(f((1, 2)), {"x": (1, 2)}) + self.assertEqual(f(None), {"x": None}) + + def test_patma_209(self): + def f(w): + match w: + case _: + out = locals() + del out["w"] + return out + self.assertEqual(f(42), {}) + self.assertEqual(f(None), {}) + self.assertEqual(f((1, 2)), {}) + + def test_patma_210(self): + def f(w): + match w: + case (x, y, z): + out = locals() + del out["w"] + return out + self.assertEqual(f((1, 2, 3)), {"x": 1, "y": 2, "z": 3}) + self.assertIs(f((1, 2)), None) + self.assertIs(f((1, 2, 3, 4)), None) + self.assertIs(f(123), None) + self.assertIs(f("abc"), None) + self.assertIs(f(b"abc"), None) + self.assertEqual(f(array.array("b", b"abc")), {'x': 97, 'y': 98, 'z': 99}) + self.assertEqual(f(memoryview(b"abc")), {"x": 97, "y": 98, "z": 99}) + self.assertIs(f(bytearray(b"abc")), None) + + def test_patma_211(self): + def f(w): + match w: + case {"x": x, "y": "y", "z": z}: + out = locals() + del out["w"] + return out + self.assertEqual(f({"x": "x", "y": "y", "z": "z"}), {"x": "x", "z": "z"}) + self.assertEqual(f({"x": "x", "y": "y", "z": "z", "a": "a"}), {"x": "x", "z": "z"}) + self.assertIs(f(({"x": "x", "y": "yy", "z": "z", "a": "a"})), None) + self.assertIs(f(({"x": "x", "y": "y"})), None) + + def test_patma_212(self): + def f(w): + match w: + case Point(int(xx), y="hello"): + out = locals() + del out["w"] + return out + self.assertEqual(f(Point(42, "hello")), {"xx": 42}) + + def test_patma_213(self): + def f(w): + match w: + case (p, q) as x: + out = locals() + del out["w"] + return out + self.assertEqual(f((1, 2)), {"p": 1, "q": 2, "x": (1, 2)}) + self.assertEqual(f([1, 2]), {"p": 1, "q": 2, "x": [1, 2]}) + self.assertIs(f(12), None) + self.assertIs(f((1, 2, 3)), None) + + def test_patma_214(self): + def f(): + match 42: + case 42: + return locals() + self.assertEqual(set(f()), set()) + + def test_patma_215(self): + def f(): + match 1: + case 1 | 2 | 3: + return locals() + self.assertEqual(set(f()), set()) + + def test_patma_216(self): + def f(): + match ...: + case _: + return locals() + self.assertEqual(set(f()), set()) + + def test_patma_217(self): + def f(): + match ...: + case abc: + return locals() + self.assertEqual(set(f()), {"abc"}) + + def test_patma_218(self): + def f(): + match ..., ...: + case a, b: + return locals() + self.assertEqual(set(f()), {"a", "b"}) + + def test_patma_219(self): + def f(): + match {"k": ..., "l": ...}: + case {"k": a, "l": b}: + return locals() + self.assertEqual(set(f()), {"a", "b"}) + + def test_patma_220(self): + def f(): + match Point(..., ...): + case Point(x, y=y): + return locals() + self.assertEqual(set(f()), {"x", "y"}) + + def test_patma_221(self): + def f(): + match ...: + case b as a: + return locals() + self.assertEqual(set(f()), {"a", "b"}) + + def test_patma_222(self): + def f(x): + match x: + case _: + return 0 + self.assertEqual(f(0), 0) + self.assertEqual(f(1), 0) + self.assertEqual(f(2), 0) + self.assertEqual(f(3), 0) + + def test_patma_223(self): + def f(x): + match x: + case 0: + return 0 + self.assertEqual(f(0), 0) + self.assertIs(f(1), None) + self.assertIs(f(2), None) + self.assertIs(f(3), None) + + def test_patma_224(self): + def f(x): + match x: + case 0: + return 0 + case _: + return 1 + self.assertEqual(f(0), 0) + self.assertEqual(f(1), 1) + self.assertEqual(f(2), 1) + self.assertEqual(f(3), 1) + + def test_patma_225(self): + def f(x): + match x: + case 0: + return 0 + case 1: + return 1 + self.assertEqual(f(0), 0) + self.assertEqual(f(1), 1) + self.assertIs(f(2), None) + self.assertIs(f(3), None) + + def test_patma_226(self): + def f(x): + match x: + case 0: + return 0 + case 1: + return 1 + case _: + return 2 + self.assertEqual(f(0), 0) + self.assertEqual(f(1), 1) + self.assertEqual(f(2), 2) + self.assertEqual(f(3), 2) + + def test_patma_227(self): + def f(x): + match x: + case 0: + return 0 + case 1: + return 1 + case 2: + return 2 + self.assertEqual(f(0), 0) + self.assertEqual(f(1), 1) + self.assertEqual(f(2), 2) + self.assertIs(f(3), None) + + def test_patma_228(self): + match(): + case(): + x = 0 + self.assertEqual(x, 0) + + def test_patma_229(self): + x = 0 + match(x): + case(x): + y = 0 + self.assertEqual(x, 0) + self.assertEqual(y, 0) + + def test_patma_230(self): + x = 0 + match x: + case False: + y = 0 + case 0: + y = 1 + self.assertEqual(x, 0) + self.assertEqual(y, 1) + + def test_patma_231(self): + x = 1 + match x: + case True: + y = 0 + case 1: + y = 1 + self.assertEqual(x, 1) + self.assertEqual(y, 1) + + def test_patma_232(self): + class Eq: + def __eq__(self, other): + return True + x = eq = Eq() + y = None + match x: + case None: + y = 0 + self.assertIs(x, eq) + self.assertEqual(y, None) + + def test_patma_233(self): + x = False + match x: + case False: + y = 0 + self.assertIs(x, False) + self.assertEqual(y, 0) + + def test_patma_234(self): + x = True + match x: + case True: + y = 0 + self.assertIs(x, True) + self.assertEqual(y, 0) + + def test_patma_235(self): + x = None + match x: + case None: + y = 0 + self.assertIs(x, None) + self.assertEqual(y, 0) + + def test_patma_236(self): + x = 0 + match x: + case (0 as w) as z: + y = 0 + self.assertEqual(w, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_237(self): + x = 0 + match x: + case (0 as w) as z: + y = 0 + self.assertEqual(w, 0) + self.assertEqual(x, 0) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_238(self): + x = ((0, 1), (2, 3)) + match x: + case ((a as b, c as d) as e) as w, ((f as g, h) as i) as z: + y = 0 + self.assertEqual(a, 0) + self.assertEqual(b, 0) + self.assertEqual(c, 1) + self.assertEqual(d, 1) + self.assertEqual(e, (0, 1)) + self.assertEqual(f, 2) + self.assertEqual(g, 2) + self.assertEqual(h, 3) + self.assertEqual(i, (2, 3)) + self.assertEqual(w, (0, 1)) + self.assertEqual(x, ((0, 1), (2, 3))) + self.assertEqual(y, 0) + self.assertEqual(z, (2, 3)) + + def test_patma_239(self): + x = collections.UserDict({0: 1, 2: 3}) + match x: + case {2: 3}: + y = 0 + self.assertEqual(x, {0: 1, 2: 3}) + self.assertEqual(y, 0) + + def test_patma_240(self): + x = collections.UserDict({0: 1, 2: 3}) + match x: + case {2: 3, **z}: + y = 0 + self.assertEqual(x, {0: 1, 2: 3}) + self.assertEqual(y, 0) + self.assertEqual(z, {0: 1}) + + def test_patma_241(self): + x = [[{0: 0}]] + match x: + case list([({-0-0j: int(real=0+0j, imag=0-0j) | (1) as z},)]): + y = 0 + self.assertEqual(x, [[{0: 0}]]) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_patma_242(self): + x = range(3) + match x: + case [y, *_, z]: + w = 0 + self.assertEqual(w, 0) + self.assertEqual(x, range(3)) + self.assertEqual(y, 0) + self.assertEqual(z, 2) + + def test_patma_243(self): + x = range(3) + match x: + case [_, *_, y]: + z = 0 + self.assertEqual(x, range(3)) + self.assertEqual(y, 2) + self.assertEqual(z, 0) + + def test_patma_244(self): + x = range(3) + match x: + case [*_, y]: + z = 0 + self.assertEqual(x, range(3)) + self.assertEqual(y, 2) + self.assertEqual(z, 0) + + def test_patma_245(self): + x = {"y": 1} + match x: + case {"y": (0 as y) | (1 as y)}: + z = 0 + self.assertEqual(x, {"y": 1}) + self.assertEqual(y, 1) + self.assertEqual(z, 0) + + def test_patma_246(self): + def f(x): + match x: + case ((a, b, c, d, e, f, g, h, i, 9) | + (h, g, i, a, b, d, e, c, f, 10) | + (g, b, a, c, d, -5, e, h, i, f) | + (-1, d, f, b, g, e, i, a, h, c)): + w: object = 0 # annotation is for Cython, otherwise it's an int and it's always in locals + out = locals() + del out["x"] + return out + alts = [ + dict(a=0, b=1, c=2, d=3, e=4, f=5, g=6, h=7, i=8, w=0), + dict(h=1, g=2, i=3, a=4, b=5, d=6, e=7, c=8, f=9, w=0), + dict(g=0, b=-1, a=-2, c=-3, d=-4, e=-6, h=-7, i=-8, f=-9, w=0), + dict(d=-2, f=-3, b=-4, g=-5, e=-6, i=-7, a=-8, h=-9, c=-10, w=0), + dict(), + ] + self.assertEqual(f(range(10)), alts[0]) + self.assertEqual(f(range(1, 11)), alts[1]) + self.assertEqual(f(range(0, -10, -1)), alts[2]) + self.assertEqual(f(range(-1, -11, -1)), alts[3]) + self.assertEqual(f(range(10, 20)), alts[4]) + + def test_patma_247(self): + def f(x): + match x: + case [y, (a, b, c, d, e, f, g, h, i, 9) | + (h, g, i, a, b, d, e, c, f, 10) | + (g, b, a, c, d, -5, e, h, i, f) | + (-1, d, f, b, g, e, i, a, h, c), z]: + w: object = 0 # annotation is for Cython, otherwise it's an int and always in locals + out = locals() + del out["x"] + return out + alts = [ + dict(a=0, b=1, c=2, d=3, e=4, f=5, g=6, h=7, i=8, w=0, y=False, z=True), + dict(h=1, g=2, i=3, a=4, b=5, d=6, e=7, c=8, f=9, w=0, y=False, z=True), + dict(g=0, b=-1, a=-2, c=-3, d=-4, e=-6, h=-7, i=-8, f=-9, w=0, y=False, z=True), + dict(d=-2, f=-3, b=-4, g=-5, e=-6, i=-7, a=-8, h=-9, c=-10, w=0, y=False, z=True), + dict(), + ] + self.assertEqual(f((False, range(10), True)), alts[0]) + self.assertEqual(f((False, range(1, 11), True)), alts[1]) + self.assertEqual(f((False, range(0, -10, -1), True)), alts[2]) + self.assertEqual(f((False, range(-1, -11, -1), True)), alts[3]) + self.assertEqual(f((False, range(10, 20), True)), alts[4]) + + def test_patma_248(self): + class C(dict): + @staticmethod + def get(key, default=None): + return 'bar' + + x = C({'foo': 'bar'}) + match x: + case {'foo': bar}: + y = bar + + self.assertEqual(y, 'bar') + + def test_patma_249(self): + return # disabled + class C: + __attr = "eggs" # mangled to _C__attr + _Outer__attr = "bacon" + class Outer: + def f(self, x): + match x: + # looks up __attr, not _C__attr or _Outer__attr + case C(__attr=y): + return y + c = C() + setattr(c, "__attr", "spam") # setattr is needed because we're in a class scope + self.assertEqual(Outer().f(c), "spam") + + +class TestSyntaxErrors(unittest.TestCase): + + def assert_syntax_error(self, code: str): + with self.assertRaises(SyntaxError): + compile(inspect.cleandoc(code), "<test>", "exec") + + def test_alternative_patterns_bind_different_names_0(self): + self.assert_syntax_error(""" + match ...: + case "a" | a: + pass + """) + + def test_alternative_patterns_bind_different_names_1(self): + self.assert_syntax_error(""" + match ...: + case [a, [b] | [c] | [d]]: + pass + """) + + + @disable # validation will be added when class patterns are added + def test_attribute_name_repeated_in_class_pattern(self): + self.assert_syntax_error(""" + match ...: + case Class(a=_, a=_): + pass + """) + + def test_imaginary_number_required_in_complex_literal_0(self): + self.assert_syntax_error(""" + match ...: + case 0+0: + pass + """) + + def test_imaginary_number_required_in_complex_literal_1(self): + self.assert_syntax_error(""" + match ...: + case {0+0: _}: + pass + """) + + def test_invalid_syntax_0(self): + self.assert_syntax_error(""" + match ...: + case {**rest, "key": value}: + pass + """) + + def test_invalid_syntax_1(self): + self.assert_syntax_error(""" + match ...: + case {"first": first, **rest, "last": last}: + pass + """) + + def test_invalid_syntax_2(self): + self.assert_syntax_error(""" + match ...: + case {**_}: + pass + """) + + def test_invalid_syntax_3(self): + self.assert_syntax_error(""" + match ...: + case 42 as _: + pass + """) + + def test_mapping_pattern_keys_may_only_match_literals_and_attribute_lookups(self): + self.assert_syntax_error(""" + match ...: + case {f"": _}: + pass + """) + + def test_multiple_assignments_to_name_in_pattern_0(self): + self.assert_syntax_error(""" + match ...: + case a, a: + pass + """) + + def test_multiple_assignments_to_name_in_pattern_1(self): + self.assert_syntax_error(""" + match ...: + case {"k": a, "l": a}: + pass + """) + + def test_multiple_assignments_to_name_in_pattern_2(self): + self.assert_syntax_error(""" + match ...: + case MyClass(x, x): + pass + """) + + def test_multiple_assignments_to_name_in_pattern_3(self): + self.assert_syntax_error(""" + match ...: + case MyClass(x=x, y=x): + pass + """) + + def test_multiple_assignments_to_name_in_pattern_4(self): + self.assert_syntax_error(""" + match ...: + case MyClass(x, y=x): + pass + """) + + def test_multiple_assignments_to_name_in_pattern_5(self): + self.assert_syntax_error(""" + match ...: + case a as a: + pass + """) + + @disable # will be implemented as part of sequence patterns + def test_multiple_starred_names_in_sequence_pattern_0(self): + self.assert_syntax_error(""" + match ...: + case *a, b, *c, d, *e: + pass + """) + + @disable # will be implemented as part of sequence patterns + def test_multiple_starred_names_in_sequence_pattern_1(self): + self.assert_syntax_error(""" + match ...: + case a, *b, c, *d, e: + pass + """) + + def test_name_capture_makes_remaining_patterns_unreachable_0(self): + self.assert_syntax_error(""" + match ...: + case a | "a": + pass + """) + + def test_name_capture_makes_remaining_patterns_unreachable_1(self): + self.assert_syntax_error(""" + match 42: + case x: + pass + case y: + pass + """) + + def test_name_capture_makes_remaining_patterns_unreachable_2(self): + self.assert_syntax_error(""" + match ...: + case x | [_ as x] if x: + pass + """) + + def test_name_capture_makes_remaining_patterns_unreachable_3(self): + self.assert_syntax_error(""" + match ...: + case x: + pass + case [x] if x: + pass + """) + + def test_name_capture_makes_remaining_patterns_unreachable_4(self): + self.assert_syntax_error(""" + match ...: + case x: + pass + case _: + pass + """) + + def test_patterns_may_only_match_literals_and_attribute_lookups_0(self): + self.assert_syntax_error(""" + match ...: + case f"": + pass + """) + + def test_patterns_may_only_match_literals_and_attribute_lookups_1(self): + self.assert_syntax_error(""" + match ...: + case f"{x}": + pass + """) + + def test_real_number_required_in_complex_literal_0(self): + self.assert_syntax_error(""" + match ...: + case 0j+0: + pass + """) + + def test_real_number_required_in_complex_literal_1(self): + self.assert_syntax_error(""" + match ...: + case 0j+0j: + pass + """) + + def test_real_number_required_in_complex_literal_2(self): + self.assert_syntax_error(""" + match ...: + case {0j+0: _}: + pass + """) + + def test_real_number_required_in_complex_literal_3(self): + self.assert_syntax_error(""" + match ...: + case {0j+0j: _}: + pass + """) + + def test_wildcard_makes_remaining_patterns_unreachable_0(self): + self.assert_syntax_error(""" + match ...: + case _ | _: + pass + """) + + def test_wildcard_makes_remaining_patterns_unreachable_1(self): + self.assert_syntax_error(""" + match ...: + case (_ as x) | [x]: + pass + """) + + def test_wildcard_makes_remaining_patterns_unreachable_2(self): + self.assert_syntax_error(""" + match ...: + case _ | _ if condition(): + pass + """) + + def test_wildcard_makes_remaining_patterns_unreachable_3(self): + self.assert_syntax_error(""" + match ...: + case _: + pass + case None: + pass + """) + + def test_wildcard_makes_remaining_patterns_unreachable_4(self): + self.assert_syntax_error(""" + match ...: + case (None | _) | _: + pass + """) + + def test_wildcard_makes_remaining_patterns_unreachable_5(self): + self.assert_syntax_error(""" + match ...: + case _ | (True | False): + pass + """) + + @disable # validation will be added when class patterns are added + def test_mapping_pattern_duplicate_key(self): + self.assert_syntax_error(""" + match ...: + case {"a": _, "a": _}: + pass + """) + + @disable # validation will be added when class patterns are added + def test_mapping_pattern_duplicate_key_edge_case0(self): + self.assert_syntax_error(""" + match ...: + case {0: _, False: _}: + pass + """) + + @disable # validation will be added when class patterns are added + def test_mapping_pattern_duplicate_key_edge_case1(self): + self.assert_syntax_error(""" + match ...: + case {0: _, 0.0: _}: + pass + """) + + @disable # validation will be added when class patterns are added + def test_mapping_pattern_duplicate_key_edge_case2(self): + self.assert_syntax_error(""" + match ...: + case {0: _, -0: _}: + pass + """) + + @disable # validation will be added when class patterns are added + def test_mapping_pattern_duplicate_key_edge_case3(self): + self.assert_syntax_error(""" + match ...: + case {0: _, 0j: _}: + pass + """) + +class TestTypeErrors(unittest.TestCase): + + def test_accepts_positional_subpatterns_0(self): + class Class: + __match_args__ = () + x = Class() + y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y): + z = 0 + self.assertIs(y, None) + self.assertIs(z, None) + + def test_accepts_positional_subpatterns_1(self): + x = range(10) + y = None + with self.assertRaises(TypeError): + match x: + case range(10): + y = 0 + self.assertEqual(x, range(10)) + self.assertIs(y, None) + + def test_got_multiple_subpatterns_for_attribute_0(self): + class Class: + __match_args__ = ("a", "a") + a = None + x = Class() + w = y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y, z): + w = 0 + self.assertIs(w, None) + self.assertIs(y, None) + self.assertIs(z, None) + + def test_got_multiple_subpatterns_for_attribute_1(self): + class Class: + __match_args__ = ("a",) + a = None + x = Class() + w = y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y, a=z): + w = 0 + self.assertIs(w, None) + self.assertIs(y, None) + self.assertIs(z, None) + + def test_match_args_elements_must_be_strings(self): + class Class: + __match_args__ = (None,) + x = Class() + y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y): + z = 0 + self.assertIs(y, None) + self.assertIs(z, None) + + def test_match_args_must_be_a_tuple_0(self): + class Class: + __match_args__ = None + x = Class() + y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y): + z = 0 + self.assertIs(y, None) + self.assertIs(z, None) + + def test_match_args_must_be_a_tuple_1(self): + class Class: + __match_args__ = "XYZ" + x = Class() + y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y): + z = 0 + self.assertIs(y, None) + self.assertIs(z, None) + + def test_match_args_must_be_a_tuple_2(self): + class Class: + __match_args__ = ["spam", "eggs"] + spam = 0 + eggs = 1 + x = Class() + w = y = z = None + with self.assertRaises(TypeError): + match x: + case Class(y, z): + w = 0 + self.assertIs(w, None) + self.assertIs(y, None) + self.assertIs(z, None) + + +class TestValueErrors(unittest.TestCase): + + def test_mapping_pattern_checks_duplicate_key_1(self): + class Keys: + KEY = "a" + x = {"a": 0, "b": 1} + w = y = z = None + with self.assertRaises(ValueError): + match x: + case {Keys.KEY: y, "a": z}: + w = 0 + self.assertIs(w, None) + self.assertIs(y, None) + self.assertIs(z, None) + + +if __name__ == "__main__": + """ + # From inside environment using this Python, with pyperf installed: + sudo $(which pyperf) system tune && \ + $(which python) -m test.test_patma --rigorous; \ + sudo $(which pyperf) system reset + """ + import pyperf + + + class PerfPatma(TestPatma): + + def assertEqual(*_, **__): + pass + + def assertIs(*_, **__): + pass + + def assertRaises(*_, **__): + assert False, "this test should be a method of a different class!" + + def run_perf(self, count): + tests = [] + for attr in vars(TestPatma): + if attr.startswith("test_"): + tests.append(getattr(self, attr)) + tests *= count + start = pyperf.perf_counter() + for test in tests: + test() + return pyperf.perf_counter() - start + + + runner = pyperf.Runner() + runner.bench_time_func("patma", PerfPatma().run_perf) |