I'd like to get some feedback on my AVL tree implementation. I appreciate any comments related to style and performance, particularly if something can be written in a more concise or idiomatic way. If you see a way to refactor `delete` into something shorter and with less repetition, please let me know. I am also interested to know whether selective application of strictness declarations can help at all.
More generally, I would like some advice on how to write tests for this data structure. So far, I have worked out bugs by loading this module into GHCi and inserting, mapping, or deleting until something went wrong, then fixing it.
In a similar vein, I have heard that proof assistants like Coq can be used to guarantee certain properties of Haskell programs. If I wanted to do this with my AVL tree implementation, what should I look into?
One thing that I couldn't figure out is how to build a map (fully polymorphic in the non-key type) from `AVLTree` without rewriting the data constructor and functions. I don't know if it can be done easily, but it occurs to me that if I made pairs `(a, b)` with `Ord a` into an instance of a partial ordering (comparing only the keys) and changed the pattern matches to handle non-comparable values, then with a few small changes I could pass key-value pairs to the `Node` constructor. If you have any ideas about how to do this elegantly, I would love to hear them.
I haven't included the derivations for `rotateLeft` and `rotateRight`, since they're somewhat long and tedious, but I can include them if you want.
My pretty printer is fairly long and not terribly elegant, so I'll submit it for review in another question if this one gets any attention.
{-# LANGUAGE GADTs #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE DeriveFoldable #-}
-- I opted not to use MultiWayIf because 'case () of' works well with
-- emacs' haskell-mode indentation while MultiWayIf does not. There are
-- some instances where LambdaCase would have saved a few characters,
-- but I thought it would be distracting for someone not familiar with
-- it.
-- | To use this module without naming conflicts, use the following:
-- import AVLTree (AVLTree)
-- import qualified AVLTree as AVLTree
module AVLTree ( AVLTree(..)
, fromList
, toList
, insert
, delete
, search
, map
) where
import Prelude hiding (map)
import Data.List (foldl')
import qualified Data.List as List (map)
import Data.Maybe (fromJust)
--import Data.Ord (comparing)
-- | A self-balancing (AVL) binary search tree.
--
-- Both constructors carry an @Ord a@ context. Pattern matching on either
-- one therefore brings the ordering into scope, which is why functions
-- such as 'insert' and 'delete' need no @Ord a@ in their signatures. The
-- price is one stored dictionary per node. It also means the type cannot
-- be an instance of classes such as 'Functor' that must accept arbitrary
-- element types.
data AVLTree a where
  -- The empty tree.
  Nil :: Ord a => AVLTree a
  -- An interior node. The balance factor is the height of the right
  -- subtree minus the height of the left subtree; insert and delete are
  -- meant to keep it in [-1, 1]. It is strict and unpacked because every
  -- rebalancing step reads it, so there is no reason to store it as a
  -- thunk.
  Node
    :: Ord a
    => a                    -- value stored at this node
    -> {-# UNPACK #-} !Int  -- balance factor
    -> AVLTree a            -- left subtree (smaller values)
    -> AVLTree a            -- right subtree (larger values)
    -> AVLTree a
-- | Instance derivations and definitions.
--
-- Nil and Node carry an Ord context, which an ordinary attached deriving
-- clause cannot handle, so Show is derived with a standalone deriving
-- declaration. Foldable is written by hand below instead of derived, so
-- that folds run in order.
deriving instance (Show a) => Show (AVLTree a)
{- Folds visit the elements in ascending (in-order) order: the left
subtree, then the node's own value, then the right subtree. A derived
Foldable instance would follow the constructor's field order instead,
visiting each node's value before its subtrees (pre-order), so the
instance is written by hand. All folds below, including foldMap, use the
same in-order traversal, so e.g. foldMap (:[]) agrees with toList.

elem, minimum and maximum are overridden to use the search-tree ordering
(available from the Ord context stored in every Node). They follow a
single root-to-leaf path instead of folding over the whole tree. -}
instance Foldable AVLTree where
  -- foldMap :: Monoid m => (a -> m) -> AVLTree a -> m
  foldMap f t = case t of
    Nil -> mempty
    Node x _ l r -> foldMap f l <> f x <> foldMap f r

  -- foldr :: (a -> b -> b) -> b -> AVLTree a -> b
  foldr f z t = case t of
    Nil -> z
    Node x _ l r -> foldr f (f x (foldr f z r)) l

  -- foldl :: (b -> a -> b) -> b -> AVLTree a -> b
  foldl f z t = case t of
    Nil -> z
    Node x _ l r -> foldl f (f (foldl f z l) x) r

  null t = case t of
    Nil -> True
    Node _ _ _ _ -> False

  -- Binary search instead of the default linear scan.
  elem x t = case t of
    Nil -> False
    Node y _ l r -> case compare x y of
      LT -> elem x l
      GT -> elem x r
      EQ -> True

  -- Leftmost value; same error message as the default on an empty tree.
  minimum t = case t of
    Nil -> errorWithoutStackTrace "minimum: empty structure"
    Node x _ Nil _ -> x
    Node _ _ l _ -> minimum l

  -- Rightmost value; same error message as the default on an empty tree.
  maximum t = case t of
    Nil -> errorWithoutStackTrace "maximum: empty structure"
    Node x _ _ Nil -> x
    Node _ _ _ r -> maximum r
{- instance Show a => Show (AVLTree a) where
-- | show :: AVLTree a -> String
show Nil = "Nil"
show (Node x _ Nil Nil) = show x
show t = draw t -}
-- | Accessors

-- | Balance factor stored at the root (height of the right subtree minus
-- height of the left subtree). An empty tree counts as balanced, 0.
balanceFactor :: AVLTree a -> Int
balanceFactor t = case t of
  Node _ bal _ _ -> bal
  Nil -> 0
-- | Value stored at the root, or Nothing for an empty tree.
value :: AVLTree a -> Maybe a
value t = case t of
  Node v _ _ _ -> Just v
  Nil -> Nothing
-- | Left subtree of the root, or Nothing for an empty tree.
left :: AVLTree a -> Maybe (AVLTree a)
left t = case t of
  Node _ _ sub _ -> Just sub
  Nil -> Nothing
-- | Right subtree of the root, or Nothing for an empty tree.
right :: AVLTree a -> Maybe (AVLTree a)
right t = case t of
  Node _ _ _ sub -> Just sub
  Nil -> Nothing
-- | Check which data constructor built the tree: True only for Nil.
nil :: AVLTree a -> Bool
nil t = case t of
  Node _ _ _ _ -> False
  Nil -> True
-- | A single-node tree holding the given value: two empty subtrees and
-- balance factor 0.
leaf :: Ord a => a -> AVLTree a
leaf = \v -> Node v 0 Nil Nil
-- | List functions

-- | All elements in ascending (in-order) order. Each subtree's elements
-- are prepended to an accumulator holding everything to its right, so
-- the result is produced lazily from the left.
toList :: AVLTree a -> [a]
toList t = go t []
  where
    go :: AVLTree b -> [b] -> [b]
    go Nil acc = acc
    go (Node v _ l r) acc = go l (v : go r acc)
-- | Build a tree from a list.
--
-- Elements are inserted from left to right with a strict left fold, so
-- each intermediate tree is forced to weak head normal form before the
-- next insertion. A right fold (foldr insert Nil) would instead nest n
-- pending insertions that are all unwound when the root is first
-- demanded, using O(n) stack. Because 'insert' keeps an element that is
-- already present, the first of several elements that compare equal is
-- the one retained.
fromList :: Ord a => [a] -> AVLTree a
fromList = foldl' (flip insert) Nil
-- | Insert a value into the tree. If an equal value is already present,
-- the tree is returned unchanged.
--
-- Since Ord a implies Eq a, a value that is neither less than nor greater
-- than a node's value is treated as equal to it. After inserting into a
-- subtree, the node's balance factor is updated. A balance of -2 or 2 is
-- repaired with a single or double rotation.
insert :: a -> AVLTree a -> AVLTree a
insert x t = case t of
  Nil -> leaf x
  -- A leaf needs its own case: the new child makes it lean to one side.
  Node y _ Nil Nil -> case compare x y of
    EQ -> t
    LT -> Node y (-1) (leaf x) Nil
    GT -> Node y 1 Nil (leaf x)
  Node y b l r -> case compare x y of
    EQ -> t
    LT
      | nil l -> Node y 0 (leaf x) r
      | bal == -2, balanceFactor newL > 0 -> rotateLeftRight node
      | bal == -2 -> rotateRight node
      | otherwise -> node
      where
        newL = insert x l
        -- The left side grew taller only if it was perfectly balanced
        -- before and is lopsided now.
        bal
          | balanceFactor l == 0 = b - abs (balanceFactor newL)
          | otherwise = b
        node = Node y bal newL r
    GT
      | nil r -> Node y 0 l (leaf x)
      | bal == 2, balanceFactor newR < 0 -> rotateRightLeft node
      | bal == 2 -> rotateLeft node
      | otherwise -> node
      where
        newR = insert x r
        -- The right side grew taller only if it was perfectly balanced
        -- before and is lopsided now.
        bal
          | balanceFactor r == 0 = b + abs (balanceFactor newR)
          | otherwise = b
        node = Node y bal l newR
-- | Delete a value from the tree. If the value is not present, the tree
-- is returned unchanged.
--
-- A node holding the value is replaced by its in-order predecessor when
-- the node is left-heavy, otherwise by its in-order successor. That
-- replacement value is then deleted from the corresponding subtree. All
-- four cases that remove something from a subtree share the balance
-- bookkeeping in 'rebalanceAfterLeftDelete' and
-- 'rebalanceAfterRightDelete'.
delete :: a -> AVLTree a -> AVLTree a
delete x t = case t of
  Nil -> Nil
  Node y _ Nil Nil -> if x == y then Nil else t
  Node y b l r -> case compare x y of
    LT
      | nil l -> t
      | otherwise -> rebalanceAfterLeftDelete y b l (delete x l) r
    GT
      | nil r -> t
      | otherwise -> rebalanceAfterRightDelete y b l r (delete x r)
    EQ
      -- Left-heavy: replace y by its in-order predecessor.
      | b < 0, nil l -> r
      | b < 0 ->
          rebalanceAfterLeftDelete predecessor b l (delete predecessor l) r
      -- Balanced or right-heavy: replace y by its in-order successor.
      | nil r -> l
      | otherwise ->
          rebalanceAfterRightDelete successor b l r (delete successor r)
      where
        predecessor = findMax l
        successor = findMin r

-- | Rebuild a node after deleting from its left subtree.
--
-- Arguments:
--
-- * the node's value and its old balance factor;
-- * the left subtree before and after the deletion;
-- * the unchanged right subtree.
--
-- The left subtree is treated as one level shorter when it became empty,
-- or when its balance went from -1 or 1 to 0. In that case the node
-- leans one step further right, and a balance of 2 is repaired with a
-- left (or right-left) rotation.
rebalanceAfterLeftDelete
  :: Ord a => a -> Int -> AVLTree a -> AVLTree a -> AVLTree a -> AVLTree a
rebalanceAfterLeftDelete y b l l' r
  | b' == 2, balanceFactor r == -1 = rotateRightLeft t'
  | b' == 2 = rotateLeft t'
  | otherwise = t'
  where
    shrank = nil l' || (abs (balanceFactor l) == 1 && balanceFactor l' == 0)
    b' = if shrank then b + 1 else b
    t' = Node y b' l' r

-- | Mirror image of 'rebalanceAfterLeftDelete'.
--
-- Arguments:
--
-- * the node's value and its old balance factor;
-- * the unchanged left subtree;
-- * the right subtree before and after the deletion.
--
-- A shorter right subtree makes the node lean one step further left. A
-- balance of -2 is repaired with a right (or left-right) rotation.
rebalanceAfterRightDelete
  :: Ord a => a -> Int -> AVLTree a -> AVLTree a -> AVLTree a -> AVLTree a
rebalanceAfterRightDelete y b l r r'
  | b' == -2, balanceFactor l == 1 = rotateLeftRight t'
  | b' == -2 = rotateRight t'
  | otherwise = t'
  where
    shrank = nil r' || (abs (balanceFactor r) == 1 && balanceFactor r' == 0)
    b' = if shrank then b - 1 else b
    t' = Node y b' l r'
-- | Determine whether the tree contains an element equal to the given
-- one. The search follows a single path from the root, guided by
-- compare.
--
-- A variant of type @a -> AVLTree a -> Maybe a@ that returns the stored
-- element would be useful when elements that compare equal can still be
-- told apart, e.g. key-value pairs ordered by key only.
search :: a -> AVLTree a -> Bool
search _ Nil = False
search x (Node y _ l r) =
  case compare x y of
    LT -> search x l
    GT -> search x r
    EQ -> True
-- | Map a function over the tree.
--
-- The result is rebuilt from scratch: the elements are collected with
-- toList, mapped, and passed to fromList. Keeping the original shape and
-- applying f in place would give a valid search tree only if f were
-- strictly increasing, which cannot be checked. So map cannot simply be:
--
-- > map _ Nil = Nil
-- > map f (Node x b l r) = Node (f x) b (map f l) (map f r)
--
-- For the same reason a binary search tree of ordered values cannot be a
-- Functor. Only a tree of key-value pairs ordered by key could map over
-- the values in place.
map :: Ord b => (a -> b) -> AVLTree a -> AVLTree b
map f t = fromList (List.map f (toList t))
-- | Helper functions.
-- Note: these helpers are only meaningful for trees of a particular
-- shape. Applying one where its precondition does not hold is a
-- programming error, so it fails immediately with an error instead of
-- continuing with a made-up result.

-- | Smallest element of a non-empty tree, i.e. the value in its leftmost
-- node. Calls 'error' on an empty tree.
findMin :: AVLTree a -> a
findMin t = case t of
  Nil -> error "AVLTree.findMin: empty tree"
  Node x _ Nil _ -> x
  Node _ _ l _ -> findMin l
-- | Largest element of a non-empty tree, i.e. the value in its rightmost
-- node. Calls 'error' on an empty tree.
findMax :: AVLTree a -> a
findMax t = case t of
  Nil -> error "AVLTree.findMax: empty tree"
  Node x _ _ Nil -> x
  Node _ _ _ r -> findMax r
-- | Rotate right:
--
-- >       Z
-- >      /          Y
-- >     Y    ->    / \
-- >    /          X   Z
-- >   X
--
-- The balance factors are recomputed for any input balances, not only
-- for the textbook (-2, -1) case. Write h for subtree height and take
-- balance = h(right) - h(left). Let b be the old balance of Z and bl the
-- old balance of Y. Then:
--
-- > new balance of Z = h(r) - h(lr)  = b + 1 - min 0 bl
-- > new balance of Y = h(Z') - h(ll) = 1 + max bl (1 + b + max 0 bl)
--
-- Calls 'error' if the tree has no left child.
rotateRight :: AVLTree a -> AVLTree a
rotateRight t = case t of
  Node z b (Node y bl ll lr) r ->
    Node y b' ll (Node z br' lr r)
    where
      b' = 1 + max bl (1 + b + max 0 bl)
      br' = b + 1 - min 0 bl
  _ -> error "AVLTree.rotateRight: tree has no left child"
-- | Rotate left:
--
-- >   X
-- >    \            Y
-- >     Y    ->    / \
-- >      \        X   Z
-- >       Z
--
-- The balance factors are recomputed for any input balances, not only
-- for the textbook (2, 1) case. Write h for subtree height and take
-- balance = h(right) - h(left). Let b be the old balance of X and br the
-- old balance of Y. Then:
--
-- > new balance of X = h(rl) - h(l)  = b - 1 - max 0 br
-- > new balance of Y = h(rr) - h(X') = -1 + min br (b - 1 + min 0 br)
--
-- Calls 'error' if the tree has no right child.
rotateLeft :: AVLTree a -> AVLTree a
rotateLeft t = case t of
  Node x b l (Node y br rl rr) ->
    Node y b' (Node x bl' l rl) rr
    where
      b' = -1 + min br (b - 1 + min 0 br)
      bl' = b - 1 - max 0 br
  _ -> error "AVLTree.rotateLeft: tree has no right child"
-- | Rotate the left child left, then rotate the root right:
--
-- >       Z            Z
-- >      /            /          Y
-- >     X    ->      Y    ->    / \
-- >      \          /          X   Z
-- >       Y        X
--
-- The first rotation can change the height of the left subtree, so Z's
-- balance b is adjusted before the second rotation:
--
-- * If the left child's balance is 1, rotating it left keeps its height,
--   and Z keeps balance b.
-- * If the child's balance is 0 or -1, it becomes one level taller, and
--   Z's balance becomes b - 1.
--
-- insert and delete only call this when the left child's balance is
-- positive, which for a valid AVL child means exactly 1.
--
-- NOTE(review): the GT branch (child balance >= 2) cannot arise from a
-- valid AVL child. For such a child the height change depends on the
-- grandchild, so b + 1 is not always right.
--
-- Calls 'error' on an empty tree.
rotateLeftRight :: AVLTree a -> AVLTree a
rotateLeftRight t = case t of
  Node z b l r -> rotateRight (Node z b' l' r)
    where
      l' = rotateLeft l
      bl = balanceFactor l
      b' = case compare bl 1 of
        EQ -> b
        GT -> b + 1
        LT -> b - 1
  _ -> error "AVLTree.rotateLeftRight: empty tree"
-- | Rotate the right child right, then rotate the root left:
--
-- >   X            X
-- >    \            \          Y
-- >     Z    ->      Y   ->   / \
-- >    /              \      X   Z
-- >   Y                Z
--
-- The first rotation can change the height of the right subtree, so X's
-- balance b is adjusted before the second rotation:
--
-- * If the right child's balance is -1, rotating it right keeps its
--   height, and X keeps balance b.
-- * If the child's balance is 0 or 1, it becomes one level taller, and
--   X's balance becomes b + 1.
--
-- insert and delete only call this when the right child's balance is
-- negative, which for a valid AVL child means exactly -1.
--
-- NOTE(review): the LT branch (child balance <= -2) cannot arise from a
-- valid AVL child. For such a child the height change depends on the
-- grandchild, so b - 1 is not always right.
--
-- Calls 'error' on an empty tree.
rotateRightLeft :: AVLTree a -> AVLTree a
rotateRightLeft t = case t of
  Node x b l r -> rotateLeft (Node x b' l r')
    where
      r' = rotateRight r
      br = balanceFactor r
      b' = case compare br (-1) of
        EQ -> b
        LT -> b - 1
        GT -> b + 1
  _ -> error "AVLTree.rotateRightLeft: empty tree"