santit96's picture
Fix bug in updating the state based on a mask, in update_with_mask function
0ffced8
"""
Keep the state in a 1D int array.
index[0] = remaining steps
The rest of the array is the flattened table
[[status, status, status, status, status]
for _ in "ABCD..."]
(one status per word position, for every char), where status has codes
[0, 0, 0] - no information about the char
[1, 0, 0] - char is definitely not in this spot
[0, 1, 0] - char is maybe in this spot
[0, 0, 1] - char is definitely in this spot
"""
import collections
from typing import List, Tuple
import numpy as np
from .const import CHAR_REWARD, WORDLE_CHARS, WORDLE_N
# Type alias for the flat integer state vector whose layout is described in
# the module docstring (step counter followed by per-char status flags).
WordleState = np.ndarray
def get_nvec(max_turns: int):
    """Build the size vector that mirrors the layout of the state array.

    The first entry is ``max_turns`` (the slot holding the remaining steps);
    it is followed by one ``2`` for each of the
    ``3 * WORDLE_N * len(WORDLE_CHARS)`` status flags.

    NOTE(review): presumably consumed as the ``nvec`` of a MultiDiscrete-style
    space -- confirm against the caller.

    :param max_turns: value placed in the first (step counter) entry
    :return: list with ``1 + 3 * WORDLE_N * len(WORDLE_CHARS)`` entries
    """
    flag_count = 3 * WORDLE_N * len(WORDLE_CHARS)
    nvec = [max_turns]
    nvec.extend([2] * flag_count)
    return nvec
def new(max_turns: int) -> WordleState:
    """Create the initial state for a fresh game.

    :param max_turns: number of guesses available; stored at index 0
    :return: int32 array holding ``max_turns`` followed by an all-zero
        (no information) status triple for every char/position pair
    """
    blank_flags = [0] * (3 * WORDLE_N * len(WORDLE_CHARS))
    return np.array([max_turns, *blank_flags], dtype=np.int32)
def remaining_steps(state: WordleState) -> int:
    """Return the step counter stored in the first entry of ``state``.

    :param state: flat state vector
    :return: the value at index 0 (remaining steps)
    """
    steps_left = state[0]
    return steps_left
# Per-letter feedback codes of a guess mask:
#   NO        - the letter gets no hit at this position
#   SOMEWHERE - the letter is in the goal word, but at another position
#   YES       - the letter is at exactly this position
NO, SOMEWHERE, YES = 0, 1, 2
def update_from_mask(state: WordleState, word: str, mask: List[int]) -> WordleState:
    """
    Return a copy of state updated with the feedback mask of a guess.

    The goal word is not known here, so everything is derived from the
    mask alone. The step counter at index 0 is decremented by one.

    :param state: current state; it is copied, not modified
    :param word: the guessed word
    :param mask: one code (NO, SOMEWHERE or YES) per letter of ``word``
    :return: the updated copy of ``state``
    """
    new_state = state.copy()
    new_state[0] -= 1

    first_code = ord(WORDLE_CHARS[0])
    stride = WORDLE_N * 3
    # Letters of this guess already known to be in the goal word, either
    # through a YES (collected in pass one) or an earlier SOMEWHERE.
    known_present = set()

    # Pass one: record every definite hit, so that pass two can tell a
    # "no more copies of this letter" NO from a plain NO.
    for pos, letter in enumerate(word):
        if mask[pos] != YES:
            continue
        letter_idx = ord(letter) - first_code
        known_present.add(letter)
        _set_yes(new_state, 1 + letter_idx * stride, letter_idx, pos)

    # Pass two: handle the misplaced and absent letters.
    for pos, letter in enumerate(word):
        code = mask[pos]
        start = 1 + (ord(letter) - first_code) * stride
        if code == SOMEWHERE:
            known_present.add(letter)
            # Not at this position; maybe at every position that has no
            # information yet. Other chars are left untouched.
            _set_no(new_state, start, pos)
            _set_if_cero(new_state, start, [0, 1, 0])
        elif code == NO:
            if letter in known_present:
                # The letter exists elsewhere in the goal: rule out this
                # position and every position without information, but
                # keep the existing yes/maybe entries.
                _set_no(new_state, start, pos)
                _set_if_cero(new_state, start, [1, 0, 0])
            else:
                # The letter does not appear in the goal at all.
                _set_all_no(new_state, start)
    return new_state
def get_mask(word: str, goal_word: str) -> List[int]:
    """Compute the feedback mask of the guess ``word`` against ``goal_word``.

    Each entry of the result describes the letter at the same position of
    ``word``:

    * ``2`` - the letter is at exactly this position in ``goal_word``
    * ``1`` - the letter occurs in ``goal_word`` at another position and a
      copy of it is still unaccounted for
    * ``0`` - the letter is absent, or all of its copies are already
      accounted for by other letters of the guess

    The mask is sized from the guess instead of being fixed to five
    entries, so guesses of any length are supported.

    :param word: the guessed word
    :param goal_word: the word to be found; must be at least as long as
        ``word``
    :return: list with one code per letter of ``word``
    """
    mask = [0] * len(word)
    # Copies of each goal letter not yet matched by a letter of the guess.
    unmatched = collections.Counter(goal_word)

    # Definite yesses first, so that they claim their copy of the letter
    # before any misplaced occurrence can.
    for i, c in enumerate(word):
        if goal_word[i] == c:
            mask[i] = 2
            unmatched[c] -= 1

    # Remaining letters are "somewhere" only while copies are left; a
    # Counter yields 0 for letters that are not in the goal word.
    for i, c in enumerate(word):
        if mask[i] != 2 and unmatched[c] > 0:
            mask[i] = 1
            unmatched[c] -= 1

    return mask
