# BentleyOttmann sweep-line implementation
# (for finding all intersections in a set of line segments)

from __future__ import annotations

__all__ = (
    "isect_segments",
    "isect_polygon",

    # same as above, but each intersection also lists the segments forming it
    "isect_segments_include_segments",
    "isect_polygon_include_segments",

    # for testing only (correct but slow)
    "isect_segments__naive",
    "isect_polygon__naive",
)

# ----------------------------------------------------------------------------
# Main Poly Intersection

# Defines to change behavior.
#
# Whether to ignore intersections of line segments when both
# their end points form the intersection point.
USE_IGNORE_SEGMENT_ENDINGS = True

USE_DEBUG = True

USE_VERBOSE = False

# checks we should NOT need,
# but do them in case we find a test-case that fails.
USE_PARANOID = False

# Support vertical segments,
# (the bentley-ottmann method doesn't support this).
# We use the term 'START_VERTICAL' for a vertical segment,
# to differentiate it from START/END/INTERSECTION
USE_VERTICAL = True
# end defines!
# ------------

# ---------
# Constants
X, Y = 0, 1

# -----------------------------------------------------------------------------
# Switchable Number Implementation

NUMBER_TYPE = 'native'

if NUMBER_TYPE == 'native':
    Real = float
    NUM_EPS = Real("1e-10")
    NUM_INF = Real(float("inf"))
elif NUMBER_TYPE == 'decimal':
    # Not passing tests!
    import decimal
    Real = decimal.Decimal
    decimal.getcontext().prec = 80
    NUM_EPS = Real("1e-10")
    NUM_INF = Real(float("inf"))
elif NUMBER_TYPE == 'numpy':
    import numpy
    Real = numpy.float64
    del numpy
    NUM_EPS = Real("1e-10")
    NUM_INF = Real(float("inf"))
elif NUMBER_TYPE == 'gmpy2':
    # Not passing tests!
    import gmpy2
    gmpy2.set_context(gmpy2.ieee(128))
    Real = gmpy2.mpz
    NUM_EPS = Real(float("1e-10"))
    NUM_INF = gmpy2.get_emax_max()
    del gmpy2
else:
    raise Exception("Type not found")

NUM_EPS_SQ = NUM_EPS * NUM_EPS
NUM_ZERO = Real(0.0)
NUM_ONE = Real(1.0)


class Event:
    """
    A sweep-line event: the start or end of a segment, a vertical segment,
    or an intersection point (which carries no segment).
    """
    __slots__ = (
        "type",
        "point",
        "segment",

        # this is just cache,
        # we may remove or calculate slope on the fly
        "slope",
        "span",
    ) + (() if not USE_DEBUG else (
        # debugging only
        "other",
        "in_sweep",
    ))

    class Type:
        # The values double as indices into the per-point buckets
        # created by ``EventQueue.offer``.
        END = 0
        INTERSECTION = 1
        START = 2
        if USE_VERTICAL:
            START_VERTICAL = 3

    def __init__(self, type, point, segment, slope):
        """
        :arg type: One of the ``Event.Type`` constants.
        :arg point: The ``(x, y)`` tuple at which the event occurs.
        :arg segment: The ``(p0, p1)`` segment, or None for INTERSECTION.
        :arg slope: The cached segment slope, or None for INTERSECTION.
        """
        assert(isinstance(point, tuple))
        self.type = type
        self.point = point
        self.segment = segment

        # will be None for INTERSECTION
        self.slope = slope
        if segment is not None:
            self.span = segment[1][X] - segment[0][X]
        else:
            # Always initialize the slot so attribute access
            # on INTERSECTION events can't raise AttributeError.
            self.span = None

        if USE_DEBUG:
            self.other = None
            self.in_sweep = False

    # note that this isn't essential,
    # it just avoids non-deterministic ordering, see #9.
    def __hash__(self):
        return hash(self.point)

    def is_vertical(self):
        """Return True when the segment has zero extent along X."""
        # return self.segment[0][X] == self.segment[1][X]
        return self.span == NUM_ZERO

    def y_intercept_x(self, x: Real):
        """
        Return the Y value of the segment at ``x`` (clamped to the segment
        end-points), or None for a vertical segment when USE_VERTICAL is set.
        """
        # vertical events only for comparison (above_all check)
        # never added into the binary-tree its self
        if USE_VERTICAL:
            if self.is_vertical():
                return None

        if x <= self.segment[0][X]:
            return self.segment[0][Y]
        elif x >= self.segment[1][X]:
            return self.segment[1][Y]

        # use the largest to avoid float precision error with nearly vertical lines.
        delta_x0 = x - self.segment[0][X]
        delta_x1 = self.segment[1][X] - x
        if delta_x0 > delta_x1:
            ifac = delta_x0 / self.span
            fac = NUM_ONE - ifac
        else:
            fac = delta_x1 / self.span
            ifac = NUM_ONE - fac
        assert(fac <= NUM_ONE)
        return (self.segment[0][Y] * fac) + (self.segment[1][Y] * ifac)

    @staticmethod
    def Compare(sweep_line, this, that):
        """
        Three-way comparison (-1, 0, 1) of two events along the sweep line
        at its current X position, used as the ``cmp`` of the sweep tree.

        Ties in Y (within NUM_EPS) are broken by slope, whose sense is
        flipped by ``sweep_line._before``, then by the segments' X extents.
        """
        if this is that:
            return 0
        if USE_DEBUG:
            if this.other is that:
                return 0
        current_point_x = sweep_line._current_event_point_x
        this_y = this.y_intercept_x(current_point_x)
        that_y = that.y_intercept_x(current_point_x)
        # print(this_y, that_y)
        if USE_VERTICAL:
            # Vertical segments have no single Y at ``x``, use the event point.
            if this_y is None:
                this_y = this.point[Y]
            if that_y is None:
                that_y = that.point[Y]

        delta_y = this_y - that_y

        assert((delta_y < NUM_ZERO) == (this_y < that_y))
        # NOTE, VERY IMPORTANT TO USE EPSILON HERE!
        # otherwise w/ float precision errors we get incorrect comparisons
        # can get very strange & hard to debug output without this.
        if abs(delta_y) > NUM_EPS:
            return -1 if (delta_y < NUM_ZERO) else 1
        else:
            this_slope = this.slope
            that_slope = that.slope
            if this_slope != that_slope:
                if sweep_line._before:
                    return -1 if (this_slope > that_slope) else 1
                else:
                    return 1 if (this_slope > that_slope) else -1

        delta_x_p1 = this.segment[0][X] - that.segment[0][X]
        if delta_x_p1 != NUM_ZERO:
            return -1 if (delta_x_p1 < NUM_ZERO) else 1

        delta_x_p2 = this.segment[1][X] - that.segment[1][X]
        if delta_x_p2 != NUM_ZERO:
            return -1 if (delta_x_p2 < NUM_ZERO) else 1

        return 0

    def __repr__(self):
        # INTERSECTION events have no segment, don't index into None.
        segment = self.segment
        if segment is None:
            s0 = s1 = None
        else:
            s0, s1 = segment[0], segment[1]
        return ("Event(0x%x, s0=%r, s1=%r, p=%r, type=%d, slope=%r)" % (
            id(self),
            s0, s1,
            self.point,
            self.type,
            self.slope,
        ))


