|
import numpy as np |
|
import torch |
|
from torch.utils.data import DataLoader |
|
from sklearn.model_selection import train_test_split |
|
from torch.utils.data.dataset import IterableDataset |
|
from collections import deque |
|
|
|
from numpy.random import default_rng |
|
|
|
# Pre-reshaped sudoku dataset; the archive exposes "quizzes" and "solutions"
# keys (arrays are read lazily on key access, so each access yields a fresh
# copy). NOTE(review): assumes arrays have shape (N, 2, 9, 9, 9) — confirm
# against the file that produced the npz.
DATA = np.load(
    "sudoku_reshaped_3_million.npz"
)

# Module-level random generator used for hole filling in get_datasets.
rng = np.random.default_rng()
|
|
|
|
|
def get_datasets(
    add_proba_fill=False, train_size=1280 // 2, test_size=1280 // 2, max_holes=None
):
    """Build (train, test) TensorDatasets from the global sudoku arrays.

    Args:
        add_proba_fill: if True, double the training split with a copy whose
            cells at solution positions are filled with random 0/1 values
            (targets are duplicated unchanged).
        train_size: number of puzzles reserved for training.
        test_size: number of puzzles reserved for testing.
        max_holes: if set (non-zero), repeatedly fill one randomly chosen
            hole per puzzle from its solution until no puzzle has more than
            `max_holes` holes.

    Returns:
        (train, test) pair of ``torch.utils.data.TensorDataset`` with inputs
        and solutions flattened to shape (n, 2, 729) float tensors.
    """
    quizzes = DATA["quizzes"][: train_size + test_size]
    solutions = DATA["solutions"][: train_size + test_size]
    X = quizzes
    if max_holes:
        while True:
            # A cell is a "hole" when its one-hot digit vector is all zero.
            # NOTE(review): assumes X has shape (n, 2, 9, 9, 9) — confirm.
            x_holes = X[:, 1].sum(-1) == 0
            x_nb_holes = x_holes.sum((1, 2))
            mask_x_max_holes = x_nb_holes > max_holes
            # np.any instead of builtin any(): vectorized and unambiguous on
            # ndarrays (builtin any iterates element-by-element in Python).
            if not np.any(mask_x_max_holes):
                break
            # Fill exactly one random hole per offending puzzle, then re-check.
            for idx_x in np.nonzero(mask_x_max_holes)[0]:
                sub_x_holes = x_holes[idx_x]
                idx_fill = rng.choice(np.transpose(np.nonzero(sub_x_holes)))
                X[idx_x, :, idx_fill[0], idx_fill[1], :] = solutions[
                    idx_x, :, idx_fill[0], idx_fill[1], :
                ]
    X = X.reshape(X.shape[0], 2, 9 * 9 * 9)
    solutions = solutions.reshape(solutions.shape[0], 2, 9 * 9 * 9)

    X_train, X_test, solutions_train, solutions_test = train_test_split(
        X, solutions, test_size=test_size, random_state=42
    )
    if add_proba_fill:
        X_train_bis = X_train.copy()
        mask = solutions_train == 1
        # Use the module-level Generator instead of the legacy np.random API,
        # consistent with the hole-filling logic above.
        X_train_bis[mask] = rng.integers(0, 2, size=mask.sum())
        X_train = np.concatenate([X_train, X_train_bis])
        solutions_train = np.concatenate([solutions_train, solutions_train])

    train = torch.utils.data.TensorDataset(
        torch.Tensor(X_train), torch.Tensor(solutions_train)
    )
    test = torch.utils.data.TensorDataset(
        torch.Tensor(X_test), torch.Tensor(solutions_test)
    )
    return train, test
|
|
|
|
|
# Import-time side effect: materializes default-sized datasets (reads DATA).
train_dataset, test_dataset = get_datasets()
|
|
|
|
|
def data_loader(batch_size=32, add_proba_fill=False):
    """Wrap the train/test sudoku datasets in DataLoaders.

    Args:
        batch_size: mini-batch size shared by both loaders.
        add_proba_fill: forwarded to ``get_datasets``.

    Returns:
        (train_loader, test_loader) tuple of ``DataLoader`` instances.
    """
    train_set, test_set = get_datasets(add_proba_fill=add_proba_fill)
    train_loader = DataLoader(train_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, test_loader
|
|
|
|
|
class DataIterBuffer(IterableDataset):
    """Iterable dataset mixing fresh samples with a replay-style buffer.

    Iteration interleaves unseen items from ``raw_dataset`` with items pushed
    back via :meth:`append`, drawing a fresh raw sample with probability
    ``prop_new`` while the buffer holds at most ``buffer_optim`` items, and
    otherwise draining the buffer first. Iteration ends when both sources
    are exhausted.
    """

    def __init__(self, raw_dataset=None, buffer_optim=50, prop_new=0.1, seed=1):
        """Initialize the buffer.

        Args:
            raw_dataset: indexable dataset of samples (defaults to empty;
                ``None`` avoids the shared mutable-default-argument pitfall).
            buffer_optim: soft cap — fresh raw samples are only preferred
                while the buffer holds at most this many items.
            prop_new: probability of drawing a fresh raw sample per step.
            seed: seed for the internal random generator.
        """
        self.raw_dataset = [] if raw_dataset is None else raw_dataset

        self.buffer = deque()
        self.buffer_optim = buffer_optim
        self.prop_new = prop_new
        self.rng = default_rng(seed=seed)
        # Cursor into raw_dataset; persists across __iter__ calls.
        self.idx_dataset = 0

    def __iter__(self):
        """Yield samples until the raw dataset and buffer are both empty.

        Bug fix: draws from ``self.rng`` (seeded in ``__init__``) instead of
        the global ``np.random`` — previously the ``seed`` parameter had no
        effect on iteration order.
        """
        while True:
            draw_new = (self.rng.random() < self.prop_new) and (
                len(self.buffer) <= self.buffer_optim
            )
            if draw_new:
                if self.idx_dataset >= len(self.raw_dataset):
                    # Raw data exhausted: fall back to the buffer.
                    if len(self.buffer) != 0:
                        yield self.buffer.popleft()
                    else:
                        break
                else:
                    yield self.raw_dataset[self.idx_dataset]
                    self.idx_dataset += 1
            else:
                if len(self.buffer) != 0:
                    yield self.buffer.popleft()
                else:
                    # Buffer empty: fall back to raw data if any remains.
                    if self.idx_dataset >= len(self.raw_dataset):
                        break
                    else:
                        yield self.raw_dataset[self.idx_dataset]
                        self.idx_dataset += 1

    def append(self, X, Y) -> None:
        """Queue the (x, y) pairs that are not yet solved.

        Note: mutates ``X`` in place — entries where ``Y == 0`` are zeroed
        before comparison. Pairs where x already equals y are skipped.

        Args:
            X: batch of model states; reshapable to (batch, 2 * 729).
               NOTE(review): assumes shape (batch, 2, 729) — confirm.
            Y: batch of target solutions with the same shape as ``X``.
        """
        X[Y == 0] = 0
        # Keep only rows where X differs from Y somewhere.
        unsolved = ~(X == Y).view(-1, 2 * 729).all(dim=1)

        for x, y in zip(X[unsolved], Y[unsolved]):
            self.buffer.append((x, y))

    def __len__(self):
        # Buffered items plus the full raw dataset size (the cursor is not
        # subtracted — this matches the original behavior).
        return len(self.buffer) + len(self.raw_dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|