advent24-llm / day16 /solution_jerpint.py
jerpint's picture
Add solution files
a4da721
raw
history blame
8.42 kB
from copy import deepcopy
def load_data(file):
    """Read a text file into a grid of single characters.

    Args:
        file: Path of the text file to read.

    Returns:
        A list of rows, one per non-empty line of the file, where each row is
        the list of that line's characters with the newline removed.
    """
    with open(file) as f:
        rows = (line.rstrip("\n") for line in f)
        # Skip empty lines (e.g. trailing blank lines): an empty row would make
        # every later `grid[i][j]` scan over `len(grid[0])` columns fail.
        return [list(row) for row in rows if row]
def get_end_pos(grid):
    """Locate the first "E" cell of the grid in row-major order.

    Args:
        grid: List of rows; only the first ``len(grid[0])`` columns of each
            row are inspected.

    Returns:
        The ``(row, column)`` tuple of the first "E" found, or ``None`` when
        the scanned area contains no "E".
    """
    n_rows, n_cols = len(grid), len(grid[0])
    matches = (
        (r, c)
        for r in range(n_rows)
        for c in range(n_cols)
        if grid[r][c] == "E"
    )
    return next(matches, None)
class UnvisitedSet:
    """Per-cell bookkeeping for a lowest-value-first search over a grid.

    Attributes:
        grid: The grid passed to the constructor (not copied).
        visited: ``visited[i][j]`` is True once ``mark_visited((i, j))`` ran.
        values: One ``[value, pos, direction]`` list per cell, where
            ``pos = (i, j)``. The "S" cell starts at ``[0, pos, ">"]``; every
            other cell starts at ``[inf, pos, ""]``. The list is re-sorted by
            value after each update.

    The entries in ``values`` are shared with a private position index, so
    change them through ``update_value`` rather than replacing them by hand.
    """

    def __init__(self, grid):
        """Create one entry per cell of the ``len(grid)`` x ``len(grid[0])`` area."""
        self.grid = grid
        n_rows = len(grid)
        n_cols = len(grid[0])
        self.visited = [[False] * n_cols for _ in range(n_rows)]
        self.values = []
        # pos -> the very same entry list stored in `values`, so lookups by
        # position do not have to scan the whole list.
        self._entries = {}
        for i in range(n_rows):
            for j in range(n_cols):
                pos = (i, j)
                if grid[i][j] == "S":
                    entry = [0, pos, ">"]
                else:
                    entry = [float("inf"), pos, ""]
                self.values.append(entry)
                self._entries[pos] = entry

    def _entry_for(self, pos):
        """Return the entry stored for ``pos``, or None when there is none."""
        try:
            return self._entries.get(pos)
        except TypeError:
            # Unhashable positions (e.g. a list) never equal a stored tuple.
            return None

    def update_value(self, new_pos, new_val, new_dir):
        """Lower the value of ``new_pos`` to ``new_val`` if it is strictly smaller.

        When the update is accepted the stored direction becomes ``new_dir``.
        Unknown positions are ignored. ``values`` is re-sorted in every case.
        """
        entry = self._entry_for(new_pos)
        if entry is not None and new_val < entry[0]:
            entry[0] = new_val
            entry[2] = new_dir
        self.sort_values()

    def get_value(self, pos):
        """Return the value stored for ``pos``, or None for an unknown position."""
        entry = self._entry_for(pos)
        return None if entry is None else entry[0]

    def get_direction(self, pos):
        """Return the direction stored for ``pos``, or None for an unknown position."""
        entry = self._entry_for(pos)
        return None if entry is None else entry[2]

    def mark_visited(self, pos):
        """Flag the cell at ``pos = (i, j)`` as visited."""
        i, j = pos
        self.visited[i][j] = True

    def sort_values(self):
        """Sort ``values`` in place by ascending value (stable for ties)."""
        self.values.sort(key=lambda x: x[0])

    def get_min_unvisited(self):
        """Return ``(value, pos, direction)`` of the cheapest unvisited cell.

        Only cells with a finite value qualify. When none is left the
        sentinel triple ``("empty", "empty", "empty")`` is returned.
        """
        self.sort_values()
        for val, pos, d in self.values:
            i, j = pos
            if not self.visited[i][j] and val < float("inf"):
                return val, pos, d
        return "empty", "empty", "empty"

    def values_to_grid(self):
        """Return a deep copy of the grid with every finite-valued cell set to "O"."""
        new_grid = deepcopy(self.grid)
        for value, pos, _ in self.values:
            i, j = pos
            if value < float("inf"):
                new_grid[i][j] = "O"
        return new_grid