class SweepLine:
    """
    The sweep-line state: the segments currently crossing the line
    (ordered by ``Event.Compare``) and the intersections found so far.
    """
    __slots__ = (
        # A map holding all intersection points mapped to the Events
        # that form these intersections.
        # {Point: set(Event, ...), ...}
        "intersections",
        "queue",

        # Events (sorted set of ordered events, no values)
        #
        # note: START & END events are considered the same so checking if an event is in the tree
        # will return true if its opposite side is found.
        # This is essential for the algorithm to work, and why we don't explicitly remove START events.
        # Instead, the END events are never added to the current sweep, and removing them also removes the start.
        "_events_current_sweep",
        # The point of the current Event.
        "_current_event_point_x",
        # A flag to indicate if we're slightly before or after the line.
        "_before",
    )

    def __init__(self, queue: EventQueue):
        self.intersections = {}
        self.queue = queue

        self._current_event_point_x = None
        self._events_current_sweep = RBTree(cmp=Event.Compare, cmp_data=self)
        self._before = True

    def get_intersections(self):
        """
        Return a list of unordered intersection points.
        """
        if Real is float:
            return list(self.intersections.keys())
        else:
            return [(float(p[0]), float(p[1])) for p in self.intersections.keys()]

    # Not essential for implementing this algorithm, but useful.
    def get_intersections_with_segments(self):
        """
        Return a list of unordered intersection '(point, segment)' pairs,
        where segments may contain 2 or more values.
        """
        if Real is float:
            return [
                (p, [event.segment for event in event_set])
                for p, event_set in self.intersections.items()
            ]
        else:
            return [
                (
                    (float(p[0]), float(p[1])),
                    [((float(event.segment[0][0]), float(event.segment[0][1])),
                      (float(event.segment[1][0]), float(event.segment[1][1])))
                     for event in event_set],
                )
                for p, event_set in self.intersections.items()
            ]

    # Checks if an intersection exists between two Events 'a' and 'b'.
    def _check_intersection(self, a: Event, b: Event):
        """
        Record the intersection of the segments of ``a`` and ``b`` (if any),
        queueing a new INTERSECTION event the first time a point is found.
        """
        # Return immediately in case either of the events is null, or
        # if one of them is an INTERSECTION event.
        if (
                (a is None or b is None) or
                (a.type == Event.Type.INTERSECTION) or
                (b.type == Event.Type.INTERSECTION)
        ):
            return

        if a is b:
            return

        # Get the intersection point between 'a' and 'b'.
        p = isect_seg_seg_v2_point(
            a.segment[0], a.segment[1],
            b.segment[0], b.segment[1],
        )

        # No intersection exists.
        if p is None:
            return

        # If the intersection is formed by both the segment endings, AND
        # USE_IGNORE_SEGMENT_ENDINGS is true,
        # return from this method.
        if USE_IGNORE_SEGMENT_ENDINGS:
            if ((len_squared_v2v2(p, a.segment[0]) < NUM_EPS_SQ or
                 len_squared_v2v2(p, a.segment[1]) < NUM_EPS_SQ) and
                (len_squared_v2v2(p, b.segment[0]) < NUM_EPS_SQ or
                 len_squared_v2v2(p, b.segment[1]) < NUM_EPS_SQ)):
                return

        # Add the intersection.
        # (pop & re-assign, this moves the point to the end of the dict).
        events_for_point = self.intersections.pop(p, set())
        is_new = len(events_for_point) == 0
        events_for_point.add(a)
        events_for_point.add(b)
        self.intersections[p] = events_for_point

        # If the intersection occurs to the right of the sweep line, OR
        # if the intersection is on the sweep line and it's above the
        # current event-point, add it as a new Event to the queue.
        if is_new and p[X] >= self._current_event_point_x:
            event_isect = Event(Event.Type.INTERSECTION, p, None, None)
            self.queue.offer(p, event_isect)

    def _sweep_to(self, p):
        """Move the sweep line to the X position of point ``p``."""
        if p[X] == self._current_event_point_x:
            # happens in rare cases,
            # we can safely ignore
            return

        self._current_event_point_x = p[X]

    def insert(self, event):
        """Add ``event`` to the current sweep, it must not already be in it."""
        assert(event not in self._events_current_sweep)
        assert(not USE_VERTICAL or event.type != Event.Type.START_VERTICAL)
        if USE_DEBUG:
            assert(event.in_sweep == False)
            assert(event.other.in_sweep == False)

        self._events_current_sweep.insert(event, None)

        if USE_DEBUG:
            event.in_sweep = True
            event.other.in_sweep = True

    def remove(self, event):
        """
        Remove ``event`` from the current sweep.

        :return: True when the event was removed, False when it wasn't found.
        """
        try:
            self._events_current_sweep.remove(event)
            if USE_DEBUG:
                assert(event.in_sweep == True)
                assert(event.other.in_sweep == True)
                event.in_sweep = False
                event.other.in_sweep = False
            return True
        except KeyError:
            if USE_DEBUG:
                assert(event.in_sweep == False)
                assert(event.other.in_sweep == False)
            return False

    def above(self, event):
        """Return the event directly above ``event`` in the sweep, or None."""
        return self._events_current_sweep.succ_key(event, None)

    def below(self, event):
        """Return the event directly below ``event`` in the sweep, or None."""
        return self._events_current_sweep.prev_key(event, None)

    def above_all(self, event):
        """Iterate over the events in the sweep that compare >= ``event``."""
        # assert(event not in self._events_current_sweep)
        return self._events_current_sweep.key_slice(event, None, reverse=False)

    def handle(self, p, events_current):
        """
        Process all events of one type located at point ``p``,
        the sweep line must already have been moved to ``p``.
        """
        if not events_current:
            return
        # done already
        # self._sweep_to(events_current[0])
        assert(p[0] == self._current_event_point_x)

        if not USE_IGNORE_SEGMENT_ENDINGS:
            # Events sharing a point intersect each other at that point.
            if len(events_current) > 1:
                for i in range(0, len(events_current) - 1):
                    for j in range(i + 1, len(events_current)):
                        self._check_intersection(
                            events_current[i], events_current[j])

        for e in events_current:
            self.handle_event(e)

    def handle_event(self, event):
        """Update the sweep for a single event, dispatching on its type."""
        t = event.type
        if t == Event.Type.START:
            # print("  START")
            self._before = False
            self.insert(event)

            e_above = self.above(event)
            e_below = self.below(event)

            self._check_intersection(event, e_above)
            self._check_intersection(event, e_below)
            if USE_PARANOID:
                self._check_intersection(e_above, e_below)

        elif t == Event.Type.END:
            # print("  END")
            self._before = True

            e_above = self.above(event)
            e_below = self.below(event)

            self.remove(event)

            # The neighbors become adjacent once the segment is removed.
            self._check_intersection(e_above, e_below)
            if USE_PARANOID:
                self._check_intersection(event, e_above)
                self._check_intersection(event, e_below)

        elif t == Event.Type.INTERSECTION:
            # print("  INTERSECTION")
            self._before = True
            event_set = self.intersections[event.point]
            # note: events_current aren't sorted.
            reinsert_stack = []  # Stack
            for e in event_set:
                # Since we know the Event wasn't already removed,
                # we want to insert it later on.
                if self.remove(e):
                    reinsert_stack.append(e)
            self._before = False

            # Insert all Events that we were able to remove.
            # (re-inserting with '_before' cleared flips the slope ordering
            # used by 'Event.Compare', swapping the crossing segments).
            while reinsert_stack:
                e = reinsert_stack.pop()

                self.insert(e)

                e_above = self.above(e)
                e_below = self.below(e)

                self._check_intersection(e, e_above)
                self._check_intersection(e, e_below)

                if USE_PARANOID:
                    self._check_intersection(e_above, e_below)

        elif (USE_VERTICAL and
                (t == Event.Type.START_VERTICAL)):

            # just check sanity
            assert(event.segment[0][X] == event.segment[1][X])
            assert(event.segment[0][Y] <= event.segment[1][Y])

            # In this case we only need to find all segments in this span.
            y_above_max = event.segment[1][Y]

            # self.insert(event)
            for e_above in self.above_all(event):
                if e_above.type == Event.Type.START_VERTICAL:
                    continue
                y_above = e_above.y_intercept_x(
                    self._current_event_point_x)
                if USE_IGNORE_SEGMENT_ENDINGS:
                    if y_above >= y_above_max - NUM_EPS:
                        break
                else:
                    if y_above > y_above_max:
                        break

                # We know this intersects,
                # so we could use a faster function now:
                # ix = (self._current_event_point_x, y_above)
                # ...however best use existing functions
                # since it does all sanity checks on endpoints... etc.
                self._check_intersection(event, e_above)

            # self.remove(event)


