Spaces:
Sleeping
Sleeping
File size: 3,487 Bytes
bbae066 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import argparse
from pathlib import Path
from typing import Sequence, Union
from PIL import Image
import torch
import torchvision
import numpy as np
import re
import image_to_fen.util as util
# Directory holding the staged model artifact: <directory of this file>/artifacts/image-to-fen
STAGED_MODEL_DIRNAME = Path(__file__).resolve().parent.joinpath("artifacts", "image-to-fen")
# File name of the TorchScript model inside STAGED_MODEL_DIRNAME; ImageToFen loads
# STAGED_MODEL_DIRNAME / MODEL_FILE when no explicit model path is given.
MODEL_FILE = "model.pt"
class ImageToFen:
    """Takes image of chess board and returns FEN string."""

    def __init__(self, model_path=None):
        """Load the board-detection model.

        Args:
            model_path: Path to a TorchScript model file, loaded with
                ``torch.jit.load``. Defaults to ``STAGED_MODEL_DIRNAME / MODEL_FILE``.
        """
        if model_path is None:
            model_path = STAGED_MODEL_DIRNAME / MODEL_FILE
        self.model = torch.jit.load(model_path)

    @torch.no_grad()
    def predict(self, image: Union[str, Path, Image.Image]) -> str:
        """Predict FEN string for image of chess board.

        Args:
            image: A PIL image, or a path that ``util.read_image_pil`` can read.

        Returns:
            The board placement as eight ranks separated by '-',
            e.g. ``"8-8-8-8-8-8-8-8"`` for an empty board.
        """
        if isinstance(image, Image.Image):
            # The path branch below asks for a grayscale image, so convert
            # caller-supplied PIL images to single-channel 'L' as well; otherwise
            # RGB/RGBA inputs would reach the model with a different channel layout.
            image = image.convert("L")
        else:
            image = util.read_image_pil(image, grayscale=True)
        # 200 px = 8 squares of 25 px, the default square_size of boxes_labels_to_fen.
        image = image.resize((200, 200))
        # Divide by 255 to scale 8-bit pixel values into [0, 1].
        tensor = torchvision.transforms.PILToTensor()(image) / 255
        # The model takes a list of images; [1][0] presumably selects the detections
        # (from a scripted detector's (losses, detections) output) for our single
        # image — confirm against the exported model.
        pred = self.model([tensor])[1][0]
        nms_pred = apply_nms(pred, iou_thresh=0.2)
        return boxes_labels_to_fen(nms_pred['boxes'], nms_pred['labels'])
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Drop overlapping detections with non-maximum suppression.

    Args:
        orig_prediction: Detection dict with 'boxes', 'scores' and 'labels'
            entries that are indexed in parallel along their first dimension.
            The dict is not modified.
        iou_thresh: IoU threshold passed to ``torchvision.ops.nms``; boxes
            overlapping a higher-scoring kept box by more than this are removed.

    Returns:
        A new dict with the same keys as ``orig_prediction``, in which 'boxes',
        'scores' and 'labels' contain only the kept detections.
    """
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
    # Shallow-copy so the caller's prediction dict is left intact.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
def boxes_labels_to_fen(boxes, labels, square_size=25):
    """Place detected pieces on an 8x8 grid and encode the result as a FEN board.

    Args:
        boxes: Tensor of detection boxes; only columns 0 and 1 (the x and y
            pixel coordinates of each box's corner) are read.
        labels: Class labels parallel to ``boxes``. Label ``k`` is written at
            one-hot index ``12 - k`` (label 1 -> 'Q', ..., label 12 -> 'p').
        square_size: Side length of one board square in pixels.

    Returns:
        The board placement as eight ranks separated by '-'.
    """
    # Snap each corner to the nearest square boundary, expressed in whole squares.
    cells = torch.round(boxes / square_size)
    eye = np.eye(13)
    one_hot = onehot_from_fen("8-8-8-8-8-8-8-8")  # start from an empty board
    for cell, label in zip(cells, labels):
        col = int(cell[0])
        row = int(cell[1])
        # Check row and column separately: a bound on the flat index alone would
        # let col == 8 wrap into the next rank and negative coordinates index
        # from the end.
        if not (0 <= col < 8 and 0 <= row < 8):
            continue
        one_hot[row * 8 + col] = eye[12 - int(label)]
    return fen_from_onehot(one_hot)
def onehot_from_fen(fen):
    """Encode a FEN board string as a one-hot matrix with one row per square.

    Ranks may be separated by '-' (the convention used in this module) or by
    '/' (standard FEN). Digits 1-8 expand to that many empty squares.

    Args:
        fen: Board-placement string, e.g. ``"8-8-8-8-8-8-8-8"``.

    Returns:
        A float array of shape ``(n_squares, 13)``. Columns 0-11 follow
        ``'prbnkqPRBNKQ'`` and column 12 marks an empty square.

    Raises:
        ValueError: If ``fen`` contains a character that is not a piece symbol,
            a digit 1-8 or a rank separator.
    """
    piece_symbols = 'prbnkqPRBNKQ'
    eye = np.eye(13)
    # Collect rows in a list and build the array once; np.append inside the loop
    # would copy the whole array on every character.
    rows = []
    for char in fen:
        if char in '-/':
            continue
        if char in '12345678':
            rows.extend([eye[12]] * int(char))
        else:
            try:
                idx = piece_symbols.index(char)
            except ValueError:
                raise ValueError(f"invalid character {char!r} in FEN {fen!r}") from None
            rows.append(eye[idx])
    # reshape keeps the (0, 13) shape when the string has no squares.
    return np.array(rows, dtype=float).reshape(-1, 13)
def fen_from_onehot(one_hot):
    """Decode the first 64 one-hot rows into a '-'-separated FEN board string.

    Each row's first entry equal to 1 selects a symbol from ``'prbnkqPRBNKQ'``;
    index 12 means an empty square. Runs of empty squares within a rank are
    written as their count. A row without any 1 raises IndexError.
    """
    symbols = 'prbnkqPRBNKQ'
    ranks = []
    for rank_start in range(0, 64, 8):
        encoded = ''
        gap = 0  # consecutive empty squares not yet written out
        for square in range(rank_start, rank_start + 8):
            piece = np.where(one_hot[square] == 1)[0][0]
            if piece == 12:
                gap += 1
                continue
            if gap:
                encoded += str(gap)
                gap = 0
            encoded += symbols[piece]
        if gap:
            encoded += str(gap)
        ranks.append(encoded)
    return '-'.join(ranks)
def main():
    """Run prediction on image.

    Command line: a required image path and an optional ``--model-path``.
    The predicted FEN board string is printed to stdout.
    Example image argument: ``image_to_fen/tests/support/boards/phpSrRLQ1.png``.
    """
    cli = argparse.ArgumentParser(description="Predict FEN string for image of chess board.")
    cli.add_argument("image", type=Path, help="Path to image file.")
    cli.add_argument("--model-path", type=Path, help="Path to model file.")
    options = cli.parse_args()
    fen = ImageToFen(options.model_path).predict(options.image)
    print(f"Prediction: {fen}")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()