def pprint(grid):
    """Print every row of ``grid`` in its ``str()`` form, one row per line."""
    print(*map(str, grid), sep="\n")
def pprint2(grid):
    """Print ``grid`` with each row's string elements concatenated, one row per line."""
    print("\n".join(map("".join, grid)))
def get_neighbours(pos, grid):
    """List the in-bounds, non-wall cells orthogonally adjacent to ``pos``.

    Args:
        pos: ``(row, column)`` of the cell whose neighbours are wanted.
        grid: List of rows; bounds are ``len(grid)`` by ``len(grid[0])``.

    Returns:
        Neighbour positions whose cell is not "#", in the fixed order
        (0, +1), (+1, 0), (-1, 0), (0, -1) of (row, column) offsets.
    """
    offsets = ((0, 1), (1, 0), (-1, 0), (0, -1))
    valid_rows = range(len(grid))
    valid_cols = range(len(grid[0]))
    row, col = pos
    candidates = ((row + d_row, col + d_col) for d_row, d_col in offsets)
    return [
        (r, c)
        for r, c in candidates
        if r in valid_rows and c in valid_cols and grid[r][c] != "#"
    ]
def get_cost(pos, next_pos, direction):
    """Price a move from ``pos`` to ``next_pos`` given the current heading.

    Only meaningful when ``next_pos`` is at most one neighbour away. The
    heading of the move is derived from the row offset first, then the column
    offset: a larger row index is labelled "^", a smaller one "v", a larger
    column index ">" and a smaller one "<".

    NOTE(review): on a grid printed top-to-bottom the two vertical labels look
    swapped; the cost only depends on whether headings are equal, opposite or
    neither, so the labels are kept exactly as they were.

    Args:
        pos: ``(row, column)`` of the current cell.
        next_pos: ``(row, column)`` of the target cell.
        direction: Heading label currently stored for ``pos``.

    Returns:
        ``(cost, heading)`` where cost is 1 for keeping the heading, 2001 for
        reversing it and 1001 for anything else.

    Raises:
        ValueError: If both offsets are zero.
    """
    row, col = pos
    next_row, next_col = next_pos
    d_row = next_row - row
    d_col = next_col - col

    if d_row > 0:
        heading = "^"
    elif d_row < 0:
        heading = "v"
    elif d_col > 0:
        heading = ">"
    elif d_col < 0:
        heading = "<"
    else:
        raise ValueError("Debug?")

    # (current, wanted) pairs that amount to a full reversal.
    reversals = (("^", "v"), ("v", "^"), (">", "<"), ("<", ">"))

    if direction == heading:
        return 1, heading
    if (direction, heading) in reversals:
        return 2001, heading
    # Everything else is treated as a 90-degree rotation plus one step.
    return 1001, heading
# Path of the input file to solve; the commented line below is an alternative
# input that can be swapped in.
file = "input.txt"
# file = "test2.txt"
# The grid is loaded as soon as this module is executed.
grid = load_data(file)
def djikstra(grid):
    """Run a lowest-value-first search from "S" until the "E" cell is visited.

    Each round takes the cheapest unvisited cell, offers every neighbour the
    value ``val + cost`` (with the heading returned by ``get_cost``) and then
    marks the cell visited.

    NOTE(review): only one (value, direction) pair is kept per cell and an
    update is accepted only when it is strictly cheaper, so an arrival that
    costs a little more but has a better heading for the next move is
    discarded. Confirm this is acceptable for the inputs being solved.

    Args:
        grid: List of rows containing an "S" and an "E" cell; "#" cells are
            never entered.

    Returns:
        The ``UnvisitedSet`` holding the values found. If "E" cannot be
        reached the search stops once no finite unvisited cell is left and
        the value of "E" stays infinite.
    """
    end_pos = get_end_pos(grid)
    unvisited_set = UnvisitedSet(grid)

    ei, ej = end_pos
    while not unvisited_set.visited[ei][ej]:
        val, pos, d = unvisited_set.get_min_unvisited()
        if pos == "empty":
            # Nothing reachable is left: stop instead of passing the sentinel
            # on to get_neighbours, which would fail while unpacking it.
            break

        for next_pos in get_neighbours(pos, grid):
            cost, next_dir = get_cost(pos, next_pos, d)
            unvisited_set.update_value(next_pos, val + cost, next_dir)

        unvisited_set.mark_visited(pos)

    return unvisited_set