class EventQueue:
    """
    The queue of pending events, keyed by point and
    consumed in ascending point order.
    """
    __slots__ = (
        # note: we only ever pop_min, this could use a 'heap' structure.
        # The sorted map holding the points -> event list
        # [Point: Event] (tree)
        "events_scan",
    )

    def __init__(self, segments):
        """
        :arg segments: ``(p0, p1)`` pairs already ordered so ``p0[X] <= p1[X]``.
        """
        self.events_scan = RBTree()
        # segments = [s for s in segments if s[0][0] != s[1][0] and s[0][1] != s[1][1]]

        for s in segments:
            assert(s[0][X] <= s[1][X])

            slope = slope_v2v2(*s)

            if s[0] == s[1]:
                # Zero length segment, no events.
                pass
            elif USE_VERTICAL and (s[0][X] == s[1][X]):
                e_start = Event(Event.Type.START_VERTICAL, s[0], s, slope)

                if USE_DEBUG:
                    e_start.other = e_start  # FAKE, avoid error checking

                self.offer(s[0], e_start)
            else:
                e_start = Event(Event.Type.START, s[0], s, slope)
                e_end = Event(Event.Type.END, s[1], s, slope)

                if USE_DEBUG:
                    e_start.other = e_end
                    e_end.other = e_start

                self.offer(s[0], e_start)
                self.offer(s[1], e_end)

    def offer(self, p, e: Event):
        """
        Offer a new event ``e`` at point ``p`` in this queue.
        """
        # One bucket per event type (indexed by the 'Event.Type' value),
        # so events at a point are handled in type order.
        existing = self.events_scan.setdefault(
            p, ([], [], [], []) if USE_VERTICAL else
               ([], [], []),
        )
        existing[e.type].append(e)

    # return a set of events
    def poll(self):
        """
        Get, and remove, the first (lowest) item from this queue.

        :return: the first (lowest) item from this queue.
        :rtype: Point, Event pair.
        """
        assert(len(self.events_scan) != 0)
        p, events_current = self.events_scan.pop_min()
        return p, events_current


