python sweetness

  • Fast PyPy-compatible ordered map in 89 lines of Python

    Skip lists are a freaking awesome data structure you should go and read about today. Despite a full implementation fitting comfortably on a single A4 page, they still manage to perform well compared to significantly more complex tree structures. That aside, who doesn’t want to use a data structure that requires random.random() in their code??

    [Edit: as is unfortunately typical of Wikipedia CS articles, the skip list article manages to say everything without explaining anything; refer to the original paper to learn more.]

    Below is an implementation done in Python. The great innovation here is that, unlike a dictionary, a skip list keeps its items permanently in order, making it possible to walk the collection backwards and forwards, from start or end, or from an arbitrary key, much more efficiently than would ever be possible with a dict.
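
    To make the difference concrete, here is a minimal sketch of the kind of walk a plain dict cannot do cheaply. It is not the skip list: it uses a sorted key list and the standard `bisect` module as a stand-in, and the names `walk_from` and `index` are made up for illustration. In this sketch a reverse walk starts at the greatest key less than or equal to `start`.

```python
import bisect

def walk_from(sorted_keys, mapping, start, reverse=False):
    """Yield (key, value) pairs in key order, beginning at `start`.

    Forwards: keys >= start, ascending. Backwards: keys <= start, descending.
    """
    if reverse:
        pos = bisect.bisect_right(sorted_keys, start)
        keys = reversed(sorted_keys[:pos])
    else:
        pos = bisect.bisect_left(sorted_keys, start)
        keys = sorted_keys[pos:]
    for key in keys:
        yield key, mapping[key]

index = {"pear": 3, "apple": 1, "fig": 2, "plum": 4}
keys = sorted(index)  # with only a dict, every ordered walk starts with a full sort

forwards = list(walk_from(keys, index, "g"))
backwards = list(walk_from(keys, index, "g", reverse=True))
print(forwards)   # [('pear', 3), ('plum', 4)]
print(backwards)  # [('fig', 2), ('apple', 1)]
```

    Keeping `keys` sorted by hand costs O(n) per insert because the list has to shift its elements, which is exactly the cost a skip list avoids.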

    I needed this class to implement efficient indexing of an in-memory collection, where updating a single item would also update its position in a sorted list, without the cost of having to re-build or re-sort that list on every update. Inserting 938,075 string keys of ~8-10 bytes manages 75k inserts/second on PyPy on a Core 2 Duo, while still achieving a not-too-shabby 23k/second on CPython. Lookup is just as awesome: nearly 100k random searches/second on PyPy and 44k/second on CPython.
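
    If you want to reproduce numbers like these on your own machine, something along these lines should do. This is a hypothetical harness, not the script used for the figures above: `random_keys` and `ops_per_second` are names invented here, the key count is much smaller, and a plain dict stands in for the map under test (swap in `SkipList.insert` and `SkipList.search`).

```python
import random
import string
import time

def random_keys(count, seed=0):
    """Build `count` random lowercase string keys of 8-10 characters."""
    rng = random.Random(seed)
    return ["".join(rng.choice(string.ascii_lowercase)
                    for _ in range(rng.randint(8, 10)))
            for _ in range(count)]

def ops_per_second(operation, keys):
    """Call `operation(key)` once per key and return the achieved rate."""
    start = time.perf_counter()
    for key in keys:
        operation(key)
    elapsed = time.perf_counter() - start
    return len(keys) / elapsed if elapsed else float("inf")

store = {}  # stand-in for the ordered map being measured

def insert(key):
    store[key] = True

keys = random_keys(20000)
insert_rate = ops_per_second(insert, keys)

# Look the same keys up again in a different random order.
lookups = random.Random(1).sample(keys, len(keys))
search_rate = ops_per_second(store.get, lookups)

print("inserts/second: %d" % insert_rate)
print("searches/second: %d" % search_rate)
```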

    As many will attest, it’s easy to live life without an ordered map in Python, but the moment you need one Python starts to suck really damn hard. This should be built into the language somehow. Instead we’re stuck with 100 different shabby implementations of almost the same thing on the Cheese Shop, none of which quite provides everything, not to mention that the majority are C-based implementations, unsuitable for PyPy.

    Well, here’s my attempt. If you’re reading this because you arrived via Google, consider this code offered under the terms of the MIT license. Use it as you please.

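    import math
    import random
    
    try:
        xrange            # Python 2
    except NameError:
        xrange = range    # Python 3 has no xrange; alias it
    
    class SkipList: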
        """Doubly linked non-indexable skip list, providing logarithmic insertion
        and deletion. Keys are any orderable Python object.
    
            `maxsize`:
                Maximum number of items expected to exist in the list. Performance
                will degrade when this number is surpassed.
        """
        def __init__(self, maxsize=65535):
            self.max_level = int(math.log(maxsize, 2))
            self.level = 0
            self.head = self._makeNode(self.max_level, None, None)
            self.nil = self._makeNode(-1, None, None)
            self.tail = self.nil
            # Levels run from 0 to max_level inclusive, so head needs 1 + max_level forward pointers.
            self.head[3:] = [self.nil for x in xrange(self.max_level + 1)]
            self._update = [self.head] * (1 + self.max_level)
            self.p = 1/math.e
    
        def _makeNode(self, level, key, value):
            # Node layout: [key, value, prev, next_0, next_1, ..., next_level]
            node = [None] * (4 + level)
            node[0] = key
            node[1] = value
            return node
    
        def _randomLevel(self):
            lvl = 0
            max_level = min(self.max_level, self.level + 1)
            while random.random() < self.p and lvl < max_level:
                lvl += 1
            return lvl
    
        def items(self, searchKey=None, reverse=False):
            """Yield (key, value) pairs starting from `searchKey`, or the next
            greater key, or the end of the list. Subsequent iterations move
            backwards if `reverse=True`. If `searchKey` is ``None`` then start at
            either the beginning or end of the list."""
            if reverse:
                node = self.tail
            else:
                node = self.head[3]
            if searchKey is not None:
                update = self._update[:]
                found = self._findLess(update, searchKey)
                # Forwards, a key past the end yields nothing; backwards, start from the tail.
                if found[3] is not self.nil or not reverse:
                    node = found[3]
            idx = 2 if reverse else 3
            while node[0] is not None:
                yield node[0], node[1]
                node = node[idx]
    
        def _findLess(self, update, searchKey):
            node = self.head
            for i in xrange(self.level, -1, -1):
                key = node[3 + i][0]
                while key is not None and key < searchKey:
                    node = node[3 + i]
                    key = node[3 + i][0]
                update[i] = node
            return node
    
        def insert(self, searchKey, value):
            """Insert `searchKey` into the list with `value`. If `searchKey`
            already exists, its previous value is overwritten."""
            assert searchKey is not None
            update = self._update[:]
            node = self._findLess(update, searchKey)
            prev = node
            node = node[3]
            if node[0] == searchKey:
                node[1] = value
            else:
                lvl = self._randomLevel()
                self.level = max(self.level, lvl)
                node = self._makeNode(lvl, searchKey, value)
                node[2] = prev
                for i in xrange(0, lvl+1):
                    node[3 + i] = update[i][3 + i]
                    update[i][3 + i] = node
                if node[3] is self.nil:
                    self.tail = node
                else:
                    node[3][2] = node
    
        def delete(self, searchKey):
            """Delete `searchKey` from the list, returning ``True`` if it
            existed."""
            update = self._update[:]
            node = self._findLess(update, searchKey)
            node = node[3]
            if node[0] == searchKey:
                node[3][2] = update[0]
                for i in xrange(self.level + 1):
                    if update[i][3 + i] is not node:
                        break
                    update[i][3 + i] = node[3 + i]
                while self.level > 0 and self.head[3 + self.level][0] is None:
                    self.level -= 1
                if self.tail is node:
                    self.tail = node[2]
                return True
    
        def search(self, searchKey):
            """Return the value associated with `searchKey`, or ``None`` if
            `searchKey` does not exist."""
            node = self.head
            for i in xrange(self.level, -1, -1):
                key = node[3 + i][0]
                while key is not None and key < searchKey:
                    node = node[3 + i]
                    key = node[3 + i][0]
            node = node[3]
            if node[0] == searchKey:
                return node[1]
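
    For a feel of what the random levels look like, here is a small standalone sketch of the coin-flipping in `_randomLevel`. It is a simplification: it keeps `p = 1/e` and the `max_level` cap from the class above, but drops the extra rule that the list only grows one level at a time. The `random_level` function and the sample size are made up for illustration.

```python
import math
import random
from collections import Counter

def random_level(max_level, p=1 / math.e, rng=random):
    """Count consecutive successes of a biased coin, capped at max_level."""
    lvl = 0
    while rng.random() < p and lvl < max_level:
        lvl += 1
    return lvl

max_level = int(math.log(65535, 2))  # same formula as __init__ with the default maxsize
rng = random.Random(42)              # fixed seed so runs are repeatable
samples = 100000
counts = Counter(random_level(max_level, rng=rng) for _ in range(samples))

for lvl in sorted(counts):
    expected = (1 - 1 / math.e) * (1 / math.e) ** lvl
    print("level %2d: observed %.4f, expected about %.4f"
          % (lvl, counts[lvl] / float(samples), expected))
```

    Roughly 63% of nodes stay at level 0, and each level above holds about 1/e as many nodes as the one below it, so the average node carries only about 1.6 forward pointers.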
    

    This implementation may look ghastly at first sight; however, it’s worth noting that:

    • The node pointer list is reused for the node structure to save memory, otherwise a minimum of 72 bytes is wasted per node on CPython, in addition to malloc slack for 2 allocations (the list object itself, and the array of pointers). As it stands, CPython burns about 113 bytes per record, so ‘clean’ code here would potentially double the memory requirements in addition to added runtime cost.

    • Numeric, rather than symbolic, indexing wins 5k lookups/second on CPython. It’s debatable whether using speed hacks like “search(key, IDX_PREV=IDX_PREV, IDX_NEXT=IDX_NEXT)” is uglier than the bare numbers themselves. Using accessor functions to pretty up the code also costs quite a lot.
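
    For anyone who hasn’t seen the default-argument hack before, here is a rough sketch of the three styles side by side. The `IDX_*` constants and the `next_key_*` functions are invented for this example and are not part of the class above. Timings depend heavily on interpreter and version, so run it yourself rather than trusting any particular ratio.

```python
import timeit

# Symbolic names for the node layout [key, value, prev, next_0, ...].
IDX_KEY, IDX_VALUE, IDX_PREV, IDX_NEXT = 0, 1, 2, 3

def next_key_global(node):
    # Each IDX_* name is looked up in the module globals on every call.
    return node[IDX_NEXT][IDX_KEY]

def next_key_bound(node, IDX_KEY=IDX_KEY, IDX_NEXT=IDX_NEXT):
    # Default arguments are bound once, at definition time, and then
    # behave as fast local variables inside the function.
    return node[IDX_NEXT][IDX_KEY]

def next_key_numeric(node):
    # Bare numbers: constants compiled straight into the bytecode.
    return node[3][0]

# Two linked nodes plus a nil sentinel, in the same layout as the class.
nil = [None, None, None]
second = ["b", 2, None, nil]
first = ["a", 1, None, second]
second[IDX_PREV] = first

for fn in (next_key_global, next_key_bound, next_key_numeric):
    seconds = timeit.timeit(lambda: fn(first), number=200000)
    print("%-18s %.4f s" % (fn.__name__, seconds))
```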

    The result of these optimizations is an overhead of around 44 bytes per record for our test set on amd64, compared to a Python dict. Dicts use around 69 bytes per record, SkipList uses ~113.
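
    A quick way to see where the bytes go is to compare shallow sizes with `sys.getsizeof`. This is only a sketch: `CleanNode`, `flat_node` and `shallow_size` are hypothetical names, the numbers exclude the keys and values themselves as well as malloc slack, and they will not match the figures above since they vary by CPython version and platform. It assumes CPython; `sys.getsizeof` is not usable on PyPy.

```python
import sys

class CleanNode(object):
    """The 'clean' alternative: an object plus a separate pointer list."""
    def __init__(self, key, value, level):
        self.key = key
        self.value = value
        self.prev = None
        self.next = [None] * (level + 1)

def flat_node(key, value, level):
    """Same layout as _makeNode: one list holding fields and pointers."""
    node = [None] * (4 + level)
    node[0] = key
    node[1] = value
    return node

def shallow_size(obj):
    """Size of the object itself plus its attribute dict, if it has one."""
    size = sys.getsizeof(obj)
    if hasattr(obj, "__dict__"):
        size += sys.getsizeof(obj.__dict__)
    return size

clean = CleanNode("key", "value", 0)
clean_bytes = shallow_size(clean) + sys.getsizeof(clean.next)
flat_bytes = sys.getsizeof(flat_node("key", "value", 0))

print("clean node: %d bytes" % clean_bytes)
print("flat node:  %d bytes" % flat_bytes)
```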

    If you use this code for something, please click back to the blog’s home page, then “Ask Me Anything”, then drop a comment telling me what you used it for. Thanks!
