I have written a quadtree program in Python 2.7 to cross-correlate large catalogs with each other, i.e. find the common objects in the catalogs based on their position. The problem is that it's still quite slow. Accuracy is my primary goal (no throwing out real matches, no erroneous matches, and getting the true closest match) and speed is a close second.
Summary of the code
I read one of the catalog text files into the quadtree, and the quadtree organizes the objects based on the x and y positions.
I read another catalog text file into a list.
I iterate over the list to find the closest match in the quadtree and, if they are within some specified distance, I call it a true match and save that object.
Quadtree.py
import math
from bigfloat import *
import _norm
import geom_utils as gu
import Quadtree_Utilities as utils
# Number of sources a leaf may hold; inserting into a leaf that already
# holds exactly this many triggers a split into four sub-quadrants.
MAX = 60


class Quadtree(object):
    """
    Quadtree base class. Only functions that are agnostic to
    the type of coordinate system or source object used. Must
    use a subclass.

    Subclasses supply ``insert(source)``, ``norm2(x1, y1, x2, y2)`` and
    ``initial_dist(xmax, xmin, ymax, ymin)``; the search methods below
    call the latter two through ``self``.
    """

    def __init__(self, xmin, ymin, xmax, ymax):
        """Create an empty tree whose root node spans the given box."""
        self.top = Node(xmin, ymin, xmax, ymax)
        # Call counters, reported by debug().
        self.num_subdivides = 0
        self.num_insert = 0
        self.num_inserttonodes = 0
        self.num_matched = 0
        self.num_inserttoquads = 0
        self.num_nearersources = 0

    def debug(self):
        """Print how often each instrumented method has been called."""
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("Number of subdivides: %d" % self.num_subdivides)
        print("Inserttonode was called %d times" % self.num_inserttonodes)
        print("Matched was called %d times" % self.num_matched)
        print("Inserttoquad was called %d times" % self.num_inserttoquads)
        print("Nearer sources was called %d times" % self.num_nearersources)
        print("Insert was called %d times" % self.num_insert)

    def inserttonode(self, node, source):
        """Insert ``source`` into ``node``, splitting the node if it is full.

        A leaf holding exactly MAX sources is subdivided first; the source
        then goes to the matching sub-quadrant if the node has children,
        otherwise onto the node's own ``contents`` list.
        """
        self.num_inserttonodes += 1
        # NOTE(review): if more than MAX sources always fall into the same
        # sub-quadrant (e.g. identical coordinates) the split/insert cycle
        # has no depth limit -- confirm the catalogs cannot contain that.
        if len(node.contents) == MAX:
            self.subdivide(node)
        if node.q1 is not None:
            self.inserttoquad(node, source)
        else:
            # If no subquads exist add source to the list in CONTENTS element
            node.contents.append(source)

    def inserttoquad(self, node, source):
        """Route ``source`` to the child quadrant of ``node`` containing it.

        Points on a midline go to the upper / right-hand quadrant
        (the comparisons are ``>=``).
        """
        self.num_inserttoquads += 1
        right = BigFloat(source.x) >= BigFloat(node.xmid)
        upper = BigFloat(source.y) >= BigFloat(node.ymid)
        if right:
            quadrant = node.q1 if upper else node.q4
        else:
            quadrant = node.q2 if upper else node.q3
        self.inserttonode(quadrant, source)

    def subdivide(self, node):
        """Give ``node`` four children and move its sources into them."""
        self.num_subdivides += 1
        node.q1 = Node(node.xmid, node.ymid, node.xmax, node.ymax)
        node.q2 = Node(node.xmin, node.ymid, node.xmid, node.ymax)
        node.q3 = Node(node.xmin, node.ymin, node.xmid, node.ymid)
        node.q4 = Node(node.xmid, node.ymin, node.xmax, node.ymid)
        # Pop the list and insert the sources as they come off
        while node.contents:
            self.inserttoquad(node, node.contents.pop())

    def match(self, x, y):
        """Return the stored source nearest to ``(x, y)`` (see nearestsource)."""
        self.num_matched += 1
        return self.nearestsource(self, x, y)

    def nearestsource(self, tree, x, y):
        """Search ``tree`` for the source nearest to ``(x, y)``.

        The search starts with a radius of ``self.initial_dist(...)``; only
        sources whose ``norm2`` value is below the square of the current
        radius are accepted.  Returns ``nearest.source`` as left by the
        search (whatever ``utils.Nearest()`` initialises it to when nothing
        qualifies).
        """
        nearest = utils.Nearest()
        nearest.dist = self.initial_dist(tree.top.xmax, tree.top.xmin,
                                         tree.top.ymax, tree.top.ymin)
        interest = utils.Interest(x - nearest.dist, y - nearest.dist,
                                  x + nearest.dist, y + nearest.dist)
        interest = gu.clip_box(interest.xmin, interest.ymin,
                               interest.xmax, interest.ymax,
                               tree.top.xmin, tree.top.ymin,
                               tree.top.xmax, tree.top.ymax)
        # From here on nearest.dist holds a squared distance, to be
        # compared against norm2() results.
        nearest.dist = nearest.dist * nearest.dist
        self.nearersource(tree, tree.top, x, y, nearest, interest)
        return nearest.source

    def nearersource(self, tree, node, x, y, nearest, interest):
        """Recursively refine ``nearest`` using the sources under ``node``.

        ``interest`` is the box still worth searching.  It is shared by the
        whole recursion and is shrunk *in place* whenever a closer source
        is found, so that sibling quadrants visited later are pruned with
        the tightest known box.
        """
        self.num_nearersources += 1
        if not gu.intersecting(node.xmin, node.xmax,
                               node.ymin, node.ymax,
                               interest.xmin, interest.xmax,
                               interest.ymin, interest.ymax):
            return

        if node.q1 is not None:
            for child in (node.q1, node.q2, node.q3, node.q4):
                self.nearersource(tree, child, x, y, nearest, interest)
            return

        # Leaf node.  The query coordinates do not change inside the loop,
        # so convert them once instead of once per stored source.
        query_x = BigFloat(x)
        query_y = BigFloat(y)
        top = tree.top
        for s in node.contents:
            s_dist = self.norm2(BigFloat(s.x), BigFloat(s.y),
                                query_x, query_y)
            if s_dist < nearest.dist:
                nearest.source = s.source
                nearest.dist = s_dist
                dist = math.sqrt(s_dist)
                clipped = gu.clip_box(x - dist, y - dist,
                                      x + dist, y + dist,
                                      top.xmin, top.ymin,
                                      top.xmax, top.ymax)
                # Copy the bounds back rather than rebinding ``interest``:
                # rebinding would detach this frame from the object the
                # callers hold and later improvements would not prune.
                interest.xmin = clipped.xmin
                interest.ymin = clipped.ymin
                interest.xmax = clipped.xmax
                interest.ymax = clipped.ymax