# Run the search on the loaded grid and print the value recorded for the "E" cell.
unvisited_set = djikstra(grid)
end_pos = get_end_pos(grid)
print(unvisited_set.get_value(end_pos))
# ## Part 2
# # def djikstra(grid):
# # end_pos = get_end_pos(grid)
# # unvisited_set = UnvisitedSet(grid)
# #
# # ei, ej = end_pos
# # while not unvisited_set.visited[ei][ej]:
# #
# # val, pos, d = unvisited_set.get_min_unvisited()
# # i, j = pos
# # if unvisited_set.visited[i][j]:
# # continue
# #
# # ns = get_neighbours(pos, grid)
# #
# #
# #
# # for next_pos in ns:
# # cost, next_dir = get_cost(pos, next_pos, d)
# # unvisited_set.update_value(next_pos, val+cost, next_dir)
# #
# # unvisited_set.mark_visited(pos)
# #
# # return unvisited_set
# # file = "test.txt"
# # file = "input.txt"
# file = "test2.txt"
# # file = "input.txt"
# grid = load_data(file)
# end_pos = get_end_pos(grid)
# unvisited_set = djikstra(grid)
# end_pos = get_end_pos(grid)
# def draw_grid_from_paths(grid, min_paths):
# grid_copy = deepcopy(grid)
# for pos in min_paths:
# i,j = pos
# grid_copy[i][j] = "O"
# # pprint2(grid_copy)
# return grid_copy
# # Here we will get all points in the minimum path backtracking back to the start point
# min_paths = [end_pos]
# # for direction in [">", "^"]:
# # direction = ">"
# queue = [(end_pos, "v"), (end_pos, "<")]
# visited = set()
# unvisited_set.update_value((1, 14), float("inf"), new_dir = "v")
# while len(queue) > 0:
# pos, direction = queue.pop(0)
# if pos in visited:
# continue
# val = unvisited_set.get_value(pos)
# end_val = unvisited_set.get_value(end_pos)
# # direction = unvisited_set.get_direction(pos)
# ns = get_neighbours(pos, grid)
# n_vals_pos_dirs = []
# for n_pos in ns:
# n_val = unvisited_set.get_value(n_pos)
# n_dir = unvisited_set.get_direction(n_pos)
# n_vals_pos_dirs.append((n_val, n_pos, n_dir))
# # print(pos)
# # print(val)
# # print(n_vals_pos_dirs)
# # print()
# # print()
# # Work backwards and subtract cost
# pprint2(draw_grid_from_paths(grid, min_paths))
# for n_val, n_pos, _ in n_vals_pos_dirs:
# if n_pos in visited:
# continue
# # if pos == end_pos:
# # cost = 1
# # else:
# # cost, n_dir = get_cost(pos, n_pos, direction)
# cost, n_dir = get_cost(n_pos, pos, direction)
# print(pos, n_pos, cost, direction, n_dir)
# import ipdb; ipdb.set_trace();
# if n_val > end_val:
# continue
# # if n_val - (val - cost) >= 0:
# if n_val <= (val - cost):
# min_paths.append(n_pos)
# pprint2(draw_grid_from_paths(grid, min_paths))
# if n_pos not in visited:
# queue.append((n_pos, n_dir))
# visited.add(pos)
# # draw_grid_from_paths(grid, min_paths)
# # print()
# #
# #
# # while len(queue) > 0:
# # pos = queue.pop(0)
# # ns = get_neighbours(pos, grid)
# #
# # if pos in visited:
# # continue
# #
# # n_vals_pos = []
# # for n_pos in ns:
# # n_val = unvisited_set.get_value(n_pos)
# # n_vals_pos.append((n_val, n_pos))
# #
# # n_vals_not_inf = sum([val < float("inf") for val, pos in n_vals_pos])
# # if n_vals_not_inf > 0:
# # for n in ns:
# # queue.append(n)
# # min_paths.append(pos)
# # else:
# # print("Here")
# #
# # visited.add(pos)
# print(len(set(min_paths)))
# # import numpy as np
# # import matplotlib.pyplot as plt
# #
# # grid_copy = deepcopy(grid)
# # M = len(grid_copy)
# # N = len(grid_copy[0])
# # for value in unvisited_set.values:
# # val, pos, dir = value
# # i, j = pos
# # grid_copy[i][j] = val if val != float("inf") else grid[i][j]
# # for i in range(M):
# # for j in range(N):
# # if grid_copy[i][j] == "#":
# # grid_copy[i][j] = -1
# # if grid_copy[i][j] == ".":
# # grid_copy[i][j] = 0
# # arr = np.array(grid_copy)
# # plt.figure()
# #
# # plt.imshow(arr)