diff options
Diffstat (limited to 'compiler/codeGen/StgCmmUtils.hs')
-rw-r--r-- | compiler/codeGen/StgCmmUtils.hs | 118 |
1 files changed, 74 insertions, 44 deletions
diff --git a/compiler/codeGen/StgCmmUtils.hs b/compiler/codeGen/StgCmmUtils.hs index 94013f5c6d..68949bf190 100644 --- a/compiler/codeGen/StgCmmUtils.hs +++ b/compiler/codeGen/StgCmmUtils.hs @@ -55,6 +55,7 @@ import CLabel import CmmUtils import CmmSwitch +import BasicTypes (BranchWeight) import ForeignCall import IdInfo import Type @@ -74,8 +75,6 @@ import RepType import qualified Data.ByteString as BS import qualified Data.Map as M import Data.Char -import Data.List -import Data.Ord import Data.Word @@ -448,16 +447,16 @@ unscramble dflags vertices = mapM_ do_component components emitSwitch :: CmmExpr -- Tag to switch on - -> [(ConTagZ, CmmAGraphScoped)] -- Tagged branches - -> Maybe CmmAGraphScoped -- Default branch (if any) + -> [(ConTagZ, CmmAGraphScoped, BranchWeight)] -- Tagged branches + -> Maybe (CmmAGraphScoped, BranchWeight) -- Default branch (if any) -> ConTagZ -> ConTagZ -- Min and Max possible values; -- behaviour outside this range is -- undefined -> FCode () -- First, two rather common cases in which there is no work to do -emitSwitch _ [] (Just code) _ _ = emit (fst code) -emitSwitch _ [(_,code)] Nothing _ _ = emit (fst code) +emitSwitch _ [] (Just code) _ _ = emit (fst $ fst code) +emitSwitch _ [(_,code,_)] Nothing _ _ = emit (fst code) -- Right, off we go emitSwitch tag_expr branches mb_deflt lo_tag hi_tag = do @@ -467,7 +466,8 @@ emitSwitch tag_expr branches mb_deflt lo_tag hi_tag = do tag_expr' <- assignTemp' tag_expr -- Sort the branches before calling mk_discrete_switch - let branches_lbls' = [ (fromIntegral i, l) | (i,l) <- sortBy (comparing fst) branches_lbls ] + let branches_lbls' = [ (fromIntegral i, l, f) + | (i,l,f) <- sortWith fstOf3 branches_lbls ] let range = (fromIntegral lo_tag, fromIntegral hi_tag) emit $ mk_discrete_switch False tag_expr' branches_lbls' mb_deflt_lbl range @@ -476,19 +476,19 @@ emitSwitch tag_expr branches mb_deflt lo_tag hi_tag = do mk_discrete_switch :: Bool -- ^ Use signed comparisons -> CmmExpr - -> [(Integer, BlockId)] - -> Maybe BlockId + -> [(Integer, BlockId, BranchWeight)] + -> Maybe (BlockId, BranchWeight) -> (Integer, Integer) -> CmmAGraph -- SINGLETON TAG RANGE: no case analysis to do -mk_discrete_switch _ _tag_expr [(tag, lbl)] _ (lo_tag, hi_tag) +mk_discrete_switch _ _tag_expr [(tag, lbl, _f)] _ (lo_tag, hi_tag) | lo_tag == hi_tag = ASSERT( tag == lo_tag ) mkBranch lbl -- SINGLETON BRANCH, NO DEFAULT: no case analysis to do -mk_discrete_switch _ _tag_expr [(_tag,lbl)] Nothing _ +mk_discrete_switch _ _tag_expr [(_tag,lbl,_)] Nothing _ = mkBranch lbl -- The simplifier might have eliminated a case -- so we may have e.g. case xs of @@ -499,25 +499,17 @@ mk_discrete_switch _ _tag_expr [(_tag,lbl)] Nothing _ -- SOMETHING MORE COMPLICATED: defer to CmmImplementSwitchPlans -- See Note [Cmm Switches, the general plan] in CmmSwitch mk_discrete_switch signed tag_expr branches mb_deflt range - = mkSwitch tag_expr $ mkSwitchTargets signed range mb_deflt (M.fromList branches) - -divideBranches :: Ord a => [(a,b)] -> ([(a,b)], a, [(a,b)]) -divideBranches branches = (lo_branches, mid, hi_branches) - where - -- 2 branches => n_branches `div` 2 = 1 - -- => branches !! 1 give the *second* tag - -- There are always at least 2 branches here - (mid,_) = branches !! (length branches `div` 2) - (lo_branches, hi_branches) = span is_lo branches - is_lo (t,_) = t < mid + = mkSwitch tag_expr $ + mkSwitchTargets signed range mb_deflt + (M.fromList $ map (\(i,e,f)-> (i,(e,f))) branches) -------------- emitCmmLitSwitch :: CmmExpr -- Tag to switch on - -> [(Literal, CmmAGraphScoped)] -- Tagged branches - -> CmmAGraphScoped -- Default branch (always) + -> [(Literal, CmmAGraphScoped, BranchWeight)] -- Tagged branches + -> (CmmAGraphScoped, BranchWeight) -- Default branch (always) -> FCode () -- Emit the code -emitCmmLitSwitch _scrut [] deflt = emit $ fst deflt -emitCmmLitSwitch scrut branches deflt = do +emitCmmLitSwitch _scrut [] (deflt,_dfreq) = emit $ fst deflt +emitCmmLitSwitch scrut branches (deflt,dfreq) = do scrut' <- assignTemp' scrut join_lbl <- newBlockId deflt_lbl <- label_code join_lbl deflt @@ -529,20 +521,22 @@ emitCmmLitSwitch scrut branches deflt = do -- We find the necessary type information in the literals in the branches let signed = case head branches of - (MachInt _, _) -> True - (MachInt64 _, _) -> True + (MachInt _, _, _) -> True + (MachInt64 _, _, _) -> True _ -> False let range | signed = (tARGET_MIN_INT dflags, tARGET_MAX_INT dflags) | otherwise = (0, tARGET_MAX_WORD dflags) if isFloatType cmm_ty - then emit =<< mk_float_switch rep scrut' deflt_lbl noBound branches_lbls + then emit =<< mk_float_switch rep scrut' + (deflt_lbl, dfreq) noBound + branches_lbls else emit $ mk_discrete_switch signed scrut' - [(litValue lit,l) | (lit,l) <- branches_lbls] - (Just deflt_lbl) + [(litValue lit,l,f) | (lit,l,f) <- branches_lbls] + (Just (deflt_lbl, dfreq)) range emitLabel join_lbl @@ -552,11 +546,30 @@ type LitBound = (Maybe Literal, Maybe Literal) noBound :: LitBound noBound = (Nothing, Nothing) -mk_float_switch :: Width -> CmmExpr -> BlockId +{- TODO: + Currently this generates a binary search tree for the given value. + + Given we have branch weights we would ideally balance the tree + by weight instead. + + Eg. given (lit,weight) of [(0,1),(1,1),(2,1),(3,99)] we want to split the + list into [(0,1),(1,1),(2,1)] and [(3,99)]. + + Things to consider: + * Does it make a difference often enough to be worth the complexity + and increase in compile time. + * Negative weights have to be rounded up to zero, + otherwise they would distort the results. + * How should entries with no information be treated? + -> Probably good enough to use the default value. + * If implemented should this only apply when optimizations are + active? +-} +mk_float_switch :: Width -> CmmExpr -> (BlockId, BranchWeight) -> LitBound - -> [(Literal,BlockId)] + -> [(Literal,BlockId,BranchWeight)] -> FCode CmmAGraph -mk_float_switch rep scrut deflt _bounds [(lit,blk)] +mk_float_switch rep scrut (deflt, _dfrq) _bounds [(lit,blk,_frq)] = do dflags <- getDynFlags return $ mkCbranch (cond dflags) deflt blk Nothing where @@ -565,17 +578,32 @@ mk_float_switch rep scrut deflt _bounds [(lit,blk)] cmm_lit = mkSimpleLit dflags lit ne = MO_F_Ne rep -mk_float_switch rep scrut deflt_blk_id (lo_bound, hi_bound) branches +mk_float_switch rep scrut (deflt_blk_id,dfreq) (lo_bound, hi_bound) branches = do dflags <- getDynFlags - lo_blk <- mk_float_switch rep scrut deflt_blk_id bounds_lo lo_branches - hi_blk <- mk_float_switch rep scrut deflt_blk_id bounds_hi hi_branches - mkCmmIfThenElse (cond dflags) lo_blk hi_blk + lo_blk <- mk_float_switch + rep scrut (deflt_blk_id,dfreq) + bounds_lo lo_branches + hi_blk <- mk_float_switch + rep scrut + (deflt_blk_id,dfreq) bounds_hi hi_branches + mkCmmIfThenElse (cond dflags) lo_blk hi_blk Nothing where + (lo_branches, mid_lit, hi_branches) = divideBranches branches bounds_lo = (lo_bound, Just mid_lit) bounds_hi = (Just mid_lit, hi_bound) + divideBranches :: Ord a => [(a,b,c)] -> ([(a,b,c)], a, [(a,b,c)]) + divideBranches branches = (lo_branches, mid, hi_branches) + where + -- 2 branches => n_branches `div` 2 = 1 + -- => branches !! 1 give the *second* tag + -- There are always at least 2 branches here + (mid,_,_) = branches !! (length branches `div` 2) + (lo_branches, hi_branches) = span is_lo branches + is_lo (t,_,_) = t < mid + cond dflags = CmmMachOp lt [scrut, CmmLit cmm_lit] where cmm_lit = mkSimpleLit dflags mid_lit @@ -583,21 +611,23 @@ mk_float_switch rep scrut deflt_blk_id (lo_bound, hi_bound) branches -------------- -label_default :: BlockId -> Maybe CmmAGraphScoped -> FCode (Maybe BlockId) +label_default :: BlockId -> Maybe (CmmAGraphScoped, BranchWeight) + -> FCode (Maybe (BlockId, BranchWeight)) label_default _ Nothing = return Nothing -label_default join_lbl (Just code) +label_default join_lbl (Just (code,f)) = do lbl <- label_code join_lbl code - return (Just lbl) + return (Just (lbl,f)) -------------- -label_branches :: BlockId -> [(a,CmmAGraphScoped)] -> FCode [(a,BlockId)] +label_branches :: BlockId -> [(a,CmmAGraphScoped, BranchWeight)] + -> FCode [(a,BlockId,BranchWeight)] label_branches _join_lbl [] = return [] -label_branches join_lbl ((tag,code):branches) +label_branches join_lbl ((tag,code,freq):branches) = do lbl <- label_code join_lbl code branches' <- label_branches join_lbl branches - return ((tag,lbl):branches') + return ((tag,lbl,freq):branches') -------------- label_code :: BlockId -> CmmAGraphScoped -> FCode BlockId |