def update_mask(state: WordleState, word: str, goal_word: str) -> WordleState:
    """
    Return a copy of state updated with the outcome of guessing ``word``.

    The feedback mask is computed with ``get_mask`` and then applied with
    ``update_from_mask``.

    :param state: current state
    :param word: the guessed word
    :param goal_word: the word to be found
    :return: the updated copy of ``state``
    """
    return update_from_mask(state, word, get_mask(word, goal_word))
def update(state: WordleState, word: str, goal_word: str) -> Tuple[WordleState, float]:
    """
    Return an updated copy of state together with the reward of the guess.

    The step counter at index 0 is decremented by one. Every letter at the
    right position adds ``CHAR_REWARD`` to the reward; every misplaced
    letter that still has an unaccounted copy in the goal adds
    ``CHAR_REWARD * 0.1``.

    :param state: current state; it is copied, not modified
    :param word: the guessed word
    :param goal_word: the word to be found
    :return: tuple ``(new_state, reward)``
    """
    new_state = state.copy()
    new_state[0] -= 1
    reward = 0

    first_code = ord(WORDLE_CHARS[0])
    stride = WORDLE_N * 3
    # How many times each letter of the guess has been handled so far.
    handled = {}

    # Pass one: exact matches. The char becomes yes at its position and
    # every other char becomes no at that position.
    for pos, letter in enumerate(word):
        if goal_word[pos] != letter:
            continue
        letter_idx = ord(letter) - first_code
        reward += CHAR_REWARD
        _set_yes(new_state, 1 + letter_idx * stride, letter_idx, pos)
        handled[letter] = handled.get(letter, 0) + 1

    # Pass two: every letter that is not an exact match.
    for pos, letter in enumerate(word):
        if goal_word[pos] == letter:
            continue
        start = 1 + (ord(letter) - first_code) * stride
        copies_in_goal = goal_word.count(letter)
        if copies_in_goal > handled.get(letter, 0):
            # A copy is still unaccounted for: not at this position,
            # maybe at every position without information yet.
            _set_no(new_state, start, pos)
            _set_if_cero(new_state, start, [0, 1, 0])
            reward += CHAR_REWARD * 0.1
        elif letter not in goal_word:
            # The letter is not part of the goal at all.
            _set_all_no(new_state, start)
        else:
            # Every copy in the goal is already accounted for: no at this
            # position and at every position without information.
            _set_no(new_state, start, pos)
            _set_if_cero(new_state, start, [1, 0, 0])
        handled[letter] = handled.get(letter, 0) + 1

    return new_state, reward
def _set_if_cero(state, offset, value):
    """Write ``value`` into every still-empty status triple of one char.

    ``offset`` is the index of the char's first triple. Only triples that
    are exactly ``(0, 0, 0)`` (no information) are overwritten; triples
    that already hold a status are kept. ``state`` is modified in place.
    """
    for pos in range(WORDLE_N):
        start = offset + 3 * pos
        stop = start + 3
        if list(state[start:stop]) != [0, 0, 0]:
            # Already carries information - leave it alone.
            continue
        state[start:stop] = value
def _set_yes(state, offset, char_int, char_pos):
    """Mark a char as definitely located at position ``char_pos``.

    The triple at ``offset + 3 * char_pos`` becomes ``[0, 0, 1]`` and the
    triple at the same position of every char whose index differs from
    ``char_int`` becomes ``[1, 0, 0]``. ``state`` is modified in place.

    ``offset`` is expected to be the start of the status block of
    ``char_int``.
    """
    shift = 3 * char_pos
    own_start = offset + shift
    state[own_start : own_start + 3] = [0, 0, 1]

    stride = WORDLE_N * 3
    for other in range(len(WORDLE_CHARS)):
        if other == char_int:
            continue
        # No other char can occupy a position that is already taken.
        other_start = 1 + other * stride + shift
        state[other_start : other_start + 3] = [1, 0, 0]
def _set_no(state, offset, char_pos):
# Set offset character = no at char_pos position
state[offset + 3 * char_pos : offset + 3 * char_pos + 3] = [1, 0, 0]
def _set_all_no(state, offset):
    """Mark one char as definitely not located at any position.

    ``offset`` is the index of the char's first status triple; all
    ``WORDLE_N`` triples of the char are overwritten with ``[1, 0, 0]``
    in place.
    """
    all_no = [1, 0, 0] * WORDLE_N
    state[offset : offset + len(all_no)] = all_no