class Node(object):
    """
    One rectangular cell of the quadtree.

    Stores its bounds and midpoints as BigFloat values, four child
    slots (q1..q4, all None until the node is subdivided) and the
    list of sources held directly by this node.
    """

    def __init__(self, xmin, ymin, xmax, ymax):
        low_x = BigFloat(xmin)
        low_y = BigFloat(ymin)
        high_x = BigFloat(xmax)
        high_y = BigFloat(ymax)

        self.xmin = low_x
        self.ymin = low_y
        self.xmax = high_x
        self.ymax = high_y

        # Midpoints are the lines along which the cell is split.
        self.xmid = BigFloat((low_x + high_x) / 2.0)
        self.ymid = BigFloat((low_y + high_y) / 2.0)

        # No children yet.
        self.q1 = None
        self.q2 = None
        self.q3 = None
        self.q4 = None
        self.contents = []
class Point(object):
    """
    Uniform wrapper for anything stored in the Quadtree.

    Pairs the original ``source`` object with the two coordinates the
    tree should index it by, so the tree itself does not care whether
    those are pixel or equatorial coordinates or what type the source is.
    The coordinates are stored as BigFloat values.
    """

    def __init__(self, source, x, y):
        self.source = source
        self.x, self.y = BigFloat(x), BigFloat(y)
class ScamPixelQuadtree(Quadtree):
    """Quadtree that indexes sources by their ``ximg`` / ``yimg`` attributes."""

    def __init__(self, xmin, ymin, xmax, ymax):
        super(ScamPixelQuadtree, self).__init__(xmin, ymin, xmax, ymax)

    def insert(self, source):
        """Wrap ``source`` in a Point at (ximg, yimg) and add it to the tree."""
        self.num_insert += 1
        point = Point(source, source.ximg, source.yimg)
        self.inserttonode(self.top, point)

    def norm2(self, x1, y1, x2, y2):
        """Distance measure between two positions, delegated to _norm.norm2."""
        # presumably the squared Euclidean distance -- confirm in _norm
        return _norm.norm2(x1, y1, x2, y2)

    def initial_dist(self, x2, x1, y2, y1):
        """Initial search radius: 1/1000 of the shorter side of the box."""
        width = x2 - x1
        height = y2 - y1
        shorter_side = min(width, height)
        return shorter_side / 1000.0
