summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2011-04-29 18:06:03 +0100
committerSimon Peyton Jones <simonpj@microsoft.com>2011-04-29 18:06:03 +0100
commit4ac2bb39dffb4b825ece73b349ff0d56d79092d7 (patch)
treefeed26ef7e157d3fa025cb5d2df97e277940b00b
parent5ccf658872ea2304f34eda6b1fb840fc1bfc0ba0 (diff)
downloadhaskell-4ac2bb39dffb4b825ece73b349ff0d56d79092d7.tar.gz
Simon's hacking on monad-comp; incomplete
-rw-r--r--compiler/deSugar/Coverage.lhs29
-rw-r--r--compiler/deSugar/DsArrows.lhs13
-rw-r--r--compiler/deSugar/DsExpr.lhs210
-rw-r--r--compiler/deSugar/DsListComp.lhs372
-rw-r--r--compiler/hsSyn/HsExpr.lhs74
-rw-r--r--compiler/hsSyn/HsUtils.lhs32
-rw-r--r--compiler/parser/Parser.y.pp21
-rw-r--r--compiler/parser/RdrHsSyn.lhs30
-rw-r--r--compiler/prelude/PrelNames.lhs3
-rw-r--r--compiler/rename/RnExpr.lhs206
-rw-r--r--compiler/typecheck/TcExpr.lhs4
-rw-r--r--compiler/typecheck/TcGenDeriv.lhs15
-rw-r--r--compiler/typecheck/TcHsSyn.lhs18
-rw-r--r--compiler/typecheck/TcMatches.lhs285
14 files changed, 670 insertions, 642 deletions
diff --git a/compiler/deSugar/Coverage.lhs b/compiler/deSugar/Coverage.lhs
index e73c2499e8..711f66e9ab 100644
--- a/compiler/deSugar/Coverage.lhs
+++ b/compiler/deSugar/Coverage.lhs
@@ -301,11 +301,9 @@ addTickHsExpr (HsLet binds e) =
liftM2 HsLet
(addTickHsLocalBinds binds) -- to think about: !patterns.
(addTickLHsExprNeverOrAlways e)
-addTickHsExpr (HsDo cxt stmts last_exp return_exp srcloc) = do
- (stmts', last_exp') <- addTickLStmts' forQual stmts
- (addTickLHsExpr last_exp)
- return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp
- return (HsDo cxt stmts' last_exp' return_exp' srcloc)
+addTickHsExpr (HsDo cxt stmts srcloc)
+ = do { (stmts', _) <- addTickLStmts' forQual stmts (return ())
+ ; return (HsDo cxt stmts' srcloc) }
where
forQual = case cxt of
ListComp -> Just $ BinBox QualBinBox
@@ -425,14 +423,16 @@ addTickLStmts isGuard stmts = do
addTickLStmts' :: (Maybe (Bool -> BoxLabel)) -> [LStmt Id] -> TM a
-> TM ([LStmt Id], a)
addTickLStmts' isGuard lstmts res
- = bindLocals binders $ do
- lstmts' <- mapM (liftL (addTickStmt isGuard)) lstmts
- a <- res
- return (lstmts', a)
- where
- binders = collectLStmtsBinders lstmts
+ = bindLocals (collectLStmtsBinders lstmts) $
+ do { lstmts' <- mapM (liftL (addTickStmt isGuard)) lstmts
+ ; a <- res
+ ; return (lstmts', a) }
addTickStmt :: (Maybe (Bool -> BoxLabel)) -> Stmt Id -> TM (Stmt Id)
+addTickStmt _isGuard (LastStmt e ret) = do
+ liftM2 LastStmt
+ (addTickLHsExprAlways e)
+ (addTickSyntaxExpr hpcSrcSpan ret)
addTickStmt _isGuard (BindStmt pat e bind fail) = do
liftM4 BindStmt
(addTickLPat pat)
@@ -577,10 +577,9 @@ addTickHsCmd (HsLet binds c) =
liftM2 HsLet
(addTickHsLocalBinds binds) -- to think about: !patterns.
(addTickLHsCmd c)
-addTickHsCmd (HsDo cxt stmts last_exp return_exp srcloc) = do
- (stmts', last_exp') <- addTickLCmdStmts' stmts (addTickLHsCmd last_exp)
- return_exp' <- addTickSyntaxExpr hpcSrcSpan return_exp
- return (HsDo cxt stmts' last_exp' return_exp' srcloc)
+addTickHsCmd (HsDo cxt stmts srcloc)
+ = do { (stmts', _) <- addTickLCmdStmts' stmts (return ())
+ ; return (HsDo cxt stmts' srcloc) }
addTickHsCmd (HsArrApp e1 e2 ty1 arr_ty lr) =
liftM5 HsArrApp
diff --git a/compiler/deSugar/DsArrows.lhs b/compiler/deSugar/DsArrows.lhs
index 608f25e7f5..a5bf2b69d6 100644
--- a/compiler/deSugar/DsArrows.lhs
+++ b/compiler/deSugar/DsArrows.lhs
@@ -541,8 +541,8 @@ dsCmd ids local_vars env_ids stack res_ty (HsLet binds body) = do
core_body,
exprFreeVars core_binds `intersectVarSet` local_vars)
-dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts body _ _)
- = dsCmdDo ids local_vars env_ids res_ty stmts body
+dsCmd ids local_vars env_ids [] res_ty (HsDo _ctxt stmts _)
+ = dsCmdDo ids local_vars env_ids res_ty stmts
-- A |- e :: forall e. a1 (e*ts1) t1 -> ... an (e*tsn) tn -> a (e*ts) t
-- A | xs |- ci :: [tsi] ti
@@ -618,7 +618,6 @@ dsCmdDo :: DsCmdEnv -- arrow combinators
-- so don't pull on it too early
-> Type -- return type of the statement
-> [LStmt Id] -- statements to desugar
- -> LHsExpr Id -- body
-> DsM (CoreExpr, -- desugared expression
IdSet) -- set of local vars that occur free
@@ -626,15 +625,17 @@ dsCmdDo :: DsCmdEnv -- arrow combinators
-- --------------------------
-- A | xs |- do { c } :: [] t
-dsCmdDo ids local_vars env_ids res_ty [] body
+dsCmdDo _ _ _ _ [] = panic "dsCmdDo"
+
+dsCmdDo ids local_vars env_ids res_ty [L _ (LastStmt body _)]
= dsLCmd ids local_vars env_ids [] res_ty body
-dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) body = do
+dsCmdDo ids local_vars env_ids res_ty (stmt:stmts) = do
let
bound_vars = mkVarSet (collectLStmtBinders stmt)
local_vars' = local_vars `unionVarSet` bound_vars
(core_stmts, _, env_ids') <- fixDs (\ ~(_,_,env_ids') -> do
- (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts body
+ (core_stmts, fv_stmts) <- dsCmdDo ids local_vars' env_ids' res_ty stmts
return (core_stmts, fv_stmts, varSetElems fv_stmts))
(core_stmt, fv_stmt) <- dsCmdLStmt ids local_vars env_ids env_ids' stmt
return (do_compose ids
diff --git a/compiler/deSugar/DsExpr.lhs b/compiler/deSugar/DsExpr.lhs
index fb3f856c63..c55c2d4c74 100644
--- a/compiler/deSugar/DsExpr.lhs
+++ b/compiler/deSugar/DsExpr.lhs
@@ -325,29 +325,12 @@ dsExpr (HsLet binds body) = do
-- We need the `ListComp' form to use `deListComp' (rather than the "do" form)
-- because the interpretation of `stmts' depends on what sort of thing it is.
--
-dsExpr (HsDo ListComp stmts body _ result_ty)
- = -- Special case for list comprehensions
- dsListComp stmts body elt_ty
- where
- [elt_ty] = tcTyConAppArgs result_ty
-
-dsExpr (HsDo DoExpr stmts body _ result_ty)
- = dsDo stmts body result_ty
-
-dsExpr (HsDo GhciStmt stmts body _ result_ty)
- = dsDo stmts body result_ty
-
-dsExpr (HsDo MDoExpr stmts body _ result_ty)
- = dsDo stmts body result_ty
-
-dsExpr (HsDo MonadComp stmts body return_op result_ty)
- = dsMonadComp stmts return_op body result_ty
-
-dsExpr (HsDo PArrComp stmts body _ result_ty)
- = -- Special case for array comprehensions
- dsPArrComp (map unLoc stmts) body elt_ty
- where
- [elt_ty] = tcTyConAppArgs result_ty
+dsExpr (HsDo ListComp stmts res_ty) = dsListComp stmts res_ty
+dsExpr (HsDo PArrComp stmts _) = dsPArrComp (map unLoc stmts)
+dsExpr (HsDo DoExpr stmts res_ty) = dsDo stmts res_ty
+dsExpr (HsDo GhciStmt stmts res_ty) = dsDo stmts res_ty
+dsExpr (HsDo MDoExpr stmts res_ty) = dsDo stmts res_ty
+dsExpr (HsDo MonadComp stmts res_ty) = dsMonadComp stmts res_ty
dsExpr (HsIf mb_fun guard_expr then_expr else_expr)
= do { pred <- dsLExpr guard_expr
@@ -712,24 +695,24 @@ Haskell 98 report:
\begin{code}
dsDo :: [LStmt Id]
- -> LHsExpr Id
-> Type -- Type of the whole expression
-> DsM CoreExpr
-dsDo stmts body result_ty
+dsDo stmts result_ty
= goL stmts
where
- -- result_ty must be of the form (m b)
- (m_ty, _b_ty) = tcSplitAppTy result_ty
-
- goL [] = dsLExpr body
- goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
+ goL [] = panic "dsDo"
+ goL (L loc stmt:lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
+ go _ (LastStmt body ret_op) stmts
+ = ASSERT( null stmts )
+ do { body' <- dsLExpr body
+ ; ret_op' <- dsExpr ret_op
+ ; return (App ret_op' body') }
+
go _ (ExprStmt rhs then_expr _ _) stmts
= do { rhs2 <- dsLExpr rhs
- ; case tcSplitAppTy_maybe (exprType rhs2) of
- Just (container_ty, returning_ty) -> warnDiscardedDoBindings rhs container_ty returning_ty
- _ -> return ()
+ ; warnDiscardedDoBindings rhs (exprType rhs2)
; then_expr2 <- dsExpr then_expr
; rest <- goL stmts
; return (mkApps then_expr2 [rhs2, rest]) }
@@ -753,29 +736,25 @@ dsDo stmts body result_ty
go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
, recS_rec_ids = rec_ids, recS_ret_fn = return_op
, recS_mfix_fn = mfix_op, recS_bind_fn = bind_op
- , recS_rec_rets = rec_rets }) stmts
+ , recS_rec_rets = rec_rets, recS_ret_ty = body_ty }) stmts
= ASSERT( length rec_ids > 0 )
goL (new_bind_stmt : stmts)
where
- -- returnE <- dsExpr return_id
- -- mfixE <- dsExpr mfix_id
- new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats) mfix_app
- bind_op
+ new_bind_stmt = L loc $ BindStmt (mkLHsPatTup later_pats)
+ mfix_app bind_op
noSyntaxExpr -- Tuple cannot fail
tup_ids = rec_ids ++ filterOut (`elem` rec_ids) later_ids
rec_tup_pats = map nlVarPat tup_ids
later_pats = rec_tup_pats
rets = map noLoc rec_rets
-
- mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
- mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
- (mkFunTy tup_ty body_ty))
- mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
- body = noLoc $ HsDo DoExpr rec_stmts return_app noSyntaxExpr body_ty
- return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
- body_ty = mkAppTy m_ty tup_ty
- tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
+ mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
+ mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
+ (mkFunTy tup_ty body_ty))
+ mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
+ body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty
+ ret_stmt = noLoc $ LastStmt return_op (mkLHsTupleExpr rets)
+ tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr
-- In a do expression, pattern-match failure just calls
@@ -793,103 +772,6 @@ mk_fail_msg pat = "Pattern match failure in do expression at " ++
showSDoc (ppr (getLoc pat))
\end{code}
-Translation for RecStmt's:
------------------------------
-We turn (RecStmt [v1,..vn] stmts) into:
-
- (v1,..,vn) <- mfix (\~(v1,..vn). do stmts
- return (v1,..vn))
-
-\begin{code}
-{-
-dsMDo :: HsStmtContext Name
- -> [(Name,Id)]
- -> [LStmt Id]
- -> LHsExpr Id
- -> Type -- Type of the whole expression
- -> DsM CoreExpr
-
-dsMDo ctxt tbl stmts body result_ty
- = goL stmts
- where
- goL [] = dsLExpr body
- goL ((L loc stmt):lstmts) = putSrcSpanDs loc (go loc stmt lstmts)
-
- (m_ty, b_ty) = tcSplitAppTy result_ty -- result_ty must be of the form (m b)
- return_id = lookupEvidence tbl returnMName
- bind_id = lookupEvidence tbl bindMName
- then_id = lookupEvidence tbl thenMName
- fail_id = lookupEvidence tbl failMName
-
- go _ (LetStmt binds) stmts
- = do { rest <- goL stmts
- ; dsLocalBinds binds rest }
-
- go _ (ExprStmt rhs then_expr rhs_ty) stmts
- = do { rhs2 <- dsLExpr rhs
- ; warnDiscardedDoBindings rhs m_ty rhs_ty
- ; then_expr2 <- dsExpr then_expr
- ; rest <- goL stmts
- ; return (mkApps then_expr2 [rhs2, rest]) }
-
- go _ (BindStmt pat rhs bind_op _) stmts
- = do { body <- goL stmts
- ; rhs' <- dsLExpr rhs
- ; bind_op' <- dsExpr bind_op
- ; var <- selectSimpleMatchVarL pat
- ; match <- matchSinglePat (Var var) (StmtCtxt ctxt) pat
- result_ty (cantFailMatchResult body)
- ; match_code <- handle_failure pat match fail_op
- ; return (mkApps bind_op [rhs', Lam var match_code]) }
-
- go loc (RecStmt { recS_stmts = rec_stmts, recS_later_ids = later_ids
- , recS_rec_ids = rec_ids, recS_rec_rets = rec_rets
- , recS_mfix_fn = mfix_op, recS_bind_fn = bind_op }) stmts
- = ASSERT( length rec_ids > 0 )
- ASSERT( length rec_ids == length rec_rets )
- ASSERT( isEmptyTcEvBinds _ev_binds )
- pprTrace "dsMDo" (ppr later_ids) $
- goL (new_bind_stmt : stmts)
- where
- new_bind_stmt = L loc $ BindStmt (mk_tup_pat later_pats) mfix_app
- bind_op noSyntaxExpr
-
- -- Remove the later_ids that appear (without fancy coercions)
- -- in rec_rets, because there's no need to knot-tie them separately
- -- See Note [RecStmt] in HsExpr
- later_ids' = filter (`notElem` mono_rec_ids) later_ids
- mono_rec_ids = [ id | HsVar id <- rec_rets ]
-
- mfix_app = nlHsApp (noLoc mfix_op) mfix_arg
- mfix_arg = noLoc $ HsLam (MatchGroup [mkSimpleMatch [mfix_pat] body]
- (mkFunTy tup_ty body_ty))
-
- -- The rec_tup_pat must bind the rec_ids only; remember that the
- -- trimmed_laters may share the same Names
- -- Meanwhile, the later_pats must bind the later_vars
- rec_tup_pats = map mk_wild_pat later_ids' ++ map nlVarPat rec_ids
- later_pats = map nlVarPat later_ids' ++ map mk_later_pat rec_ids
- rets = map nlHsVar later_ids' ++ map noLoc rec_rets
-
- mfix_pat = noLoc $ LazyPat $ mk_tup_pat rec_tup_pats
- body = noLoc $ HsDo ctxt rec_stmts return_app noSyntaxExpr body_ty
- body_ty = mkAppTy m_ty tup_ty
- tup_ty = mkBoxedTupleTy (map idType (later_ids' ++ rec_ids)) -- Deals with singleton case
-
- return_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
-
- mk_wild_pat :: Id -> LPat Id
- mk_wild_pat v = noLoc $ WildPat $ idType v
-
- mk_later_pat :: Id -> LPat Id
- mk_later_pat v | v `elem` later_ids' = mk_wild_pat v
- | otherwise = nlVarPat v
-
- mk_tup_pat :: [LPat Id] -> LPat Id
- mk_tup_pat [p] = p
- mk_tup_pat ps = noLoc $ mkVanillaTuplePat ps Boxed
--}
-\end{code}
%************************************************************************
%* *
@@ -929,30 +811,34 @@ conversionNames
\begin{code}
-- Warn about certain types of values discarded in monadic bindings (#3263)
-warnDiscardedDoBindings :: LHsExpr Id -> Type -> Type -> DsM ()
-warnDiscardedDoBindings rhs container_ty returning_ty = do {
- -- Warn about discarding non-() things in 'monadic' binding
- ; warn_unused <- doptDs Opt_WarnUnusedDoBind
- ; if warn_unused && not (returning_ty `tcEqType` unitTy)
- then warnDs (unusedMonadBind rhs returning_ty)
- else do {
- -- Warn about discarding m a things in 'monadic' binding of the same type,
- -- but only if we didn't already warn due to Opt_WarnUnusedDoBind
- ; warn_wrong <- doptDs Opt_WarnWrongDoBind
- ; case tcSplitAppTy_maybe returning_ty of
- Just (returning_container_ty, _) -> when (warn_wrong && container_ty `tcEqType` returning_container_ty) $
- warnDs (wrongMonadBind rhs returning_ty)
- _ -> return () } }
+warnDiscardedDoBindings :: LHsExpr Id -> Type -> DsM ()
+warnDiscardedDoBindings rhs rhs_ty
+ | Just (m_ty, elt_ty) <- tcSplitAppTy_maybe rhs_ty
+ = do { -- Warn about discarding non-() things in 'monadic' binding
+ ; warn_unused <- doptDs Opt_WarnUnusedDoBind
+ ; if warn_unused && not (isUnitTy elt_ty)
+ then warnDs (unusedMonadBind rhs elt_ty)
+ else
+ -- Warn about discarding m a things in 'monadic' binding of the same type,
+ -- but only if we didn't already warn due to Opt_WarnUnusedDoBind
+ do { warn_wrong <- doptDs Opt_WarnWrongDoBind
+ ; case tcSplitAppTy_maybe elt_ty of
+ Just (elt_m_ty, _) | warn_wrong, m_ty `tcEqType` elt_m_ty
+ -> warnDs (wrongMonadBind rhs elt_ty)
+ _ -> return () } }
+
+ | otherwise -- RHS does have type of form (m ty), which is wierd
+ = return () -- but at lesat this warning is irrelevant
unusedMonadBind :: LHsExpr Id -> Type -> SDoc
-unusedMonadBind rhs returning_ty
- = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr returning_ty <> dot $$
+unusedMonadBind rhs elt_ty
+ = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr elt_ty <> dot $$
ptext (sLit "Suppress this warning by saying \"_ <- ") <> ppr rhs <> ptext (sLit "\",") $$
ptext (sLit "or by using the flag -fno-warn-unused-do-bind")
wrongMonadBind :: LHsExpr Id -> Type -> SDoc
-wrongMonadBind rhs returning_ty
- = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr returning_ty <> dot $$
+wrongMonadBind rhs elt_ty
+ = ptext (sLit "A do-notation statement discarded a result of type") <+> ppr elt_ty <> dot $$
ptext (sLit "Suppress this warning by saying \"_ <- ") <> ppr rhs <> ptext (sLit "\",") $$
ptext (sLit "or by using the flag -fno-warn-wrong-do-bind")
\end{code}
diff --git a/compiler/deSugar/DsListComp.lhs b/compiler/deSugar/DsListComp.lhs
index 7fa78487e9..1ecab67e10 100644
--- a/compiler/deSugar/DsListComp.lhs
+++ b/compiler/deSugar/DsListComp.lhs
@@ -49,12 +49,12 @@ There will be at least one ``qualifier'' in the input.
\begin{code}
dsListComp :: [LStmt Id]
- -> LHsExpr Id
- -> Type -- Type of list elements
+ -> Type -- Type of entire list
-> DsM CoreExpr
-dsListComp lquals body elt_ty = do
+dsListComp lquals res_ty = do
dflags <- getDOptsDs
let quals = map unLoc lquals
+ [elt_ty] = tcTyConAppArgs res_ty
if not (dopt Opt_EnableRewriteRules dflags) || dopt Opt_IgnoreInterfacePragmas dflags
-- Either rules are switched off, or we are ignoring what there are;
@@ -62,8 +62,8 @@ dsListComp lquals body elt_ty = do
-- Wadler-style desugaring
|| isParallelComp quals
-- Foldr-style desugaring can't handle parallel list comprehensions
- then deListComp quals body (mkNilExpr elt_ty)
- else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals body)
+ then deListComp quals (mkNilExpr elt_ty)
+ else mkBuildExpr elt_ty (\(c, _) (n, _) -> dfListComp c n quals)
-- Foldr/build should be enabled, so desugar
-- into foldrs and builds
@@ -83,12 +83,11 @@ dsListComp lquals body elt_ty = do
-- and the type of the elements that it outputs (tuples of binders)
dsInnerListComp :: ([LStmt Id], [Id]) -> DsM (CoreExpr, Type)
dsInnerListComp (stmts, bndrs) = do
- expr <- dsListComp stmts (mkBigLHsVarTup bndrs) bndrs_tuple_type
- return (expr, bndrs_tuple_type)
- where
- bndrs_types = map idType bndrs
- bndrs_tuple_type = mkBigCoreTupTy bndrs_types
-
+ = do { expr <- dsListComp (stmts ++ [noLoc $ mkLastStmt (mkBigLHsVarTup bndrs)])
+ bndrs_tuple_type
+ ; return (expr, bndrs_tuple_type) }
+ where
+ bndrs_tuple_type = mkBigCoreVarTupTy bndrs
-- This function factors out commonality between the desugaring strategies for TransformStmt.
-- Given such a statement it gives you back an expression representing how to compute the transformed
@@ -228,9 +227,40 @@ with the Unboxed variety.
\begin{code}
-deListComp :: [Stmt Id] -> LHsExpr Id -> CoreExpr -> DsM CoreExpr
+deListComp :: [Stmt Id] -> CoreExpr -> DsM CoreExpr
+
+deListComp [] _ = panic "deListComp"
-deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list
+deListComp (LastStmt body _ : quals) list
+ = -- Figure 7.4, SLPJ, p 135, rule C above
+ ASSERT( null quals )
+ do { core_body <- dsLExpr body
+ ; return (mkConsExpr (exprType core_body) core_body list) }
+
+ -- Non-last: must be a guard
+deListComp (ExprStmt guard _ _ _ : quals) list = do -- rule B above
+ core_guard <- dsLExpr guard
+ core_rest <- deListComp quals list
+ return (mkIfThenElse core_guard core_rest list)
+
+-- [e | let B, qs] = let B in [e | qs]
+deListComp (LetStmt binds : quals) list = do
+ core_rest <- deListComp quals list
+ dsLocalBinds binds core_rest
+
+deListComp (stmt@(TransformStmt {}) : quals) list = do
+ (inner_list_expr, pat) <- dsTransformStmt stmt
+ deBindComp pat inner_list_expr quals list
+
+deListComp (stmt@(GroupStmt {}) : quals) list = do
+ (inner_list_expr, pat) <- dsGroupStmt stmt
+ deBindComp pat inner_list_expr quals list
+
+deListComp (BindStmt pat list1 _ _ : quals) core_list2 = do -- rule A' above
+ core_list1 <- dsLExpr list1
+ deBindComp pat core_list1 quals core_list2
+
+deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) list
= do
exps_and_qual_tys <- mapM dsInnerListComp stmtss_w_bndrs
let (exps, qual_tys) = unzip exps_and_qual_tys
@@ -239,7 +269,7 @@ deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list
-- Deal with [e | pat <- zip l1 .. ln] in example above
deBindComp pat (Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps))
- quals body list
+ quals list
where
bndrs_s = map snd stmtss_w_bndrs
@@ -247,34 +277,6 @@ deListComp (ParStmt stmtss_w_bndrs _ _ _ : quals) body list
-- pat is the pattern ((x1,..,xn), (y1,..,ym)) in the example above
pat = mkBigLHsPatTup pats
pats = map mkBigLHsVarPatTup bndrs_s
-
- -- Last: the one to return
-deListComp [] body list = do -- Figure 7.4, SLPJ, p 135, rule C above
- core_body <- dsLExpr body
- return (mkConsExpr (exprType core_body) core_body list)
-
- -- Non-last: must be a guard
-deListComp (ExprStmt guard _ _ _ : quals) body list = do -- rule B above
- core_guard <- dsLExpr guard
- core_rest <- deListComp quals body list
- return (mkIfThenElse core_guard core_rest list)
-
--- [e | let B, qs] = let B in [e | qs]
-deListComp (LetStmt binds : quals) body list = do
- core_rest <- deListComp quals body list
- dsLocalBinds binds core_rest
-
-deListComp (stmt@(TransformStmt {}) : quals) body list = do
- (inner_list_expr, pat) <- dsTransformStmt stmt
- deBindComp pat inner_list_expr quals body list
-
-deListComp (stmt@(GroupStmt {}) : quals) body list = do
- (inner_list_expr, pat) <- dsGroupStmt stmt
- deBindComp pat inner_list_expr quals body list
-
-deListComp (BindStmt pat list1 _ _ : quals) body core_list2 = do -- rule A' above
- core_list1 <- dsLExpr list1
- deBindComp pat core_list1 quals body core_list2
\end{code}
@@ -282,10 +284,9 @@ deListComp (BindStmt pat list1 _ _ : quals) body core_list2 = do -- rule A' abov
deBindComp :: OutPat Id
-> CoreExpr
-> [Stmt Id]
- -> LHsExpr Id
-> CoreExpr
-> DsM (Expr Id)
-deBindComp pat core_list1 quals body core_list2 = do
+deBindComp pat core_list1 quals core_list2 = do
let
u3_ty@u1_ty = exprType core_list1 -- two names, same thing
@@ -302,7 +303,7 @@ deBindComp pat core_list1 quals body core_list2 = do
core_fail = App (Var h) (Var u3)
letrec_body = App (Var h) core_list1
- rest_expr <- deListComp quals body core_fail
+ rest_expr <- deListComp quals core_fail
core_match <- matchSimply (Var u2) (StmtCtxt ListComp) pat rest_expr core_fail
let
@@ -337,48 +338,48 @@ TE[ e | p <- l , q ] c n = let
\begin{code}
dfListComp :: Id -> Id -- 'c' and 'n'
-> [Stmt Id] -- the rest of the qual's
- -> LHsExpr Id
-> DsM CoreExpr
- -- Last: the one to return
-dfListComp c_id n_id [] body = do
- core_body <- dsLExpr body
- return (mkApps (Var c_id) [core_body, Var n_id])
+dfListComp _ _ [] = panic "dfListComp"
+
+dfListComp c_id n_id (LastStmt body _ : quals)
+ = ASSERT( null quals )
+ do { core_body <- dsLExpr body
+ ; return (mkApps (Var c_id) [core_body, Var n_id]) }
-- Non-last: must be a guard
-dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) body = do
+dfListComp c_id n_id (ExprStmt guard _ _ _ : quals) = do
core_guard <- dsLExpr guard
- core_rest <- dfListComp c_id n_id quals body
+ core_rest <- dfListComp c_id n_id quals
return (mkIfThenElse core_guard core_rest (Var n_id))
-dfListComp c_id n_id (LetStmt binds : quals) body = do
+dfListComp c_id n_id (LetStmt binds : quals) = do
-- new in 1.3, local bindings
- core_rest <- dfListComp c_id n_id quals body
+ core_rest <- dfListComp c_id n_id quals
dsLocalBinds binds core_rest
-dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) body = do
+dfListComp c_id n_id (stmt@(TransformStmt {}) : quals) = do
(inner_list_expr, pat) <- dsTransformStmt stmt
-- Anyway, we bind the newly transformed list via the generic binding function
- dfBindComp c_id n_id (pat, inner_list_expr) quals body
+ dfBindComp c_id n_id (pat, inner_list_expr) quals
-dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) body = do
+dfListComp c_id n_id (stmt@(GroupStmt {}) : quals) = do
(inner_list_expr, pat) <- dsGroupStmt stmt
-- Anyway, we bind the newly grouped list via the generic binding function
- dfBindComp c_id n_id (pat, inner_list_expr) quals body
+ dfBindComp c_id n_id (pat, inner_list_expr) quals
-dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) body = do
+dfListComp c_id n_id (BindStmt pat list1 _ _ : quals) = do
-- evaluate the two lists
core_list1 <- dsLExpr list1
-- Do the rest of the work in the generic binding builder
- dfBindComp c_id n_id (pat, core_list1) quals body
+ dfBindComp c_id n_id (pat, core_list1) quals
dfBindComp :: Id -> Id -- 'c' and 'n'
-> (LPat Id, CoreExpr)
-> [Stmt Id] -- the rest of the qual's
- -> LHsExpr Id
-> DsM CoreExpr
-dfBindComp c_id n_id (pat, core_list1) quals body = do
+dfBindComp c_id n_id (pat, core_list1) quals = do
-- find the required type
let x_ty = hsLPatType pat
b_ty = idType n_id
@@ -387,7 +388,7 @@ dfBindComp c_id n_id (pat, core_list1) quals body = do
[b, x] <- newSysLocalsDs [b_ty, x_ty]
-- build rest of the comprehesion
- core_rest <- dfListComp c_id b quals body
+ core_rest <- dfListComp c_id b quals
-- build the pattern match
core_expr <- matchSimply (Var x) (StmtCtxt ListComp)
@@ -482,9 +483,6 @@ mkUnzipBind elt_tys = do
unzip_fn_ty = elt_tuple_list_ty `mkFunTy` elt_list_tuple_ty
mkConcatExpression (list_element_ty, head, tail) = mkConsExpr list_element_ty head tail
-
-
-
\end{code}
%************************************************************************
@@ -500,11 +498,10 @@ mkUnzipBind elt_tys = do
-- [:e | qss:] = <<[:e | qss:]>> () [:():]
--
dsPArrComp :: [Stmt Id]
- -> LHsExpr Id
- -> Type -- Don't use; called with `undefined' below
-> DsM CoreExpr
-dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension
- dePArrParComp qss body
+
+-- Special case for parallel comprehension
+dsPArrComp (ParStmt qss _ _ _ : quals) = dePArrParComp qss quals
-- Special case for simple generators:
--
@@ -515,7 +512,7 @@ dsPArrComp [ParStmt qss _ _ _] body _ = -- parallel comprehension
-- <<[:e' | p <- e, qs:]>> =
-- <<[:e' | qs:]>> p (filterP (\x -> case x of {p -> True; _ -> False}) e)
--
-dsPArrComp (BindStmt p e _ _ : qs) body _ = do
+dsPArrComp (BindStmt p e _ _ : qs) = do
filterP <- dsLookupDPHId filterPName
ce <- dsLExpr e
let ety'ce = parrElemType ce
@@ -525,38 +522,41 @@ dsPArrComp (BindStmt p e _ _ : qs) body _ = do
pred <- matchSimply (Var v) (StmtCtxt PArrComp) p true false
let gen | isIrrefutableHsPat p = ce
| otherwise = mkApps (Var filterP) [Type ety'ce, mkLams [v] pred, ce]
- dePArrComp qs body p gen
+ dePArrComp qs p gen
-dsPArrComp qs body _ = do -- no ParStmt in `qs'
+dsPArrComp qs = do -- no ParStmt in `qs'
sglP <- dsLookupDPHId singletonPName
let unitArray = mkApps (Var sglP) [Type unitTy, mkCoreTup []]
- dePArrComp qs body (noLoc $ WildPat unitTy) unitArray
+ dePArrComp qs (noLoc $ WildPat unitTy) unitArray
-- the work horse
--
dePArrComp :: [Stmt Id]
- -> LHsExpr Id
-> LPat Id -- the current generator pattern
-> CoreExpr -- the current generator expression
-> DsM CoreExpr
+
+dePArrComp [] _ _ = panic "dePArrComp"
+
--
-- <<[:e' | :]>> pa ea = mapP (\pa -> e') ea
--
-dePArrComp [] e' pa cea = do
- mapP <- dsLookupDPHId mapPName
- let ty = parrElemType cea
- (clam, ty'e') <- deLambda ty pa e'
- return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea]
+dePArrComp (LastStmt e' _ : quals) pa cea
+ = ASSERT( null quals )
+ do { mapP <- dsLookupDPHId mapPName
+ ; let ty = parrElemType cea
+ ; (clam, ty'e') <- deLambda ty pa e'
+ ; return $ mkApps (Var mapP) [Type ty, Type ty'e', clam, cea] }
--
-- <<[:e' | b, qs:]>> pa ea = <<[:e' | qs:]>> pa (filterP (\pa -> b) ea)
--
-dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do
+dePArrComp (ExprStmt b _ _ _ : qs) pa cea = do
filterP <- dsLookupDPHId filterPName
let ty = parrElemType cea
(clam,_) <- deLambda ty pa b
- dePArrComp qs body pa (mkApps (Var filterP) [Type ty, clam, cea])
+ dePArrComp qs pa (mkApps (Var filterP) [Type ty, clam, cea])
--
-- <<[:e' | p <- e, qs:]>> pa ea =
@@ -571,7 +571,7 @@ dePArrComp (ExprStmt b _ _ _ : qs) body pa cea = do
-- in
-- <<[:e' | qs:]>> (pa, p) (crossMapP ea ef)
--
-dePArrComp (BindStmt p e _ _ : qs) body pa cea = do
+dePArrComp (BindStmt p e _ _ : qs) pa cea = do
filterP <- dsLookupDPHId filterPName
crossMapP <- dsLookupDPHId crossMapPName
ce <- dsLExpr e
@@ -587,7 +587,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do
let ety'cef = ety'ce -- filter doesn't change the element type
pa' = mkLHsPatTup [pa, p]
- dePArrComp qs body pa' (mkApps (Var crossMapP)
+ dePArrComp qs pa' (mkApps (Var crossMapP)
[Type ety'cea, Type ety'cef, cea, clam])
--
-- <<[:e' | let ds, qs:]>> pa ea =
@@ -596,7 +596,7 @@ dePArrComp (BindStmt p e _ _ : qs) body pa cea = do
-- where
-- {x_1, ..., x_n} = DV (ds) -- Defined Variables
--
-dePArrComp (LetStmt ds : qs) body pa cea = do
+dePArrComp (LetStmt ds : qs) pa cea = do
mapP <- dsLookupDPHId mapPName
let xs = collectLocalBinders ds
ty'cea = parrElemType cea
@@ -611,14 +611,14 @@ dePArrComp (LetStmt ds : qs) body pa cea = do
ccase <- matchSimply (Var v) (StmtCtxt PArrComp) pa projBody cerr
let pa' = mkLHsPatTup [pa, mkLHsPatTup (map nlVarPat xs)]
proj = mkLams [v] ccase
- dePArrComp qs body pa' (mkApps (Var mapP)
+ dePArrComp qs pa' (mkApps (Var mapP)
[Type ty'cea, Type errTy, proj, cea])
--
-- The parser guarantees that parallel comprehensions can only appear as
-- singeltons qualifier lists, which we already special case in the caller.
-- So, encountering one here is a bug.
--
-dePArrComp (ParStmt _ _ _ _ : _) _ _ _ =
+dePArrComp (ParStmt _ _ _ _ : _) _ _ =
panic "DsListComp.dePArrComp: malformed comprehension AST"
-- <<[:e' | qs | qss:]>> pa ea =
@@ -627,17 +627,17 @@ dePArrComp (ParStmt _ _ _ _ : _) _ _ _ =
-- where
-- {x_1, ..., x_n} = DV (qs)
--
-dePArrParComp :: [([LStmt Id], [Id])] -> LHsExpr Id -> DsM CoreExpr
-dePArrParComp qss body = do
+dePArrParComp :: [([LStmt Id], [Id])] -> [Stmt Id] -> DsM CoreExpr
+dePArrParComp qss quals = do
(pQss, ceQss) <- deParStmt qss
- dePArrComp [] body pQss ceQss
+ dePArrComp quals pQss ceQss
where
deParStmt [] =
-- empty parallel statement lists have no source representation
panic "DsListComp.dePArrComp: Empty parallel list comprehension"
deParStmt ((qs, xs):qss) = do -- first statement
let res_expr = mkLHsVarTuple xs
- cqs <- dsPArrComp (map unLoc qs) res_expr undefined
+ cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
parStmts qss (mkLHsVarPatTup xs) cqs
---
parStmts [] pa cea = return (pa, cea)
@@ -646,7 +646,7 @@ dePArrParComp qss body = do
let pa' = mkLHsPatTup [pa, mkLHsVarPatTup xs]
ty'cea = parrElemType cea
res_expr = mkLHsVarTuple xs
- cqs <- dsPArrComp (map unLoc qs) res_expr undefined
+ cqs <- dsPArrComp (map unLoc qs ++ [mkLastStmt res_expr])
let ty'cqs = parrElemType cqs
cea' = mkApps (Var zipP) [Type ty'cea, Type ty'cqs, cea, cqs]
parStmts qss pa' cea'
@@ -701,11 +701,9 @@ data DsMonadComp = DsMonadComp
-- Entry point for monad comprehension desugaring
--
dsMonadComp :: [LStmt Id] -- the statements
- -> SyntaxExpr Id -- the "return" function
- -> LHsExpr Id -- the body
-> Type -- the final type
-> DsM CoreExpr
-dsMonadComp stmts return_op body res_ty
+dsMonadComp stmts res_ty
= dsMcStmts stmts (DsMonadComp (Left return_op) body m_ty)
where
(m_ty, _) = tcSplitAppTy res_ty
@@ -729,30 +727,33 @@ dsMcStmts ((L loc stmt) : lstmts) mc
= putSrcSpanDs loc (dsMcStmt stmt lstmts mc)
-dsMcStmt :: Stmt Id
- -> [LStmt Id]
- -> DsMonadComp
- -> DsM CoreExpr
+dsMcStmt :: Stmt Id -> [LStmt Id] -> DsM CoreExpr
+
+dsMcStmt (LastStmt body ret_op) stmts
+ = ASSERT( null stmts )
+ do { body' <- dsLExpr body
+ ; ret_op' <- dsExpr ret_op
+ ; return (App ret_op' body') }
-- [ .. | let binds, stmts ]
-dsMcStmt (LetStmt binds) stmts mc
- = do { rest <- dsMcStmts stmts mc
+dsMcStmt (LetStmt binds) stmts
+ = do { rest <- dsMcStmts stmts
; dsLocalBinds binds rest }
-- [ .. | a <- m, stmts ]
-dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts mc
- = do { rhs' <- dsLExpr rhs
- ; dsMcBindStmt pat rhs' bind_op fail_op stmts mc }
+dsMcStmt (BindStmt pat rhs bind_op fail_op) stmts
+ = do { rhs' <- dsLExpr rhs
+ ; dsMcBindStmt pat rhs' bind_op fail_op stmts }
-- Apply `guard` to the `exp` expression
--
-- [ .. | exp, stmts ]
--
-dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc
+dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts
= do { exp' <- dsLExpr exp
; guard_exp' <- dsExpr guard_exp
; then_exp' <- dsExpr then_exp
- ; rest <- dsMcStmts stmts mc
+ ; rest <- dsMcStmts stmts
; return $ mkApps then_exp' [ mkApps guard_exp' [exp']
, rest ] }
@@ -762,26 +763,38 @@ dsMcStmt (ExprStmt exp then_exp guard_exp _) stmts mc
--
-- where [| qs |] is the desugared inner monad comprehenion generated by the
-- statements `qs`.
-dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest mc
- = do { (expr, _) <- dsInnerMonadComp (stmts, binders) (mc { mc_return = Left return_op })
- ; let binders_tuple_type = mkBigCoreTupTy $ map idType binders
+dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) stmts_rest
+ = do { expr <- dsInnerMonadComp stmts binders return_op
+ ; let binders_tup_type = mkBigCoreTupTy $ map idType binders
; usingExpr' <- dsLExpr usingExpr
; using_args <- case maybeByExpr of
Nothing -> return [expr]
Just byExpr -> do
byExpr' <- dsLExpr byExpr
us <- newUniqueSupply
- tuple_binder <- newSysLocalDs binders_tuple_type
- let byExprWrapper = mkTupleCase us binders byExpr' tuple_binder (Var tuple_binder)
- return [Lam tuple_binder byExprWrapper, expr]
+ tup_binder <- newSysLocalDs binders_tup_type
+ let byExprWrapper = mkTupleCase us binders byExpr' tup_binder (Var tup_binder)
+ return [Lam tup_binder byExprWrapper, expr]
; let pat = mkBigLHsVarPatTup binders
- rhs = mkApps usingExpr' ((Type binders_tuple_type) : using_args)
+ rhs = mkApps usingExpr' ((Type binders_tup_type) : using_args)
- ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc }
+ ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
-- Group statements desugar like this:
--
+-- [| (q, then group by e using f); rest |]
+-- ---> f {qt} (\qv -> e) [| q; return qv |] >>= \ n_tup ->
+-- case unzip n_tup of qv -> [| rest |]
+--
+-- where variables (v1:t1, ..., vk:tk) are bound by q
+-- qv = (v1, ..., vk)
+-- qt = (t1, ..., tk)
+-- (>>=) :: m2 a -> (a -> m3 b) -> m3 b
+-- f :: forall a. (a -> t) -> m1 a -> m2 (n a)
+-- n_tup :: n qt
+-- unzip :: n qt -> (n t1, ..., n tk) (needs Functor n)
+--
-- [| q, then group by e using f |] -> (f (\q_v -> e) [| q |]) >>= (return . (unzip q_v))
--
-- which is equal to
@@ -790,24 +803,23 @@ dsMcStmt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) s
--
-- where unzip is of the form
--
--- unzip :: m (a,b,c,..) -> (m a,m b,m c,..)
--- unzip m_tuple = ( liftM selN1 m_tuple
--- , liftM selN2 m_tuple
+-- unzip :: n (a,b,c,..) -> (n a,n b,n c,..)
+-- unzip m_tuple = ( fmap selN1 m_tuple
+-- , fmap selN2 m_tuple
-- , .. )
-- where selN1 (a,b,c,..) = a
-- selN2 (a,b,c,..) = b
-- ..
--
-dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_rest mc
+dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op fmap_op) stmts_rest
= do { let (fromBinders, toBinders) = unzip binderMap
- fromBindersTypes = map idType fromBinders
+ fromBindersTypes = map idType fromBinders -- Types ty
fromBindersTupleTy = mkBigCoreTupTy fromBindersTypes
- toBindersTypes = map idType toBinders
+ toBindersTypes = map idType toBinders -- Types (n ty)
toBindersTupleTy = mkBigCoreTupTy toBindersTypes
- m_ty = mc_m_ty mc
-- Desugar an inner comprehension which outputs a list of tuples of the "from" binders
- ; (expr, _) <- dsInnerMonadComp (stmts, fromBinders) (mc { mc_return = Left return_op })
+ ; expr <- dsInnerMonadComp stmts fromBinders return_op
-- Work out what arguments should be supplied to that expression: i.e. is an extraction
-- function required? If so, create that desugared function and add to arguments
@@ -815,62 +827,45 @@ dsMcStmt (GroupStmt stmts binderMap by using return_op bind_op liftM_op) stmts_r
; usingArgs <- case by of
Nothing -> return [expr]
Just by_e -> do { by_e' <- dsLExpr by_e
- ; us <- newUniqueSupply
- ; from_tup_id <- newSysLocalDs fromBindersTupleTy
- ; let by_wrap = mkTupleCase us fromBinders by_e'
- from_tup_id (Var from_tup_id)
- ; return [Lam from_tup_id by_wrap, expr] }
+ ; lam <- matchTuple fromBinders by_e'
+ ; return [lam, expr] }
-- Create an unzip function for the appropriate arity and element types
- ; liftM_op' <- dsExpr liftM_op
- ; (unzip_fn, unzip_rhs) <- mkMcUnzipM liftM_op' m_ty fromBindersTypes
+ ; fmap_op' <- dsExpr fmap_op
+ ; (unzip_fn, unzip_rhs) <- mkMcUnzipM fmap_op' m_ty fromBindersTypes
-- Generate the expressions to build the grouped list
-
- ; let -- First we apply the grouping function to the inner monad
- inner_monad_expr = mkApps usingExpr' ((Type fromBindersTupleTy) : usingArgs)
- -- Then we map our "unzip" across it to turn the "monad of tuples" into "tuples of monads"
- -- We make sure we instantiate the type variable "a" to be a "monad of 'from' tuples" and
- -- the "b" to be a "tuple of 'to' monads"!
- unzipped_inner_monad_expr = mkApps liftM_op' -- !
- -- Types:
- [ Type (m_ty `mkAppTy` fromBindersTupleTy), Type toBindersTupleTy
- -- And arguments:
- , Var unzip_fn, inner_monad_expr ]
- -- Then finally we bind the unzip function around that expression
- bound_unzipped_inner_monad_expr = Let (Rec [(unzip_fn, unzip_rhs)]) unzipped_inner_monad_expr
-
- -- Build a pattern that ensures the consumer binds into the NEW binders, which hold monads
- -- rather than single values
- ; let pat = mkBigLHsVarPatTup toBinders
- rhs = bound_unzipped_inner_monad_expr
-
- ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc }
+ -- Build a pattern that ensures the consumer binds into the NEW binders,
+ -- which hold monads rather than single values
+ ; bind_op' <- dsExpr bind_op
+ ; let bind_ty = exprType bind_op' -- m2 (n (a,b,c)) -> (n (a,b,c) -> r1) -> r2
+ n_tup_ty = funArgTy $ funArgTy $ funResultTy bind_ty
+
+ ; body <- dsMcStmts stmts_rest
+ ; n_tup_var <- newSysLocalDs n_tup_ty
+ ; tup_n_var <- newSysLocalDs (mkBigCoreVarTupTy toBinders)
+ ; us <- newUniqueSupply
+ ; let unzip_n_tup = Let (Rec [(unzip_fn, unzip_rhs)]) $
+ App (Var unzip_fn) (Var n_tup_var)
+ -- unzip_n_tup :: (n a, n b, n c)
+ body' = mkTupleCase us toBinders body unzip_n_tup (Var tup_n_var)
+
+ ; return (mkApps bind_op' [rhs', Lam n_tup_var body']) }
-- Parallel statements. Use `Control.Monad.Zip.mzip` to zip parallel
-- statements, for example:
--
-- [ body | qs1 | qs2 | qs3 ]
--- -> [ body | (bndrs1, (bndrs2, bndrs3)) <- mzip qs1 (mzip qs2 qs3) ]
---
--- where `mzip` is of the form
+-- -> [ body | (bndrs1, (bndrs2, bndrs3))
+-- <- [bndrs1 | qs1] `mzip` ([bndrs2 | qs2] `mzip` [bndrs3 | qs3]) ]
--
--- mzip :: m a -> m b -> m (a,b)
---
-dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc
- = do { -- Get types for `return`
- return_op' <- dsExpr return_op
- ; let pairs_with_return = map (\tp@(_,b) -> (mkReturn b,tp)) pairs
- mkReturn bndrs = mkApps return_op' [Type (mkBigCoreTupTy (map idType bndrs))]
-
- ; pairs' <- mapM (\(r,tp) -> dsInnerMonadComp tp mc{mc_return = Right r})
- pairs_with_return
-
- ; let (exps, _qual_tys) = unzip pairs'
- -- Types of our `Id`s are getting messed up by `dsInnerMonadComp`
- -- so we construct them by hand:
- qual_tys = map (mkBigCoreTupTy . map idType . snd) pairs
+-- where `mzip` has type
+-- mzip :: forall a b. m a -> m b -> m (a,b)
+-- NB: we need a polymorphic mzip because we call it several times
+dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest
+ = do { exps <- mapM ds_inner pairs
+ ; let qual_tys = map (mkBigCoreVarTupTy . snd) pairs
; mzip_op' <- dsExpr mzip_op
; (zip_fn, zip_rhs) <- mkMcZipM mzip_op' (mc_m_ty mc) qual_tys
@@ -881,9 +876,23 @@ dsMcStmt (ParStmt pairs mzip_op bind_op return_op) stmts_rest mc
pat = foldr (\tn tm -> mkBigLHsPatTup [tn, tm]) (last vars) (init vars)
rhs = Let (Rec [(zip_fn, zip_rhs)]) (mkApps (Var zip_fn) exps)
- ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest mc }
+ ; dsMcBindStmt pat rhs bind_op noSyntaxExpr stmts_rest }
+ where
+ ds_inner (stmts, bndrs) = dsInnerMonadComp stmts bndrs mono_ret_op
+ where
+ mono_ret_op = HsWrap (WpTyApp (mkBigCoreVarTupTy bndrs)) return_op
-dsMcStmt stmt _ _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt)
+dsMcStmt stmt _ = pprPanic "dsMcStmt: unexpected stmt" (ppr stmt)
+
+
+matchTuple :: [Id] -> CoreExpr -> DsM CoreExpr
+-- (matchTuple [a,b,c] body)
+-- returns the Core term
+-- \x. case x of (a,b,c) -> body
+matchTuple ids body
+ = do { us <- newUniqueSupply
+ ; tup_id <- newSysLocalDs (mkBigLHsVarPatTup ids)
+ ; return (Lam tup_id $ mkTupleCase us ids body tup_id (Var tup_id)) }
-- general `rhs' >>= \pat -> stmts` desugaring where `rhs'` is already a
@@ -893,10 +902,9 @@ dsMcBindStmt :: LPat Id
-> SyntaxExpr Id
-> SyntaxExpr Id
-> [LStmt Id]
- -> DsMonadComp
-> DsM CoreExpr
-dsMcBindStmt pat rhs' bind_op fail_op stmts mc
- = do { body <- dsMcStmts stmts mc
+dsMcBindStmt pat rhs' bind_op fail_op stmts
+ = do { body <- dsMcStmts stmts
; bind_op' <- dsExpr bind_op
; var <- selectSimpleMatchVarL pat
; let bind_ty = exprType bind_op' -- rhs -> (pat -> res1) -> res2
@@ -922,16 +930,16 @@ dsMcBindStmt pat rhs' bind_op fail_op stmts mc
showSDoc (ppr (getLoc pat))
-- Desugar nested monad comprehensions, for example in `then..` constructs
-dsInnerMonadComp :: ([LStmt Id], [Id])
- -> DsMonadComp
- -> DsM (CoreExpr, Type)
-dsInnerMonadComp (stmts, bndrs) DsMonadComp{ mc_return, mc_m_ty }
- = do { expr <- dsMcStmts stmts mc'
- ; return (expr, bndrs_tuple_type) }
- where
- bndrs_types = map idType bndrs
- bndrs_tuple_type = mkAppTy mc_m_ty $ mkBigCoreTupTy bndrs_types
- mc' = DsMonadComp mc_return (mkBigLHsVarTup bndrs) mc_m_ty
+-- dsInnerMonadComp quals [a,b,c] ret_op
+-- returns the desugaring of
+-- [ (a,b,c) | quals ]
+
+dsInnerMonadComp :: [LStmt Id]
+ -> [Id] -- Return a tuple of these variables
+ -> LHsExpr Id -- The monomorphic "return" operator
+ -> DsM CoreExpr
+dsInnerMonadComp stmts bndrs ret_op
+ = dsMcStmts (stmts ++ [noLoc (ReturnStmt (mkBigLHsVarTup bndrs) ret_op)])
-- The `unzip` function for `GroupStmt` in a monad comprehensions
--
diff --git a/compiler/hsSyn/HsExpr.lhs b/compiler/hsSyn/HsExpr.lhs
index e367af50a3..f7b693f157 100644
--- a/compiler/hsSyn/HsExpr.lhs
+++ b/compiler/hsSyn/HsExpr.lhs
@@ -23,6 +23,7 @@ import Name
import BasicTypes
import DataCon
import SrcLoc
+import Util( dropTail )
import Outputable
import FastString
@@ -146,10 +147,6 @@ data HsExpr id
-- because in this context we never use
-- the PatGuard or ParStmt variant
[LStmt id] -- "do":one or more stmts
- (LHsExpr id) -- The body; the last expression in the
- -- 'do' of [ body | ... ] in a list comp
- (SyntaxExpr id) -- The 'return' function, see Note
- -- [Monad Comprehensions]
PostTcType -- Type of the whole expression
| ExplicitList -- syntactic list
@@ -441,7 +438,7 @@ ppr_expr (HsLet binds expr)
= sep [hang (ptext (sLit "let")) 2 (pprBinds binds),
hang (ptext (sLit "in")) 2 (ppr expr)]
-ppr_expr (HsDo do_or_list_comp stmts body _ _) = pprDo do_or_list_comp stmts body
+ppr_expr (HsDo do_or_list_comp stmts _) = pprDo do_or_list_comp stmts
ppr_expr (ExplicitList _ exprs)
= brackets (pprDeeperList fsep (punctuate comma (map ppr_lexpr exprs)))
@@ -577,7 +574,7 @@ pprParendExpr expr
HsPar {} -> pp_as_was
HsBracket {} -> pp_as_was
HsBracketOut _ [] -> pp_as_was
- HsDo sc _ _ _ _
+ HsDo sc _ _
| isListCompExpr sc -> pp_as_was
_ -> parens pp_as_was
@@ -835,7 +832,12 @@ type Stmt id = StmtLR id id
-- The SyntaxExprs in here are used *only* for do-notation and monad
-- comprehensions, which have rebindable syntax. Otherwise they are unused.
data StmtLR idL idR
- = BindStmt (LPat idL)
+ = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, DoExpr, MDoExpr
+ -- Not used for GhciStmt, PatGuard, which scope over other stuff
+ (LHsExpr idR)
+ (SyntaxExpr idR) -- The return operator, used only for MonadComp
+ -- See Note [Monad Comprehensions]
+ | BindStmt (LPat idL)
(LHsExpr idR)
(SyntaxExpr idR) -- The (>>=) operator
(SyntaxExpr idR) -- The fail operator
@@ -852,9 +854,10 @@ data StmtLR idL idR
-- ParStmts only occur in a list/monad comprehension
| ParStmt [([LStmt idL], [idR])]
- (SyntaxExpr idR) -- polymorphic `mzip` for monad comprehensions
+ (SyntaxExpr idR) -- Polymorphic `mzip` for monad comprehensions
(SyntaxExpr idR) -- The `>>=` operator
- (SyntaxExpr idR) -- polymorphic `return` operator
+ (SyntaxExpr idR) -- Polymorphic `return` operator
+ -- with type (forall a. a -> m a)
-- See notes [Monad Comprehensions]
-- After renaming, the ids are the binders bound by the stmts and used
@@ -926,6 +929,10 @@ data StmtLR idL idR
-- because the Id may be *polymorphic*, but
-- the returned thing has to be *monomorphic*,
-- so they may be type applications
+
+ , recS_ret_ty :: PostTcType -- The type of of do { stmts; return (a,b,c) }
+ -- With rebindable syntax the type might not
+ -- be quite as simple as (m (tya, tyb, tyc)).
}
deriving (Data, Typeable)
\end{code}
@@ -1022,10 +1029,10 @@ where v1..vn are the later_ids
Note [Monad Comprehensions]
~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Monad comprehensions require seperate functions like 'return' and '>>=' for
-desugaring. These functions are stored in the 'HsDo' expression and the
-statements used in monad comprehensions. For example, the 'return' of the
-'HsDo' expression is used to lift the body of the monad comprehension:
+Monad comprehensions require separate functions like 'return' and
+'>>=' for desugaring. These functions are stored in the statements
+used in monad comprehensions. For example, the 'return' of the 'LastStmt'
+expression is used to lift the body of the monad comprehension:
[ body | stmts ]
=>
@@ -1065,6 +1072,7 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR)
ppr stmt = pprStmt stmt
pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc
+pprStmt (LastStmt expr _) = ppr expr
pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, ptext (sLit "<-"), ppr expr]
pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds]
pprStmt (ExprStmt expr _ _ _) = ppr expr
@@ -1103,28 +1111,32 @@ pprBy :: OutputableBndr id => Maybe (LHsExpr id) -> SDoc
pprBy Nothing = empty
pprBy (Just e) = ptext (sLit "by") <+> ppr e
-pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> LHsExpr id -> SDoc
-pprDo DoExpr stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body
-pprDo GhciStmt stmts body = ptext (sLit "do") <+> ppr_do_stmts stmts body
-pprDo MDoExpr stmts body = ptext (sLit "mdo") <+> ppr_do_stmts stmts body
-pprDo ListComp stmts body = brackets $ pprComp stmts body
-pprDo PArrComp stmts body = pa_brackets $ pprComp stmts body
-pprDo MonadComp stmts body = brackets $ pprComp stmts body
-pprDo _ _ _ = panic "pprDo" -- PatGuard, ParStmtCxt
+pprDo :: OutputableBndr id => HsStmtContext any -> [LStmt id] -> SDoc
+pprDo DoExpr stmts = ptext (sLit "do") <+> ppr_do_stmts stmts
+pprDo GhciStmt stmts = ptext (sLit "do") <+> ppr_do_stmts stmts
+pprDo MDoExpr stmts = ptext (sLit "mdo") <+> ppr_do_stmts stmts
+pprDo ListComp stmts = brackets $ pprComp stmts
+pprDo PArrComp stmts = pa_brackets $ pprComp stmts
+pprDo MonadComp stmts = brackets $ pprComp stmts
+pprDo _ _ = panic "pprDo" -- PatGuard, ParStmtCxt
-ppr_do_stmts :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc
+ppr_do_stmts :: OutputableBndr id => [LStmt id] -> SDoc
-- Print a bunch of do stmts, with explicit braces and semicolons,
-- so that we are not vulnerable to layout bugs
-ppr_do_stmts stmts body
- = lbrace <+> pprDeeperList vcat ([ppr s <> semi | s <- stmts] ++ [ppr body])
+ppr_do_stmts stmts
+ = lbrace <+> pprDeeperList vcat ([ppr s <> semi | s <- stmts])
<+> rbrace
ppr_lc_stmts :: OutputableBndr id => [LStmt id] -> [SDoc]
ppr_lc_stmts stmts = [ppr s <> comma | s <- stmts]
-pprComp :: OutputableBndr id => [LStmt id] -> LHsExpr id -> SDoc
-pprComp quals body -- Prints: body | qual1, ..., qualn
- = hang (ppr body <+> char '|') 2 (interpp'SP quals)
+pprComp :: OutputableBndr id => [LStmt id] -> SDoc
+pprComp quals -- Prints: body | qual1, ..., qualn
+ | not (null quals)
+ , L _ (LastStmt body _) <- last quals
+ = hang (ppr body <+> char '|') 2 (interpp'SP (dropTail 1 quals))
+ | otherwise
+ = pprPanic "pprComp" (interpp'SP quals)
\end{code}
%************************************************************************
@@ -1242,11 +1254,13 @@ data HsMatchContext id -- Context of a Match
data HsStmtContext id
= ListComp
- | DoExpr
- | GhciStmt -- A command-line Stmt in GHCi pat <- rhs
- | MDoExpr -- Recursive do-expression
| MonadComp
| PArrComp -- Parallel array comprehension
+
+ | DoExpr -- do { ... }
+ | MDoExpr -- mdo { ... } ie recursive do-expression
+
+ | GhciStmt -- A command-line Stmt in GHCi pat <- rhs
| PatGuard (HsMatchContext id) -- Pattern guard for specified thing
| ParStmtCtxt (HsStmtContext id) -- A branch of a parallel stmt
| TransformStmtCtxt (HsStmtContext id) -- A branch of a transform stmt
diff --git a/compiler/hsSyn/HsUtils.lhs b/compiler/hsSyn/HsUtils.lhs
index 44e3a324af..0d91e9f796 100644
--- a/compiler/hsSyn/HsUtils.lhs
+++ b/compiler/hsSyn/HsUtils.lhs
@@ -21,7 +21,7 @@ module HsUtils(
mkMatchGroup, mkMatch, mkHsLam, mkHsIf,
mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI,
coiToHsWrapper, mkHsLams, mkHsDictLet,
- mkHsOpApp, mkHsDo, mkHsWrapPat, mkHsWrapPatCoI,
+ mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, mkDoStmts,
nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps,
nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList,
@@ -42,7 +42,7 @@ module HsUtils(
nlHsAppTy, nlHsTyVar, nlHsFunTy, nlHsTyConApp,
-- Stmts
- mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt,
+ mkTransformStmt, mkTransformByStmt, mkExprStmt, mkBindStmt, mkLastStmt,
mkGroupUsingStmt, mkGroupByStmt, mkGroupByUsingStmt,
emptyRecStmt, mkRecStmt,
@@ -190,7 +190,9 @@ mkSimpleHsAlt pat expr
mkHsIntegral :: Integer -> PostTcType -> HsOverLit id
mkHsFractional :: Rational -> PostTcType -> HsOverLit id
mkHsIsString :: FastString -> PostTcType -> HsOverLit id
-mkHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
+mkHsDo :: HsStmtContext Name -> [LStmt id] -> HsExpr id
+mkHsComp :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
+mkDoStmts :: [LStmt id] -> [LStmt id]
mkNPat :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id
mkNPlusKPat :: Located id -> HsOverLit id -> Pat id
@@ -198,6 +200,7 @@ mkNPlusKPat :: Located id -> HsOverLit id -> Pat id
mkTransformStmt :: [LStmt idL] -> LHsExpr idR -> StmtLR idL idR
mkTransformByStmt :: [LStmt idL] -> LHsExpr idR -> LHsExpr idR -> StmtLR idL idR
+mkLastStmt :: LHsExpr idR -> StmtLR idL idR
mkExprStmt :: LHsExpr idR -> StmtLR idL idR
mkBindStmt :: LPat idL -> LHsExpr idR -> StmtLR idL idR
@@ -212,7 +215,15 @@ mkHsIsString s = OverLit (HsIsString s) noRebindableInfo noSyntaxExpr
noRebindableInfo :: Bool
noRebindableInfo = error "noRebindableInfo" -- Just another placeholder;
-mkHsDo ctxt stmts body = HsDo ctxt stmts body noSyntaxExpr placeHolderType
+-- mkDoStmts turns a trailing ExprStmt into a LastStmt
+mkDoStmts [L loc (ExprStmt e _ _ _)] = [L loc (mkLastStmt e)]
+mkDoStmts (s:ss) = s : mkDoStmts ss
+mkDoStmts [] = []
+
+mkHsDo ctxt stmts = HsDo ctxt stmts placeHolderType
+mkHsComp ctxt stmts expr = mkHsDo ctxt (stmts ++ [last_stmt])
+ where
+ last_stmt = L (getLoc expr) $ mkLastStmt expr
mkHsIf :: LHsExpr id -> LHsExpr id -> LHsExpr id -> HsExpr id
mkHsIf c a b = HsIf (Just noSyntaxExpr) c a b
@@ -231,13 +242,14 @@ mkGroupUsingStmt stmts usingExpr = GroupStmt stmts [] Nothing (Le
mkGroupByStmt stmts byExpr = GroupStmt stmts [] (Just byExpr) (Right noSyntaxExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr
mkGroupByUsingStmt stmts byExpr usingExpr = GroupStmt stmts [] (Just byExpr) (Left usingExpr) noSyntaxExpr noSyntaxExpr noSyntaxExpr
+mkLastStmt expr = LastStmt expr noSyntaxExpr
mkExprStmt expr = ExprStmt expr noSyntaxExpr noSyntaxExpr placeHolderType
mkBindStmt pat expr = BindStmt pat expr noSyntaxExpr noSyntaxExpr
emptyRecStmt = RecStmt { recS_stmts = [], recS_later_ids = [], recS_rec_ids = []
, recS_ret_fn = noSyntaxExpr, recS_mfix_fn = noSyntaxExpr
, recS_bind_fn = noSyntaxExpr
- , recS_rec_rets = [] }
+ , recS_rec_rets = [], recS_ret_ty = placeHolderType }
mkRecStmt stmts = emptyRecStmt { recS_stmts = stmts }
@@ -327,8 +339,8 @@ nlWildConPat con = noLoc (ConPatIn (noLoc (getRdrName con))
nlWildPat :: LPat id
nlWildPat = noLoc (WildPat placeHolderType) -- Pre-typechecking
-nlHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> LHsExpr id
-nlHsDo ctxt stmts body = noLoc (mkHsDo ctxt stmts body)
+nlHsDo :: HsStmtContext Name -> [LStmt id] -> LHsExpr id
+nlHsDo ctxt stmts = noLoc (mkHsDo ctxt stmts)
nlHsOpApp :: LHsExpr id -> id -> LHsExpr id -> LHsExpr id
nlHsOpApp e1 op e2 = noLoc (mkHsOpApp e1 op e2)
@@ -496,7 +508,8 @@ collectStmtBinders :: StmtLR idL idR -> [idL]
-- Id Binders for a Stmt... [but what about pattern-sig type vars]?
collectStmtBinders (BindStmt pat _ _ _) = collectPatBinders pat
collectStmtBinders (LetStmt binds) = collectLocalBinders binds
-collectStmtBinders (ExprStmt _ _ _ _) = []
+collectStmtBinders (ExprStmt {}) = []
+collectStmtBinders (LastStmt {}) = []
collectStmtBinders (ParStmt xs _ _ _) = collectLStmtsBinders
$ concatMap fst xs
collectStmtBinders (TransformStmt stmts _ _ _ _ _) = collectLStmtsBinders stmts
@@ -642,7 +655,8 @@ lStmtsImplicits = hs_lstmts
hs_stmt (BindStmt pat _ _ _) = lPatImplicits pat
hs_stmt (LetStmt binds) = hs_local_binds binds
- hs_stmt (ExprStmt _ _ _ _) = emptyNameSet
+ hs_stmt (ExprStmt {}) = emptyNameSet
+ hs_stmt (LastStmt {}) = emptyNameSet
hs_stmt (ParStmt xs _ _ _) = hs_lstmts $ concatMap fst xs
hs_stmt (TransformStmt stmts _ _ _ _ _) = hs_lstmts stmts
diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp
index ec8d3fffb3..c42ea0c864 100644
--- a/compiler/parser/Parser.y.pp
+++ b/compiler/parser/Parser.y.pp
@@ -1283,14 +1283,9 @@ exp10 :: { LHsExpr RdrName }
| 'case' exp 'of' altslist { LL $ HsCase $2 (mkMatchGroup (unLoc $4)) }
| '-' fexp { LL $ NegApp $2 noSyntaxExpr }
- | 'do' stmtlist {% let loc = comb2 $1 $2 in
- checkDo loc (unLoc $2) >>= \ (stmts,body) ->
- return (L loc (mkHsDo DoExpr stmts body)) }
- | 'mdo' stmtlist {% let loc = comb2 $1 $2 in
- checkDo loc (unLoc $2) >>= \ (stmts,body) ->
- return (L loc (mkHsDo MDoExpr
- [L loc (mkRecStmt stmts)]
- body)) }
+ | 'do' stmtlist { L (comb2 $1 $2) (mkHsDo DoExpr (unLoc $2)) }
+ | 'mdo' stmtlist { L (comb2 $1 $2) (mkHsDo MDoExpr (unLoc $2)) }
+
| scc_annot exp { LL $ if opt_SccProfilingOn
then HsSCC (unLoc $1) $2
else HsPar $2 }
@@ -1465,8 +1460,10 @@ list :: { LHsExpr RdrName }
| texp ',' exp '..' { LL $ ArithSeq noPostTcExpr (FromThen $1 $3) }
| texp '..' exp { LL $ ArithSeq noPostTcExpr (FromTo $1 $3) }
| texp ',' exp '..' exp { LL $ ArithSeq noPostTcExpr (FromThenTo $1 $3 $5) }
- | texp '|' flattenedpquals {% checkMonadComp >>= \ ctxt ->
- return (sL (comb2 $1 $>) $ mkHsDo ctxt (unLoc $3) $1) }
+ | texp '|' flattenedpquals
+ {% checkMonadComp >>= \ ctxt ->
+ return (sL (comb2 $1 $>) $
+ mkHsComp ctxt (unLoc $3) $1) }
lexps :: { Located [LHsExpr RdrName] }
: lexps ',' texp { LL (((:) $! $3) $! unLoc $1) }
@@ -1538,7 +1535,7 @@ parr :: { LHsExpr RdrName }
(reverse (unLoc $1)) }
| texp '..' exp { LL $ PArrSeq noPostTcExpr (FromTo $1 $3) }
| texp ',' exp '..' exp { LL $ PArrSeq noPostTcExpr (FromThenTo $1 $3 $5) }
- | texp '|' flattenedpquals { LL $ mkHsDo PArrComp (unLoc $3) $1 }
+ | texp '|' flattenedpquals { LL $ mkHsComp PArrComp (unLoc $3) $1 }
-- We are reusing `lexps' and `flattenedpquals' from the list case.
@@ -1605,7 +1602,7 @@ apats :: { [LPat RdrName] }
-- Statement sequences
stmtlist :: { Located [LStmt RdrName] }
- : '{' stmts '}' { LL (unLoc $2) }
+ : '{' stmts '}' { LL (mkDoStmts (unLoc $2)) }
| vocurly stmts close { $2 }
-- do { ;; s ; s ; ; s ;; }
diff --git a/compiler/parser/RdrHsSyn.lhs b/compiler/parser/RdrHsSyn.lhs
index 0e22c6955e..3b14990ec0 100644
--- a/compiler/parser/RdrHsSyn.lhs
+++ b/compiler/parser/RdrHsSyn.lhs
@@ -40,8 +40,6 @@ module RdrHsSyn (
checkPattern, -- HsExp -> P HsPat
bang_RDR,
checkPatterns, -- SrcLoc -> [HsExp] -> P [HsPat]
- checkDo, -- [Stmt] -> P [Stmt]
- checkMDo, -- [Stmt] -> P [Stmt]
checkMonadComp, -- P (HsStmtContext RdrName)
checkValDef, -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl
checkValSig, -- (SrcLoc, HsExp, HsRhs, [HsDecl]) -> P HsDecl
@@ -613,34 +611,6 @@ checkPred (L spn ty)
check loc _ _ = parseErrorSDoc loc
(text "malformed class assertion:" <+> ppr ty)
----------------------------------------------------------------------------
--- Checking statements in a do-expression
--- We parse do { e1 ; e2 ; }
--- as [ExprStmt e1, ExprStmt e2]
--- checkDo (a) checks that the last thing is an ExprStmt
--- (b) returns it separately
--- same comments apply for mdo as well
-
-checkDo, checkMDo :: SrcSpan -> [LStmt RdrName] -> P ([LStmt RdrName], LHsExpr RdrName)
-
-checkDo = checkDoMDo "a " "'do'"
-checkMDo = checkDoMDo "an " "'mdo'"
-
-checkDoMDo :: String -> String -> SrcSpan -> [LStmt RdrName] -> P ([LStmt RdrName], LHsExpr RdrName)
-checkDoMDo _ nm loc [] = parseErrorSDoc loc (text ("Empty " ++ nm ++ " construct"))
-checkDoMDo pre nm _ ss = do
- check ss
- where
- check [] = panic "RdrHsSyn:checkDoMDo"
- check [L _ (ExprStmt e _ _ _)] = return ([], e)
- check [L l e] = parseErrorSDoc l
- (text ("The last statement in " ++ pre ++ nm ++
- " construct must be an expression:")
- $$ ppr e)
- check (s:ss) = do
- (ss',e') <- check ss
- return ((s:ss'),e')
-
-- -------------------------------------------------------------------------
-- Checking Patterns.
diff --git a/compiler/prelude/PrelNames.lhs b/compiler/prelude/PrelNames.lhs
index 9b59f5d9a0..421ec45536 100644
--- a/compiler/prelude/PrelNames.lhs
+++ b/compiler/prelude/PrelNames.lhs
@@ -160,6 +160,7 @@ basicKnownKeyNames
-- Monad stuff
thenIOName, bindIOName, returnIOName, failIOName,
failMName, bindMName, thenMName, returnMName,
+ fmapName,
-- MonadRec stuff
mfixName,
@@ -612,6 +613,7 @@ eqName = methName gHC_CLASSES (fsLit "==") eqClassOpKey
ordClassName = clsQual gHC_CLASSES (fsLit "Ord") ordClassKey
geName = methName gHC_CLASSES (fsLit ">=") geClassOpKey
functorClassName = clsQual gHC_BASE (fsLit "Functor") functorClassKey
+fmapName = methName gHC_BASE (fsLit "fmap") fmapClassOpKey
-- Class Monad
monadClassName, thenMName, bindMName, returnMName, failMName :: Name
@@ -1312,6 +1314,7 @@ negateClassOpKey = mkPreludeMiscIdUnique 111
failMClassOpKey = mkPreludeMiscIdUnique 112
bindMClassOpKey = mkPreludeMiscIdUnique 113 -- (>>=)
thenMClassOpKey = mkPreludeMiscIdUnique 114 -- (>>)
+fmapClassOpKey = mkPreludeMiscIdUnique 115
returnMClassOpKey = mkPreludeMiscIdUnique 117
-- Recursive do notation
diff --git a/compiler/rename/RnExpr.lhs b/compiler/rename/RnExpr.lhs
index 425cb40f59..e3e92bcfd0 100644
--- a/compiler/rename/RnExpr.lhs
+++ b/compiler/rename/RnExpr.lhs
@@ -224,16 +224,9 @@ rnExpr (HsLet binds expr)
rnLExpr expr `thenM` \ (expr',fvExpr) ->
return (HsLet binds' expr', fvExpr)
-rnExpr (HsDo do_or_lc stmts body _ _)
- = do { ((stmts', body'), fvs1) <- rnStmts do_or_lc stmts $ \ _ ->
- rnLExpr body
- ; (return_op, fvs2) <-
- if isMonadCompExpr do_or_lc
- then lookupSyntaxName returnMName
- else return (noSyntaxExpr, emptyFVs)
-
- ; return ( HsDo do_or_lc stmts' body' return_op placeHolderType
- , fvs1 `plusFV` fvs2 ) }
+rnExpr (HsDo do_or_lc stmts _)
+ = do { ((stmts', _), fvs) <- rnStmts do_or_lc stmts (\ _ -> return ())
+ ; return ( HsDo do_or_lc stmts' placeHolderType, fvs ) }
rnExpr (ExplicitList _ exps)
= rnExprs exps `thenM` \ (exps', fvs) ->
@@ -653,25 +646,53 @@ rnStmts :: HsStmtContext Name -> [LStmt RdrName]
--
-- Renaming a single RecStmt can give a sequence of smaller Stmts
-rnStmts _ [] thing_inside
- = do { (res, fvs) <- thing_inside []
- ; return (([], res), fvs) }
+rnStmts ctxt [] thing_inside
+ = do { addErr (ptext (sLit "Empty") <+> pprStmtContext ctxt)
+ ; (thing, fvs) <- thing_inside []
+ ; return (([], thing), fvs) }
+
+rnStmts MDoExpr stmts thing_inside -- Deal with mdo
+ = -- Behave like do { rec { ...all but last... }; last }
+ do { ((stmts1, (stmts2, thing)), fvs)
+ <- rnStmt MDoExpr (mkRecStmt all_but_last) $ \ bndrs ->
+ do { checkStmt MDoExpr True last_stmt
+ ; rnStmt MDoExpr last_stmt thing_inside }
+ ; return (((stmts1 ++ stmts2), thing), fvs) }
+ where
+ Just (all_but_last, last_stmt) = snocView stmts
rnStmts ctxt (stmt@(L loc _) : stmts) thing_inside
+ | null stmts
+ = setSrcSpan loc $
+ do { let last_stmt = case stmt of
+ ExprStmt e _ _ _ -> LastStmt e noSyntaxExpr
+ ; checkStmt ctxt True {- last stmt -} stmt
+ ; rnStmt ctxt stmt thing_inside }
+
+ | otherwise
= do { ((stmts1, (stmts2, thing)), fvs)
- <- setSrcSpan loc $
- rnStmt ctxt stmt $ \ bndrs1 ->
- rnStmts ctxt stmts $ \ bndrs2 ->
- thing_inside (bndrs1 ++ bndrs2)
+ <- setSrcSpan loc $
+ do { checkStmt ctxt False {- Not last -} stmt
+ ; rnStmt ctxt stmt $ \ bndrs1 ->
+ rnStmts ctxt stmts $ \ bndrs2 ->
+ thing_inside (bndrs1 ++ bndrs2) }
; return (((stmts1 ++ stmts2), thing), fvs) }
-
-rnStmt :: HsStmtContext Name -> LStmt RdrName
+----------------------
+rnStmt :: HsStmtContext Name
+ -> LStmt RdrName
-> ([Name] -> RnM (thing, FreeVars))
-> RnM (([LStmt Name], thing), FreeVars)
-- Variables bound by the Stmt, and mentioned in thing_inside,
-- do not appear in the result FreeVars
+rnStmt ctxt (L loc (LastStmt expr _)) thing_inside
+ = do { (expr', fv_expr) <- rnLExpr expr
+ ; (ret_op, fvs1) <- lookupSyntaxName returnMName
+ ; (thing, fvs3) <- thing_inside []
+ ; return (([L loc (LastStmt expr' ret_op)], thing),
+ fv_expr `plusFV` fvs1 `plusFV` fvs3) }
+
rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
= do { (expr', fv_expr) <- rnLExpr expr
; (then_op, fvs1) <- lookupSyntaxName thenMName
@@ -683,7 +704,8 @@ rnStmt ctxt (L loc (ExprStmt expr _ _ _)) thing_inside
fv_expr `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) }
rnStmt ctxt (L loc (BindStmt pat expr _ _)) thing_inside
- = do { (expr', fv_expr) <- rnLExpr expr
+ = do { checkBindStmt ctxt is_last
+ ; (expr', fv_expr) <- rnLExpr expr
-- The binders do not scope over the expression
; (bind_op, fvs1) <- lookupSyntaxName bindMName
; (fail_op, fvs2) <- lookupSyntaxName failMName
@@ -701,8 +723,7 @@ rnStmt ctxt (L loc (LetStmt binds)) thing_inside
; return (([L loc (LetStmt binds')], thing), fvs) } }
rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
- = do { checkRecStmt ctxt
-
+ = do {
-- Step1: Bring all the binders of the mdo into scope
-- (Remember that this also removes the binders from the
-- finally-returned free-vars.)
@@ -745,8 +766,7 @@ rnStmt ctxt (L _ (RecStmt { recS_stmts = rec_stmts })) thing_inside
; return ((rec_stmts', thing), fvs `plusFV` fvs1 `plusFV` fvs2 `plusFV` fvs3) } }
rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
- = do { checkParStmt ctxt
- ; ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt
+ = do { ((mzip_op, fvs1), (bind_op, fvs2), (return_op, fvs3)) <- if isMonadCompExpr ctxt
then (,,) <$> lookupSyntaxName mzipName
<*> lookupSyntaxName bindMName
<*> lookupSyntaxName returnMName
@@ -758,9 +778,7 @@ rnStmt ctxt (L loc (ParStmt segs _ _ _)) thing_inside
, fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside
- = do { checkTransformStmt ctxt
-
- ; (using', fvs1) <- rnLExpr using
+ = do { (using', fvs1) <- rnLExpr using
; ((stmts', (by', used_bndrs, thing)), fvs2)
<- rnStmts (TransformStmtCtxt ctxt) stmts $ \ bndrs ->
@@ -786,9 +804,7 @@ rnStmt ctxt (L loc (TransformStmt stmts _ using by _ _)) thing_inside
fvs1 `plusFV` fvs2 `plusFV` fvs3 `plusFV` fvs4) }
rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
- = do { checkTransformStmt ctxt
-
- -- Rename the 'using' expression in the context before the transform is begun
+ = do { -- Rename the 'using' expression in the context before the transform is begun
; (using', fvs1) <- case using of
Left e -> do { (e', fvs) <- rnLExpr e; return (Left e', fvs) }
Right _
@@ -810,11 +826,11 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
; return ((by', used_bndrs, thing), fvs) }
-- Lookup `return`, `(>>=)` and `liftM` for monad comprehensions
- ; ((return_op, fvs3), (bind_op, fvs4), (liftM_op, fvs5)) <-
+ ; ((return_op, fvs3), (bind_op, fvs4), (fmap_op, fvs5)) <-
if isMonadCompExpr ctxt
then (,,) <$> lookupSyntaxName returnMName
<*> lookupSyntaxName bindMName
- <*> lookupSyntaxName liftMName
+ <*> lookupSyntaxName fmapName
else return ( (noSyntaxExpr, emptyFVs)
, (noSyntaxExpr, emptyFVs)
, (noSyntaxExpr, emptyFVs) )
@@ -825,7 +841,7 @@ rnStmt ctxt (L loc (GroupStmt stmts _ by using _ _ _)) thing_inside
-- See Note [GroupStmt binder map] in HsExpr
; traceRn (text "rnStmt: implicitly rebound these used binders:" <+> ppr bndr_map)
- ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op liftM_op)], thing), all_fvs) }
+ ; return (([L loc (GroupStmt stmts' bndr_map by' using' return_op bind_op fmap_op)], thing), all_fvs) }
type ParSeg id = ([LStmt id], [id]) -- The Names are bound by the Stmts
@@ -1182,22 +1198,124 @@ program.
%************************************************************************
\begin{code}
-
----------------------
-- Checking when a particular Stmt is ok
-checkLetStmt :: HsStmtContext Name -> HsLocalBinds RdrName -> RnM ()
-checkLetStmt (ParStmtCtxt _) (HsIPBinds binds) = addErr (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds)
-checkLetStmt _ctxt _binds = return ()
+checkStmt :: HsStmtContext Name
+ -> Bool -- True <=> this is the last Stmt in the sequence
+ -> LStmt RdrName
+ -> RnM ()
+checkStmt ctxt is_last (L _ stmt)
+ = do { dflags <- getDOpts
+ ; case okStmt dflags ctxt is_last stmt of
+ Nothing -> return ()
+ Just extr -> addErr (msg $$ extra) }
+ where
+ msg = ptext (sLit "Unexpected") <+> pprStmtCat stmt
+ <+> ptext (sLit "statement in") <+> pprStmtContext ctxt
+
+pprStmtCat :: Stmt a -> SDoc
+pprStmtCat (TransformStmt {}) = ptext (sLit "transform")
+pprStmtCat (GroupStmt {}) = ptext (sLit "group")
+pprStmtCat (LastStmt {}) = ptext (sLit "return expression")
+pprStmtCat (ExprStmt {}) = ptext (sLit "exprssion")
+pprStmtCat (BindStmt {}) = ptext (sLit "binding")
+pprStmtCat (LetStmt {}) = ptext (sLit "let")
+pprStmtCat (RecStmt {}) = ptext (sLit "rec")
+pprStmtCat (ParStmt {}) = ptext (sLit "parallel")
+
+------------
+isOK, notOK :: Maybe SDoc
+isOK = Nothing
+notOK = Just empty
+
+okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool
+ -> Stmt RdrName -> Maybe SDoc
+-- Return Nothing if OK, (Just extra) if not ok
+-- The "extra" is an SDoc that is appended to an generic error message
+okStmt dflags GhciStmt is_last stmt
+ = case stmt of
+ ExprStmt {} -> isOK
+ BindStmt {} -> isOK
+ LetStmt {} -> isOK
+ _ -> notOK
+
+okStmt dflags (PatGuard {}) is_last stmt
+ = case stmt of
+ ExprStmt {} -> isOK
+ BindStmt {} -> isOK
+ LetStmt {} -> isOK
+ _ -> notOK
+
+okStmt dflags (ParStmtCtxt ctxt) is_last stmt
+ = case stmt of
+ LetStmt (HsIPBinds {}) -> notOK
+ _ -> okStmt dflags ctxt is_last stmt
+
+okStmt dflags (TransformStmtCtxt ctxt) is_last stmt
+ = okStmt dflags ctxt is_last stmt
+
+okStmt ctxt is_last stmt
+ | isDoExpr ctxt = okDoStmt ctxt is_last stmt
+ | isCompExpr ctxt = okCompStmt ctxt is_last stmt
+ | otherwise = pprPanic "okStmt" (pprStmtContext ctxt)
+
+----------------
+okDoStmt dflags ctxt is_last stmt
+ | is_last
+ = case stmt of
+ LastStmt {} -> isOK
+ _ -> Just (ptext (sLit "The last statement in") <+> what <+>
+ ptext (sLIt "construct must be an expression"))
+ where
+ what = case ctxt of
+ DoExpr -> ptext (sLit "a 'do'")
+ MDoExpr -> ptext (sLit "an 'mdo'")
+ _ -> panic "checkStmt"
+
+ | otherwise
+ = case stmt of
+ RecStmt {} -> isOK -- Shouldn't we test a flag?
+ BindStmt {} -> isOK
+ LetStmt {} -> isOK
+ ExprStmt {} -> isOK
+ _ -> notOK
+
+
+----------------
+okCompStmt dflags ctxt is_last stmt
+ | is_last
+ = case stmt of
+ LastStmt {} -> Nothing
+ -> pprPanic "Unexpected stmt" (ppr stmt) -- Not a user error
+
+ | otherwise
+ = case stmt of
+ BindStmt {} -> isOK
+ LetStmt {} -> isOK
+ ExprStmt {} -> isOK
+ RecStmt {} -> notOK
+ ParStmt {}
+ | dopt dflags Opt_ParallelListComp -> isOK
+ | otherwise -> Just (ptext (sLit "Use -XParallelListComp"))
+ TransformStmt {}
+ | dopt dflags Opt_transformListComp -> isOK
+ | otherwise -> Just (ptext (sLit "Use -XTransformListComp"))
+ GroupStmt {}
+ | dopt dflags Opt_transformListComp -> isOK
+ | otherwise -> Just (ptext (sLit "Use -XTransformListComp"))
+
+
+checkStmt :: HsStmtContext Name -> Stmt RdrName -> Maybe SDoc
+-- Non-last stmt
+
+checkStmt (ParStmtCtxt _) (HsIPBinds binds)
+ = Just (badIpBinds (ptext (sLit "a parallel list comprehension:")) binds)
-- We do not allow implicit-parameter bindings in a parallel
-- list comprehension. I'm not sure what it might mean.
----------
-checkRecStmt :: HsStmtContext Name -> RnM ()
-checkRecStmt MDoExpr = return () -- Recursive stmt ok in 'mdo'
-checkRecStmt DoExpr = return () -- and in 'do'
-checkRecStmt ctxt = addErr msg
- where
- msg = ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt
+checkStmt ctxt (RecStmt {})
+ | not (isDoExpr ctxt)
+ = addErr (ptext (sLit "Illegal 'rec' stmt in") <+> pprStmtContext ctxt)
---------
checkParStmt :: HsStmtContext Name -> RnM ()
diff --git a/compiler/typecheck/TcExpr.lhs b/compiler/typecheck/TcExpr.lhs
index a821f2548c..d24ebbe85f 100644
--- a/compiler/typecheck/TcExpr.lhs
+++ b/compiler/typecheck/TcExpr.lhs
@@ -415,8 +415,8 @@ tcExpr (HsIf (Just fun) pred b1 b2) res_ty -- Note [Rebindable syntax for if]
-- and it maintains uniformity with other rebindable syntax
; return (HsIf (Just fun') pred' b1' b2') }
-tcExpr (HsDo do_or_lc stmts body return_op _) res_ty
- = tcDoStmts do_or_lc stmts body return_op res_ty
+tcExpr (HsDo do_or_lc stmts _) res_ty
+ = tcDoStmts do_or_lc stmts res_ty
tcExpr (HsProc pat cmd) res_ty
= do { (pat', cmd', coi) <- tcProc pat cmd res_ty
diff --git a/compiler/typecheck/TcGenDeriv.lhs b/compiler/typecheck/TcGenDeriv.lhs
index efacac2c37..f7e5d39c94 100644
--- a/compiler/typecheck/TcGenDeriv.lhs
+++ b/compiler/typecheck/TcGenDeriv.lhs
@@ -779,7 +779,7 @@ gen_Ix_binds loc tycon
single_con_range
= mk_easy_FunBind loc range_RDR
[nlTuplePat [con_pat as_needed, con_pat bs_needed] Boxed] $
- nlHsDo ListComp stmts con_expr
+ noLoc (mkHsComp ListComp stmts con_expr)
where
stmts = zipWith3Equal "single_con_range" mk_qual as_needed bs_needed cs_needed
@@ -893,7 +893,7 @@ gen_Read_binds get_fixity loc tycon
read_nullary_cons
= case nullary_cons of
[] -> []
- [con] -> [nlHsDo DoExpr (match_con con) (result_expr con [])]
+ [con] -> [nlHsDo DoExpr (match_con con ++ [mkExprStmt (result_expr con [])])]
_ -> [nlHsApp (nlHsVar choose_RDR)
(nlList (map mk_pair nullary_cons))]
-- NB For operators the parens around (:=:) are matched by the
@@ -965,11 +965,12 @@ gen_Read_binds get_fixity loc tycon
------------------------------------------------------------------------
-- Helpers
------------------------------------------------------------------------
- mk_alt e1 e2 = genOpApp e1 alt_RDR e2 -- e1 +++ e2
- mk_parser p ss b = nlHsApps prec_RDR [nlHsIntLit p, nlHsDo DoExpr ss b] -- prec p (do { ss ; b })
- bindLex pat = noLoc (mkBindStmt pat (nlHsVar lexP_RDR)) -- pat <- lexP
- con_app con as = nlHsVarApps (getRdrName con) as -- con as
- result_expr con as = nlHsApp (nlHsVar returnM_RDR) (con_app con as) -- return (con as)
+ mk_alt e1 e2 = genOpApp e1 alt_RDR e2 -- e1 +++ e2
+ mk_parser p ss b = nlHsApps prec_RDR [nlHsIntLit p -- prec p (do { ss ; b })
+ , nlHsDo DoExpr (ss ++ [mkExprStmt b])]
+ bindLex pat = noLoc (mkBindStmt pat (nlHsVar lexP_RDR)) -- pat <- lexP
+ con_app con as = nlHsVarApps (getRdrName con) as -- con as
+ result_expr con as = nlHsApp (nlHsVar returnM_RDR) (con_app con as) -- return (con as)
punc_pat s = nlConPat punc_RDR [nlLitPat (mkHsString s)] -- Punc 'c'
diff --git a/compiler/typecheck/TcHsSyn.lhs b/compiler/typecheck/TcHsSyn.lhs
index 357db734cd..518582fa6a 100644
--- a/compiler/typecheck/TcHsSyn.lhs
+++ b/compiler/typecheck/TcHsSyn.lhs
@@ -578,12 +578,10 @@ zonkExpr env (HsLet binds expr)
zonkLExpr new_env expr `thenM` \ new_expr ->
returnM (HsLet new_binds new_expr)
-zonkExpr env (HsDo do_or_lc stmts body return_op ty)
- = zonkStmts env stmts `thenM` \ (new_env, new_stmts) ->
- zonkLExpr new_env body `thenM` \ new_body ->
- zonkExpr new_env return_op `thenM` \ new_return ->
+zonkExpr env (HsDo do_or_lc stmts ty)
+ = zonkStmts env stmts `thenM` \ (_, new_stmts) ->
zonkTcTypeToType env ty `thenM` \ new_ty ->
- returnM (HsDo do_or_lc new_stmts new_body new_return new_ty)
+ returnM (HsDo do_or_lc new_stmts new_ty)
zonkExpr env (ExplicitList ty exprs)
= zonkTcTypeToType env ty `thenM` \ new_ty ->
@@ -745,9 +743,10 @@ zonkStmt env (ParStmt stmts_w_bndrs mzip_op bind_op return_op)
zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_ids = rvs
, recS_ret_fn = ret_id, recS_mfix_fn = mfix_id, recS_bind_fn = bind_id
- , recS_rec_rets = rets })
+ , recS_rec_rets = rets, redS_ret_ty = ret_ty })
= do { new_rvs <- zonkIdBndrs env rvs
; new_lvs <- zonkIdBndrs env lvs
+ ; new_ret_ty <- zonkTcTypeToType env ret_ty
; new_ret_id <- zonkExpr env ret_id
; new_mfix_id <- zonkExpr env mfix_id
; new_bind_id <- zonkExpr env bind_id
@@ -760,7 +759,7 @@ zonkStmt env (RecStmt { recS_stmts = segStmts, recS_later_ids = lvs, recS_rec_id
RecStmt { recS_stmts = new_segStmts, recS_later_ids = new_lvs
, recS_rec_ids = new_rvs, recS_ret_fn = new_ret_id
, recS_mfix_fn = new_mfix_id, recS_bind_fn = new_bind_id
- , recS_rec_rets = new_rets }) }
+ , recS_rec_rets = new_rets, recS_ret_ty = new_ret_ty }) }
zonkStmt env (ExprStmt expr then_op guard_op ty)
= zonkLExpr env expr `thenM` \ new_expr ->
@@ -769,6 +768,11 @@ zonkStmt env (ExprStmt expr then_op guard_op ty)
zonkTcTypeToType env ty `thenM` \ new_ty ->
returnM (env, ExprStmt new_expr new_then new_guard new_ty)
+zonkStmt env (LastStmt expr ret_op)
+ = zonkLExpr env expr `thenM` \ new_expr ->
+ zonkExpr env ret_op `thenM` \ new_ret ->
+ returnM (env, LastStmt new_expr new_ret)
+
zonkStmt env (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op)
= do { (env', stmts') <- zonkStmts env stmts
; let binders' = zonkIdOccs env' binders
diff --git a/compiler/typecheck/TcMatches.lhs b/compiler/typecheck/TcMatches.lhs
index 31aa555b72..60bf7e2c3e 100644
--- a/compiler/typecheck/TcMatches.lhs
+++ b/compiler/typecheck/TcMatches.lhs
@@ -241,41 +241,31 @@ tcGRHS ctxt res_ty (GRHS guards rhs)
\begin{code}
tcDoStmts :: HsStmtContext Name
-> [LStmt Name]
- -> LHsExpr Name
- -> SyntaxExpr Name -- 'return' function for monad
- -- comprehensions
-> TcRhoType
-> TcM (HsExpr TcId) -- Returns a HsDo
-tcDoStmts ListComp stmts body _ res_ty
+tcDoStmts ListComp stmts res_ty
= do { (coi, elt_ty) <- matchExpectedListTy res_ty
- ; (stmts', body') <- tcStmts ListComp (tcLcStmt listTyCon) stmts
- elt_ty $
- tcBody body
+ ; stmts' <- tcStmts ListComp (tcLcStmt listTyCon) stmts res_ty
; return $ mkHsWrapCoI coi
- (HsDo ListComp stmts' body' noSyntaxExpr (mkListTy elt_ty)) }
+ (HsDo ListComp stmts' (mkListTy elt_ty)) }
-tcDoStmts PArrComp stmts body _ res_ty
+tcDoStmts PArrComp stmts res_ty
= do { (coi, elt_ty) <- matchExpectedPArrTy res_ty
- ; (stmts', body') <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts
- elt_ty $
- tcBody body
+ ; stmts' <- tcStmts PArrComp (tcLcStmt parrTyCon) stmts elt_ty
; return $ mkHsWrapCoI coi
- (HsDo PArrComp stmts' body' noSyntaxExpr (mkPArrTy elt_ty)) }
+ (HsDo PArrComp stmts' (mkPArrTy elt_ty)) }
-tcDoStmts DoExpr stmts body _ res_ty
- = do { (stmts', body') <- tcStmts DoExpr tcDoStmt stmts res_ty $
- tcBody body
- ; return (HsDo DoExpr stmts' body' noSyntaxExpr res_ty) }
+tcDoStmts DoExpr stmts res_ty
+ = do { stmts' <- tcStmts DoExpr tcDoStmt stmts res_ty
+ ; return (HsDo DoExpr stmts' res_ty) }
-tcDoStmts MDoExpr stmts body _ res_ty
- = do { (stmts', body') <- tcStmts MDoExpr tcDoStmt stmts res_ty $
- tcBody body
- ; return (HsDo MDoExpr stmts' body' noSyntaxExpr res_ty) }
+tcDoStmts MDoExpr stmts res_ty
+ = do { stmts' <- tcStmts MDoExpr tcDoStmt stmts res_ty
+ ; return (HsDo MDoExpr stmts' res_ty) }
-tcDoStmts MonadComp stmts body return_op res_ty
- = do { (stmts', (body', return_op')) <- tcStmts MonadComp tcMcStmt stmts res_ty $
- tcMcBody body return_op
- ; return $ HsDo MonadComp stmts' body' return_op' res_ty }
+tcDoStmts MonadComp stmts res_ty
+ = do { stmts' <- tcStmts MonadComp tcMcStmt stmts res_ty
+ ; return (HsDo MonadComp stmts' res_ty) }
tcDoStmts ctxt _ _ _ _ = pprPanic "tcDoStmts" (pprStmtContext ctxt)
@@ -306,30 +296,40 @@ tcStmts :: HsStmtContext Name
-> TcStmtChecker -- NB: higher-rank type
-> [LStmt Name]
-> TcRhoType
- -> (TcRhoType -> TcM thing)
- -> TcM ([LStmt TcId], thing)
+ -> TcM [LStmt TcId]
+tcStmts ctxt stmt_chk stmts res_ty
+ = do { (stmts', _) <- tcStmtsAndThen ctxt stmt_check stmts res_ty $
+ const (return ())
+ ; return stmts' }
+
+tcStmtsAndThen :: HsStmtContext Name
+ -> TcStmtChecker -- NB: higher-rank type
+ -> [LStmt Name]
+ -> TcRhoType
+ -> (TcRhoType -> TcM thing)
+ -> TcM ([LStmt TcId], thing)
-- Note the higher-rank type. stmt_chk is applied at different
-- types in the equations for tcStmts
-tcStmts _ _ [] res_ty thing_inside
+tcStmtsAndThen _ _ [] res_ty thing_inside
= do { thing <- thing_inside res_ty
; return ([], thing) }
-- LetStmts are handled uniformly, regardless of context
-tcStmts ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside
+tcStmtsAndThen ctxt stmt_chk (L loc (LetStmt binds) : stmts) res_ty thing_inside
= do { (binds', (stmts',thing)) <- tcLocalBinds binds $
- tcStmts ctxt stmt_chk stmts res_ty thing_inside
+ tcStmtsAndThen ctxt stmt_chk stmts res_ty thing_inside
; return (L loc (LetStmt binds') : stmts', thing) }
-- For the vanilla case, handle the location-setting part
-tcStmts ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside
+tcStmtsAndThen ctxt stmt_chk (L loc stmt : stmts) res_ty thing_inside
= do { (stmt', (stmts', thing)) <-
- setSrcSpan loc $
- addErrCtxt (pprStmtInCtxt ctxt stmt) $
- stmt_chk ctxt stmt res_ty $ \ res_ty' ->
- popErrCtxt $
- tcStmts ctxt stmt_chk stmts res_ty' $
+ setSrcSpan loc $
+ addErrCtxt (pprStmtInCtxt ctxt stmt) $
+ stmt_chk ctxt stmt res_ty $ \ res_ty' ->
+ popErrCtxt $
+ tcStmtsAndThen ctxt stmt_chk stmts res_ty' $
thing_inside
; return (L loc stmt' : stmts', thing) }
@@ -357,18 +357,23 @@ tcGuardStmt _ stmt _ _
tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray)
-> TcStmtChecker
+tcLcStmt m_tc ctxt (LastStmt body _) elt_ty thing_inside
+ = do { body' <- tcMonoExpr body elt_ty
+ ; thing <- thing_inside elt_ty
+ ; return (LastStmt body' noSyntaxExpr, thing) }
+
-- A generator, pat <- rhs
-tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) res_ty thing_inside
+tcLcStmt m_tc ctxt (BindStmt pat rhs _ _) elt_ty thing_inside
= do { pat_ty <- newFlexiTyVarTy liftedTypeKind
; rhs' <- tcMonoExpr rhs (mkTyConApp m_tc [pat_ty])
; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $
- thing_inside res_ty
+ thing_inside elt_ty
; return (BindStmt pat' rhs' noSyntaxExpr noSyntaxExpr, thing) }
-- A boolean guard
-tcLcStmt _ _ (ExprStmt rhs _ _ _) res_ty thing_inside
+tcLcStmt _ _ (ExprStmt rhs _ _ _) elt_ty thing_inside
= do { rhs' <- tcMonoExpr rhs boolTy
- ; thing <- thing_inside res_ty
+ ; thing <- thing_inside elt_ty
; return (ExprStmt rhs' noSyntaxExpr noSyntaxExpr boolTy, thing) }
-- A parallel set of comprehensions
@@ -491,20 +496,29 @@ tcLcStmt _ _ stmt _ _
tcMcStmt :: TcStmtChecker
+tcMcStmt ctxt (LastStmt body return_op) res_ty thing_inside
+ = do { a_ty <- newFlexiTyVarTy liftedTypeKind
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op
+ (a_ty `mkFunTy` res_ty)
+ ; body' <- tcMonoExpr body a_ty
+ ; return (body', return_op') }
+
-- Generators for monad comprehensions ( pat <- rhs )
--
-- [ body | q <- gen ] -> gen :: m a
-- q :: a
--
+
tcMcStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
= do { rhs_ty <- newFlexiTyVarTy liftedTypeKind
; pat_ty <- newFlexiTyVarTy liftedTypeKind
; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+
+ -- (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
; bind_op' <- tcSyntaxOp MCompOrigin bind_op
(mkFunTys [rhs_ty, mkFunTy pat_ty new_res_ty] res_ty)
- -- If (but only if) the pattern can fail,
- -- typecheck the 'fail' operator
+ -- If (but only if) the pattern can fail, typecheck the 'fail' operator
; fail_op' <- if isIrrefutableHsPat pat
then return noSyntaxExpr
else tcSyntaxOp MCompOrigin fail_op (mkFunTy stringTy new_res_ty)
@@ -540,15 +554,15 @@ tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside
-- [ body | stmts, then f ] -> f :: forall a. m a -> m a
-- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a
--
-tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) elt_ty thing_inside
+tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) res_ty thing_inside
= do {
-- We don't know the types of binders yet, so we use this dummy and
-- later unify this type with the `m_bndr_ty`
ty_dummy <- newFlexiTyVarTy liftedTypeKind
; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <-
- tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \elt_ty' -> do
- { (_, (m_ty, _)) <- matchExpectedAppTy elt_ty'
+ tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do
+ { (_, (m_ty, _)) <- matchExpectedAppTy res_ty'
; (usingExpr', maybeByExpr') <-
case maybeByExpr of
Nothing -> do
@@ -582,22 +596,22 @@ tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_
-- -> ( (a,b,c,..) -> m (a,b,c,..) )
-- -> m (a,b,c,..)
--
- ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids
+ ; let bndr_ty = mkBigCoreVarTupTy bndr_ids
m_bndr_ty = m_ty `mkAppTy` bndr_ty
; return_op' <- tcSyntaxOp MCompOrigin return_op
(bndr_ty `mkFunTy` m_bndr_ty)
; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
- m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` elt_ty)
- `mkFunTy` elt_ty
+ m_bndr_ty `mkFunTy` (bndr_ty `mkFunTy` res_ty)
+ `mkFunTy` res_ty
-- Unify types of the inner comprehension and the binders type
- ; _ <- unifyType elt_ty' m_bndr_ty
+ ; _ <- unifyType res_ty' m_bndr_ty
-- Typecheck the `thing` with out old type (which is the type
-- of the final result of our comprehension)
- ; thing <- thing_inside elt_ty
+ ; thing <- thing_inside res_ty
; return (bndr_ids, usingExpr', maybeByExpr', return_op', bind_op', thing) }
@@ -613,32 +627,21 @@ tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_
-- [ body | stmts, then group using f ]
-- -> f :: forall a. m a -> m (m a)
--
-tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) elt_ty thing_inside
- = do { let (bndr_names, m_bndr_names) = unzip bindersMap
-
- ; (_,(m_ty,_)) <- matchExpectedAppTy elt_ty
- ; let alphaMTy = m_ty `mkAppTy` alphaTy
- alphaMMTy = m_ty `mkAppTy` alphaMTy
-
- -- We don't know the type of the bindings yet. It's not elt_ty!
- ; bndr_ty_dummy <- newFlexiTyVarTy liftedTypeKind
-
- ; (stmts', (bndr_ids, by', using_ty, return_op', bind_op')) <-
- tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts bndr_ty_dummy $ \elt_ty' -> do
- { (by', using_ty) <-
- case by of
- Nothing -> -- check that using :: forall a. m a -> m (m a)
- return (Nothing, mkForAllTy alphaTyVar $
- alphaMTy `mkFunTy` alphaMMTy)
-
- Just by_e -> -- check that using :: forall a. (a -> t) -> m a -> m (m a)
- -- where by :: t
- do { (by_e', t_ty) <- tcInferRhoNC by_e
- ; return (Just by_e', mkForAllTy alphaTyVar $
- (alphaTy `mkFunTy` t_ty)
- `mkFunTy` alphaMTy
- `mkFunTy` alphaMMTy) }
-
+tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op fmap_op) res_ty thing_inside
+ = do { m1_ty <- newFlexiTyVarTy liftedTypeKind
+ ; m2_ty <- newFlexiTyVarTy liftedTypeKind
+ ; n_ty <- newFlexiTyVarTy liftedTypeKind
+ ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; let (bndr_names, n_bndr_names) = unzip bindersMap
+ m1_tup_ty = m1_ty `mkAppTy` tup_ty_var
+
+ -- 'stmts' returns a result of type (m1_ty tuple_ty),
+ -- typically something like [(Int,Bool,Int)]
+ -- We don't know what tuple_ty is yet, so we use a variable
+ ; (stmts', (bndr_ids, by_e_ty, return_op')) <-
+ tcStmts (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
+ { by_e_ty <- mapM tcInferRhoNC by_e
-- Find the Ids (and hence types) of all old binders
; bndr_ids <- tcLookupLocalIds bndr_names
@@ -646,48 +649,52 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) e
-- 'return' is only used for the binders, so we know its type.
--
-- return :: (a,b,c,..) -> m (a,b,c,..)
- --
- ; let bndr_ty = mkChunkified mkBoxedTupleTy $ map idType bndr_ids
- m_bndr_ty = m_ty `mkAppTy` bndr_ty
- ; return_op' <- tcSyntaxOp MCompOrigin return_op $ bndr_ty `mkFunTy` m_bndr_ty
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op $
+ (mkBigCoreVarTupTy bndr_ids) `mkFunTy` res_ty'
- -- '>>=' is used to pass the grouped binders to the rest of the
- -- comprehension.
- --
- -- (>>=) :: m (m a, m b, m c, ..)
- -- -> ( (m a, m b, m c, ..) -> new_elt_ty )
- -- -> elt_ty
- --
- ; let bndr_m_ty = mkChunkified mkBoxedTupleTy $ map (mkAppTy m_ty . idType) bndr_ids
- m_bndr_m_ty = m_ty `mkAppTy` bndr_m_ty
- ; new_elt_ty <- newFlexiTyVarTy liftedTypeKind
- ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
- m_bndr_m_ty `mkFunTy` (bndr_m_ty `mkFunTy` new_elt_ty)
- `mkFunTy` elt_ty
+ ; return (bndr_ids, by_e_ty, return_op') }
- -- Finally make sure the type of the inner comprehension
- -- represents the types of our binders
- ; _ <- unifyType elt_ty' m_bndr_ty
- ; return (bndr_ids, by', using_ty, return_op', bind_op') }
- ; let mk_m_bndr :: Name -> TcId -> TcId
- mk_m_bndr m_bndr_name bndr_id =
- mkLocalId m_bndr_name (m_ty `mkAppTy` idType bndr_id)
+ ; let tup_ty = mkBigCoreVarTupTy bndr_ids -- (a,b,c)
+ using_arg_ty = m1_ty `mkAppTy` tup_ty -- m1 (a,b,c)
+ n_tup_ty = n_ty `mkAppTy` tup_ty -- n (a,b,c)
+ using_res_ty = m2_ty `mkAppTy` n_tup_ty -- m2 (n (a,b,c))
+ using_fun_ty = using_arg_ty `mkFunTy` using_arg_ty
+
+ -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
+ -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
- -- Ensure that every old binder of type `b` is linked up with its
- -- new binder which should have type `m b`
- m_bndr_ids = zipWith mk_m_bndr m_bndr_names bndr_ids
- bindersMap' = bndr_ids `zip` m_bndr_ids
+ --------------- Typecheck the 'bind' function -------------
+ ; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
+ using_res_ty `mkFunTy` (n_tup_ty `mkFunTy` new_res_ty)
+ `mkFunTy` res_ty
- -- See Note [GroupStmt binder map] in HsExpr
+ --------------- Typecheck the 'using' function -------------
+ ; let using_fun_ty = (m1_ty `mkAppTy` alphaTy) `mkFunTy`
+ (m2_ty `mkAppTy` (n_ty `mkAppTy` alphaTy))
+ using_poly_ty = case by_e_ty of
+ Nothing -> mkForAllTy alphaTyVar using_fun_ty
+ -- using :: forall a. m1 a -> m2 (n a)
- ; using' <- case using of
- Left e -> do { e' <- tcPolyExpr e using_ty; return (Left e') }
- Right e -> do { e' <- tcPolyExpr (noLoc e) using_ty; return (Right (unLoc e')) }
+ Just (_,t_ty) -> mkForAllTy alphaTyVar $
+ (alphaTy `mkFunTy` t_ty) `mkFunTy` using_fun_ty
+ -- using :: forall a. (a->t) -> m1 a -> m2 (n a)
+ -- where by :: t
- -- Type check 'liftM' with 'forall a b. (a -> b) -> m_ty a -> m_ty b'
- ; liftM_op' <- fmap unLoc . tcPolyExpr (noLoc liftM_op) $
+ ; using' <- case using of
+ Left e -> do { e' <- tcPolyExpr e using_poly_ty
+ ; return (Left e') }
+ Right e -> do { e' <- tcPolyExpr (noLoc e) using_poly_ty
+ ; return (Right (unLoc e')) }
+ ; coi <- unifyType (applyTy using_poly_ty tup_ty)
+ (case by_e_ty of
+ Nothing -> using_fun_ty
+ Just (_,t_ty) -> (tup_ty `mkFunTy` t_ty) `mkFunTy` using_fun_ty)
+ ; let final_using = mkHsWrapCoI coi (HsWrap (WpTyApp tup_ty) using')
+
+ --------------- Typecheck the 'fmap' function -------------
+ ; fmap_op' <- fmap unLoc . tcPolyExpr (noLoc fmap_op) $
mkForAllTy alphaTyVar $ mkForAllTy betaTyVar $
(alphaTy `mkFunTy` betaTy)
`mkFunTy`
@@ -695,11 +702,23 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) e
`mkFunTy`
(m_ty `mkAppTy` betaTy)
+ ; let mk_n_bndr :: Name -> TcId -> TcId
+ mk_n_bndr n_bndr_name bndr_id
+ = mkLocalId bndr_name (n_ty `mkAppTy` idType bndr_id)
+
+ -- Ensure that every old binder of type `b` is linked up with its
+ -- new binder which should have type `n b`
+ -- See Note [GroupStmt binder map] in HsExpr
+ n_bndr_ids = zipWith mk_n_bndr n_bndr_names bndr_ids
+ bindersMap' = bndr_ids `zip` n_bndr_ids
+
-- Type check the thing in the environment with these new binders and
-- return the result
- ; thing <- tcExtendIdEnv m_bndr_ids (thing_inside elt_ty)
+ ; thing <- tcExtendIdEnv n_bndr_ids (thing_inside res_ty)
- ; return (GroupStmt stmts' bindersMap' by' using' return_op' bind_op' liftM_op', thing) }
+ ; return (GroupStmt stmts' bindersMap'
+ (fmap fst by_e_ty) final_using
+ return_op' bind_op' fmap_op', thing) }
-- Typecheck `ParStmt`. See `tcLcStmt` for more informations about typechecking
-- of `ParStmt`s.
@@ -712,8 +731,8 @@ tcMcStmt ctxt (GroupStmt stmts bindersMap by using return_op bind_op liftM_op) e
-- -> (m st2 -> m st3 -> m (st2, st3)) -- recursive call
-- -> m (st1, (st2, st3))
--
-tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_inside
- = do { (_,(m_ty,_)) <- matchExpectedAppTy elt_ty
+tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) res_ty thing_inside
+ = do { (_,(m_ty,_)) <- matchExpectedAppTy res_ty
; (pairs', thing) <- loop m_ty bndr_stmts_s
; let mzip_ty = mkForAllTys [alphaTyVar, betaTyVar] $
@@ -725,19 +744,22 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_insi
; mzip_op' <- unLoc `fmap` tcPolyExpr (noLoc mzip_op) mzip_ty
-- Typecheck bind:
- ; let tys = map (mkChunkified mkBoxedTupleTy . map idType . snd) pairs'
+ ; let tys = map (mkBigCoreVarTupTy . snd) pairs'
tuple_ty = mk_tuple_ty tys
; bind_op' <- tcSyntaxOp MCompOrigin bind_op $
(m_ty `mkAppTy` tuple_ty)
`mkFunTy`
- (tuple_ty `mkFunTy` elt_ty)
+ (tuple_ty `mkFunTy` res_ty)
`mkFunTy`
- elt_ty
+ res_ty
; return_op' <- fmap unLoc . tcPolyExpr (noLoc return_op) $
mkForAllTy alphaTyVar $
alphaTy `mkFunTy` (m_ty `mkAppTy` alphaTy)
+ ; return_op' <- tcSyntaxOp MCompOrigin return_op
+ (bndr_ty `mkFunTy` m_bndr_ty)
+
; return (ParStmt pairs' mzip_op' bind_op' return_op', thing) }
where mk_tuple_ty tys = foldr (\tn tm -> mkBoxedTupleTy [tn, tm]) (last tys) (init tys)
@@ -745,16 +767,16 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_insi
-- loop :: Type -- m_ty
-- -> [([LStmt Name], [Name])]
-- -> TcM ([([LStmt TcId], [TcId])], thing)
- loop _ [] = do { thing <- thing_inside elt_ty
+ loop _ [] = do { thing <- thing_inside res_ty
; return ([], thing) } -- matching in the branches
loop m_ty ((stmts, names) : pairs)
= do { -- type dummy since we don't know all binder types yet
ty_dummy <- newFlexiTyVarTy liftedTypeKind
; (stmts', (ids, pairs', thing))
- <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \elt_ty' ->
+ <- tcStmts ctxt tcMcStmt stmts ty_dummy $ \res_ty' ->
do { ids <- tcLookupLocalIds names
- ; _ <- unifyType elt_ty' (m_ty `mkAppTy` (mkChunkified mkBoxedTupleTy) (map idType ids))
+ ; _ <- unifyType res_ty' (m_ty `mkAppTy` mkBigCoreVarTupTy ids)
; (pairs', thing) <- loop m_ty pairs
; return (ids, pairs', thing) }
; return ( (stmts', ids) : pairs', thing ) }
@@ -762,27 +784,17 @@ tcMcStmt ctxt (ParStmt bndr_stmts_s mzip_op bind_op return_op) elt_ty thing_insi
tcMcStmt _ stmt _ _
= pprPanic "tcMcStmt: unexpected Stmt" (ppr stmt)
--- Typecheck 'body' with type 'a' instead of 'm a' like the rest of the
--- statements, ignore the second type argument coming from the tcStmts loop
-tcMcBody :: LHsExpr Name
- -> SyntaxExpr Name
- -> TcRhoType
- -> TcM (LHsExpr TcId, SyntaxExpr TcId)
-tcMcBody body return_op res_ty
- = do { (_, (_, a_ty)) <- matchExpectedAppTy res_ty
- ; body' <- tcMonoExpr body a_ty
- ; return_op' <- tcSyntaxOp MCompOrigin return_op
- (a_ty `mkFunTy` res_ty)
- ; return (body', return_op')
- }
-
-
--------------------------------
-- Do-notation
-- The main excitement here is dealing with rebindable syntax
tcDoStmt :: TcStmtChecker
+tcDoStmt ctxt (LastStmt body _) res_ty thing_inside
+ = do { body' <- tcMonoExpr body res_ty
+ ; thing <- thing_inside body_ty
+ ; return (LastStmt body' noSyntaxExpr, thing) }
+
tcDoStmt ctxt (BindStmt pat rhs bind_op fail_op) res_ty thing_inside
= do { -- Deal with rebindable syntax:
-- (>>=) :: rhs_ty -> (pat_ty -> new_res_ty) -> res_ty
@@ -862,7 +874,7 @@ tcDoStmt ctxt (RecStmt { recS_stmts = stmts, recS_later_ids = later_names
; return (RecStmt { recS_stmts = stmts', recS_later_ids = later_ids
, recS_rec_ids = rec_ids, recS_ret_fn = ret_op'
, recS_mfix_fn = mfix_op', recS_bind_fn = bind_op'
- , recS_rec_rets = tup_rets }, thing)
+ , recS_rec_rets = tup_rets, recS_ret_ty = stmts_ty }, thing)
}}
tcDoStmt _ stmt _ _
@@ -888,6 +900,7 @@ the expected/inferred stuff is back to front (see Trac #3613).
tcMDoStmt :: (LHsExpr Name -> TcM (LHsExpr TcId, TcType)) -- RHS inference
-> TcStmtChecker
+-- Used only by TcArrows... should be gotten rid of
tcMDoStmt tc_rhs ctxt (BindStmt pat rhs _ _) res_ty thing_inside
= do { (rhs', pat_ty) <- tc_rhs rhs
; (pat', thing) <- tcPat (StmtCtxt ctxt) pat pat_ty $