def isect_segments_impl(segments, *, include_segments=False, validate=True) -> list:
    """
    Find all intersections in ``segments`` using a sweep line.

    :arg segments: A sequence of ``((x0, y0), (x1, y1))`` pairs.
    :arg include_segments: When True return ``(point, segments)`` pairs
       instead of points.
    :arg validate: When True, zero length and duplicate segments are skipped.
    :return: A list of intersections (unordered).
    """
    # order points left -> right
    if Real is float:
        segments = [
            # in nearly all cases, comparing X is enough,
            # but compare Y too for vertical lines
            (s[0], s[1]) if (s[0] <= s[1]) else
            (s[1], s[0])
            for s in segments]
    else:
        segments = [
            # in nearly all cases, comparing X is enough,
            # but compare Y too for vertical lines
            (
                (Real(s[0][0]), Real(s[0][1])),
                (Real(s[1][0]), Real(s[1][1])),
            ) if (s[0] <= s[1]) else
            (
                (Real(s[1][0]), Real(s[1][1])),
                (Real(s[0][0]), Real(s[0][1])),
            )
            for s in segments]

    # Ensure segments don't have duplicates or single points, see: #24.
    if validate:
        segments_old = segments
        segments = []
        visited = set()
        for s in segments_old:
            # Ignore points.
            if s[0] == s[1]:
                continue
            # Ignore duplicates.
            if s in visited:
                continue
            visited.add(s)
            segments.append(s)
        del segments_old

    queue = EventQueue(segments)
    sweep_line = SweepLine(queue)

    while len(queue.events_scan) > 0:
        if USE_VERBOSE:
            print(len(queue.events_scan), sweep_line._current_event_point_x)
        p, e_ls = queue.poll()
        for events_current in e_ls:
            if events_current:
                sweep_line._sweep_to(p)
                sweep_line.handle(p, events_current)

    if include_segments is False:
        return sweep_line.get_intersections()
    else:
        return sweep_line.get_intersections_with_segments()


def isect_polygon_impl(points, *, include_segments=False, validate=True) -> list:
    """
    Find all self-intersections of the closed polygon ``points``,
    see ``isect_segments_impl`` for the keyword arguments & return value.
    """
    n = len(points)
    segments = [
        (tuple(points[i]), tuple(points[(i + 1) % n]))
        for i in range(n)
    ]
    return isect_segments_impl(segments, include_segments=include_segments, validate=validate)


def isect_segments(segments, *, validate=True) -> list:
    """Return the intersection points of ``segments``."""
    return isect_segments_impl(segments, include_segments=False, validate=validate)


def isect_polygon(segments, *, validate=True) -> list:
    """Return the self-intersection points of a polygon (a sequence of points)."""
    return isect_polygon_impl(segments, include_segments=False, validate=validate)


def isect_segments_include_segments(segments, *, validate=True) -> list:
    """Return ``(point, segments)`` pairs for the intersections of ``segments``."""
    return isect_segments_impl(segments, include_segments=True, validate=validate)


def isect_polygon_include_segments(segments, *, validate=True) -> list:
    """Return ``(point, segments)`` pairs for a polygon (a sequence of points)."""
    return isect_polygon_impl(segments, include_segments=True, validate=validate)


# ----------------------------------------------------------------------------
# 2D math utilities