class ScamEquatorialQuadtree(Quadtree):
    """Quadtree that indexes sources by their ``ra`` / ``dec`` attributes."""

    def __init__(self, xmin, ymin, xmax, ymax):
        super(ScamEquatorialQuadtree, self).__init__(xmin, ymin, xmax, ymax)

    def insert(self, source):
        """Wrap ``source`` in a Point at (ra, dec) and add it to the tree."""
        self.num_insert += 1
        self.inserttonode(self.top, Point(source, source.ra, source.dec))

    def norm2(self, x1, y1, x2, y2):
        """Distance measure between two positions, delegated to
        _angular_dist.angular_dist2.
        """
        # The module-level imports only bring in ``_norm``; without this
        # import the call below raises NameError.  Importing an already
        # loaded module again is just a cache lookup.
        # NOTE(review): assumes the extension module is importable as
        # ``_angular_dist`` -- confirm the module name.
        import _angular_dist
        return _angular_dist.angular_dist2(x1, y1, x2, y2)

    def initial_dist(self, x2, x1, y2, y1):
        """Initial search radius: 1/100 of the shorter side of the box."""
        return min(BigFloat(x2) - BigFloat(x1),
                   BigFloat(y2) - BigFloat(y1)) / 100.0
Where the various helper classes are,
class Interest(object):
    """Axis-aligned box (xmin, ymin, xmax, ymax) with mutable bounds.

    Declared as a new-style class (inherits ``object``) so it behaves the
    same under Python 2 as the other classes in this code.
    """

    def __init__(self, xmin, ymin, xmax, ymax):
        self.xmin = xmin
        self.ymin = ymin
        self.xmax = xmax
        self.ymax = ymax
class Nearest(object):
    """Mutable record of the best candidate found so far.

    Both fields start as ``None`` and are filled in by the code using the
    record.  Declared as a new-style class (inherits ``object``) so it
    behaves the same under Python 2 as the other classes in this code.
    """

    def __init__(self):
        self.source = None
        self.dist = None
The program where I do the matching, test_tree.py, looks like this:
'''
Testing the quadtree
'''
import sys
import Sources
import Quadtree
import phot_utils
import _norm
def associate(list1, tree2, dist=2):
    """Match every entry of ``list1`` against the sources stored in ``tree2``.

    For each entry the tree is asked for its nearest source at
    (entry.ximg, entry.yimg).  The returned source is kept when
    ``_norm.norm`` between the two positions is at most ``dist``.

    list1 -- iterable of objects with ``ximg`` and ``yimg`` attributes
    tree2 -- object with a ``match(x, y)`` method that returns a source
             (with ``ximg`` / ``yimg``) or None
    dist  -- largest separation accepted as a match (default 2, the value
             that used to be hard-coded)

    Returns the list of matched sources from the tree, in ``list1`` order.
    """
    matches = []
    for entry in list1:
        match = tree2.match(entry.ximg, entry.yimg)
        if match is None:
            # Nothing was found inside the tree's search radius.
            continue
        if _norm.norm(entry.ximg, entry.yimg, match.ximg, match.yimg) <= dist:
            matches.append(match)
    return matches
# Read the first catalog and load it into a pixel-coordinate quadtree whose
# bounds are the extremes of the catalog's ximg / yimg values.
with open('test_1.cat', 'r') as f:
    catalog = [line for line in f if phot_utils.no_head(line)]
test_1_catalog = [Sources.SCAMSource(line) for line in catalog]
ximg = [source.ximg for source in test_1_catalog]
yimg = [source.yimg for source in test_1_catalog]
test_1_sources = Quadtree.ScamPixelQuadtree(min(ximg), min(yimg),
                                            max(ximg), max(yimg))
# The tree built above is ``test_1_sources``; the previous code inserted
# into an undefined name (``g_sources``).
for source in test_1_catalog:
    test_1_sources.insert(source)

# Read the second catalog into a plain list.
with open('test_2.cat', 'r') as f:
    catalog = [line for line in f if phot_utils.no_head(line)]
test_2_sources = [Sources.SCAMSource(line) for line in catalog]

# Match the list against the tree (previously referenced the undefined
# names ``i_sources`` / ``g_sources``).
matches = associate(test_2_sources, test_1_sources)
I used cProfile
to get an idea of how long each function is taking. The biggest time sinks are nearer/nearestsource()
which I'm sure is due to the recursion. Are there any ways that I can cut down on the recursive calls to these functions? I know that some people use memoize decorators to optimize recursive functions, but I'm not sure if that makes sense for nearersource. I'm already sort of caching by keeping track of the current nearest source. I'd appreciate any insight from people more familiar with memoize decorators and recursive functions for what I can do here.
Reading the catalog into the quadtree (inserttonode
and inserttoquad
) takes a fair bit of time as well; I don't know what I can do for those functions to get them to go faster.
Using the BigFloats
adds time as well. I unfortunately need the extra precision though or else entire regions in the quadtree get rejected when I do the matching.
scipy.spatial.KDTree
? – Gareth Rees Jul 14 '14 at 20:33