summaryrefslogtreecommitdiff
path: root/compiler/GHC/Data/UnionFind.hs
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/GHC/Data/UnionFind.hs')
-rw-r--r--compiler/GHC/Data/UnionFind.hs91
1 files changed, 91 insertions, 0 deletions
diff --git a/compiler/GHC/Data/UnionFind.hs b/compiler/GHC/Data/UnionFind.hs
new file mode 100644
index 0000000000..21687b5a09
--- /dev/null
+++ b/compiler/GHC/Data/UnionFind.hs
@@ -0,0 +1,91 @@
+{- Union-find data structure compiled from Distribution.Utils.UnionFind -}
+module GHC.Data.UnionFind where
+
+import GHC.Prelude
+import Data.STRef
+import Control.Monad.ST
+import Control.Monad
+
+-- | A variable which can be unified; alternately, this can be thought
+-- of as an equivalence class with a distinguished representative.
+newtype Point s a = Point (STRef s (Link s a))
+ deriving (Eq)
+
+-- | Mutable write to a 'Point'
+writePoint :: Point s a -> Link s a -> ST s ()
+writePoint (Point v) = writeSTRef v
+
+-- | Read the current value of 'Point'.
+readPoint :: Point s a -> ST s (Link s a)
+readPoint (Point v) = readSTRef v
+
+-- | The internal data structure for a 'Point', which either records
+-- the representative element of an equivalence class, or a link to
+-- the 'Point' that actually stores the representative type.
+data Link s a
+ -- NB: it is too bad we can't say STRef Int#; the weights remain boxed
+ = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
+ | Link {-# UNPACK #-} !(Point s a)
+
+-- | Create a fresh equivalence class with one element.
+fresh :: a -> ST s (Point s a)
+fresh desc = do
+ weight <- newSTRef 1
+ descriptor <- newSTRef desc
+ Point `fmap` newSTRef (Info weight descriptor)
+
+-- | Flatten any chains of links, returning a 'Point'
+-- which points directly to the canonical representation.
+repr :: Point s a -> ST s (Point s a)
+repr point = readPoint point >>= \r ->
+ case r of
+ Link point' -> do
+ point'' <- repr point'
+ when (point'' /= point') $ do
+ writePoint point =<< readPoint point'
+ return point''
+ Info _ _ -> return point
+
+-- | Return the canonical element of an equivalence
+-- class 'Point'.
+find :: Point s a -> ST s a
+find point =
+ -- Optimize length 0 and 1 case at expense of
+ -- general case
+ readPoint point >>= \r ->
+ case r of
+ Info _ d_ref -> readSTRef d_ref
+ Link point' -> readPoint point' >>= \r' ->
+ case r' of
+ Info _ d_ref -> readSTRef d_ref
+ Link _ -> repr point >>= find
+
+-- | Unify two equivalence classes, so that they share
+-- a canonical element. Keeps the descriptor of point2.
+union :: Point s a -> Point s a -> ST s ()
+union refpoint1 refpoint2 = do
+ point1 <- repr refpoint1
+ point2 <- repr refpoint2
+ when (point1 /= point2) $ do
+ l1 <- readPoint point1
+ l2 <- readPoint point2
+ case (l1, l2) of
+ (Info wref1 dref1, Info wref2 dref2) -> do
+ weight1 <- readSTRef wref1
+ weight2 <- readSTRef wref2
+ -- Should be able to optimize the == case separately
+ if weight1 >= weight2
+ then do
+ writePoint point2 (Link point1)
+ -- The weight calculation here seems a bit dodgy
+ writeSTRef wref1 (weight1 + weight2)
+ writeSTRef dref1 =<< readSTRef dref2
+ else do
+ writePoint point1 (Link point2)
+ writeSTRef wref2 (weight1 + weight2)
+ _ -> error "UnionFind.union: repr invariant broken"
+
+-- | Test if two points are in the same equivalence class.
+equivalent :: Point s a -> Point s a -> ST s Bool
+equivalent point1 point2 = liftM2 (==) (repr point1) (repr point2)
+