def slope_v2v2(p1, p2):
    """Return the slope from ``p1`` to ``p2``, +/- NUM_INF when vertical."""
    if p1[X] == p2[X]:
        if p1[Y] < p2[Y]:
            return NUM_INF
        else:
            return -NUM_INF
    else:
        return (p2[Y] - p1[Y]) / (p2[X] - p1[X])


def sub_v2v2(a, b):
    """Return the 2D vector ``a - b``."""
    return (
        a[0] - b[0],
        a[1] - b[1])


def dot_v2v2(a, b):
    """Return the dot product of two 2D vectors."""
    return (
        (a[0] * b[0]) +
        (a[1] * b[1]))


def len_squared_v2v2(a, b):
    """Return the squared distance between two 2D points."""
    c = sub_v2v2(a, b)
    return dot_v2v2(c, c)


def line_point_factor_v2(p, l1, l2, default=NUM_ZERO):
    """
    Return the factor of ``p`` projected onto the line ``l1 -> l2``
    (0 at ``l1``, 1 at ``l2``), or ``default`` for a zero length line.
    """
    u = sub_v2v2(l2, l1)
    h = sub_v2v2(p, l1)
    dot = dot_v2v2(u, u)
    return (dot_v2v2(u, h) / dot) if dot != NUM_ZERO else default


def isect_seg_seg_v2_point(v1, v2, v3, v4, bias=NUM_ZERO):
    """
    Return the point where segment ``v1 -> v2`` crosses ``v3 -> v4``,
    or None when they're parallel or the crossing is outside either
    segment (``bias`` extends the accepted factor range at both ends).
    """
    # Only for predictability and hashable point when same input is given
    if v1 > v2:
        v1, v2 = v2, v1
    if v3 > v4:
        v3, v4 = v4, v3

    if (v1, v2) > (v3, v4):
        v1, v2, v3, v4 = v3, v4, v1, v2

    div = (v2[0] - v1[0]) * (v4[1] - v3[1]) - (v2[1] - v1[1]) * (v4[0] - v3[0])
    if div == NUM_ZERO:
        return None

    vi = (((v3[0] - v4[0]) *
           (v1[0] * v2[1] - v1[1] * v2[0]) - (v1[0] - v2[0]) *
           (v3[0] * v4[1] - v3[1] * v4[0])) / div,
          ((v3[1] - v4[1]) *
           (v1[0] * v2[1] - v1[1] * v2[0]) - (v1[1] - v2[1]) *
           (v3[0] * v4[1] - v3[1] * v4[0])) / div,
          )

    fac = line_point_factor_v2(vi, v1, v2, default=-NUM_ONE)
    if fac < NUM_ZERO - bias or fac > NUM_ONE + bias:
        return None

    fac = line_point_factor_v2(vi, v3, v4, default=-NUM_ONE)
    if fac < NUM_ZERO - bias or fac > NUM_ONE + bias:
        return None

    # vi = round(vi[X], 8), round(vi[Y], 8)
    return vi


# ----------------------------------------------------------------------------
# Simple naive line intersect, (for testing only)


def isect_segments__naive(segments) -> list:
    """
    Brute force O(n2) version of ``isect_segments`` for test validation.
    """
    isect = []

    # order points left -> right
    if Real is float:
        segments = [
            (s[0], s[1]) if s[0][X] <= s[1][X] else
            (s[1], s[0])
            for s in segments]
    else:
        segments = [
            (
                (Real(s[0][0]), Real(s[0][1])),
                (Real(s[1][0]), Real(s[1][1])),
            ) if (s[0] <= s[1]) else
            (
                (Real(s[1][0]), Real(s[1][1])),
                (Real(s[0][0]), Real(s[0][1])),
            )
            for s in segments]

    n = len(segments)

    for i in range(n):
        a0, a1 = segments[i]
        for j in range(i + 1, n):
            b0, b1 = segments[j]
            # Segments sharing an end-point are skipped.
            if a0 not in (b0, b1) and a1 not in (b0, b1):
                ix = isect_seg_seg_v2_point(a0, a1, b0, b1)
                if ix is not None:
                    # USE_IGNORE_SEGMENT_ENDINGS handled already
                    isect.append(ix)

    return isect


def isect_polygon__naive(points) -> list:
    """
    Brute force O(n2) version of ``isect_polygon`` for test validation.
    """
    isect = []

    n = len(points)

    if Real is float:
        pass
    else:
        points = [(Real(p[0]), Real(p[1])) for p in points]

    for i in range(n):
        a0, a1 = points[i], points[(i + 1) % n]
        for j in range(i + 1, n):
            b0, b1 = points[j], points[(j + 1) % n]
            # Edges sharing a vertex are skipped.
            if a0 not in (b0, b1) and a1 not in (b0, b1):
                ix = isect_seg_seg_v2_point(a0, a1, b0, b1)
                if ix is not None:

                    if USE_IGNORE_SEGMENT_ENDINGS:
                        if ((len_squared_v2v2(ix, a0) < NUM_EPS_SQ or
                             len_squared_v2v2(ix, a1) < NUM_EPS_SQ) and
                            (len_squared_v2v2(ix, b0) < NUM_EPS_SQ or
                             len_squared_v2v2(ix, b1) < NUM_EPS_SQ)):
                            continue

                    isect.append(ix)

    return isect


