Monday, November 03, 2008

python interval tree

EDIT: added a couple points inline.

I've been obsessed with trees lately -- of the CS variety, not the plant variety. Then again, we are studying poplar, so I'll be using trees to study trees.
I'd tried a couple of times to implement an interval tree from scratch following the Wikipedia entry.
Today I more or less did that in Python. It's the simplest possible form. There's no insertion (though that's trivial to add). It just takes a list of 'things' with start and stop attributes and creates a tree with a .find() method.
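To make "a list of 'things' with start and stop attributes" concrete, here is a minimal sketch of one such thing, plus the brute-force linear search that the speedups later in this post are measured against. The `Feature` namedtuple and its `name` field, and the `brute_force_find` helper, are hypothetical names for illustration only. Any object exposing `.start` and `.stop` would do, and overlap here uses the same closed-interval test as `find()` below.

```python
from collections import namedtuple

# Hypothetical 'thing': any object with .start and .stop attributes works.
# The extra name field is just for readable output.
Feature = namedtuple("Feature", ["start", "stop", "name"])


def brute_force_find(things, start, stop):
    """Linear scan: return every thing overlapping the closed range [start, stop].

    This is the O(n)-per-query baseline an interval tree is meant to beat.
    """
    return [t for t in things if t.stop >= start and t.start <= stop]


features = [
    Feature(100, 250, "a"),
    Feature(200, 400, "b"),
    Feature(500, 600, "c"),
]
print([f.name for f in brute_force_find(features, 260, 450)])  # ['b']
```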
The part of the Wikipedia entry that baffled me was storing 2 copies of each node's intervals: one sorted by start and the other by stop. I didn't do that, as I think in most cases it won't improve lookup time. I figure if you have 1 million elements and a tree of depth 16, then you have on average about 15 intervals per node (1,000,000 / 2^16 ≈ 15; actually fewer, since there are also the non-leaf nodes). So I just brute-force search each of those nodes and move on to the next. I think that increases the worst case but makes no difference in actual search time, with the benefit of halving storage.
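For comparison, here is a minimal sketch of the Wikipedia-style node that this post chose not to build, shown for point queries only. Every interval stored at a node contains the node's center. So for a query point left of the center, only the start matters, and a start-sorted scan can stop at the first miss. Right of the center, only the stop matters, and a stop-sorted scan can stop early in the same way. The `CenteredNode` class and `query_point` method are hypothetical names, not code from the post or its repo.

```python
from collections import namedtuple

# Hypothetical interval type; any object with .start and .stop works.
Interval = namedtuple("Interval", ["start", "stop"])


class CenteredNode(object):
    """One node of a centered interval tree that stores its intervals twice:
    once sorted by start (ascending) and once sorted by stop (descending)."""

    def __init__(self, center, intervals):
        self.center = center
        # Every interval here is assumed to contain center: start <= center <= stop.
        self.by_start = sorted(intervals, key=lambda i: i.start)
        self.by_stop = sorted(intervals, key=lambda i: i.stop, reverse=True)

    def query_point(self, x):
        """Return this node's intervals that contain point x."""
        hits = []
        if x <= self.center:
            # All stops are >= center >= x, so only start can rule one out.
            for i in self.by_start:
                if i.start > x:
                    break  # every later interval starts even further right
                hits.append(i)
        else:
            # All starts are <= center < x, so only stop can rule one out.
            for i in self.by_stop:
                if i.stop < x:
                    break  # every later interval ends even further left
                hits.append(i)
        return hits


node = CenteredNode(10, [Interval(2, 12), Interval(8, 20), Interval(5, 10)])
print(node.query_point(6))   # [Interval(start=2, stop=12), Interval(start=5, stop=10)]
print(node.query_point(15))  # [Interval(start=8, stop=20)]
```

The cost is exactly the doubled storage mentioned above. The benefit is that a node's scan touches only the hits plus one miss, rather than every interval in the node.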

EDIT: the version in my repo now keeps the intervals sorted by start, so during a search it can skip the brute-force scan at a node whenever search.stop < node.intervals[0].start. This did improve performance.
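Here is a minimal sketch of what keeping a node's intervals sorted by start buys. The first check is the whole-node skip described in the EDIT. The `bisect` step goes a little further and is an assumption, not something the post says the repo does: only the prefix of intervals with `start <= stop` can overlap, so the scan can be limited to that prefix. The `node_overlaps` helper and the parallel `starts` list are hypothetical.

```python
import bisect
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])


def node_overlaps(intervals, starts, start, stop):
    """Return the intervals in one node that overlap [start, stop].

    Assumes `intervals` is sorted by .start and `starts` is the parallel
    list of start values (a hypothetical layout for illustration).
    """
    # Whole-node skip: nothing here begins before the query ends.
    if not intervals or stop < intervals[0].start:
        return []
    # Only intervals with start <= stop can overlap, and they form a prefix.
    end = bisect.bisect_right(starts, stop)
    # Stops within that prefix are unsorted, so they still need a linear check.
    return [i for i in intervals[:end] if i.stop >= start]


ivs = sorted([Interval(1, 5), Interval(3, 9), Interval(7, 8), Interval(12, 20)])
starts = [i.start for i in ivs]
print(node_overlaps(ivs, starts, 6, 10))  # [Interval(start=3, stop=9), Interval(start=7, stop=8)]
print(node_overlaps(ivs, starts, -5, 0))  # []
```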

The tree class takes a list of intervals and calculates a center point. From there it partitions them into left, overlapping, and right according to their relation to the center point. The overlapping intervals are assigned to the current node, and left and right are recursively partitioned in the same fashion until there are fewer than `minbucket` intervals in a node, or the specified `depth` has been reached AND there are fewer intervals than `maxbucket`. So a tree can be deeper than the requested `depth` if it would otherwise have more than `maxbucket` intervals in a single node. The Wikipedia version doesn't have maxbucket or minbucket...
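The two rules in that paragraph, the three-way partition and the leaf test, can be pulled out as small standalone functions. This is a sketch: `partition` and `is_leaf` are hypothetical helper names, and `depth` here means the remaining depth budget after decrementing, as in the class below. The leaf test matches the class's condition as long as `minbucket <= maxbucket`, which holds for the defaults of 96 and 4096.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])


def partition(intervals, center):
    """Split intervals into (lefts, overlapping, rights) around center."""
    lefts, overlapping, rights = [], [], []
    for i in intervals:
        if i.stop < center:        # entirely left of center
            lefts.append(i)
        elif i.start > center:     # entirely right of center
            rights.append(i)
        else:                      # start <= center <= stop
            overlapping.append(i)
    return lefts, overlapping, rights


def is_leaf(n, depth, minbucket, maxbucket):
    """Stop splitting when the node is small, or when the depth budget is
    used up AND the node holds fewer than maxbucket intervals."""
    return n < minbucket or (depth == 0 and n < maxbucket)


ivs = [Interval(0, 3), Interval(4, 6), Interval(5, 9), Interval(8, 10)]
lefts, mid, rights = partition(ivs, 5.0)
print(lefts)   # [Interval(start=0, stop=3)]
print(mid)     # [Interval(start=4, stop=6), Interval(start=5, stop=9)]
print(rights)  # [Interval(start=8, stop=10)]
print(is_leaf(5000, 0, 96, 4096))  # False: over maxbucket, so keep splitting
```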

EDIT: maxbucket actually only works on leaf nodes and has no effect otherwise.

I'm sure that's painfully obvious to anyone who's ever taken a CS course, but it was foggy at best for me until I implemented it. Below is the entire implementation:

```python
class IntervalTree(object):
    __slots__ = ('intervals', 'left', 'right', 'center')

    def __init__(self, intervals, depth=16, minbucket=96, _extent=None, maxbucket=4096):

        depth -= 1
        # Leaf: the node is small, or the depth budget is spent and the
        # node holds fewer than maxbucket intervals.
        if (depth == 0 or len(intervals) < minbucket) and len(intervals) < maxbucket:
            self.intervals = intervals
            self.left = self.right = None
            self.center = None
            return

        left, right = _extent or \
            (min(i.start for i in intervals), max(i.stop for i in intervals))
        center = (left + right) / 2.0

        self.intervals = []
        lefts, rights = [], []

        # Partition into entirely-left, entirely-right, and overlapping center.
        for interval in intervals:
            if interval.stop < center:
                lefts.append(interval)
            elif interval.start > center:
                rights.append(interval)
            else:  # overlapping.
                self.intervals.append(interval)

        # Pass maxbucket down so it applies below the root node as well.
        self.left = lefts and IntervalTree(lefts, depth, minbucket, (left, center), maxbucket) or None
        self.right = rights and IntervalTree(rights, depth, minbucket, (center, right), maxbucket) or None
        self.center = center

    def find(self, start, stop):
        """find all elements between (or overlapping) start and stop"""
        overlapping = [i for i in self.intervals if i.stop >= start
                       and i.start <= stop]

        if self.left and start <= self.center:
            overlapping += self.left.find(start, stop)

        if self.right and stop >= self.center:
            overlapping += self.right.find(start, stop)

        return overlapping
```

Only about 45 lines of code. I had added a couple of extra attributes so that searching could do fewer checks, but that only improved performance by ~15%, and I liked the simplicity. One way to improve the search speed, and the distribution of intervals on skewed data, would be to sort the intervals at the top node so they'd then be sorted in every other node. Then, instead of using center = (left + right) / 2, it could use the midpoint of the middle interval at each node. That would also allow short-circuiting the brute-force search at the top of the find method with something like:

```python
overlapping = []
if self.intervals and not (start > self.intervals[-1].stop
                           or stop < self.intervals[0].start):
    overlapping = [i for i in self.intervals if i.stop >= start and i.start <= stop]
```

But all told, that adds 5 or so lines of code. Oh, and depending on how it's used, it's between 15 and 25 times faster than brute-force search.
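Here is a rough sketch of the "sort once, center on the middle interval" idea from the paragraph above. It is a hypothetical `MedianIntervalTree`, not the code in the repo. It has only a single `minbucket` stopping rule, with a small default chosen for illustration, and it includes just the short-circuit check that the EDIT below says is safe. A useful side effect: the middle interval always contains its own midpoint, so it always stays in the current node, and each child is strictly smaller than its parent.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])


class MedianIntervalTree(object):
    """Sketch: sort by start once at the top, then center each node on the
    midpoint of its middle interval instead of (left + right) / 2."""
    __slots__ = ("intervals", "left", "right", "center")

    def __init__(self, intervals, minbucket=4, _sorted=False):
        if not _sorted:
            # Sort once; the partition loop below appends in order, so every
            # child list stays sorted by start without re-sorting.
            intervals = sorted(intervals, key=lambda i: i.start)
        self.left = self.right = self.center = None
        if len(intervals) < minbucket:
            self.intervals = intervals
            return
        mid = intervals[len(intervals) // 2]
        center = (mid.start + mid.stop) / 2.0  # mid always overlaps its own midpoint
        lefts, rights, self.intervals = [], [], []
        for i in intervals:
            if i.stop < center:
                lefts.append(i)
            elif i.start > center:
                rights.append(i)
            else:
                self.intervals.append(i)
        self.center = center
        if lefts:
            self.left = MedianIntervalTree(lefts, minbucket, True)
        if rights:
            self.right = MedianIntervalTree(rights, minbucket, True)

    def find(self, start, stop):
        """Return all intervals overlapping [start, stop]."""
        found = []
        # Sorted by start: if the query ends before the first interval
        # begins, nothing in this node can overlap it.
        if self.intervals and stop >= self.intervals[0].start:
            found = [i for i in self.intervals
                     if i.stop >= start and i.start <= stop]
        if self.left is not None and start <= self.center:
            found += self.left.find(start, stop)
        if self.right is not None and stop >= self.center:
            found += self.right.find(start, stop)
        return found


ivs = [Interval(k * 37 % 1000, k * 37 % 1000 + k % 60) for k in range(400)]
tree = MedianIntervalTree(ivs)
print(len(tree.find(100, 200)))
```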

EDIT: I added the above check, but it can only do the 2nd comparison, `stop < self.intervals[0].start`, as the first is invalid given a very long interval. Regarding speed, the smaller the search window, the bigger the performance improvement. The code is now more than 20 times as fast as brute force for a search window of 100K, which is very large in terms of looking for genomic features. With a search space of 50K, it's 50+ times as fast as linear search.
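A tiny counterexample shows why the first comparison is invalid. When a node's intervals are sorted by start, the last interval is the one that starts latest, not necessarily the one that ends latest. A single long interval near the front breaks the check. The values below are made up for illustration, and caching a per-node maximum stop (here just computed inline) is one assumed way to make the first check safe again.

```python
from collections import namedtuple

Interval = namedtuple("Interval", ["start", "stop"])

# A node's intervals sorted by start; the first one is very long.
node = sorted([Interval(0, 1000), Interval(10, 20), Interval(30, 40)])
query_start, query_stop = 500, 600

# Invalid: intervals[-1] starts latest but does not end latest,
# so this says "skip the node" even though Interval(0, 1000) overlaps.
skip_by_last_stop = query_start > node[-1].stop    # 500 > 40 -> True

# Valid: intervals[0] starts earliest, so a query ending before it
# cannot overlap anything in the node.
skip_by_first_start = query_stop < node[0].start   # 600 < 0 -> False

# A safe version of the first check compares against the largest stop.
max_stop = max(i.stop for i in node)
safe_skip = query_start > max_stop                 # 500 > 1000 -> False

actual = [i for i in node if i.stop >= query_start and i.start <= query_stop]
print(skip_by_last_stop, skip_by_first_start, safe_skip, actual)
# True False False [Interval(start=0, stop=1000)]
```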

The full code (including a docstring with a Homer Simpson quote) is in my Google Code repo. If I've made obvious mistakes or you have improvements, I'd be glad to know about them.