summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2011-05-02 13:56:37 +0100
committerSimon Peyton Jones <simonpj@microsoft.com>2011-05-02 13:56:37 +0100
commitf6d254cccd3dc25fff9ff50c2e1bea52b10345e4 (patch)
tree0319f4e1e423455d6ef2e76dca39d8b970b38f82
parentd76d9636aeebe933d160157331b8c8c0087e73ac (diff)
downloadhaskell-f6d254cccd3dc25fff9ff50c2e1bea52b10345e4.tar.gz
More on monad-comp; an intermediate state, so don't pull
-rw-r--r--compiler/deSugar/DsExpr.lhs8
-rw-r--r--compiler/hsSyn/HsExpr.lhs18
-rw-r--r--compiler/hsSyn/HsUtils.lhs8
-rw-r--r--compiler/parser/Parser.y.pp2
-rw-r--r--compiler/rename/RnExpr.lhs96
-rw-r--r--compiler/typecheck/TcMatches.lhs27
6 files changed, 85 insertions, 74 deletions
diff --git a/compiler/deSugar/DsExpr.lhs b/compiler/deSugar/DsExpr.lhs
index 418bda5275..4088e44b1b 100644
--- a/compiler/deSugar/DsExpr.lhs
+++ b/compiler/deSugar/DsExpr.lhs
@@ -740,6 +740,7 @@ dsDo stmts
noSyntaxExpr -- Tuple cannot fail
tup_ids = rec_ids ++ filterOut (`elem` rec_ids) later_ids
+ tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
rec_tup_pats = map nlVarPat tup_ids
later_pats = rec_tup_pats
rets = map noLoc rec_rets
@@ -748,8 +749,11 @@ dsDo stmts
(mkFunTy tup_ty body_ty))
mfix_pat = noLoc $ LazyPat $ mkLHsPatTup rec_tup_pats
body = noLoc $ HsDo DoExpr (rec_stmts ++ [ret_stmt]) body_ty
- ret_stmt = noLoc $ LastStmt (mkLHsTupleExpr rets) return_op
- tup_ty = mkBoxedTupleTy (map idType tup_ids) -- Deals with singleton case
+ ret_app = nlHsApp (noLoc return_op) (mkLHsTupleExpr rets)
+ ret_stmt = noLoc $ mkLastStmt ret_app
+ -- This LastStmt will be desugared with dsDo,
+ -- which ignores the return_op in the LastStmt,
+ -- so we must apply the return_op explicitly
handle_failure :: LPat Id -> MatchResult -> SyntaxExpr Id -> DsM CoreExpr
-- In a do expression, pattern-match failure just calls
diff --git a/compiler/hsSyn/HsExpr.lhs b/compiler/hsSyn/HsExpr.lhs
index cf9c0d7402..fba270ce23 100644
--- a/compiler/hsSyn/HsExpr.lhs
+++ b/compiler/hsSyn/HsExpr.lhs
@@ -833,7 +833,8 @@ type Stmt id = StmtLR id id
-- The SyntaxExprs in here are used *only* for do-notation and monad
-- comprehensions, which have rebindable syntax. Otherwise they are unused.
data StmtLR idL idR
- = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp, DoExpr, MDoExpr
+ = LastStmt -- Always the last Stmt in ListComp, MonadComp, PArrComp,
+ -- and (after the renamer) DoExpr, MDoExpr
-- Not used for GhciStmt, PatGuard, which scope over other stuff
(LHsExpr idR)
(SyntaxExpr idR) -- The return operator, used only for MonadComp
@@ -1090,7 +1091,7 @@ instance (OutputableBndr idL, OutputableBndr idR) => Outputable (StmtLR idL idR)
ppr stmt = pprStmt stmt
pprStmt :: (OutputableBndr idL, OutputableBndr idR) => (StmtLR idL idR) -> SDoc
-pprStmt (LastStmt expr _) = ppr expr
+pprStmt (LastStmt expr _) = ifPprDebug (ptext (sLit "[last]")) <+> ppr expr
pprStmt (BindStmt pat expr _ _) = hsep [ppr pat, ptext (sLit "<-"), ppr expr]
pprStmt (LetStmt binds) = hsep [ptext (sLit "let"), pprBinds binds]
pprStmt (ExprStmt expr _ _ _) = ppr expr
@@ -1354,8 +1355,8 @@ pprAStmtContext ctxt = article <+> pprStmtContext ctxt
-----------------
pprStmtContext GhciStmt = ptext (sLit "interactive GHCi command")
-pprStmtContext DoExpr = ptext (sLit "'do' expression")
-pprStmtContext MDoExpr = ptext (sLit "'mdo' expression")
+pprStmtContext DoExpr = ptext (sLit "'do' block")
+pprStmtContext MDoExpr = ptext (sLit "'mdo' block")
pprStmtContext ListComp = ptext (sLit "list comprehension")
pprStmtContext MonadComp = ptext (sLit "monad comprehension")
pprStmtContext PArrComp = ptext (sLit "array comprehension")
@@ -1402,8 +1403,13 @@ pprMatchInCtxt ctxt match = hang (ptext (sLit "In") <+> pprMatchContext ctxt <>
pprStmtInCtxt :: (OutputableBndr idL, OutputableBndr idR)
=> HsStmtContext idL -> StmtLR idL idR -> SDoc
-pprStmtInCtxt ctxt stmt = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon)
- 4 (ppr_stmt stmt)
+pprStmtInCtxt ctxt (LastStmt e _)
+ | isListCompExpr ctxt -- For [ e | .. ], do not mutter about "stmts"
+ = hang (ptext (sLit "In the expression:")) 2 (ppr e)
+
+pprStmtInCtxt ctxt stmt
+ = hang (ptext (sLit "In a stmt of") <+> pprAStmtContext ctxt <> colon)
+ 2 (ppr_stmt stmt)
where
-- For Group and Transform Stmts, don't print the nested stmts!
ppr_stmt (GroupStmt { grpS_by = by, grpS_using = using
diff --git a/compiler/hsSyn/HsUtils.lhs b/compiler/hsSyn/HsUtils.lhs
index de883f25a5..51a0de35c7 100644
--- a/compiler/hsSyn/HsUtils.lhs
+++ b/compiler/hsSyn/HsUtils.lhs
@@ -21,7 +21,7 @@ module HsUtils(
mkMatchGroup, mkMatch, mkHsLam, mkHsIf,
mkHsWrap, mkLHsWrap, mkHsWrapCoI, mkLHsWrapCoI,
coiToHsWrapper, mkHsLams, mkHsDictLet,
- mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI, mkDoStmts,
+ mkHsOpApp, mkHsDo, mkHsComp, mkHsWrapPat, mkHsWrapPatCoI,
nlHsTyApp, nlHsVar, nlHsLit, nlHsApp, nlHsApps, nlHsIntLit, nlHsVarApps,
nlHsDo, nlHsOpApp, nlHsLam, nlHsPar, nlHsIf, nlHsCase, nlList,
@@ -192,7 +192,6 @@ mkHsFractional :: Rational -> PostTcType -> HsOverLit id
mkHsIsString :: FastString -> PostTcType -> HsOverLit id
mkHsDo :: HsStmtContext Name -> [LStmt id] -> HsExpr id
mkHsComp :: HsStmtContext Name -> [LStmt id] -> LHsExpr id -> HsExpr id
-mkDoStmts :: [LStmt id] -> [LStmt id]
mkNPat :: HsOverLit id -> Maybe (SyntaxExpr id) -> Pat id
mkNPlusKPat :: Located id -> HsOverLit id -> Pat id
@@ -215,11 +214,6 @@ mkHsIsString s = OverLit (HsIsString s) noRebindableInfo noSyntaxExpr
noRebindableInfo :: Bool
noRebindableInfo = error "noRebindableInfo" -- Just another placeholder;
--- mkDoStmts turns a trailing ExprStmt into a LastStmt
-mkDoStmts [L loc (ExprStmt e _ _ _)] = [L loc (mkLastStmt e)]
-mkDoStmts (s:ss) = s : mkDoStmts ss
-mkDoStmts [] = []
-
mkHsDo ctxt stmts = HsDo ctxt stmts placeHolderType
mkHsComp ctxt stmts expr = mkHsDo ctxt (stmts ++ [last_stmt])
where
diff --git a/compiler/parser/Parser.y.pp b/compiler/parser/Parser.y.pp
index c42ea0c864..aa20ea6799 100644
--- a/compiler/parser/Parser.y.pp
+++ b/compiler/parser/Parser.y.pp
@@ -1602,7 +1602,7 @@ apats :: { [LPat RdrName] }
-- Statement sequences
stmtlist :: { Located [LStmt RdrName] }
- : '{' stmts '}' { LL (mkDoStmts (unLoc $2)) }
+ : '{' stmts '}' { LL (unLoc $2) }
| vocurly stmts close { $2 }
-- do { ;; s ; s ; ; s ;; }
diff --git a/compiler/rename/RnExpr.lhs b/compiler/rename/RnExpr.lhs
index d1dd222be7..11d44e3bad 100644
--- a/compiler/rename/RnExpr.lhs
+++ b/compiler/rename/RnExpr.lhs
@@ -648,32 +648,22 @@ rnStmts MDoExpr stmts thing_inside -- Deal with mdo
= -- Behave like do { rec { ...all but last... }; last }
do { ((stmts1, (stmts2, thing)), fvs)
<- rnStmt MDoExpr (noLoc $ mkRecStmt all_but_last) $ \ _ ->
- do { checkStmt MDoExpr True last_stmt
- ; rnStmt MDoExpr last_stmt thing_inside }
+ do { last_stmt' <- checkLastStmt MDoExpr last_stmt
+ ; rnStmt MDoExpr last_stmt' thing_inside }
; return (((stmts1 ++ stmts2), thing), fvs) }
where
Just (all_but_last, last_stmt) = snocView stmts
-rnStmts ctxt (lstmt@(L loc stmt) : lstmts) thing_inside
+rnStmts ctxt (lstmt@(L loc _) : lstmts) thing_inside
| null lstmts
= setSrcSpan loc $
- do { -- Turn a final ExprStmt into a LastStmt
- -- This is the first place it's convenient to do this
- -- (In principle the parser could do it, but it's
- -- just not very convenient to do so.)
- let stmt' | okEmpty ctxt
- = lstmt
- | otherwise
- = case stmt of
- ExprStmt e _ _ _ -> L loc (mkLastStmt e)
- _ -> lstmt
- ; checkStmt ctxt True {- last stmt -} stmt'
- ; rnStmt ctxt stmt' thing_inside }
+ do { lstmt' <- checkLastStmt ctxt lstmt
+ ; rnStmt ctxt lstmt' thing_inside }
| otherwise
= do { ((stmts1, (stmts2, thing)), fvs)
<- setSrcSpan loc $
- do { checkStmt ctxt False {- Not last -} lstmt
+ do { checkStmt ctxt lstmt
; rnStmt ctxt lstmt $ \ bndrs1 ->
rnStmts ctxt lstmts $ \ bndrs2 ->
thing_inside (bndrs1 ++ bndrs2) }
@@ -1211,7 +1201,7 @@ checkEmptyStmts :: HsStmtContext Name -> RnM ()
checkEmptyStmts ctxt
= unless (okEmpty ctxt) (addErr (emptyErr ctxt))
-okEmpty :: HsStmtContext Name -> Bool
+okEmpty :: HsStmtContext a -> Bool
okEmpty (PatGuard {}) = True
okEmpty _ = False
@@ -1221,14 +1211,42 @@ emptyErr (TransformStmtCtxt {}) = ptext (sLit "Empty statement group preceding '
emptyErr ctxt = ptext (sLit "Empty") <+> pprStmtContext ctxt
----------------------
+checkLastStmt :: HsStmtContext Name
+ -> LStmt RdrName
+ -> RnM (LStmt RdrName)
+checkLastStmt ctxt lstmt@(L loc stmt)
+ = case ctxt of
+ ListComp -> check_comp
+ MonadComp -> check_comp
+ PArrComp -> check_comp
+ DoExpr -> check_do
+ MDoExpr -> check_do
+ _ -> check_other
+ where
+ check_do -- Expect ExprStmt, and change it to LastStmt
+ = case stmt of
+ ExprStmt e _ _ _ -> return (L loc (mkLastStmt e))
+ LastStmt {} -> return lstmt -- "Deriving" clauses may generate a
+ -- LastStmt directly (unlike the parser)
+ _ -> do { addErr (hang last_error 2 (ppr stmt)); return lstmt }
+ last_error = (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
+ <+> ptext (sLit "must be an expression"))
+
+ check_comp -- Expect LastStmt; this should be enforced by the parser!
+ = case stmt of
+ LastStmt {} -> return lstmt
+ _ -> pprPanic "checkLastStmt" (ppr lstmt)
+
+ check_other -- Behave just as if this wasn't the last stmt
+ = do { checkStmt ctxt lstmt; return lstmt }
+
-- Checking when a particular Stmt is ok
checkStmt :: HsStmtContext Name
- -> Bool -- True <=> this is the last Stmt in the sequence
-> LStmt RdrName
-> RnM ()
-checkStmt ctxt is_last (L _ stmt)
+checkStmt ctxt (L _ stmt)
= do { dflags <- getDOpts
- ; case okStmt dflags ctxt is_last stmt of
+ ; case okStmt dflags ctxt stmt of
Nothing -> return ()
Just extra -> addErr (msg $$ extra) }
where
@@ -1250,42 +1268,32 @@ isOK, notOK :: Maybe SDoc
isOK = Nothing
notOK = Just empty
-okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name -> Bool
+okStmt, okDoStmt, okCompStmt :: DynFlags -> HsStmtContext Name
-> Stmt RdrName -> Maybe SDoc
-- Return Nothing if OK, (Just extra) if not ok
-- The "extra" is an SDoc that is appended to an generic error message
-okStmt _ (PatGuard {}) _ stmt
+okStmt _ (PatGuard {}) stmt
= case stmt of
ExprStmt {} -> isOK
BindStmt {} -> isOK
LetStmt {} -> isOK
_ -> notOK
-okStmt dflags (ParStmtCtxt ctxt) _ stmt
+okStmt dflags (ParStmtCtxt ctxt) stmt
= case stmt of
LetStmt (HsIPBinds {}) -> notOK
- _ -> okStmt dflags ctxt False stmt
- -- NB: is_last=False in recursive
- -- call; the branches of of a Par
- -- not finish with a LastStmt
+ _ -> okStmt dflags ctxt stmt
-okStmt dflags (TransformStmtCtxt ctxt) _ stmt
- = okStmt dflags ctxt False stmt
+okStmt dflags (TransformStmtCtxt ctxt) stmt
+ = okStmt dflags ctxt stmt
-okStmt dflags ctxt is_last stmt
- | isDoExpr ctxt = okDoStmt dflags ctxt is_last stmt
- | isListCompExpr ctxt = okCompStmt dflags ctxt is_last stmt
+okStmt dflags ctxt stmt
+ | isDoExpr ctxt = okDoStmt dflags ctxt stmt
+ | isListCompExpr ctxt = okCompStmt dflags ctxt stmt
| otherwise = pprPanic "okStmt" (pprStmtContext ctxt)
----------------
-okDoStmt dflags ctxt is_last stmt
- | is_last
- = case stmt of
- LastStmt {} -> isOK
- _ -> Just (ptext (sLit "The last statement in") <+> pprAStmtContext ctxt
- <+> ptext (sLit "must be an expression"))
-
- | otherwise
+okDoStmt dflags _ stmt
= case stmt of
RecStmt {}
| Opt_DoRec `xopt` dflags -> isOK
@@ -1297,13 +1305,7 @@ okDoStmt dflags ctxt is_last stmt
----------------
-okCompStmt dflags _ is_last stmt
- | is_last
- = case stmt of
- LastStmt {} -> Nothing
- _ -> pprPanic "Unexpected stmt" (ppr stmt) -- Not a user error
-
- | otherwise
+okCompStmt dflags _ stmt
= case stmt of
BindStmt {} -> isOK
LetStmt {} -> isOK
diff --git a/compiler/typecheck/TcMatches.lhs b/compiler/typecheck/TcMatches.lhs
index 820e5172ce..87449b6d5c 100644
--- a/compiler/typecheck/TcMatches.lhs
+++ b/compiler/typecheck/TcMatches.lhs
@@ -358,7 +358,7 @@ tcLcStmt :: TyCon -- The list/Parray type constructor ([] or PArray)
-> TcStmtChecker
tcLcStmt _ _ (LastStmt body _) elt_ty thing_inside
- = do { body' <- tcMonoExpr body elt_ty
+ = do { body' <- tcMonoExprNC body elt_ty
; thing <- thing_inside (panic "tcLcStmt: thing_inside")
; return (LastStmt body' noSyntaxExpr, thing) }
@@ -502,7 +502,7 @@ tcMcStmt _ (LastStmt body return_op) res_ty thing_inside
= do { a_ty <- newFlexiTyVarTy liftedTypeKind
; return_op' <- tcSyntaxOp MCompOrigin return_op
(a_ty `mkFunTy` res_ty)
- ; body' <- tcMonoExpr body a_ty
+ ; body' <- tcMonoExprNC body a_ty
; thing <- thing_inside (panic "tcMcStmt: thing_inside")
; return (LastStmt body' return_op', thing) }
@@ -558,15 +558,20 @@ tcMcStmt _ (ExprStmt rhs then_op guard_op _) res_ty thing_inside
-- [ body | stmts, then f by e ] -> f :: forall a. (a -> t) -> m a -> m a
--
tcMcStmt ctxt (TransformStmt stmts binders usingExpr maybeByExpr return_op bind_op) res_ty thing_inside
- = do {
- -- We don't know the types of binders yet, so we use this dummy and
- -- later unify this type with the `m_bndr_ty`
- ty_dummy <- newFlexiTyVarTy liftedTypeKind
+ = do { let star_star_kind = liftedTypeKind `mkArrowKind` liftedTypeKind
+ ; m1_ty <- newFlexiTyVarTy star_star_kind
+ ; m2_ty <- newFlexiTyVarTy star_star_kind
+ ; n_ty <- newFlexiTyVarTy star_star_kind
+ ; tup_ty_var <- newFlexiTyVarTy liftedTypeKind
+ ; new_res_ty <- newFlexiTyVarTy liftedTypeKind
+ ; let m1_tup_ty = m1_ty `mkAppTy` tup_ty_var
+ -- 'stmts' returns a result of type (m1_ty tuple_ty),
+ -- typically something like [(Int,Bool,Int)]
+ -- We don't know what tuple_ty is yet, so we use a variable
; (stmts', (binders', usingExpr', maybeByExpr', return_op', bind_op', thing)) <-
- tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts ty_dummy $ \res_ty' -> do
- { (_, (m_ty, _)) <- matchExpectedAppTy res_ty'
- ; (usingExpr', maybeByExpr') <-
+ tcStmtsAndThen (TransformStmtCtxt ctxt) tcMcStmt stmts m1_tup_ty $ \res_ty' -> do
+ { (usingExpr', maybeByExpr') <-
case maybeByExpr of
Nothing -> do
-- We must validate that usingExpr :: forall a. m a -> m a
@@ -671,8 +676,8 @@ tcMcStmt ctxt (GroupStmt { grpS_stmts = stmts, grpS_bndrs = bindersMap
using_res_ty = m2_ty `mkAppTy` n_tup_ty -- m2 (n (a,b,c))
using_fun_ty = using_arg_ty `mkFunTy` using_arg_ty
- -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
- -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
+ -- (>>=) :: m2 (n (a,b,c)) -> ( n (a,b,c) -> new_res_ty ) -> res_ty
+ -- using :: ((a,b,c)->t) -> m1 (a,b,c) -> m2 (n (a,b,c))
--------------- Typecheck the 'bind' function -------------
; bind_op' <- tcSyntaxOp MCompOrigin bind_op $