# ----------------------------------------------------------------------------
# Inline Libs
#
# bintrees: 2.0.2, extracted from:
# http://pypi.python.org/pypi/bintrees
#
# - Removed unused functions, such as slicing and range iteration.
# - Added 'cmp' and 'cmp_data' arguments,
#   so we can define our own comparison that takes an arg.
#   Needed for sweep-line.
# - Added support for 'default' arguments for prev_item/succ_item,
#   so we can avoid exception handling.

# -------
# ABCTree

from operator import attrgetter
_sentinel = object()


class _ABCTree(object):
    """
    Base class for the binary tree, sub-classes provide
    ``insert(key, value)`` and ``remove(key)``.
    """

    def __init__(self, cmp=None, cmp_data=None):
        """
        :arg cmp: Optional three-way comparison called as
           ``cmp(cmp_data, a, b)``, defaults to comparing with ``<`` & ``>``.
        :arg cmp_data: Passed as the first argument to ``cmp``.
        """
        self._root = None
        self._count = 0
        if cmp is None:
            def cmp(cmp_data, a, b):
                if a < b:
                    return -1
                elif a > b:
                    return 1
                else:
                    return 0
        self._cmp = cmp
        self._cmp_data = cmp_data

    def clear(self):
        """T.clear() -> None.  Remove all items from T."""
        def _clear(node):
            if node is not None:
                _clear(node.left)
                _clear(node.right)
                node.free()
        _clear(self._root)
        self._count = 0
        self._root = None

    @property
    def count(self):
        """Get items count."""
        return self._count

    def _get_value_or_sentinel(self, key):
        """Return the value stored for ``key``, or ``_sentinel`` when missing."""
        node = self._root
        while node is not None:
            cmp = self._cmp(self._cmp_data, key, node.key)
            if cmp == 0:
                return node.value
            elif cmp < 0:
                node = node.left
            else:
                node = node.right
        return _sentinel

    def get_value(self, key):
        """Return the value stored for ``key``, raises KeyError when missing."""
        value = self._get_value_or_sentinel(key)
        if value is _sentinel:
            raise KeyError(str(key))
        return value

    def pop_item(self):
        """T.pop_item() -> (k, v), remove and return some (key, value) pair as a
        2-tuple; but raise KeyError if T is empty.
        """
        if self.is_empty():
            raise KeyError("pop_item(): tree is empty")
        node = self._root
        # Walk down to a leaf.
        while True:
            if node.left is not None:
                node = node.left
            elif node.right is not None:
                node = node.right
            else:
                break
        key = node.key
        value = node.value
        self.remove(key)
        return key, value
    popitem = pop_item  # for compatibility  to dict()

    def min_item(self):
        """Get item with min key of tree, raises ValueError if tree is empty."""
        if self.is_empty():
            raise ValueError("Tree is empty")
        node = self._root
        while node.left is not None:
            node = node.left
        return node.key, node.value

    def max_item(self):
        """Get item with max key of tree, raises ValueError if tree is empty."""
        if self.is_empty():
            raise ValueError("Tree is empty")
        node = self._root
        while node.right is not None:
            node = node.right
        return node.key, node.value

    def succ_item(self, key, default=_sentinel):
        """Get successor (k,v) pair of key, raises KeyError if key is max key
        or key does not exist (``default`` is returned instead when given).
        optimized for pypy.
        """
        # removed graingets version, because it was little slower on CPython and much slower on pypy
        # this version runs about 4x faster with pypy than the Cython version
        # Note: Code sharing of succ_item() and ceiling_item() is possible, but has always a speed penalty.
        node = self._root
        succ_node = None
        while node is not None:
            cmp = self._cmp(self._cmp_data, key, node.key)
            if cmp == 0:
                break
            elif cmp < 0:
                if (succ_node is None) or self._cmp(self._cmp_data, node.key, succ_node.key) < 0:
                    succ_node = node
                node = node.left
            else:
                node = node.right

        if node is None:  # stay at dead end
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        # found node of key
        if node.right is not None:
            # find smallest node of right subtree
            node = node.right
            while node.left is not None:
                node = node.left
            if succ_node is None:
                succ_node = node
            elif self._cmp(self._cmp_data, node.key, succ_node.key) < 0:
                succ_node = node
        elif succ_node is None:  # given key is biggest in tree
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        return succ_node.key, succ_node.value

    def prev_item(self, key, default=_sentinel):
        """Get predecessor (k,v) pair of key, raises KeyError if key is min key
        or key does not exist (``default`` is returned instead when given).
        optimized for pypy.
        """
        # removed graingets version, because it was little slower on CPython and much slower on pypy
        # this version runs about 4x faster with pypy than the Cython version
        # Note: Code sharing of prev_item() and floor_item() is possible, but has always a speed penalty.
        node = self._root
        prev_node = None

        while node is not None:
            cmp = self._cmp(self._cmp_data, key, node.key)
            if cmp == 0:
                break
            elif cmp < 0:
                node = node.left
            else:
                if (prev_node is None) or self._cmp(self._cmp_data, prev_node.key, node.key) < 0:
                    prev_node = node
                node = node.right

        if node is None:  # stay at dead end (None)
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        # found node of key
        if node.left is not None:
            # find biggest node of left subtree
            node = node.left
            while node.right is not None:
                node = node.right
            if prev_node is None:
                prev_node = node
            elif self._cmp(self._cmp_data, prev_node.key, node.key) < 0:
                prev_node = node
        elif prev_node is None:  # given key is smallest in tree
            if default is _sentinel:
                raise KeyError(str(key))
            return default
        return prev_node.key, prev_node.value

    def __repr__(self):
        """T.__repr__(...) <==> repr(x)"""
        tpl = "%s({%s})" % (self.__class__.__name__, '%s')
        # Use 'iter_items', this stripped down class has no 'items' method.
        return tpl % ", ".join(("%r: %r" % item for item in self.iter_items()))

    def __contains__(self, key):
        """k in T -> True if T has a key k, else False"""
        return self._get_value_or_sentinel(key) is not _sentinel

    def __len__(self):
        """T.__len__() <==> len(x)"""
        return self.count

    def is_empty(self):
        """T.is_empty() -> False if T contains any items else True"""
        return self.count == 0

    def set_default(self, key, default=None):
        """T.set_default(k[,d]) -> T.get(k,d), also set T[k]=d if k not in T"""
        value = self._get_value_or_sentinel(key)
        if value is _sentinel:
            self.insert(key, default)
            return default
        return value
    setdefault = set_default  # for compatibility to dict()

    def get(self, key, default=None):
        """T.get(k[,d]) -> T[k] if k in T, else d.  d defaults to None."""
        value = self._get_value_or_sentinel(key)
        if value is _sentinel:
            return default
        return value

    def pop(self, key, *args):
        """T.pop(k[,d]) -> v, remove specified key and return the corresponding value.
        If key is not found, d is returned if given, otherwise KeyError is raised
        """
        if len(args) > 1:
            raise TypeError("pop expected at most 2 arguments, got %d" % (1 + len(args)))
        value = self._get_value_or_sentinel(key)
        if value is _sentinel:
            if len(args) == 0:
                raise KeyError(str(key))
            return args[0]
        self.remove(key)
        return value

    def prev_key(self, key, default=_sentinel):
        """Get predecessor to key, raises KeyError if key is min key
        or key does not exist (``default`` is returned instead when given).
        """
        item = self.prev_item(key, default)
        return default if item is default else item[0]

    def succ_key(self, key, default=_sentinel):
        """Get successor to key, raises KeyError if key is max key
        or key does not exist (``default`` is returned instead when given).
        """
        item = self.succ_item(key, default)
        return default if item is default else item[0]

    def pop_min(self):
        """T.pop_min() -> (k, v), remove item with minimum key, raise ValueError
        if T is empty.
        """
        item = self.min_item()
        self.remove(item[0])
        return item

    def pop_max(self):
        """T.pop_max() -> (k, v), remove item with maximum key, raise ValueError
        if T is empty.
        """
        item = self.max_item()
        self.remove(item[0])
        return item

    def min_key(self):
        """Get min key of tree, raises ValueError if tree is empty. """
        return self.min_item()[0]

    def max_key(self):
        """Get max key of tree, raises ValueError if tree is empty. """
        return self.max_item()[0]

    def key_slice(self, start_key, end_key, reverse=False):
        """T.key_slice(start_key, end_key) -> key iterator:
        start_key <= key < end_key.

        Yields keys in ascending order if reverse is False else in descending order.
        """
        return (k for k, v in self.iter_items(start_key, end_key, reverse=reverse))

    def iter_items(self, start_key=None, end_key=None, reverse=False):
        """Iterates over the (key, value) items of the associated tree,
        in ascending order if reverse is False, otherwise in descending order,
        reverse defaults to False"""
        # optimized iterator (reduced method calls) - faster on CPython but slower on pypy

        if self.is_empty():
            return []
        if reverse:
            return self._iter_items_backward(start_key, end_key)
        else:
            return self._iter_items_forward(start_key, end_key)

    def _iter_items_forward(self, start_key=None, end_key=None):
        for item in self._iter_items(left=attrgetter("left"), right=attrgetter("right"),
                                     start_key=start_key, end_key=end_key):
            yield item

    def _iter_items_backward(self, start_key=None, end_key=None):
        for item in self._iter_items(left=attrgetter("right"), right=attrgetter("left"),
                                     start_key=start_key, end_key=end_key):
            yield item

    def _iter_items(self, left=attrgetter("left"), right=attrgetter("right"), start_key=None, end_key=None):
        """
        In-order traversal without recursion (the tree must not be empty),
        ``left`` & ``right`` are swapped by the caller to walk backwards.
        """
        node = self._root
        stack = []
        go_left = True
        in_range = self._get_in_range_func(start_key, end_key)

        while True:
            if left(node) is not None and go_left:
                stack.append(node)
                node = left(node)
            else:
                if in_range(node.key):
                    yield node.key, node.value
                if right(node) is not None:
                    node = right(node)
                    go_left = True
                else:
                    if not len(stack):
                        return  # all done
                    node = stack.pop()
                    go_left = False

    def _get_in_range_func(self, start_key, end_key):
        """Return a predicate testing ``start_key <= key < end_key``."""
        if start_key is None and end_key is None:
            return lambda x: True
        else:
            if start_key is None:
                start_key = self.min_key()
            if end_key is None:
                return (lambda x: self._cmp(self._cmp_data, start_key, x) <= 0)
            else:
                return (lambda x: self._cmp(self._cmp_data, start_key, x) <= 0 and
                        self._cmp(self._cmp_data, x, end_key) < 0)


