diff --git a/pathfinding/core/heap.py b/pathfinding/core/heap.py index 75f873b..dbff1b7 100644 --- a/pathfinding/core/heap.py +++ b/pathfinding/core/heap.py @@ -1,19 +1,18 @@ """Simple heap with ordering and removal.""" import heapq +from typing import Optional from .graph import Graph from .grid import Grid from .world import World +from .node import Node class SimpleHeap: - """Simple wrapper around open_list that keeps track of order and removed - nodes automatically.""" + """Simple wrapper around open_list that keeps track of order.""" def __init__(self, node, grid): self.grid = grid self.open_list = [self._get_node_tuple(node, 0)] - self.removed_node_tuples = set() - self.heap_order = {} self.number_pushed = 0 def _get_node_tuple(self, node, heap_order): @@ -34,28 +33,27 @@ def _get_node_id(self, node): elif isinstance(self.grid, World): return (node.x, node.y, node.grid_id) - def pop_node(self): - """ - Pops node off the heap. i.e. returns the one with the lowest f. - - Notes: - 1. Checks if that values is in removed_node_tuples first, if not tries - again. - 2. We use this approach to avoid invalidating the heap structure. - """ - node_tuple = heapq.heappop(self.open_list) - while node_tuple in self.removed_node_tuples: + def pop_node(self) -> Optional[Node]: + """Pops node off the heap. i.e. returns the one with the lowest f.""" + while self.open_list: node_tuple = heapq.heappop(self.open_list) - if isinstance(self.grid, Graph): - node = self.grid.node(node_tuple[2]) - elif isinstance(self.grid, Grid): - node = self.grid.node(node_tuple[2], node_tuple[3]) - elif isinstance(self.grid, World): - node = self.grid.grids[ - node_tuple[4]].node(node_tuple[2], node_tuple[3]) - - return node + if isinstance(self.grid, Graph): + node = self.grid.node(node_tuple[2]) + elif isinstance(self.grid, Grid): + node = self.grid.node(node_tuple[2], node_tuple[3]) + elif isinstance(self.grid, World): + node = self.grid.grids[ + node_tuple[4]].node(node_tuple[2], node_tuple[3]) + + # node already updated with lower f, ignore + f = node_tuple[0] + if f > node.f: + continue + else: + return node + + return None def push_node(self, node): """ @@ -65,27 +63,9 @@ def push_node(self, node): """ self.number_pushed = self.number_pushed + 1 node_tuple = self._get_node_tuple(node, self.number_pushed) - node_id = self._get_node_id(node) - - self.heap_order[node_id] = self.number_pushed heapq.heappush(self.open_list, node_tuple) - def remove_node(self, node, f): - """ - Remove the node from the heap. - - This just stores it in a set and we just ignore the node if it does - get popped from the heap. - - :param node: The node to remove. - :param f: The old f value of the node. - """ - node_id = self._get_node_id(node) - heap_order = self.heap_order[node_id] - node_tuple = self._get_node_tuple(node, heap_order) - self.removed_node_tuples.add(node_tuple) - def __len__(self): """Returns the length of the open_list.""" return len(self.open_list) diff --git a/pathfinding/finder/finder.py b/pathfinding/finder/finder.py index ac29007..e034525 100644 --- a/pathfinding/finder/finder.py +++ b/pathfinding/finder/finder.py @@ -131,7 +131,6 @@ def process_node( ng = parent.g + graph.calc_cost(parent, node, self.weighted) if not node.opened or ng < node.g: - old_f = node.f node.g = ng node.h = node.h or self.apply_heuristic(node, end, graph=graph) # f is the estimated total cost from start to goal @@ -144,7 +143,6 @@ def process_node( # the node can be reached with smaller cost. # Since its f value has been updated, we have to # update its position in the open list - open_list.remove_node(node, old_f) open_list.push_node(node) def check_neighbors(self, start, end, graph, open_list, diff --git a/pytest.ini b/pytest.ini index 0776b8c..ea224bc 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,5 @@ [pytest] +pythonpath = . log_cli = 1 log_cli_level = INFO log_cli_format = %(asctime)s.%(msecs)03d [%(levelname)8s] (%(filename)s:%(lineno)s) %(message)s diff --git a/test/test_heap.py b/test/test_heap.py index b675a91..4a4a1aa 100644 --- a/test/test_heap.py +++ b/test/test_heap.py @@ -16,11 +16,17 @@ def test_heap(): open_list.push_node(grid.node(1, 2)) open_list.push_node(grid.node(1, 3)) - # Test removal and pop - assert len(open_list) == 3 - open_list.remove_node(grid.node(1, 2), 0) - assert len(open_list) == 3 - assert open_list.pop_node() == grid.node(1, 1) + assert open_list.pop_node() == grid.node(1, 2) assert open_list.pop_node() == grid.node(1, 3) assert len(open_list) == 0 + + # Test inconsistent f + test_node = grid.node(1,1) + test_node.f = 1 + open_list.push_node(test_node) + test_node.f = 0 + open_list.push_node(test_node) + assert open_list.pop_node() == test_node + assert open_list.pop_node() is None + assert len(open_list) == 0