"""Reindeer maze: cheapest route from "S" to "E" on a "#"-walled character grid.

Moving one cell forward costs 1 and every 90-degree turn costs 1000; the
reindeer starts on "S" facing east (">"). Run as a script to print the
Part 1 answer for ``input.txt``.
"""

import heapq
from copy import deepcopy

STEP_COST = 1  # moving one cell in the current facing
TURN_COST = 1000  # each 90-degree rotation made before a step

# Row indices grow downwards, so a step to a larger row index faces "v".
_OPPOSITE = {">": "<", "<": ">", "^": "v", "v": "^"}
_FACINGS = tuple(_OPPOSITE)
_UNREACHED = (float("inf"), "")


def load_data(file):
    """Read the maze in *file* as a list of rows, each a list of 1-char cells.

    Blank lines (e.g. an extra empty line at the end of the file) are skipped
    so every returned row is a real grid row.
    """
    with open(file) as f:
        return [list(line.rstrip("\n")) for line in f if line.rstrip("\n")]


def get_end_pos(grid):
    """Return the (row, col) of the first "E" in row-major order, or None."""
    for i, row in enumerate(grid):
        for j, cell in enumerate(row):
            if cell == "E":
                return (i, j)
    return None


class UnvisitedSet:
    """Dijkstra frontier over (position, facing) states, backed by a heap.

    The facing has to be part of the search state: a cell reached most
    cheaply while facing one way can be a worse launch point than the same
    cell reached slightly more expensively while already facing the next
    move, so keeping only one (cost, direction) per cell can miss the optimum.

    Position-level queries (``get_value``, ``get_direction``, ``visited``,
    ``values``) report the cheapest state known at each cell.
    """

    def __init__(self, grid):
        self.grid = grid
        # visited[i][j] becomes True once any state at (i, j) has been settled.
        self.visited = [[False] * len(row) for row in grid]
        self._best = {}  # (pos, direction) -> cheapest known cost
        self._best_at = {}  # pos -> (cheapest cost over all facings, that facing)
        self._settled = set()  # (pos, direction) states whose cost is final
        self._heap = []  # (cost, pos, direction); may contain stale entries
        for i, row in enumerate(grid):
            for j, cell in enumerate(row):
                if cell == "S":
                    self.update_value((i, j), 0, ">")

    def update_value(self, new_pos, new_val, new_dir):
        """Record reaching *new_pos* facing *new_dir* at *new_val*, if cheaper."""
        state = (new_pos, new_dir)
        if new_val < self._best.get(state, float("inf")):
            self._best[state] = new_val
            heapq.heappush(self._heap, (new_val, new_pos, new_dir))
            if new_val < self._best_at.get(new_pos, _UNREACHED)[0]:
                self._best_at[new_pos] = (new_val, new_dir)

    def get_value(self, pos):
        """Cheapest known cost to reach *pos* in any facing (inf if unreached)."""
        return self._best_at.get(pos, _UNREACHED)[0]

    def get_direction(self, pos):
        """Facing of the cheapest known arrival at *pos* ("" if unreached)."""
        return self._best_at.get(pos, _UNREACHED)[1]

    def mark_visited(self, pos, direction=None):
        """Settle state (*pos*, *direction*); with no direction, every facing at *pos*."""
        for d in (_FACINGS if direction is None else (direction,)):
            self._settled.add((pos, d))
        i, j = pos
        self.visited[i][j] = True

    def sort_values(self):
        """No-op kept for backward compatibility; the heap keeps the order."""

    def get_min_unvisited(self):
        """Return (cost, pos, direction) of the cheapest unsettled state, without removing it.

        Returns ("empty", "empty", "empty") once no reachable state is left.
        """
        heap = self._heap
        while heap:
            val, pos, d = heap[0]
            # Lazy deletion: drop entries already settled or superseded by a cheaper push.
            if (pos, d) in self._settled or val > self._best[(pos, d)]:
                heapq.heappop(heap)
                continue
            return val, pos, d
        return "empty", "empty", "empty"

    @property
    def values(self):
        """Per-cell [cost, pos, direction] entries, cheapest first.

        Unreached cells appear as [inf, pos, ""].
        """
        entries = []
        for i, row in enumerate(self.grid):
            for j in range(len(row)):
                val, d = self._best_at.get((i, j), _UNREACHED)
                entries.append([val, (i, j), d])
        entries.sort(key=lambda entry: entry[0])
        return entries

    def values_to_grid(self):
        """Return a copy of the grid with every reached cell drawn as "O"."""
        new_grid = deepcopy(self.grid)
        for i, j in self._best_at:
            new_grid[i][j] = "O"
        return new_grid


