diff options
author | Simon Peyton Jones <simonpj@microsoft.com> | 2011-04-28 11:44:12 +0100 |
---|---|---|
committer | Simon Peyton Jones <simonpj@microsoft.com> | 2011-04-28 11:44:12 +0100 |
commit | 478e69b303eb2e653a2ebf5c888b5efdfef1fb9d (patch) | |
tree | d23ca1c0b6dc6a0ab58cc65db055fa9109f5081e /compiler/deSugar | |
parent | 66a733f23eebbd69f6e2d00a9f73c4d5541b5c39 (diff) | |
download | haskell-478e69b303eb2e653a2ebf5c888b5efdfef1fb9d.tar.gz |
Preliminary monad-comprehension patch (Trac #4370)
This is the work of Nils Schweinsberg <mail@n-sch.de>
It adds the language extension -XMonadComprehensions, which
generalises list comprehension syntax [ e | x <- xs] to work over
arbitrary monads.
Diffstat (limited to 'compiler/deSugar')
-rw-r--r-- | compiler/deSugar/Coverage.lhs | 60 | ||||
-rw-r--r-- | compiler/deSugar/DsArrows.lhs | 4 | ||||
-rw-r--r-- | compiler/deSugar/DsExpr.lhs | 20 | ||||
-rw-r--r-- | compiler/deSugar/DsGRHSs.lhs | 4 | ||||
-rw-r--r-- | compiler/deSugar/DsListComp.lhs | 366 | ||||
-rw-r--r-- | compiler/deSugar/DsMeta.hs | 8 |
6 files changed, 407 insertions, 55 deletions
diff --git a/compiler/deSugar/Coverage.lhs b/compiler/deSugar/Coverage.lhs index 0daa6befc4..e73c2499e8 100644 --- a/compiler/deSugar/Coverage.lhs +++ b/compiler/deSugar/Coverage.lhs @@ -301,10 +301,11 @@ addTickHsExpr (HsLet binds e) = liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsExprNeverOrAlways e) -addTickHsExpr (HsDo cxt stmts last_exp srcloc) = do +addTickHsExpr (HsDo cxt stmts last_exp return_exp srcloc) = do (stmts', last_exp') <- addTickLStmts' forQual stmts (addTickLHsExpr last_exp) - return (HsDo cxt stmts' last_exp' srcloc) + return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp + return (HsDo cxt stmts' last_exp' return_exp' srcloc) where forQual = case cxt of ListComp -> Just $ BinBox QualBinBox @@ -438,31 +439,38 @@ addTickStmt _isGuard (BindStmt pat e bind fail) = do (addTickLHsExprAlways e) (addTickSyntaxExpr hpcSrcSpan bind) (addTickSyntaxExpr hpcSrcSpan fail) -addTickStmt isGuard (ExprStmt e bind' ty) = do - liftM3 ExprStmt +addTickStmt isGuard (ExprStmt e bind' guard' ty) = do + liftM4 ExprStmt (addTick isGuard e) (addTickSyntaxExpr hpcSrcSpan bind') + (addTickSyntaxExpr hpcSrcSpan guard') (return ty) addTickStmt _isGuard (LetStmt binds) = do liftM LetStmt (addTickHsLocalBinds binds) -addTickStmt isGuard (ParStmt pairs) = do - liftM ParStmt +addTickStmt isGuard (ParStmt pairs mzipExpr bindExpr returnExpr) = do + liftM4 ParStmt (mapM (addTickStmtAndBinders isGuard) pairs) - -addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr) = do - liftM4 TransformStmt - (addTickLStmts isGuard stmts) - (return ids) - (addTickLHsExprAlways usingExpr) - (addTickMaybeByLHsExpr maybeByExpr) - -addTickStmt isGuard (GroupStmt stmts binderMap by using) = do - liftM4 GroupStmt - (addTickLStmts isGuard stmts) - (return binderMap) - (fmapMaybeM addTickLHsExprAlways by) - (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using) + (addTickSyntaxExpr hpcSrcSpan mzipExpr) + (addTickSyntaxExpr hpcSrcSpan bindExpr) + (addTickSyntaxExpr hpcSrcSpan returnExpr) + +addTickStmt isGuard (TransformStmt stmts ids usingExpr maybeByExpr returnExpr bindExpr) = do + t_s <- (addTickLStmts isGuard stmts) + t_u <- (addTickLHsExprAlways usingExpr) + t_m <- (addTickMaybeByLHsExpr maybeByExpr) + t_r <- (addTickSyntaxExpr hpcSrcSpan returnExpr) + t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr) + return $ TransformStmt t_s ids t_u t_m t_r t_b + +addTickStmt isGuard (GroupStmt stmts binderMap by using returnExpr bindExpr liftMExpr) = do + t_s <- (addTickLStmts isGuard stmts) + t_y <- (fmapMaybeM addTickLHsExprAlways by) + t_u <- (fmapEitherM addTickLHsExprAlways (addTickSyntaxExpr hpcSrcSpan) using) + t_f <- (addTickSyntaxExpr hpcSrcSpan returnExpr) + t_b <- (addTickSyntaxExpr hpcSrcSpan bindExpr) + t_m <- (addTickSyntaxExpr hpcSrcSpan liftMExpr) + return $ GroupStmt t_s binderMap t_y t_u t_b t_f t_m addTickStmt isGuard stmt@(RecStmt {}) = do { stmts' <- addTickLStmts isGuard (recS_stmts stmt) @@ -569,9 +577,10 @@ addTickHsCmd (HsLet binds c) = liftM2 HsLet (addTickHsLocalBinds binds) -- to think about: !patterns. (addTickLHsCmd c) -addTickHsCmd (HsDo cxt stmts last_exp srcloc) = do +addTickHsCmd (HsDo cxt stmts last_exp return_exp srcloc) = do (stmts', last_exp') <- addTickLCmdStmts' stmts (addTickLHsCmd last_exp) - return (HsDo cxt stmts' last_exp' srcloc) + return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp + return (HsDo cxt stmts' last_exp' return_exp' srcloc) addTickHsCmd (HsArrApp e1 e2 ty1 arr_ty lr) = liftM5 HsArrApp @@ -635,10 +644,11 @@ addTickCmdStmt (BindStmt pat c bind fail) = do (addTickLHsCmd c) (return bind) (return fail) -addTickCmdStmt (ExprStmt c bind' ty) = do - liftM3 ExprStmt +addTickCmdStmt (ExprStmt c bind' guard' ty) = do + liftM4 ExprStmt (addTickLHsCmd c) - (return bind') + (addTickSyntaxExpr hpcSrcSpan bind') + (addTickSyntaxExpr hpcSrcSpan guard') (return ty) addTickCmdStmt (LetStmt binds) = do liftM LetStmt diff --git a/compiler/deSugar/DsArrows.lhs b/compiler/deSugar/DsArrows.lhs index 58bf6b88e7..608f25e7f5 100644 --- a/compiler/deSugar/DsArrows.lhs +++ b/compiler/deSugar/DsArrows.lhs @@ -541,7 +541,7 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do core_body, exprFreeVars core_binds `intersectVarSet` local_vars) -dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _) +dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _ _) = dsCmdDo ids local_vars env_ids res_ty stmts body -- A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t @@ -674,7 +674,7 @@ dsCmdStmt -- ---> arr (\ (xs) -> ((xs1),(xs'))) >>> first c >>> -- arr snd >>> ss -dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ c_ty) = do +dsCmdStmt ids local_vars env_ids out_ids (ExprStmt cmd _ _ c_ty) = do (core_cmd, fv_cmd, env_ids1) <- dsfixCmd ids local_vars [] c_ty cmd core_mux <- matchEnvStack env_ids [] (mkCorePairExpr (mkBigCoreVarTup env_ids1) (mkBigCoreVarTup out_ids)) diff --git a/compiler/deSugar/DsExpr.lhs b/compiler/deSugar/DsExpr.lhs index 1781aef5f8..fb3f856c63 100644 --- a/compiler/deSugar/DsExpr.lhs +++ b/compiler/deSugar/DsExpr.lhs @@ -325,22 +325,25 @@ dsExpr (HsLet binds body) = do -- We need the `ListComp' form to use `deListComp' (rather than the "do" form) -- because the interpretation of `stmts' depends on what sort of thing it is. -- -dsExpr (HsDo ListComp stmts body result_ty) +dsExpr (HsDo ListComp stmts body _ result_ty) = -- Special case for list comprehensions dsListComp stmts body elt_ty where [elt_ty] = tcTyConAppArgs result_ty -dsExpr (HsDo DoExpr stmts body result_ty) +dsExpr (HsDo DoExpr stmts body _ result_ty) = dsDo stmts body result_ty -dsExpr (HsDo GhciStmt stmts body result_ty) +dsExpr (HsDo GhciStmt stmts body _ result_ty) = dsDo stmts body result_ty -dsExpr (HsDo MDoExpr stmts body result_ty) +dsExpr (HsDo MDoExpr stmts body _ result_ty) = dsDo stmts body result_ty -dsExpr (HsDo PArrComp stmts body result_ty) +dsExpr (HsDo MonadComp stmts body return_op result_ty) + = dsMonadComp stmts return_op body result_ty + +dsExpr (HsDo PArrComp stmts body _ result_ty) = -- Special case for array comprehensions dsPArrComp (map unLoc stmts) body elt_ty where @@ -722,7 +725,7 @@ dsDo stmts body result_ty goL [] = dsLExpr body goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts) - go _ (ExprStmt rhs then_expr _) stmts + go _ (ExprStmt rhs then_expr _ _) stmts = do { rhs2 <- dsLExpr rhs ; case tcSplitAppTy_maybe (exprType rhs2) of Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty @@ -769,7 +772,7 @@ dsDo stmts body result_ty mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body] (mkFunTy tup_ty body_ty)) mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats - body = noLoc $ HsDo DoExpr rec_stmts return_app body_ty + body = noLoc $ HsDo DoExpr rec_stmts return_app noSyntaxExpr body_ty return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets) body_ty = mkAppTy m_ty tup_ty tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case @@ -869,7 +872,7 @@ dsMDo ctxt tbl stmts body result_ty rets = map nlHsVar later_ids' ++ map noLoc rec_rets mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats - body = noLoc $ HsDo ctxt rec_stmts return_app body_ty + body = noLoc $ HsDo ctxt rec_stmts return_app noSyntaxExpr body_ty body_ty = mkAppTy m_ty tup_ty tup_ty = mkBoxedTupleTy (map idType (later_ids' ++ rec_ids)) -- Deals with singleton case @@ -888,7 +891,6 @@ dsMDo ctxt tbl stmts body result_ty -} \end{code} - %************************************************************************ %* * Warning about identities diff --git a/compiler/deSugar/DsGRHSs.lhs b/compiler/deSugar/DsGRHSs.lhs index a7260e2af8..d3fcf76d1c 100644 --- a/compiler/deSugar/DsGRHSs.lhs +++ b/compiler/deSugar/DsGRHSs.lhs @@ -106,11 +106,11 @@ matchGuards [] _ rhs _ -- NB: The success of this clause depends on the typechecker not -- wrapping the 'otherwise' in empty HsTyApp or HsWrap constructors -- If it does, you'll get bogus overlap warnings -matchGuards (ExprStmt e _ _ : stmts) ctx rhs rhs_ty +matchGuards (ExprStmt e _ _ _ : stmts) ctx rhs rhs_ty | Just addTicks <- isTrueLHsExpr e = do match_result <- matchGuards stmts ctx rhs rhs_ty return (adjustMatchResultDs addTicks match_result) -matchGuards (ExprStmt expr _ _ : stmts) ctx rhs rhs_ty = do +matchGuards (ExprStmt expr _ _ _ : stmts) ctx rhs rhs_ty = do match_result <- matchGuards stmts ctx rhs rhs_ty pred_expr <- dsLExpr expr return (mkGuardedMatchResult pred_expr match_result) diff --git a/compiler/deSugar/DsListComp.lhs b/compiler/deSugar/DsListComp.lhs index cd22b8ff8c..7fa78487e9 100644 --- a/compiler/deSugar/DsListComp.lhs +++ b/compiler/deSugar/DsListComp.lhs @@ -3,9 +3,10 @@ % (c) The GRASP/AQUA Project, Glasgow University, 1992-1998 % -Desugaring list comprehensions and array comprehensions +Desugaring list comprehensions, monad comprehensions and array comprehensions \begin{code} +{-# LANGUAGE NamedFieldPuns #-} {-# OPTIONS -fno-warn-incomplete-patterns #-} -- The above warning supression flag is a temporary kludge. -- While working on this module you are encouraged to remove it and fix @@ -13,11 +14,11 @@ Desugaring list comprehensions and array comprehensions -- http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings -- for details -module DsListComp ( dsListComp, dsPArrComp ) where +module DsListComp ( dsListComp, dsPArrComp, dsMonadComp ) where #include "HsVersions.h" -import {-# SOURCE #-} DsExpr ( dsLExpr, dsLocalBinds ) +import {-# SOURCE #-} DsExpr ( dsExpr, dsLExpr, dsLocalBinds ) import HsSyn import TcHsSyn @@ -37,6 +38,7 @@ import PrelNames import SrcLoc import Outputable import FastString +import TcType \end{code} List comprehensions may be desugared in one of two ways: ``ordinary'' @@ -72,8 +74,8 @@ dsListComp lquals body elt_ty = do -- mix of possibly a single element in length, so we do this to leave the possibility open isParallelComp = any isParallelStmt - isParallelStmt (ParStmt _) = True - isParallelStmt _ = False + isParallelStmt (ParStmt _ _ _ _) = True + isParallelStmt _ = False -- This function lets you desugar a inner list comprehension and a list of the binders @@ -92,7 +94,7 @@ dsInnerListComp (stmts, bndrs) = do -- Given such a statement it gives you back an expression representing how to compute the transformed -- list and the tuple that you need to bind from that list in order to proceed with your desugaring dsTransformStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) -dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr) +dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr _ _) = do { (expr, binders_tuple_type) <- dsInnerListComp (stmts, binders) ; usingExpr' <- dsLExpr usingExpr @@ -116,7 +118,7 @@ dsTransformStmt (TransformStmt stmts binders usingExpr maybeByExpr) -- Given such a statement it gives you back an expression representing how to compute the transformed -- list and the tuple that you need to bind from that list in order to proceed with your desugaring dsGroupStmt :: Stmt Id -> DsM (CoreExpr, LPat Id) -dsGroupStmt (GroupStmt stmts binderMap by using) = do +dsGroupStmt (GroupStmt stmts binderMap by using _ _ _) = do let (fromBinders, toBinders) = unzip binderMap fromBindersTypes = map idType fromBinders @@ -228,7 +230,7 @@ with the Unboxed variety. deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr -deListComp (ParStmt stmtss_w_bndrs : quals) body list +deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list = do exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs let (exps, qual_tys) = unzip exps_and_qual_tys @@ -252,7 +254,7 @@ deListComp [] body list = do -- Figure 7.4, SLPJ, p 135, rule C above return (mkConsExpr (exprType core_body) core_body list) -- Non-last: must be a guard -deListComp (ExprStmt guard _ _ : quals) body list = do -- rule B above +deListComp (ExprStmt guard _ _ _ : quals) body list = do -- rule B above core_guard <- dsLExpr guard core_rest <- deListComp quals body list return (mkIfThenElse core_guard core_rest list) @@ -344,7 +346,7 @@ dfListComp c_id n_id [] body = do return (mkApps (Var c_id) [core_body, Var n_id]) -- Non-last: must be a guard -dfListComp c_id n_id (ExprStmt guard _ _ : quals) body = do +dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) body = do core_guard <- dsLExpr guard core_rest <- dfListComp c_id n_id quals body return (mkIfThenElse core_guard core_rest (Var n_id)) @@ -501,7 +503,7 @@ dsPArrComp :: [Stmt Id] -> LHsExpr Id -> Type -- Don't use; called with `undefined' below -> DsM CoreExpr -dsPArrComp [ParStmt qss] body _ = -- parallel comprehension +dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension dePArrParComp qss body -- Special case for simple generators: @@ -550,7 +552,7 @@ dePArrComp [] e' pa cea = do -- -- <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea) -- -dePArrComp (ExprStmt b _ _ : qs) body pa cea = do +dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do filterP <- dsLookupDPHId filterPName let ty = parrElemType cea (clam,_) <- deLambda ty pa b @@ -616,7 +618,7 @@ dePArrComp (LetStmt ds : qs) body pa cea = do -- singeltons qualifier lists, which we already special case in the caller. -- So, encountering one here is a bug. -- -dePArrComp (ParStmt _ : _) _ _ _ = +dePArrComp (ParStmt _ _ _ _ : _) _ _ _ = panic "DsListComp.dePArrComp: malformed comprehension AST" -- <<[:e' | qs | qss:]>> pa ea = @@ -682,3 +684,341 @@ parrElemType e = _ -> panic "DsListComp.parrElemType: not a parallel array type" \end{code} + +Translation for monad comprehensions + +\begin{code} + +-- | Keep the "context" of a monad comprehension in a small data type to avoid +-- some boilerplate... +data DsMonadComp = DsMonadComp + { mc_return :: Either (SyntaxExpr Id) (Expr CoreBndr) + , mc_body :: LHsExpr Id + , mc_m_ty :: Type + } + +-- +-- Entry point for monad comprehension desugaring +-- +dsMonadComp :: [LStmt Id] -- the statements + -> SyntaxExpr Id -- the "return" function + -> LHsExpr Id -- the body + -> Type -- the final type + -> DsM CoreExpr +dsMonadComp stmts return_op body res_ty + = dsMcStmts stmts (DsMonadComp (Left return_op) body m_ty) + where + (m_ty, _) = tcSplitAppTy res_ty + + +dsMcStmts :: [LStmt Id] + -> DsMonadComp + -> DsM CoreExpr + +-- No statements left for desugaring. Desugar the body after calling "return" +-- on it. +dsMcStmts [] DsMonadComp { mc_return, mc_body } + = case mc_return of + Left ret -> dsLExpr $ noLoc ret `nlHsApp` mc_body + Right ret' -> do + { body' <- dsLExpr mc_body + ; return $ mkApps ret' [body'] } + +-- Otherwise desugar each statement step by step +dsMcStmts ((L loc stmt) : lstmts) mc + = putSrcSpanDs loc (dsMcStmt stmt lstmts mc) + + +dsMcStmt :: Stmt Id + -> [LStmt Id] + -> DsMonadComp + -> DsM CoreExpr + +-- [ .. | let binds, stmts ] +dsMcStmt (LetStmt binds) stmts mc + = do { rest <- dsMcStmts stmts mc + ; dsLocalBinds binds rest } + +-- [ .. | a <- m, stmts ] +dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts mc + = do { rhs' <- dsLExpr rhs + ; dsMcBindStmt pat rhs' bind_op fail_op stmts mc } + +-- Apply `guard` to the `exp` expression +-- +-- [ .. | exp, stmts ] +-- +dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc + = do { exp' <- dsLExpr exp + ; guard_exp' <- dsExpr guard_exp + ; then_exp' <- dsExpr then_exp + ; rest <- dsMcStmts stmts mc + ; return $ mkApps then_exp' [ mkApps guard_exp' [exp'] + , rest ] } + +-- Transform statements desugar like this: +-- +-- [ .. | qs, then f by e ] -> f (\q_v -> e) [| qs |] +-- +-- where [| qs |] is the desugared inner monad comprehenion generated by the +-- statements `qs`. +dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest mc + = do { (expr, _) <- dsInnerMonadComp (stmts, binders) (mc { mc_return = Left return_op }) + ; let binders_tuple_type = mkBigCoreTupTy $ map idType binders + ; usingExpr' <- dsLExpr usingExpr + ; using_args <- case maybeByExpr of + Nothing -> return [expr] + Just byExpr -> do + byExpr' <- dsLExpr byExpr + us <- newUniqueSupply + tuple_binder <- newSysLocalDs binders_tuple_type + let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder) + return [Lam tuple_binder byExprWrapper, expr] + + ; let pat = mkBigLHsVarPatTup binders + rhs = mkApps usingExpr' ((Type binders_tuple_type) : using_args) + + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + +-- Group statements desugar like this: +-- +-- [| q, then group by e using f |] -> (f (\q_v -> e) [| q |]) >>= (return . (unzip q_v)) +-- +-- which is equal to +-- +-- [| q, then group by e using f |] -> liftM (unzip q_v) (f (\q_v -> e) [| q |]) +-- +-- where unzip is of the form +-- +-- unzip :: m (a,b,c,..) -> (m a,m b,m c,..) +-- unzip m_tuple = ( liftM selN1 m_tuple +-- , liftM selN2 m_tuple +-- , .. ) +-- where selN1 (a,b,c,..) = a +-- selN2 (a,b,c,..) = b +-- .. +-- +dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_rest mc + = do { let (fromBinders, toBinders) = unzip binderMap + fromBindersTypes = map idType fromBinders + fromBindersTupleTy = mkBigCoreTupTy fromBindersTypes + toBindersTypes = map idType toBinders + toBindersTupleTy = mkBigCoreTupTy toBindersTypes + m_ty = mc_m_ty mc + + -- Desugar an inner comprehension which outputs a list of tuples of the "from" binders + ; (expr, _) <- dsInnerMonadComp (stmts, fromBinders) (mc { mc_return = Left return_op }) + + -- Work out what arguments should be supplied to that expression: i.e. is an extraction + -- function required? If so, create that desugared function and add to arguments + ; usingExpr' <- dsLExpr (either id noLoc using) + ; usingArgs <- case by of + Nothing -> return [expr] + Just by_e -> do { by_e' <- dsLExpr by_e + ; us <- newUniqueSupply + ; from_tup_id <- newSysLocalDs fromBindersTupleTy + ; let by_wrap = mkTupleCase us fromBinders by_e' + from_tup_id (Var from_tup_id) + ; return [Lam from_tup_id by_wrap, expr] } + + -- Create an unzip function for the appropriate arity and element types + ; liftM_op' <- dsExpr liftM_op + ; (unzip_fn, unzip_rhs) <- mkMcUnzipM liftM_op' m_ty fromBindersTypes + + -- Generate the expressions to build the grouped list + + ; let -- First we apply the grouping function to the inner monad + inner_monad_expr = mkApps usingExpr' ((Type fromBindersTupleTy) : usingArgs) + -- Then we map our "unzip" across it to turn the "monad of tuples" into "tuples of monads" + -- We make sure we instantiate the type variable "a" to be a "monad of 'from' tuples" and + -- the "b" to be a "tuple of 'to' monads"! + unzipped_inner_monad_expr = mkApps liftM_op' -- ! + -- Types: + [ Type (m_ty `mkAppTy` fromBindersTupleTy), Type toBindersTupleTy + -- And arguments: + , Var unzip_fn, inner_monad_expr ] + -- Then finally we bind the unzip function around that expression + bound_unzipped_inner_monad_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_monad_expr + + -- Build a pattern that ensures the consumer binds into the NEW binders, which hold monads + -- rather than single values + ; let pat = mkBigLHsVarPatTup toBinders + rhs = bound_unzipped_inner_monad_expr + + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + +-- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel +-- statements, for example: +-- +-- [ body | qs1 | qs2 | qs3 ] +-- -> [ body | (bndrs1, (bndrs2, bndrs3)) <- mzip qs1 (mzip qs2 qs3) ] +-- +-- where `mzip` is of the form +-- +-- mzip :: m a -> m b -> m (a,b) +-- +dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc + = do { -- Get types for `return` + return_op' <- dsExpr return_op + ; let pairs_with_return = map (\tp@(_,b) -> (mkReturn b,tp)) pairs + mkReturn bndrs = mkApps return_op' [Type (mkBigCoreTupTy (map idType bndrs))] + + ; pairs' <- mapM (\(r,tp) -> dsInnerMonadComp tp mc{mc_return = Right r}) + pairs_with_return + + ; let (exps, _qual_tys) = unzip pairs' + -- Types of our `Id`s are getting messed up by `dsInnerMonadComp` + -- so we construct them by hand: + qual_tys = map (mkBigCoreTupTy . map idType . snd) pairs + + ; mzip_op' <- dsExpr mzip_op + ; (zip_fn, zip_rhs) <- mkMcZipM mzip_op' (mc_m_ty mc) qual_tys + + ; let -- The pattern variables + vars = map (mkBigLHsVarPatTup . snd) pairs + -- Pattern with tuples of variables + -- [v1,v2,v3] => (v1, (v2, v3)) + pat = foldr (\tn tm -> mkBigLHsPatTup [tn, tm]) (last vars) (init vars) + rhs = Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps) + + ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc } + +dsMcStmt stmt _ _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt) + + +-- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a +-- desugared `CoreExpr` +dsMcBindStmt :: LPat Id + -> CoreExpr -- ^ the desugared rhs of the bind statement + -> SyntaxExpr Id + -> SyntaxExpr Id + -> [LStmt Id] + -> DsMonadComp + -> DsM CoreExpr +dsMcBindStmt pat rhs' bind_op fail_op stmts mc + = do { body <- dsMcStmts stmts mc + ; bind_op' <- dsExpr bind_op + ; var <- selectSimpleMatchVarL pat + ; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2 + res1_ty = funResultTy (funArgTy (funResultTy bind_ty)) + ; match <- matchSinglePat (Var var) (StmtCtxt DoExpr) pat + res1_ty (cantFailMatchResult body) + ; match_code <- handle_failure pat match fail_op + ; return (mkApps bind_op' [rhs', Lam var match_code]) } + + where + -- In a monad comprehension expression, pattern-match failure just calls + -- the monadic `fail` rather than throwing an exception + handle_failure pat match fail_op + | matchCanFail match + = do { fail_op' <- dsExpr fail_op + ; fail_msg <- mkStringExpr (mk_fail_msg pat) + ; extractMatchResult match (App fail_op' fail_msg) } + | otherwise + = extractMatchResult match (error "It can't fail") + + mk_fail_msg :: Located e -> String + mk_fail_msg pat = "Pattern match failure in monad comprehension at " ++ + showSDoc (ppr (getLoc pat)) + +-- Desugar nested monad comprehensions, for example in `then..` constructs +dsInnerMonadComp :: ([LStmt Id], [Id]) + -> DsMonadComp + -> DsM (CoreExpr, Type) +dsInnerMonadComp (stmts, bndrs) DsMonadComp{ mc_return, mc_m_ty } + = do { expr <- dsMcStmts stmts mc' + ; return (expr, bndrs_tuple_type) } + where + bndrs_types = map idType bndrs + bndrs_tuple_type = mkAppTy mc_m_ty $ mkBigCoreTupTy bndrs_types + mc' = DsMonadComp mc_return (mkBigLHsVarTup bndrs) mc_m_ty + +-- The `unzip` function for `GroupStmt` in a monad comprehensions +-- +-- unzip :: m (a,b,..) -> (m a,m b,..) +-- unzip m_tuple = ( liftM selN1 m_tuple +-- , liftM selN2 m_tuple +-- , .. ) +-- +-- mkMcUnzipM m [t1, t2] +-- = (unzip_fn, \ys :: m (t1, t2) -> +-- ( liftM (selN1 :: (t1, t2) -> t1) ys +-- , liftM (selN2 :: (t1, t2) -> t2) ys +-- )) +-- +mkMcUnzipM :: CoreExpr + -> Type -- m + -> [Type] -- [a,b,c,..] + -> DsM (Id, CoreExpr) +mkMcUnzipM liftM_op m_ty elt_tys + = do { ys <- newSysLocalDs monad_tuple_ty + ; xs <- mapM newSysLocalDs elt_tys + ; scrut <- newSysLocalDs tuple_tys + + ; unzip_fn <- newSysLocalDs unzip_fn_ty + + ; let -- Select one Id from our tuple + selectExpr n = mkLams [scrut] $ mkTupleSelector xs (xs !! n) scrut (Var scrut) + -- Apply 'selectVar' and 'ys' to 'liftM' + tupleElem n = mkApps liftM_op + -- Types (m is figured out by the type checker): + -- liftM :: forall a b. (a -> b) -> m a -> m b + [ Type tuple_tys, Type (elt_tys !! n) + -- Arguments: + , selectExpr n, Var ys ] + -- The final expression with the big tuple + unzip_body = mkBigCoreTup [ tupleElem n | n <- [0..length elt_tys - 1] ] + + ; return (unzip_fn, mkLams [ys] unzip_body) } + where monad_tys = map (m_ty `mkAppTy`) elt_tys -- [m a,m b,m c,..] + tuple_monad_tys = mkBigCoreTupTy monad_tys -- (m a,m b,m c,..) + tuple_tys = mkBigCoreTupTy elt_tys -- (a,b,c,..) + monad_tuple_ty = m_ty `mkAppTy` tuple_tys -- m (a,b,c,..) + unzip_fn_ty = monad_tuple_ty `mkFunTy` tuple_monad_tys -- m (a,b,c,..) -> (m a,m b,m c,..) + +-- Generate the `mzip` function for `ParStmt` in monad comprehensions, for +-- example: +-- +-- mzip :: m t1 +-- -> (m t2 -> m t3 -> m (t2, t3)) +-- -> m (t1, (t2, t3)) +-- +-- mkMcZipM m [t1, t2, t3] +-- = (zip_fn, \(q1::t1) (q2::t2) (q3::t3) -> +-- mzip q1 (mzip q2 q3)) +-- +mkMcZipM :: CoreExpr + -> Type + -> [Type] + -> DsM (Id, CoreExpr) + +mkMcZipM mzip_op m_ty tys@(_:_:_) -- min. 2 types + = do { (ids, t1, tuple_ty, zip_body) <- loop tys + ; zip_fn <- newSysLocalDs $ + (m_ty `mkAppTy` t1) + `mkFunTy` + (m_ty `mkAppTy` tuple_ty) + `mkFunTy` + (m_ty `mkAppTy` mkBigCoreTupTy [t1, tuple_ty]) + ; return (zip_fn, mkLams ids zip_body) } + + where + -- loop :: [Type] -> DsM ([Id], Type, [Type], CoreExpr) + loop [t1, t2] = do -- last run of the `loop` + { ids@[a,b] <- newSysLocalsDs (map (m_ty `mkAppTy`) [t1,t2]) + ; let zip_body = mkApps mzip_op [ Type t1, Type t2 , Var a, Var b ] + ; return (ids, t1, t2, zip_body) } + + loop (t1:tr) = do + { -- Get ty, ids etc from the "inner" zip + (ids', t1', t2', zip_body') <- loop tr + + ; a <- newSysLocalDs $ m_ty `mkAppTy` t1 + ; let tuple_ty' = mkBigCoreTupTy [t1', t2'] + zip_body = mkApps mzip_op [ Type t1, Type tuple_ty', Var a, zip_body' ] + ; return ((a:ids'), t1, tuple_ty', zip_body) } + +-- This case should never happen: +mkMcZipM _ _ tys = pprPanic "mkMcZipM: unexpected argument" (ppr tys) + +\end{code} diff --git a/compiler/deSugar/DsMeta.hs b/compiler/deSugar/DsMeta.hs index e34c6960d7..2c1939ff73 100644 --- a/compiler/deSugar/DsMeta.hs +++ b/compiler/deSugar/DsMeta.hs @@ -721,7 +721,7 @@ repE (HsLet bs e) = do { (ss,ds) <- repBinds bs ; wrapGenSyms ss z } -- FIXME: I haven't got the types here right yet -repE e@(HsDo ctxt sts body _) +repE e@(HsDo ctxt sts body _ _) | case ctxt of { DoExpr -> True; GhciStmt -> True; _ -> False } = do { (ss,zs) <- repLSts sts; body' <- addBinds ss $ repLE body; @@ -737,7 +737,7 @@ repE e@(HsDo ctxt sts body _) wrapGenSyms ss e' } | otherwise - = notHandled "mdo and [: :]" (ppr e) + = notHandled "mdo, monad comprehension and [: :]" (ppr e) repE (ExplicitList _ es) = do { xs <- repLEs es; repListExp xs } repE e@(ExplicitPArr _ _) = notHandled "Parallel arrays" (ppr e) @@ -817,7 +817,7 @@ repGuards other wrapGenSyms (concat xs) gd } where process :: LGRHS Name -> DsM ([GenSymBind], (Core (TH.Q (TH.Guard, TH.Exp)))) - process (L _ (GRHS [L _ (ExprStmt e1 _ _)] e2)) + process (L _ (GRHS [L _ (ExprStmt e1 _ _ _)] e2)) = do { x <- repLNormalGE e1 e2; return ([], x) } process (L _ (GRHS ss rhs)) @@ -876,7 +876,7 @@ repSts (LetStmt bs : ss) = ; z <- repLetSt ds ; (ss2,zs) <- addBinds ss1 (repSts ss) ; return (ss1++ss2, z : zs) } -repSts (ExprStmt e _ _ : ss) = +repSts (ExprStmt e _ _ _ : ss) = do { e2 <- repLE e ; z <- repNoBindSt e2 ; (ss2,zs) <- repSts ss |