Code Review Stack Exchange is a question and answer site for peer programmer code reviews. Join them; it only takes a minute:

Sign up
Here's how it works:
  1. Anybody can ask a question
  2. Anybody can answer
  3. The best answers are voted up and rise to the top

I've made an attempt at writing a skip list in Python. I'm using NumPy to generate geometric random variables, but since that's a bit of a heavy dependency to drag around I could easily implement that myself.

I based my implementation on the basic (none of the improvements such as capping node height etc.) algorithm here.

What do you think needs improvement?

import numpy as np

class SkipList:

    def __init__(self, p=0.5):
        """
        Create a Skiplist object.

        >>>> l = Skiplist()                                  # An empty skip list
        >>>> l = Skiplist.from_iter(zip(range(5), range(5))) # A skip list from an iterable

        """
        self.p = p

        self.head = SkipList.Node()
        self.max_height = 1

        self.__length = 0

    def from_iter(it, p=0.5):
        """
        Create a SkipList from an iterable of (Key, Value) tuples
        """
        s = SkipList(p=p)
        for k, v in it:
            s.insert(k, v)

        return s

    def __getitem__(self, key):
        curr = self.head
        for level in range(self.max_height - 1, -1, -1):
            while curr.forward[level] and curr.forward[level].key < key:
                curr = curr.forward[level]

        res = curr.forward[0]
        if res and res.key == key:
            return res.value
        else:
            raise KeyError("Key {} not found".format(key))

    def __setitem__(self, key, value):
        """
        If the key is already present, the current value will be overwritten with the new value.
        """
        new_node = None
        curr = self.head
        update = [None for _ in range(self.max_height)]
        for level in range(self.max_height - 1, -1, -1):
            while curr.forward[level] and curr.forward[level].key < key:
                curr = curr.forward[level]

            update[level] = curr

        if curr.forward[0] and curr.forward[0].key == key:
            curr.forward[0].value = value
        else:
            height = np.random.geometric(self.p)
            new_forward = [n.forward[l] for l, n in enumerate(update[:height])]

            if height > self.max_height:
                new_forward += [None for _ in range(self.max_height, height)]
                self.head.forward += [None for _ in range(self.max_height, height)]
                update += [self.head for l in range(self.max_height, height)]
                self.max_height = height

            new_node = SkipList.Node(key=key, value=value, forward=new_forward)

            for l, n in enumerate(update[:height]):
                n.forward[l] = new_node

            self.__length += 1

    def __delitem__(self, key):
        curr = self.head
        update = [None for _ in range(self.max_height)]
        for level in range(self.max_height - 1, -1, -1):
            while curr.forward[level] and curr.forward[level].key < key:
                curr = curr.forward[level]

            update[level] = curr

        del_node = curr.forward[0]
        if del_node and del_node.key == key:
            for l, f in enumerate(del_node.forward):
                update[l].forward[l] = f
            self.__length -= 1
        else:
            raise KeyError("Key {} not found".format(key))

    def items(self):
        """
        Generator in the style of dict.items
        """
        curr = self.head.forward[0]
        while curr:
            yield (curr.key, curr.value)
            curr = curr.forward[0]

    def __contains__(self, key):
        try:
            self[key]
        except KeyError:
            return False

        return True

    def __iter__(self):
        curr = self.head.forward[0]
        while curr:
            yield curr.key
            curr = curr.forward[0]

    def __len__(self):
        return self.__length

    def __eq__(self, other):
        if len(self) != len(other):
            return False

        for (k1, v1), (k2, v2) in zip(self.items(), other.items()):
            if not (k1 == k2 and v1 == v2):
                return False

        return True

    class Node:
        def __init__(self, key=None, value=None, forward=None):
            if forward is None:
                forward = [None]

            self.key = key
            self.value = value
            self.forward = forward
share|improve this question

Remove repetition

You have almost identical code:

def items(self):
    """
    Generator in the style of dict.items
    """
    curr = self.head.forward[0]
    while curr:
        yield (curr.key, curr.value)
        curr = curr.forward[0]

def __iter__(self):
    curr = self.head.forward[0]
    while curr:
        yield curr.key
        curr = curr.forward[0]

You may avoid the repetition writing:

def __iter__(self):
    for (key, _) in self.items:
        yield key

The nested loops:

    for level in range(self.max_height - 1, -1, -1):
        while curr.forward[level] and curr.forward[level].key < key:
            curr = curr.forward[level]

Are repeated identical 3 times, extract them into a function.

Use the all built-in

You do not need a manual for loop in __eq__:

def __eq__(self, other):
    if len(self) != len(other):
        return False

    return all(self_pair = other_pair
                 for self_pair, other_pair in zip(self.items(), other.items())

all and avoiding tuple unpacking is closer to how you would describe the function in English (all pairs should be equal)

You may also use and instead of a separate if

def __eq__(self, other):
    return len(self) == len(other) and \
           all(self_pair = other_pair
                 for self_pair, other_pair in zip(self.items(), other.items())

It makes the code even nearer to English (The length should be equal and all pairs should be equal)

share|improve this answer
    
Thanks! I've made the changes. One thing I'm unsure about is that the function I created to replace the repated loop code builds the update list even when it's not necessary (i.e. for getitem). – Davis Yoshida Jan 31 at 3:56

1. Review

  1. The goal here seems to be create a class presenting the mutable mapping interface (so that it can be used like a dictionary). But the implementation is not complete: in addition to the methods implemented in the SkipList class, a mutable mapping also has the methods keys, values, get, pop, popitem, clear, update and setdefault.

    The easiest way to implement the full mutable mapping interface is to inherit from the collections.abc.MutableMapping abstract base class. The idea is that you implement __getitem__, __setitem__, __delitem__, __iter__, and __len__ methods, and the MutableMapping class implements everything else in terms of those.

    (For efficiency, you'll eventually want to implement your own items and values methods, as the default implementations of these call __iter__ and then look up each key, resulting in a \$Ω(n \log n)\$ algorithm instead of \$O(n)\$. But you could live with the default implementations to start with: a factor of \$\log n\$ is not a big deal.)

  2. The class is named SkipList but the examples in the docstring use the name Skiplist. The doctest module can be used to check examples in docstrings.

  3. The from_iter method doesn't work, as it's missing a self parameter. Presumably you forgot to decorate it with @staticmethod.

    But it would be better to avoid the from_iter method altogether by taking an optional mapping or iterator of (key, value) pairs in the __init__ method. This matches the interface to other mapping classes.

  4. Instead of:

    range(self.max_height - 1, -1, -1)
    

    write:

    reversed(range(self.max_height))
    
  5. This loop:

    while curr.forward[level] and curr.forward[level].key < key:
        curr = curr.forward[level]
    

    requires three lookups of curr.forward[level] on each iteration. Instead, look it up once and cache it in a local variable:

    next = curr.forward[level]
    while next and next.key < key:
        curr = next
        next = curr.forward[level]
    

    Note that when the loop over level exits, the value of next is exactly the value of res, so we can avoid the lookup of the latter.

  6. I think that prev and curr are better names than curr and next.

  7. Instead of:

    update = [None for _ in range(self.max_height)]
    

    write:

    update = [None] * self.max_height
    
  8. It's not clear why you use __length rather than length. Double underscores are used when you need to avoid conflict between attributes of the same name in two or more classes, when combining the classes via inheritance. But this use case doesn't apply here. If you want to indicate that the attribute is private to the class, it's conventional to use a single underscore.

  9. In __setitem__ the number of levels can increase arbitrarily, depending on the random number. Pugh says (page 671):

    If we generate a random level that is more than one greater than the current maximum level in the list, we [could] simply use one plus the current maximum level in the list as the level of the new node. In practice and intuitively, this change seems to work well. However, it totally destroys our ability to analyze the resulting algorithms, since the level of a node is no longer completely random. Although programmers may implement this method, purists should avoid it.

    As I'm a programmer rather than a purist, I would make this change.

  10. The same or similar search logic appears in the __getitem__, __setitem__ and __delitem__ methods. This is risky (it would be easy to get this wrong in just one place) and hard to maintain (you have to edit all three methods). It would be a good idea to put the common code in its own method.

2. Revised code

import numpy as np
from collections.abc import Mapping, MutableMapping

class SkipListNode:
    def __init__(self, key=None, value=None, forward=None):
        if forward is None:
            forward = [None]
        self.key = key
        self.value = value
        self.forward = forward

class SkipList(MutableMapping):
    """A mutable mapping implemented using a skip list."""

    def __init__(self, iterable_or_mapping=None, p=0.5):
        """Create a SkipList.

        SkipList() -> new empty skiplist
        SkipList(mapping) -> new skiplist initialized from a mapping
        SkipList(iterable) -> new skiplist initialized from iterable
            of (key, value) pairs

        Keyword argument p (default 0.5) gives the probability that an
        item appearing in layer i will also appear in layer i+1.

        For example:

        >>> l = SkipList(zip(range(5), 'abcde'))
        >>> l[0]
        'a'
        >>> 3 in l
        True
        >>> l[2] = 'C'
        >>> del l[0]
        >>> len(l)
        4
        >>> sorted(l.items())
        [(1, 'b'), (2, 'C'), (3, 'd'), (4, 'e')]

        """
        self._p = p
        self._head = SkipListNode()
        self._max_height = 1
        self._length = 0
        if iterable_or_mapping is not None:
            self.update(iterable_or_mapping)

    def _search(self, key):
        prev = self._head
        update = [None] * self._max_height
        for level in reversed(range(self._max_height)):
            curr = prev.forward[level]
            while curr and curr.key < key:
                prev = curr
                curr = prev.forward[level]
            update[level] = prev
        return update, curr

    def __getitem__(self, key):
        _, curr = self._search(key)
        if curr and curr.key == key:
            return curr.value
        else:
            raise KeyError(key)

    def __setitem__(self, key, value):
        update, curr = self._search(key)
        if curr and curr.key == key:
            curr.value = value
        else:
            height = np.random.geometric(self._p)
            forward = [n.forward[l] for l, n in enumerate(update[:height])]
            if height > self._max_height:
                forward.append(None)
                self._head.forward.append(None)
                update.append(self._head)
                self._max_height += 1
            node = SkipListNode(key=key, value=value, forward=forward)
            for l, n in enumerate(update[:height]):
                n.forward[l] = node
            self._length += 1

    def __delitem__(self, key):
        update, curr = self._search(key)
        if curr and curr.key == key:
            for l, n in enumerate(curr.forward):
                update[l].forward[l] = n
            self._length -= 1
        else:
            raise KeyError(key)

    def __iter__(self):
        prev = self._head.forward[0]
        while prev:
            yield prev.key
            prev = prev.forward[0]

    def __len__(self):
        return self._length
share|improve this answer

Your Answer

 
discard

By posting your answer, you agree to the privacy policy and terms of service.

Not the answer you're looking for? Browse other questions tagged or ask your own question.