def pprint(grid):
    """Print each row's list repr on its own line (works for non-string cells)."""
    print("\n".join(str(row) for row in grid))


def pprint2(grid):
    """Print a grid of single-character strings as plain text, one row per line."""
    print("\n".join("".join(row) for row in grid))


def get_neighbours(pos, grid):
    """Return in-bounds, non-wall ("#") cells orthogonally adjacent to *pos*.

    Order: right, down, up, left.
    """
    i, j = pos
    neighbours = []
    for di, dj in ((0, 1), (1, 0), (-1, 0), (0, -1)):
        ni, nj = i + di, j + dj
        if 0 <= ni < len(grid) and 0 <= nj < len(grid[ni]) and grid[ni][nj] != "#":
            neighbours.append((ni, nj))
    return neighbours


def get_cost(pos, next_pos, direction):
    """Return (cost, new_direction) for stepping from *pos* to *next_pos*.

    Only meaningful for orthogonally adjacent cells. The step costs
    STEP_COST, plus TURN_COST for a quarter turn or 2 * TURN_COST for a
    reversal. Any other *direction* (e.g. "") is charged one quarter turn.

    Raises:
        ValueError: if *next_pos* equals *pos*.
    """
    di = next_pos[0] - pos[0]
    dj = next_pos[1] - pos[1]
    if di > 0:
        new_direction = "v"  # larger row index = further down the grid
    elif di < 0:
        new_direction = "^"
    elif dj > 0:
        new_direction = ">"
    elif dj < 0:
        new_direction = "<"
    else:
        raise ValueError(f"next_pos {next_pos} must differ from pos {pos}")

    if direction == new_direction:
        cost = STEP_COST
    elif _OPPOSITE.get(direction) == new_direction:
        cost = STEP_COST + 2 * TURN_COST
    else:
        cost = STEP_COST + TURN_COST
    return cost, new_direction


def djikstra(grid):
    """Run Dijkstra over (position, facing) states from "S" until "E" is settled.

    Returns the UnvisitedSet holding the costs; ``get_value(end_pos)`` is the
    cheapest score, or inf when "E" cannot be reached.

    Raises:
        ValueError: if the grid contains no "E" cell.
    """
    end_pos = get_end_pos(grid)
    if end_pos is None:
        raise ValueError('grid has no "E" cell')
    unvisited_set = UnvisitedSet(grid)

    ei, ej = end_pos
    while not unvisited_set.visited[ei][ej]:
        val, pos, d = unvisited_set.get_min_unvisited()
        if pos == "empty":
            break  # frontier exhausted: "E" is unreachable
        for next_pos in get_neighbours(pos, grid):
            cost, next_dir = get_cost(pos, next_pos, d)
            unvisited_set.update_value(next_pos, val + cost, next_dir)
        unvisited_set.mark_visited(pos, d)

    return unvisited_set


if __name__ == "__main__":
    file = "input.txt"
    # file = "test2.txt"
    grid = load_data(file)
    unvisited_set = djikstra(grid)
    end_pos = get_end_pos(grid)
    print(unvisited_set.get_value(end_pos))


