diff options
-rw-r--r-- | lambda/matching.ml | 390 |
1 files changed, 221 insertions, 169 deletions
diff --git a/lambda/matching.ml b/lambda/matching.ml index 1c0a480554..bcd2012188 100644 --- a/lambda/matching.ml +++ b/lambda/matching.ml @@ -143,18 +143,93 @@ let all_record_args lbls = type 'a clause = 'a * lambda module Non_empty_clause = struct - type 'a t = ('a * pattern list) clause + type 'a t = ('a * Typedtree.pattern list) clause let of_initial = function | [], _ -> assert false | pat :: patl, act -> ((pat, patl), act) + + let map_head f ((p, patl), act) = ((f p, patl), act) end -module General = struct - type nonrec pattern = pattern +type simple_view = [ + | `Any + | `Constant of constant + | `Tuple of pattern list + | `Construct of Longident.t loc * constructor_description * pattern list + | `Variant of label * pattern option * row_desc ref + | `Record of (Longident.t loc * label_description * pattern) list * closed_flag + | `Array of pattern list + | `Lazy of pattern +] + +type half_simple_view = [ + | simple_view + | `Or of pattern * pattern * row_desc option +] +type general_view = [ + | half_simple_view + | `Var of Ident.t * string loc + | `Alias of pattern * Ident.t * string loc +] + +module General : sig + type pattern = general_view pattern_data type clause = pattern Non_empty_clause.t + + val view : Typedtree.pattern -> pattern + val erase : [< general_view ] pattern_data -> Typedtree.pattern +end = struct + type pattern = general_view pattern_data + type clause = pattern Non_empty_clause.t + + let view_desc = function + | Tpat_any -> + `Any + | Tpat_var (id, str) -> + `Var (id, str) + | Tpat_alias (p, id, str) -> + `Alias (p, id, str) + | Tpat_constant cst -> + `Constant cst + | Tpat_tuple ps -> + `Tuple ps + | Tpat_construct (cstr, cstr_descr, args) -> + `Construct (cstr, cstr_descr, args) + | Tpat_variant (cstr, arg, row_desc) -> + `Variant (cstr, arg, row_desc) + | Tpat_record (fields, closed) -> + `Record (fields, closed) + | Tpat_array ps -> `Array ps + | Tpat_or (p, q, row_desc) -> `Or (p, q, row_desc) + | Tpat_lazy p -> `Lazy p + + let view p : pattern = + { p with pat_desc = view_desc p.pat_desc } + + let erase_desc = function + | `Any -> Tpat_any + | `Var (id, str) -> Tpat_var (id, str) + | `Alias (p, id, str) -> Tpat_alias (p, id, str) + | `Constant cst -> Tpat_constant cst + | `Tuple ps -> Tpat_tuple ps + | `Construct (cstr, cst_descr, args) -> + Tpat_construct (cstr, cst_descr, args) + | `Variant (cstr, arg, row_desc) -> + Tpat_variant (cstr, arg, row_desc) + | `Record (fields, closed) -> + Tpat_record (fields, closed) + | `Array ps -> Tpat_array ps + | `Or (p, q, row_desc) -> Tpat_or (p, q, row_desc) + | `Lazy p -> Tpat_lazy p + + let erase p = + { p with pat_desc = erase_desc p.pat_desc } end +let omega_ : [> `Any ] pattern_data = + { Parmatch.omega with pat_desc = `Any } + module Half_simple : sig (** Half-simplified patterns are patterns where: - records are expanded so that they possess all fields @@ -177,28 +252,24 @@ module Half_simple : sig In particular, or-patterns may still occur in the leading column, so this is only a "half-simplification". *) - type pattern - - val to_pattern : pattern -> General.pattern + type pattern = half_simple_view pattern_data type clause = pattern Non_empty_clause.t val of_clause : args:(lambda * 'a) list -> General.clause -> clause end = struct - type nonrec pattern = pattern + type pattern = half_simple_view pattern_data type clause = pattern Non_empty_clause.t - let to_pattern p = p - - let rec simpl_orpat p = + let rec simpl_under_orpat p = match p.pat_desc with | Tpat_any | Tpat_var _ -> p | Tpat_alias (q, id, s) -> - { p with pat_desc = Tpat_alias (simpl_orpat q, id, s) } + { p with pat_desc = Tpat_alias (simpl_under_orpat q, id, s) } | Tpat_or (p1, p2, o) -> - let p1, p2 = (simpl_orpat p1, simpl_orpat p2) in + let p1, p2 = (simpl_under_orpat p1, simpl_under_orpat p2) in if le_pat p1 p2 then p1 else @@ -209,40 +280,37 @@ end = struct | _ -> p let of_clause ~args cl = - let rec aux ((pat, patl), action) = - match pat.pat_desc with - | Tpat_any -> ((pat, patl), action) - | Tpat_var (id, s) -> - let p = { pat with pat_desc = Tpat_alias (omega, id, s) } in - aux ((p, patl), action) - | Tpat_alias (p, id, _) -> + let rec aux (((p, patl), action) : General.clause) : clause = + let continue p (view : general_view) : clause = + aux (({ p with pat_desc = view }, patl), action) in + let stop p (view : half_simple_view) : clause = + (({ p with pat_desc = view }, patl), action) in + match p.pat_desc with + | `Any -> stop p `Any + | `Var (id, s) -> continue p (`Alias (omega, id, s)) + | `Alias (p, id, _) -> let arg = match args with | [] -> assert false | (arg, _) :: _ -> arg in - let k = Typeopt.value_kind pat.pat_env pat.pat_type in - aux ((p, patl), bind_with_value_kind Alias (id, k) arg action) - | Tpat_record ([], _) -> ((omega, patl), action) - | Tpat_record (lbls, closed) -> - let all_lbls = all_record_args lbls in - let full_pat = - { pat with pat_desc = Tpat_record (all_lbls, closed) } - in - ((full_pat, patl), action) - | Tpat_or _ -> ( - let pat_simple = simpl_orpat pat in - match pat_simple.pat_desc with - | Tpat_or _ -> ((pat_simple, patl), action) - | _ -> aux ((pat_simple, patl), action) + let k = Typeopt.value_kind p.pat_env p.pat_type in + aux ((General.view p, patl), + bind_with_value_kind Alias (id, k) arg action) + | `Record ([], _) as view -> stop p view + | `Record (lbls, closed) -> + let full_view = `Record (all_record_args lbls, closed) in + stop p full_view + | `Or _ -> ( + let orpat = + General.view (simpl_under_orpat (General.erase p)) in + match orpat.pat_desc with + | `Or _ as or_view -> stop orpat or_view + | other_view -> continue orpat other_view ) - | Tpat_constant _ - | Tpat_tuple _ - | Tpat_construct _ - | Tpat_variant _ - | Tpat_array _ - | Tpat_lazy _ -> - ((pat, patl), action) + | (`Constant _ | `Tuple _ | `Construct _ | `Variant _ + | `Array _ | `Lazy _) as view -> + stop p view in aux cl end @@ -250,37 +318,41 @@ end exception Cannot_flatten module Simple : sig - type pattern - (** A fully simplified pattern: or-patterns have been exploded, and the - remaining aliases have been removed and replaced by bindings in actions *) - + type pattern = simple_view pattern_data type clause = pattern Non_empty_clause.t - val try_no_or : Half_simple.pattern -> pattern option - - val to_pattern : pattern -> General.pattern - val head : pattern -> Pattern_head.t val explode_or_pat : - Half_simple.pattern * General.pattern list -> + Half_simple.pattern * Typedtree.pattern list -> arg:Ident.t option -> mk_action:(vars:Ident.t list -> lambda) -> vars:Ident.t list -> clause list -> clause list - - val omega : pattern end = struct - type nonrec pattern = pattern - - let omega = omega - + type pattern = simple_view pattern_data type clause = pattern Non_empty_clause.t - let to_pattern p = p - - let head p = fst (Pattern_head.deconstruct p) + let head p = + fst (Pattern_head.deconstruct (General.erase (p :> General.pattern))) + + let alpha env (p : pattern) : pattern = + let alpha_pat env p = Typedtree.alpha_pat env p in + let pat_desc = match p.pat_desc with + | `Any -> `Any + | `Constant cst -> `Constant cst + | `Tuple ps -> `Tuple (List.map (alpha_pat env) ps) + | `Construct (cstr, cst_descr, args) -> + `Construct (cstr, cst_descr, List.map (alpha_pat env) args) + | `Variant (cstr, argo, row_desc) -> + `Variant (cstr, Option.map (alpha_pat env) argo, row_desc) + | `Record (fields, closed) -> + let alpha_field env (lid, l, p) = (lid, l, alpha_pat env p) in + `Record (List.map (alpha_field env) fields, closed) + | `Array ps -> `Array (List.map (alpha_pat env) ps) + | `Lazy p -> `Lazy (alpha_pat env p) + in { p with pat_desc } let mk_alpha_env arg aliases ids = List.map @@ -294,27 +366,25 @@ end = struct Ident.create_local (Ident.name id) )) ids - let explode_or_pat (p, patl) ~arg ~mk_action ~vars rem = + let explode_or_pat ((p : Half_simple.pattern), patl) ~arg ~mk_action ~vars + (rem : clause list) : clause list = let rec explode p aliases rem = + let split_explode p aliases rem = + explode (General.view p) aliases rem in match p.pat_desc with - | Tpat_or (p1, p2, _) -> - explode p1 aliases (explode p2 aliases rem) - | Tpat_alias (p, id, _) -> - explode p (id :: aliases) rem - | Tpat_var (x, _) -> - let env = mk_alpha_env arg (x :: aliases) vars in - ((omega, patl), mk_action ~vars:(List.map snd env)) :: rem - | _ -> + | `Or (p1, p2, _) -> + split_explode p1 aliases (split_explode p2 aliases rem) + | `Alias (p, id, _) -> + split_explode p (id :: aliases) rem + | `Var (id, str) -> + explode + { p with pat_desc = `Alias (Parmatch.omega, id, str) } aliases rem + | #simple_view as view -> let env = mk_alpha_env arg aliases vars in - ((alpha_pat env p, patl), mk_action ~vars:(List.map snd env)) :: rem + ((alpha env { p with pat_desc = view }, patl), + mk_action ~vars:(List.map snd env)) :: rem in - explode (Half_simple.to_pattern p) [] rem - - let try_no_or hsp = - let p = Half_simple.to_pattern hsp in - match p.pat_desc with - | Tpat_or _ -> None - | _ -> Some p + explode (p : Half_simple.pattern :> General.pattern) [] rem end type initial_clause = pattern list clause @@ -838,7 +908,7 @@ let pretty_hc_pm pm = pretty_cases (List.map (fun ((p, ps), act) -> - (Half_simple.to_pattern p :: ps, act)) + (General.erase p :: ps, act)) pm.cases); if not (Default_environment.is_empty pm.default) then Default_environment.pp pm.default @@ -846,7 +916,7 @@ let pretty_hc_pm pm = let pretty_sc_pm pm = pretty_cases (List.map - (fun ((p, ps), act) -> (Simple.to_pattern p :: ps, act)) + (fun ((p, ps), act) -> (General.erase p :: ps, act)) pm.cases); if not (Default_environment.is_empty pm.default) then Default_environment.pp pm.default @@ -951,7 +1021,7 @@ let same_actions = function None ) -let safe_before to_pattern ((p, ps), act_p) l = +let safe_before ((p, ps), act_p) l = (* Test for swapping two clauses *) let same_actions act1 act2 = match (make_key act1, make_key act2) with @@ -963,14 +1033,20 @@ let safe_before to_pattern ((p, ps), act_p) l = List.for_all (fun ((q, qs), act_q) -> same_actions act_p act_q - || not (may_compats (to_pattern p :: ps) (to_pattern q :: qs))) + || not (may_compats (General.erase p :: ps) (General.erase q :: qs))) l -let half_simplify_clause args cls = +let half_simplify_nonempty + args (cls : Typedtree.pattern Non_empty_clause.t) : Half_simple.clause = cls - |> Non_empty_clause.of_initial + |> Non_empty_clause.map_head General.view |> Half_simple.of_clause ~args +let half_simplify_clause args (cls : Typedtree.pattern list clause) = + cls + |> Non_empty_clause.of_initial + |> half_simplify_nonempty args + let half_simplify_cases args cls = List.map (half_simplify_clause args) cls (* Once matchings are *fully* simplified, one can easily find @@ -978,7 +1054,7 @@ let half_simplify_cases args cls = List.map (half_simplify_clause args) cls let rec what_is_cases ~skip_any cases = match cases with - | [] -> Simple.omega + | [] -> omega_ | ((p, _), _) :: rem -> ( match Pattern_head.desc (Simple.head p) with | Any when skip_any -> what_is_cases ~skip_any rem @@ -1117,11 +1193,11 @@ let rec omega_like p = let equiv_pat p q = le_pat p q && le_pat q p -let rec extract_equiv_head to_pattern p l = +let rec extract_equiv_head p l = match l with | (((q, _), _) as cl) :: rem -> - if equiv_pat p (to_pattern q) then - let others, rem = extract_equiv_head to_pattern p rem in + if equiv_pat p (General.erase q) then + let others, rem = extract_equiv_head p rem in (cl :: others, rem) else ([], l) @@ -1150,10 +1226,10 @@ module Or_matrix = struct let safe_below (ps, act) qs = (not (is_guarded act)) && Parmatch.le_pats ps qs - let safe_below_or_matrix to_pattern l (q, qs) = + let safe_below_or_matrix l (q, qs) = List.for_all (fun ((p, ps), act_p) -> - let p = to_pattern p in + let p = General.erase p in match p.pat_desc with | Tpat_or _ -> disjoint p q || safe_below (ps, act_p) qs | _ -> true) @@ -1168,14 +1244,14 @@ module Or_matrix = struct let insert_or_append (head, ps, act) rev_ors rev_no = let safe_to_insert rem (p, ps) seen = let _, not_e = - extract_equiv_head Half_simple.to_pattern p rem + extract_equiv_head p rem in (* check append condition for head of O *) - safe_below_or_matrix Half_simple.to_pattern not_e (p, ps) + safe_below_or_matrix not_e (p, ps) && (* check insert condition for tail of O *) List.for_all (fun ((q, _), _) -> - disjoint p (Half_simple.to_pattern q)) + disjoint p (General.erase q)) seen in let rec attempt seen = function @@ -1183,8 +1259,8 @@ module Or_matrix = struct [seen] (but maybe not [rem] yet) *) | [] -> (((head, ps), act) :: rev_ors, rev_no) | (((q, qs), act_q) as cl) :: rem -> - let p = Half_simple.to_pattern head in - let q = Half_simple.to_pattern q in + let p = General.erase head in + let q = General.erase q in if (not (is_or q)) || disjoint p q then attempt (cl :: seen) rem else if @@ -1208,8 +1284,8 @@ end (* Reconstruct default information from half_compiled pm list *) -let as_matrix pat_of_head cases = - get_mins le_pats (List.map (fun ((p, ps), _) -> pat_of_head p :: ps) cases) +let as_matrix cases = + get_mins le_pats (List.map (fun ((p, ps), _) -> General.erase p :: ps) cases) (* Split a matching along the first column. @@ -1260,12 +1336,14 @@ let rec split_or argo (cls : Half_simple.clause list) args def = let rec do_split (rev_before : Simple.clause list) rev_ors rev_no = function | [] -> cons_next (List.rev rev_before) (List.rev rev_ors) (List.rev rev_no) - | cl :: rem when not (safe_before Half_simple.to_pattern cl rev_no) -> + | cl :: rem when not (safe_before cl rev_no) -> do_split rev_before rev_ors (cl :: rev_no) rem | (((p, ps), act) as cl) :: rem -> ( - match Simple.try_no_or p with - | Some sp when safe_before Half_simple.to_pattern cl rev_ors -> - do_split (((sp, ps), act) :: rev_before) rev_ors rev_no rem + match p.pat_desc with + | #simple_view as view when safe_before cl rev_ors -> + do_split + ((({ p with pat_desc = view }, ps), act) :: rev_before) + rev_ors rev_no rem | _ -> let rev_ors, rev_no = Or_matrix.insert_or_append (p, ps, act) rev_ors rev_no @@ -1307,18 +1385,8 @@ and split_no_or cls args def k = let discr = what_is_first_case cls in collect discr [] [] cls and collect group_discr rev_yes rev_no = function -<<<<<<< HEAD:lambda/matching.ml - | ([], _) :: _ -> assert false - | [ ((p, ps, _) as cl) ] - when rev_yes <> [] && List.for_all omega_like (p:: ps) -> -||||||| parent of 667fb6af7... Add intermediary types to keep track of the compilation state of pms:bytecomp/matching.ml - | ([], _) :: _ -> assert false - | [ ((p, ps, _) as cl) ] - when rev_yes <> [] && List.for_all heads_are_var (p:: ps) -> -======= | [ (((p, ps), _) as cl) ] - when rev_yes <> [] && group_var p && List.for_all heads_are_var ps -> ->>>>>>> 667fb6af7... Add intermediary types to keep track of the compilation state of pms:bytecomp/matching.ml + when rev_yes <> [] && group_var p && List.for_all omega_like ps -> (* This enables an extra division in some frequent cases: last row is made of variables only @@ -1331,7 +1399,7 @@ and split_no_or cls args def k = testsuite/tests/basic/patmatch_split_no_or.ml *) collect group_discr rev_yes (cl :: rev_no) [] | (((p, _), _) as cl) :: rem -> - if can_group group_discr p && safe_before Simple.to_pattern cl rev_no then + if can_group group_discr p && safe_before cl rev_no then collect group_discr (cl :: rev_yes) rev_no rem else if should_split group_discr then ( assert (rev_no = []); @@ -1358,8 +1426,8 @@ and split_no_or cls args def k = (Default_environment.cons matrix idef def) ((idef, next) :: nexts) and should_split group_discr = - match (Simple.to_pattern group_discr).pat_desc with - | Tpat_construct (_, { cstr_tag = Cstr_extension _ }, _) -> + match Pattern_head.desc (Simple.head group_discr) with + | Construct { cstr_tag = Cstr_extension _ } -> (* it is unlikely that we will raise anything, so we split now *) true | _ -> false @@ -1401,7 +1469,7 @@ and precompile_var args cls def k = | _ -> let rec rebuild_matrix pmh = match pmh with - | Pm pm -> as_matrix Simple.to_pattern pm.cases + | Pm pm -> as_matrix pm.cases | PmOr { or_matrix = m } -> m | PmVar x -> add_omega_column (rebuild_matrix x.inside) in @@ -1437,23 +1505,23 @@ and precompile_var args cls def k = and do_not_precompile args cls def k = ( { me = Pm { cases = cls; args; default = def }; - matrix = as_matrix Simple.to_pattern cls; + matrix = as_matrix cls; top_default = def }, k ) -and precompile_or argo cls ors args def k = +and precompile_or argo (cls : Simple.clause list) ors args def k = let rec do_cases = function | [] -> ([], []) | ((p, patl), action) :: rem -> ( - match Simple.try_no_or p with - | Some sp -> + match p.pat_desc with + | #simple_view as view -> let new_ord, new_to_catch = do_cases rem in - (((sp, patl), action) :: new_ord, new_to_catch) - | None -> - let orp = Half_simple.to_pattern p in + ((({ p with pat_desc = view }, patl), action) :: new_ord, new_to_catch) + | `Or _ -> + let orp = General.erase p in let others, rem = - extract_equiv_head Half_simple.to_pattern orp rem + extract_equiv_head orp rem in let orpm = { cases = @@ -1498,15 +1566,8 @@ and precompile_or argo cls ors args def k = let cases, handlers = do_cases ors in let matrix = as_matrix - (fun x -> x) - (List.map - (fun ((p, ps), act) -> ((Simple.to_pattern p, ps), act)) - cls - @ List.map - (fun ((p, ps), act) -> - ((Half_simple.to_pattern p, ps), act)) - ors - ) + ((cls : Simple.clause list :> General.clause list) + @ (ors : Half_simple.clause list :> General.clause list)) and body = { cases = cls @ cases; args; default = def } in ( { me = PmOr { body; handlers; or_matrix = matrix }; matrix; @@ -1516,11 +1577,7 @@ and precompile_or argo cls ors args def k = let split_and_precompile_nonempty argo pm = let pm = - { pm with - cases = - List.map (Half_simple.of_clause ~args:pm.args) pm.cases - } - in + { pm with cases = List.map (half_simplify_nonempty pm.args) pm.cases } in let { me = next }, nexts = split_or argo pm.cases pm.args pm.default in if dbg @@ -1603,7 +1660,7 @@ let add_in_div make_matching_fun eq_key key patl_action division = let divide make eq_key get_key get_args ctx (pm : Simple.clause pattern_matching) = let add ((p, patl), action) division = - let p = Simple.to_pattern p in + let p = General.erase p in add_in_div (make p pm.default ctx) eq_key (get_key p) (get_args p patl, action) division @@ -1617,7 +1674,7 @@ let add_line patl_action pm = let divide_line make_ctx make get_args discr ctx (pm : Simple.clause pattern_matching) = let add ((p, patl), action) submatrix = - let p = Simple.to_pattern p in + let p = General.erase p in add_line (get_args p patl, action) submatrix in let pm = List.fold_right add pm.cases (make pm.default pm.args) in @@ -1837,36 +1894,31 @@ let make_variant_matching_nonconst p lab def ctx = function let divide_variant row ctx { cases = cl; args; default = def } = let row = Btype.row_repr row in let rec divide = function - | [] -> { args; cells = [] } - | ((p, patl), action) :: rem -> ( - let p = Simple.to_pattern p in - match p.pat_desc with - | Tpat_variant (lab, pato, _) -> ( - let variants = divide rem in - if - try - Btype.row_field_repr (List.assoc lab row.row_fields) = Rabsent - with Not_found -> true - then - variants - else - let tag = Btype.hash_variant lab in - match pato with - | None -> - add_in_div - (make_variant_matching_constant p lab def ctx) - ( = ) (Cstr_constant tag) (patl, action) variants - | Some pat -> - add_in_div - (make_variant_matching_nonconst p lab def ctx) - ( = ) (Cstr_block tag) - (pat :: patl, action) - variants - ) - | _ -> - (* I really want to assert false here. *) - { args; cells = [] } + | (({ pat_desc = `Variant (lab, pato, _) } as p, patl), action) :: rem -> ( + let p = General.erase p in + let variants = divide rem in + if + try + Btype.row_field_repr (List.assoc lab row.row_fields) = Rabsent + with Not_found -> true + then + variants + else + let tag = Btype.hash_variant lab in + match pato with + | None -> + add_in_div + (make_variant_matching_constant p lab def ctx) + ( = ) (Cstr_constant tag) (patl, action) variants + | Some pat -> + add_in_div + (make_variant_matching_nonconst p lab def ctx) + ( = ) (Cstr_block tag) + (pat :: patl, action) + variants ) + | _ -> + { args; cells = [] } in divide cl @@ -3279,7 +3331,7 @@ and compile_simplified repr partial ctx | _ -> assert false and compile_half_compiled repr partial ctx - (m : pattern Non_empty_clause.t pattern_matching) = + (m : Typedtree.pattern Non_empty_clause.t pattern_matching) = match m with | { cases = []; args = [] } -> comp_exit ctx m | { args = ((Lvar v as arg), str) :: argl } -> @@ -3334,7 +3386,7 @@ and do_compile_matching repr partial ctx pmh = in let pat = what_is_cases pm.cases in let ph = Simple.head pat in - let pat = Simple.to_pattern pat in + let pat = General.erase pat in match Pattern_head.desc ph with | Any -> compile_no_test divide_var Context.rshift repr partial ctx pm @@ -3731,7 +3783,7 @@ let flatten_cases size cases = List.map (function | (p, []), action -> ( - match flatten_pattern size (Simple.to_pattern p) with + match flatten_pattern size (General.erase p) with | p :: ps -> ((p, ps), action) | [] -> assert false ) |