# ------
# RBTree

class Node(object):
    """Internal object, represents a tree node."""
    __slots__ = ['key', 'value', 'red', 'left', 'right']

    def __init__(self, key=None, value=None):
        self.key = key
        self.value = value
        self.red = True
        self.left = None
        self.right = None

    def free(self):
        """Clear all references held by this node."""
        self.left = None
        self.right = None
        self.key = None
        self.value = None

    def __getitem__(self, key):
        """N.__getitem__(key) <==> x[key], where key is 0 (left) or 1 (right)."""
        return self.left if key == 0 else self.right

    def __setitem__(self, key, value):
        """N.__setitem__(key, value) <==> x[key]=value, where key is 0 (left) or 1 (right)."""
        if key == 0:
            self.left = value
        else:
            self.right = value


class RBTree(_ABCTree):
    """
    RBTree implements a balanced binary tree with a dict-like interface.

    see: http://en.wikipedia.org/wiki/Red_black_tree
    """

    @staticmethod
    def is_red(node):
        """Return True when ``node`` exists and is red."""
        if (node is not None) and node.red:
            return True
        else:
            return False

    @staticmethod
    def jsw_single(root, direction):
        """Single rotation of ``root`` towards ``direction``, returns the new sub-tree root."""
        other_side = 1 - direction
        save = root[other_side]
        root[other_side] = save[direction]
        save[direction] = root
        root.red = True
        save.red = False
        return save

    @staticmethod
    def jsw_double(root, direction):
        """Double rotation of ``root`` towards ``direction``, returns the new sub-tree root."""
        other_side = 1 - direction
        root[other_side] = RBTree.jsw_single(root[other_side], other_side)
        return RBTree.jsw_single(root, direction)

    def _new_node(self, key, value):
        """Create a new tree node."""
        self._count += 1
        return Node(key, value)

    def insert(self, key, value):
        """T.insert(key, value) <==> T[key] = value, insert key, value into tree."""
        if self._root is None:  # Empty tree case
            self._root = self._new_node(key, value)
            self._root.red = False  # make root black
            return

        head = Node()  # False tree root
        grand_parent = None
        grand_grand_parent = head
        parent = None  # parent
        direction = 0
        last = 0

        # Set up helpers
        grand_grand_parent.right = self._root
        node = grand_grand_parent.right
        # Search down the tree
        while True:
            if node is None:  # Insert new node at the bottom
                node = self._new_node(key, value)
                parent[direction] = node
            elif RBTree.is_red(node.left) and RBTree.is_red(node.right):  # Color flip
                node.red = True
                node.left.red = False
                node.right.red = False

            # Fix red violation
            if RBTree.is_red(node) and RBTree.is_red(parent):
                direction2 = 1 if grand_grand_parent.right is grand_parent else 0
                if node is parent[last]:
                    grand_grand_parent[direction2] = RBTree.jsw_single(grand_parent, 1 - last)
                else:
                    grand_grand_parent[direction2] = RBTree.jsw_double(grand_parent, 1 - last)

            # Stop if found
            if self._cmp(self._cmp_data, key, node.key) == 0:
                node.value = value  # set new value for key
                break

            last = direction
            direction = 0 if (self._cmp(self._cmp_data, key, node.key) < 0) else 1
            # Update helpers
            if grand_parent is not None:
                grand_grand_parent = grand_parent
            grand_parent = parent
            parent = node
            node = node[direction]

        self._root = head.right  # Update root
        self._root.red = False  # make root black

    def remove(self, key):
        """T.remove(key) <==> del T[key], remove item <key> from tree,
        raises KeyError if the key isn't found."""
        if self._root is None:
            raise KeyError(str(key))
        head = Node()  # False tree root
        node = head
        node.right = self._root
        parent = None
        grand_parent = None
        found = None  # Found item
        direction = 1

        # Search and push a red down
        while node[direction] is not None:
            last = direction

            # Update helpers
            grand_parent = parent
            parent = node
            node = node[direction]

            direction = 1 if (self._cmp(self._cmp_data, node.key, key) < 0) else 0

            # Save found node
            if self._cmp(self._cmp_data, key, node.key) == 0:
                found = node

            # Push the red node down
            if not RBTree.is_red(node) and not RBTree.is_red(node[direction]):
                if RBTree.is_red(node[1 - direction]):
                    parent[last] = RBTree.jsw_single(node, direction)
                    parent = parent[last]
                elif not RBTree.is_red(node[1 - direction]):
                    sibling = parent[1 - last]
                    if sibling is not None:
                        if (not RBTree.is_red(sibling[1 - last])) and (not RBTree.is_red(sibling[last])):
                            # Color flip
                            parent.red = False
                            sibling.red = True
                            node.red = True
                        else:
                            direction2 = 1 if grand_parent.right is parent else 0
                            if RBTree.is_red(sibling[last]):
                                grand_parent[direction2] = RBTree.jsw_double(parent, last)
                            elif RBTree.is_red(sibling[1 - last]):
                                grand_parent[direction2] = RBTree.jsw_single(parent, last)
                            # Ensure correct coloring
                            grand_parent[direction2].red = True
                            node.red = True
                            grand_parent[direction2].left.red = False
                            grand_parent[direction2].right.red = False

        # Replace and remove if found
        if found is not None:
            found.key = node.key
            found.value = node.value
            parent[int(parent.right is node)] = node[int(node.left is None)]
            node.free()
            self._count -= 1

        # Update root and make it black
        self._root = head.right
        if self._root is not None:
            self._root.red = False
        # Explicit identity check, don't rely on the truthiness of 'Node'.
        if found is None:
            raise KeyError(str(key))