# ## Part 2 -- earlier, unfinished attempt (kept for reference, not executed).
# NOTE(review): written against the earlier position-only UnvisitedSet;
# backtracking from E needs per-(position, facing) costs to be reliable.
# Line breaks below were reconstructed from a collapsed copy of the file.
#
# def djikstra(grid):
#     end_pos = get_end_pos(grid)
#     unvisited_set = UnvisitedSet(grid)
#
#     ei, ej = end_pos
#     while not unvisited_set.visited[ei][ej]:
#         val, pos, d = unvisited_set.get_min_unvisited()
#         i, j = pos
#         if unvisited_set.visited[i][j]:
#             continue
#
#         ns = get_neighbours(pos, grid)
#         for next_pos in ns:
#             cost, next_dir = get_cost(pos, next_pos, d)
#             unvisited_set.update_value(next_pos, val+cost, next_dir)
#
#         unvisited_set.mark_visited(pos)
#
#     return unvisited_set
#
# # file = "test.txt"
# # file = "input.txt"
# file = "test2.txt"
# # file = "input.txt"
# grid = load_data(file)
# end_pos = get_end_pos(grid)
# unvisited_set = djikstra(grid)
# end_pos = get_end_pos(grid)
#
# def draw_grid_from_paths(grid, min_paths):
#     grid_copy = deepcopy(grid)
#     for pos in min_paths:
#         i,j = pos
#         grid_copy[i][j] = "O"
#     # pprint2(grid_copy)
#     return grid_copy
#
# # Here we will get all points in the minimum path backtracking back to the start point
# min_paths = [end_pos]
# # for direction in [">", "^"]:
# # direction = ">"
# queue = [(end_pos, "v"), (end_pos, "<")]
# visited = set()
# unvisited_set.update_value((1, 14), float("inf"), new_dir = "v")
# while len(queue) > 0:
#     pos, direction = queue.pop(0)
#     if pos in visited:
#         continue
#     val = unvisited_set.get_value(pos)
#     end_val = unvisited_set.get_value(end_pos)
#     # direction = unvisited_set.get_direction(pos)
#     ns = get_neighbours(pos, grid)
#     n_vals_pos_dirs = []
#     for n_pos in ns:
#         n_val = unvisited_set.get_value(n_pos)
#         n_dir = unvisited_set.get_direction(n_pos)
#         n_vals_pos_dirs.append((n_val, n_pos, n_dir))
#     # print(pos)
#     # print(val)
#     # print(n_vals_pos_dirs)
#     # print()
#     # print()
#     # Work backwards and subtract cost
#     pprint2(draw_grid_from_paths(grid, min_paths))
#     for n_val, n_pos, _ in n_vals_pos_dirs:
#         if n_pos in visited:
#             continue
#         # if pos == end_pos:
#         #     cost = 1
#         # else:
#         #     cost, n_dir = get_cost(pos, n_pos, direction)
#         cost, n_dir = get_cost(n_pos, pos, direction)
#         print(pos, n_pos, cost, direction, n_dir)
#         import ipdb; ipdb.set_trace();
#         if n_val > end_val:
#             continue
#         # if n_val - (val - cost) >= 0:
#         if n_val <= (val - cost):
#             min_paths.append(n_pos)
#             pprint2(draw_grid_from_paths(grid, min_paths))
#             if n_pos not in visited:
#                 queue.append((n_pos, n_dir))
#     visited.add(pos)
#
# # draw_grid_from_paths(grid, min_paths)
# # print()
#
# # while len(queue) > 0:
# #     pos = queue.pop(0)
# #     ns = get_neighbours(pos, grid)
# #
# #     if pos in visited:
# #         continue
# #
# #     n_vals_pos = []
# #     for n_pos in ns:
# #         n_val = unvisited_set.get_value(n_pos)
# #         n_vals_pos.append((n_val, n_pos))
# #
# #     n_vals_not_inf = sum([val < float("inf") for val, pos in n_vals_pos])
# #     if n_vals_not_inf > 0:
# #         for n in ns:
# #             queue.append(n)
# #         min_paths.append(pos)
# #     else:
# #         print("Here")
# #
# #     visited.add(pos)
# print(len(set(min_paths)))
#
# # import numpy as np
# # import matplotlib.pyplot as plt
# #
# # grid_copy = deepcopy(grid)
# # M = len(grid_copy)
# # N = len(grid_copy[0])
# # for value in unvisited_set.values:
# #     val, pos, dir = value
# #     i, j = pos
# #     grid_copy[i][j] = val if val != float("inf") else grid[i][j]
# # for i in range(M):
# #     for j in range(N):
# #         if grid_copy[i][j] == "#":
# #             grid_copy[i][j] = -1
# #         if grid_copy[i][j] == ".":
# #             grid_copy[i][j] = 0
# # arr = np.array(grid_copy)
# # plt.figure()
# #
# # plt.imshow(arr)