asataura's picture
initial commit
6fa23b0
raw
history blame
2.9 kB
import numpy as np
from utilities.data_structures.Node import Node
class Max_Heap(object):
    """Generic array-backed max heap of Node objects ordered by their ``key``.

    The heap is stored 1-indexed in ``self.heap``: the largest key lives at
    index 1, the parent of index ``i`` is ``i // 2`` and its children are
    ``2 * i`` and ``2 * i + 1``. Index 0 holds a sentinel node with an infinite
    key so that a node is never swapped above the root.
    """

    def __init__(self, max_size, dimension_of_value_attribute, default_key_to_use):
        """Stores the configuration and allocates the backing array.

        Args:
            max_size: used to size the backing array (``max_size * 4 + 1`` slots).
            dimension_of_value_attribute: length of the placeholder ``value``
                tuple given to every initial node.
            default_key_to_use: key given to every initial (placeholder) node.
        """
        self.max_size = max_size
        self.dimension_of_value_attribute = dimension_of_value_attribute
        self.default_key_to_use = default_key_to_use
        self.heap = self.initialise_heap()

    def initialise_heap(self):
        """Initialises a heap of Nodes of length self.max_size * 4 + 1"""
        # Placeholder payload; sized from the configured dimension for every
        # node, including the sentinel (previously hard-coded to 5 entries).
        empty_value = (None,) * self.dimension_of_value_attribute
        heap = np.array([Node(self.default_key_to_use, empty_value) for _ in range(self.max_size * 4 + 1)])
        # We don't use the 0th element in a heap so we want it to have infinite value so it is never swapped with a lower node
        heap[0] = Node(float("inf"), empty_value)
        return heap

    def update_element_and_reorganise_heap(self, heap_index_for_change, new_element):
        """Overwrites the node at the given index and then restores heap order around it."""
        self.update_heap_element(heap_index_for_change, new_element)
        self.reorganise_heap(heap_index_for_change)

    def update_heap_element(self, heap_index, new_element):
        """Overwrites the node at ``heap_index`` without reordering the heap."""
        self.heap[heap_index] = new_element

    def reorganise_heap(self, heap_index_changed):
        """Restores heap order after the node at ``heap_index_changed`` has changed.

        The node is moved up while its key is greater than its parent's key,
        otherwise moved down while its key is smaller than its biggest child's
        key. This keeps the max key at index position 1 of self.heap.
        """
        current_index = heap_index_changed
        while True:
            node_key = self.heap[current_index].key
            # Integer floor division: exact for any index size, unlike int(i / 2).
            parent_index = current_index // 2
            if node_key > self.heap[parent_index].key:
                self.swap_heap_elements(current_index, parent_index)
                current_index = parent_index
                continue
            biggest_child_index = self.calculate_index_of_biggest_child(current_index)
            # Stop at a node with no children inside the array, or once the
            # node is no longer smaller than its biggest child.
            if biggest_child_index is None:
                return
            if not node_key < self.heap[biggest_child_index].key:
                return
            self.swap_heap_elements(current_index, biggest_child_index)
            current_index = biggest_child_index

    def swap_heap_elements(self, index1, index2):
        """Swaps the position of two heap elements"""
        self.heap[index1], self.heap[index2] = self.heap[index2], self.heap[index1]

    def calculate_index_of_biggest_child(self, heap_index_changed):
        """Calculates the heap index of the node's child with the biggest key.

        Returns:
            The index of the child with the larger key (the right child on a
            tie), the left child's index if only that one lies inside the
            array, or None if neither child index lies inside the array.
        """
        left_child_index = int(heap_index_changed * 2)
        right_child_index = left_child_index + 1
        heap_length = len(self.heap)
        if left_child_index >= heap_length:
            return None
        if right_child_index >= heap_length:
            return left_child_index
        if self.heap[left_child_index].key > self.heap[right_child_index].key:
            return left_child_index
        return right_child_index

    def give_max_key(self):
        """Returns the maximum key currently in the heap. Because it is a max heap this is the top element of the heap"""
        return self.heap[1].key