{-# LANGUAGE GADTs, BangPatterns #-} module CmmCommonBlockElim ( elimCommonBlocks ) where import GhcPrelude hiding (iterate, succ, unzip, zip) import BlockId import Cmm import CmmUtils import CmmSwitch (eqSwitchTargetWith) import CmmContFlowOpt -- import PprCmm () import Hoopl.Block import Hoopl.Graph import Hoopl.Label import Hoopl.Collections import Data.Bits import Data.Maybe (mapMaybe) import qualified Data.List as List import Data.Word import qualified Data.Map as M import Outputable import DynFlags (DynFlags) import UniqFM import UniqDFM import qualified TrieMap as TM import Unique import Control.Arrow (first, second) -- ----------------------------------------------------------------------------- -- Eliminate common blocks -- If two blocks are identical except for the label on the first node, -- then we can eliminate one of the blocks. To ensure that the semantics -- of the program are preserved, we have to rewrite each predecessor of the -- eliminated block to proceed with the block we keep. -- The algorithm iterates over the blocks in the graph, -- checking whether it has seen another block that is equal modulo labels. -- If so, then it adds an entry in a map indicating that the new block -- is made redundant by the old block. -- Otherwise, it is added to the useful blocks. -- To avoid comparing every block with every other block repeatedly, we group -- them by -- * a hash of the block, ignoring labels (explained below) -- * the list of outgoing labels -- The hash is invariant under relabeling, so we only ever compare within -- the same group of blocks. -- -- The list of outgoing labels is updated as we merge blocks (that is why they -- are not included in the hash, which we want to calculate only once). -- -- All in all, two blocks should never be compared if they have different -- hashes, and at most once otherwise. 
-- Previously, we were slower, and people rightfully complained: #10397

-- TODO: Use optimization fuel

-- | Eliminate blocks that are duplicates of another block (ignoring the
-- entry label): compute a substitution from redundant block ids to the ids
-- of the blocks that are kept, copy ticks over, and rewrite all labels.
elimCommonBlocks :: DynFlags -> CmmGraph -> CmmGraph
elimCommonBlocks dflags g = replaceLabels subst (copyTicks subst g)
  where
    -- Blocks bucketed by their label-independent hash.
    buckets = groupByInt (hash_block dflags) (postorderDfs g)

    -- Initially every block is alone in its list, keyed by its successors.
    keyed_buckets = map (map singleton_entry) buckets
    singleton_entry blk = (successors blk, [blk])

    subst = iterate dflags mapEmpty keyed_buckets

-- Invariant: The blocks in the list are pairwise distinct
-- (so avoid comparing them again)
type DistinctBlocks = [CmmBlock]
type Key = [Label]
type Subst = LabelMap BlockId

-- | Repeatedly merge blocks until a round discovers no new duplicates.
--
-- The outer list groups by hash. We retain this grouping throughout.
iterate :: DynFlags -> Subst -> [[(Key, DistinctBlocks)]] -> Subst
iterate dflags subst blocks =
    if mapNull fresh
      then subst
      else iterate dflags extended rekeyed
  where
    -- Within each hash bucket, gather the block lists that share a key.
    by_key :: [[(Key, [DistinctBlocks])]]
    by_key = map groupByLabel blocks

    -- Merge every such group, accumulating the substitutions found this round.
    merged :: [[(Key, DistinctBlocks)]]
    (fresh, merged) =
        List.mapAccumL (List.mapAccumL merge_group) mapEmpty by_key

    merge_group !acc (key, candidates) =
        let (found, survivors) = mergeBlockList dflags subst candidates
        in (acc `mapUnion` found, (key, survivors))

    extended = subst `mapUnion` fresh

    -- The keys are successor labels, so they must be redirected through the
    -- enlarged substitution before the next round.
    rekeyed = map (map (first (map (lookupBid extended)))) merged

-- | Merge the blocks of @new@ into @existing@. Returns the substitution for
-- the blocks of @new@ that turned out to be duplicates, and the resulting
-- list of distinct blocks.
mergeBlocks :: DynFlags -> Subst
            -> DistinctBlocks -> DistinctBlocks
            -> (Subst, DistinctBlocks)
mergeBlocks dflags subst existing new =
    foldr classify (mapEmpty, existing) new
  where
    same_body = eqBlockBodyWith dflags (eqBid subst)

    classify blk rest =
        case List.find (same_body blk) existing of
          -- This block is a duplicate. Drop it, and add it to the substitution
          Just keeper ->
              first (mapInsert (entryLabel blk) (entryLabel keeper)) rest
          -- This block is not a duplicate, keep it.
          Nothing -> second (blk :) rest

-- | Merge a non-empty list of block lists into one, left to right, collecting
-- all substitutions produced on the way. Panics on an empty list.
mergeBlockList :: DynFlags -> Subst -> [DistinctBlocks]
               -> (Subst, DistinctBlocks)
mergeBlockList dflags subst groups =
    case groups of
      [] -> pprPanic "mergeBlockList" empty
      first_group : others -> absorb mapEmpty first_group others
  where
    absorb !acc kept [] = (acc, kept)
    absorb !acc kept (next : rest) =
        let (found, kept') = mergeBlocks dflags subst kept next
        in absorb (acc `mapUnion` found) kept' rest

-- -----------------------------------------------------------------------------
-- Hashing and equality on blocks

-- Below here is mostly boilerplate: hashing blocks ignoring labels,
-- and comparing blocks modulo a label mapping.

-- To speed up comparisons, we hash each basic block modulo jump labels.
-- The hashing is a bit arbitrary (the numbers are completely arbitrary),
-- but it should be fast and good enough.

-- We want to get as many small buckets as possible, as comparing blocks is
-- expensive. So include as much as possible in the hash. Ideally everything
-- that is compared with (==) in eqBlockBodyWith.

{- Note [Equivalence up to local registers in CBE]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

CBE treats two blocks which are equivalent up to alpha-renaming of
locally-bound local registers as equivalent. This was not always the case (see
#14226) but is quite important for effective CBE. For instance, consider the
blocks,

    c2VZ: // global
        _c2Yd::I64 = _s2Se::I64 + 1;
        _s2Sx::I64 = _c2Yd::I64;
        _s2Se::I64 = _s2Sx::I64;
        goto c2TE;

    c2VY: // global
        _c2Yb::I64 = _s2Se::I64 + 1;
        _s2Sw::I64 = _c2Yb::I64;
        _s2Se::I64 = _s2Sw::I64;
        goto c2TE;

These clearly implement precisely the same logic, differing only in register
naming. This happens quite often in the code produced by GHC.

This alpha-equivalence relation must be accounted for in two places:

 1. the block hash function (hash_block), which we use for approximate "binning"
 2. the exact block comparison function, which computes pair-wise equivalence

In (1) we maintain a de Bruijn numbering of each block's locally-bound local
registers and compute the hash relative to this numbering.

For (2) we maintain a substitution which maps the local registers of one block
onto those of the other. We then compare local registers modulo this
substitution.

-}

type HashCode = Int

type LocalRegEnv a = UniqFM a

type DeBruijn = Int

-- | Maintains a de Bruijn numbering of local registers bound within a block.
--
-- See Note [Equivalence up to local registers in CBE]
data HashEnv = HashEnv { localRegHashEnv :: !(LocalRegEnv DeBruijn)
                       , nextIndex       :: !DeBruijn
                       }

-- | Hash a block, ignoring its entry node and all labels it mentions.
-- Local registers bound earlier in the block contribute their de Bruijn
-- index rather than their unique.
hash_block :: DynFlags -> CmmBlock -> HashCode
hash_block dflags block =
    --pprTrace "hash_block" (ppr (entryLabel block) $$ ppr masked)
    masked
  where
    initial :: (HashEnv, Word32)
    initial = (HashEnv mempty 0, 0)

    (_, raw) = foldBlockNodesF3 (on_entry, on_node, on_node) block initial

    -- UniqFM doesn't like negative Ints
    masked = fromIntegral (raw .&. (0x7fffffff :: Word32))

    -- The entry node carries only the label, which is ignored.
    on_entry _ acc = acc

    -- Middle and last nodes are folded into the running hash the same way.
    on_node :: CmmNode O x -> (HashEnv, Word32) -> (HashEnv, Word32)
    on_node node (env, h) =
        let (env', h') = hash_node env node
        in (env', h' + h `shiftL` 1)

    hash_node :: HashEnv -> CmmNode O x -> (HashEnv, Word32)
    hash_node env n = (extended, node_hash)
      where
        -- Registers defined by this node are numbered for the nodes after it;
        -- the node's own hash is computed against the incoming environment.
        extended = foldLocalRegsDefd dflags (flip bind_local_reg) env n

        node_hash :: Word32
        node_hash
          | dont_care n = 0 -- don't care
          | otherwise =
              case n of
                -- don't include register as it is a binding occurrence
                CmmAssign (CmmLocal _) e -> hash_e env e
                CmmAssign r e -> hash_reg env r + hash_e env e
                CmmStore e e' -> hash_e env e + hash_e env e'
                CmmUnsafeForeignCall t _ as ->
                    hash_tgt env t + hash_list (hash_e env) as
                CmmBranch _ -> 23 -- NB. ignore the label
                CmmCondBranch p _ _ _ -> hash_e env p
                CmmCall e _ _ _ _ _ -> hash_e env e
                CmmForeignCall t _ _ _ _ _ _ -> hash_tgt env t
                CmmSwitch e _ -> hash_e env e
                _ -> error "hash_node: unknown Cmm node!"

    hash_reg :: HashEnv -> CmmReg -> Word32
    hash_reg env reg =
        case reg of
          CmmLocal localReg ->
              -- Falling back to the unique for registers not bound in this
              -- block is important for performance, see #10397
              maybe (hash_unique localReg) fromIntegral
                    (lookupUFM (localRegHashEnv env) localReg)
          CmmGlobal _ -> 19

    hash_e :: HashEnv -> CmmExpr -> Word32
    hash_e env expr =
        case expr of
          CmmLit l -> hash_lit l
          CmmLoad e _ -> 67 + hash_e env e
          CmmReg r -> hash_reg env r
          -- pessimal - no operator check
          CmmMachOp _ es -> hash_list (hash_e env) es
          CmmRegOff r i -> hash_reg env r + cvt i
          CmmStackSlot _ _ -> 13

    hash_lit :: CmmLit -> Word32
    hash_lit lit =
        case lit of
          CmmInt i _ -> fromInteger i
          CmmFloat r _ -> truncate r
          CmmVec ls -> hash_list hash_lit ls
          CmmLabel _ -> 119 -- ugh
          CmmLabelOff _ i -> cvt $ 199 + i
          CmmLabelDiffOff _ _ i -> cvt $ 299 + i
          CmmBlock _ -> 191 -- ugh
          CmmHighStackMark -> cvt 313

    hash_tgt :: HashEnv -> ForeignTarget -> Word32
    hash_tgt env tgt =
        case tgt of
          ForeignTarget e _ -> hash_e env e
          PrimTarget _ -> 31 -- lots of these

    -- Sum of the element hashes.
    hash_list :: (a -> Word32) -> [a] -> Word32
    hash_list f = List.foldl' (\acc x -> f x + acc) 0

    cvt :: Int -> Word32
    cvt = fromInteger . toInteger

    -- Give the register the next free de Bruijn index.
    bind_local_reg :: LocalReg -> HashEnv -> HashEnv
    bind_local_reg reg (HashEnv regs next) =
        HashEnv (addToUFM regs reg next) (next + 1)

    hash_unique :: Uniquable a => a -> Word32
    hash_unique x = cvt (getKey (getUnique x))

-- | Ignore these node types for equality
dont_care :: CmmNode O x -> Bool
dont_care node =
    case node of
      CmmComment {} -> True
      CmmTick {} -> True
      CmmUnwind {} -> True
      _other -> False

-- Utilities: equality and substitution on the graph.

-- Given a map ``subst'' from BlockID -> BlockID, we define equality.
-- | Two block ids are equal under @subst@ when they resolve to the same id.
eqBid :: LabelMap BlockId -> BlockId -> BlockId -> Bool
eqBid subst bid bid' = lookupBid subst bid == lookupBid subst bid'

-- | Follow @subst@ from the given block id until reaching an id that the
-- substitution does not map any further.
lookupBid :: LabelMap BlockId -> BlockId -> BlockId
lookupBid subst bid = case mapLookup bid subst of
                        Just bid  -> lookupBid subst bid
                        Nothing -> bid

-- | Maps the local registers of one block to those of another
--
-- See Note [Equivalence up to local registers in CBE]
type LocalRegMapping = LocalRegEnv LocalReg

-- | Element-wise comparison of two lists. Unlike @and (zipWith eltEq xs ys)@
-- this is False when the lists have different lengths, instead of silently
-- ignoring the surplus elements of the longer one.
eqListWith :: (a -> b -> Bool) -> [a] -> [b] -> Bool
eqListWith eltEq = go
  where
    go []     []     = True
    go (x:xs) (y:ys) = eltEq x y && go xs ys
    go _      _      = False

-- Middle nodes and expressions can contain BlockIds, in particular in
-- CmmStackSlot and CmmBlock, so we have to use a special equality for
-- these.
--
-- | Compare two middle nodes under the register mapping built so far.
-- Returns the mapping to use for the following nodes together with the
-- verdict. On a match the incoming mapping is extended with the registers the
-- two nodes define, paired up positionally.
--
-- NOTE(review): binding occurrences are treated as interchangeable. That is
-- only sound if a register bound here is not read under its original name by
-- a later block; nothing in this function checks that -- confirm.
eqMiddleWith :: DynFlags
             -> (BlockId -> BlockId -> Bool)
             -> LocalRegMapping
             -> CmmNode O O -> CmmNode O O
             -> (LocalRegMapping, Bool)
eqMiddleWith dflags eqBid env a b =
  case (a, b) of
     -- registers aren't compared since they are binding occurrences
     (CmmAssign (CmmLocal _) e1,  CmmAssign (CmmLocal _) e2) ->
        let eq = eqExprWith eqBid env e1 e2
        in (env', eq)

     (CmmAssign r1 e1,  CmmAssign r2 e2) ->
        let eq = r1 == r2
              && eqExprWith eqBid env e1 e2
        in (env', eq)

     (CmmStore l1 r1,  CmmStore l2 r2) ->
        let eq = eqExprWith eqBid env l1 l2
              && eqExprWith eqBid env r1 r2
        in (env', eq)

     -- result registers aren't compared since they are binding occurrences
     (CmmUnsafeForeignCall t1 _ a1,  CmmUnsafeForeignCall t2 _ a2) ->
        let eq = t1 == t2
              && eqListWith (eqExprWith eqBid env) a1 a2
        in (env', eq)

     _ -> (env, False)
  where
    -- Extend (rather than replace) the incoming mapping: bindings made by
    -- earlier nodes must stay visible to every later node of the block,
    -- mirroring the de Bruijn numbering used by the block hash.
    env' = List.foldl' (\acc (ra,rb) -> addToUFM acc ra rb) env
           $ List.zip defd_a defd_b
    defd_a = foldLocalRegsDefd dflags (flip (:)) [] a
    defd_b = foldLocalRegsDefd dflags (flip (:)) [] b

-- | Structural equality on expressions, comparing block ids with @eqBid@ and
-- local registers modulo the given mapping.
eqExprWith :: (BlockId -> BlockId -> Bool)
           -> LocalRegMapping
           -> CmmExpr -> CmmExpr
           -> Bool
eqExprWith eqBid env = eq
 where
  CmmLit l1          `eq` CmmLit l2          = eqLit l1 l2
  -- NOTE(review): the type of the load is not compared here -- confirm that
  -- loads of different types from equal addresses cannot reach this point.
  CmmLoad e1 _       `eq` CmmLoad e2 _       = e1 `eq` e2
  CmmReg r1          `eq` CmmReg r2          = r1 `eqReg` r2
  CmmRegOff r1 i1    `eq` CmmRegOff r2 i2    = r1 `eqReg` r2 && i1==i2
  CmmMachOp op1 es1  `eq` CmmMachOp op2 es2  = op1==op2 && es1 `eqs` es2
  CmmStackSlot a1 i1 `eq` CmmStackSlot a2 i2 = eqArea a1 a2 && i1==i2
  _e1                `eq` _e2                = False

  -- Operand lists must agree in length as well as element-wise.
  xs `eqs` ys = eqListWith eq xs ys

  -- See Note [Equivalence up to local registers in CBE]
  CmmLocal a `eqReg` CmmLocal b
    | Just a' <- lookupUFM env a
    = a' == b
  a `eqReg` b = a == b

  eqLit (CmmBlock id1) (CmmBlock id2) = eqBid id1 id2
  eqLit l1 l2 = l1 == l2

  eqArea Old Old = True
  eqArea (Young id1) (Young id2) = eqBid id1 id2
  eqArea _ _ = False

-- Equality on the body of a block, modulo a function mapping block
-- IDs to block IDs.
--
-- | The entry nodes are not compared, and middle nodes for which 'dont_care'
-- holds are skipped. The remaining middle nodes are compared pairwise in
-- order, threading the local register mapping through, and finally the last
-- nodes are compared under the resulting mapping. Blocks with a different
-- number of remaining middle nodes are never equal.
eqBlockBodyWith :: DynFlags -> (BlockId -> BlockId -> Bool) -> CmmBlock -> CmmBlock -> Bool
eqBlockBodyWith dflags eqBid block block'
  {-
  | equal     = pprTrace "equal" (vcat [ppr block, ppr block']) True
  | otherwise = pprTrace "not equal" (vcat [ppr block, ppr block']) False
  -}
  = equal
  where (_,m,l)   = blockSplit block
        nodes     = filter (not . dont_care) (blockToList m)
        (_,m',l') = blockSplit block'
        nodes'    = filter (not . dont_care) (blockToList m')

        equal = go emptyUFM nodes nodes'

        -- Walk both node lists in lock step. Zipping them would truncate to
        -- the shorter list and so accept a block whose body is a proper
        -- prefix of the other's whenever their hashes happen to collide.
        go env (n:ns) (n':ns') =
            case eqMiddleWith dflags eqBid env n n' of
              (env', True) -> go env' ns ns'
              (_, False)   -> False
        go env [] [] = eqLastWith eqBid env l l'
        go _   _  _  = False

-- | Compare the last nodes of two blocks: block ids via @eqBid@, expressions
-- via 'eqExprWith' under the given mapping, everything else with '=='.
eqLastWith :: (BlockId -> BlockId -> Bool) -> LocalRegMapping
           -> CmmNode O C -> CmmNode O C
           -> Bool
eqLastWith eqBid env a b =
  case (a, b) of
    (CmmBranch bid1, CmmBranch bid2) -> eqBid bid1 bid2
    (CmmCondBranch c1 t1 f1 l1, CmmCondBranch c2 t2 f2 l2) ->
      eqExprWith eqBid env c1 c2 && l1 == l2 && eqBid t1 t2 && eqBid f1 f2
    (CmmCall t1 c1 g1 a1 r1 u1, CmmCall t2 c2 g2 a2 r2 u2) ->
         t1 == t2
      && eqMaybeWith eqBid c1 c2
      && a1 == a2 && r1 == r2 && u1 == u2 && g1 == g2
    (CmmSwitch e1 ids1, CmmSwitch e2 ids2) ->
         eqExprWith eqBid env e1 e2 && eqSwitchTargetWith eqBid ids1 ids2
    -- result registers aren't compared since they are binding occurrences
    (CmmForeignCall t1 _ a1 s1 ret_args1 ret_off1 intrbl1,
     CmmForeignCall t2 _ a2 s2 ret_args2 ret_off2 intrbl2) ->
         t1 == t2
      && eqListWith (eqExprWith eqBid env) a1 a2
      && s1 == s2
      && ret_args1 == ret_args2
      && ret_off1 == ret_off2
      && intrbl1 == intrbl2
    _ -> False

-- | Lift an element comparison to 'Maybe': two 'Nothing's are equal, two
-- 'Just's are compared with the given function, anything else is unequal.
eqMaybeWith :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
eqMaybeWith eltEq (Just e) (Just e') = eltEq e e'
eqMaybeWith _ Nothing Nothing = True
eqMaybeWith _ _ _ = False

-- | Given a block map, ensure that all "target" blocks are covered by
-- the same ticks as the respective "source" blocks. This not only
-- means copying ticks, but also adjusting tick scopes where
-- necessary.
copyTicks :: LabelMap BlockId -> CmmGraph -> CmmGraph
copyTicks env g
  | mapNull env = g
  | otherwise   = ofBlockMap (g_entry g) $ mapMap copyTo blockMap
  where -- Reverse block merge map
        blockMap = toBlockMap g
        -- Maps each value of env to the list of keys that map to it.
        revEnv = mapFoldWithKey insertRev M.empty env
        insertRev k x = M.insertWith (const (k:)) x [k]
        -- Copy ticks and scopes into the given block
        copyTo block = case M.lookup (entryLabel block) revEnv of
          Nothing -> block
          Just ls -> foldr copy block $ mapMaybe (flip mapLookup blockMap) ls
        copy from to =
          let ticks = blockTicks from
              CmmEntry  _   scp0        = firstNode from
              (CmmEntry lbl scp1, code) = blockSplitHead to
          in CmmEntry lbl (combineTickScopes scp0 scp1) `blockJoinHead`
             foldr blockCons code (map CmmTick ticks)

-- Group by [Label]
-- | Collect the values of all entries whose keys are equal label lists.
groupByLabel :: [(Key, a)] -> [(Key, [a])]
groupByLabel = go (TM.emptyTM :: TM.ListMap UniqDFM a)
  where
    go !m [] = TM.foldTM (:) m []
    go !m ((k,v) : entries) = go (TM.alterTM k' adjust m) entries
      where k' = map getUnique k
            adjust Nothing       = Just (k,[v])
            adjust (Just (_,vs)) = Just (k,v:vs)

-- | Partition a list into groups of elements on which @f@ agrees.
groupByInt :: (a -> Int) -> [a] -> [[a]]
groupByInt f xs = nonDetEltsUFM $ List.foldl' go emptyUFM xs
  -- See Note [Unique Determinism and code generation]
  where go m x = alterUFM (Just . maybe [x] (x:)) m (f x)