|
--- |
|
license: mit |
|
datasets: |
|
- yp-edu/stockfish-debug |
|
model-index:
- name: yp-edu/gpt2-stockfish-debug
  results:
  - task:
      type: text-generation
    dataset:
      name: stockfish-debug
      type: yp-edu/stockfish-debug
    metrics:
    - name: train-loss
      type: loss
      value: 0.151
      verified: false
    - name: eval-loss
      type: loss
      value: 0.138
      verified: false
|
widget: |
|
- text: "FEN: rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1\nMOVE:" |
|
example_title: "Init Board" |
|
- text: "FEN: r2q1rk1/1p3ppp/4bb2/p2p4/5B2/1P1P4/1PPQ1PPP/R3R1K1 w - - 1 17\nMOVE:" |
|
example_title: "Middle Board" |
|
- text: "FEN: 4r1k1/1p1b1ppp/8/8/3P4/2P5/1q3PPP/6K1 b - - 0 28\nMOVE:" |
|
example_title: "Checkmate Possible" |
|
--- |
|
# Model Card for gpt2-stockfish-debug |
|
|
|
See my [blog post](https://yp-edu.github.io/projects/training-gpt2-on-stockfish-games) for additional details. |
|
|
|
## Training Details |
|
|
|
The model was trained for 1 epoch on the [yp-edu/stockfish-debug](https://huggingface.co/datasets/yp-edu/stockfish-debug) dataset (no hyperparameter tuning was done). The samples have the form:
|
|
|
```json |
|
{"prompt":"FEN: {fen}\nMOVE:", "completion": " {move}"} |
|
``` |
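For illustration, here is a minimal sketch of how one such training sample could be assembled; the `make_sample` helper is hypothetical and not part of the dataset tooling:

```python
def make_sample(fen: str, move_uci: str) -> dict:
    """Format one (FEN, move) pair into the prompt/completion shape above."""
    # Note the leading space in the completion, matching the dataset format.
    return {"prompt": f"FEN: {fen}\nMOVE:", "completion": f" {move_uci}"}

sample = make_sample("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", "e2e4")
```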
|
|
|
Two simple extensions come to mind:

- Expand the FEN string so that digits become explicit empty squares: `r2qk3/...` -> `r11qk111/...` or equivalent
- Condition on the game result (Elo ratings are not available in the dataset):
|
```json |
|
{"prompt":"RES: {res}\nFEN: {fen}\nMOVE:", "completion": " {move}"} |
|
``` |
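The first extension can be sketched as follows; `expand_fen` is a hypothetical helper that only rewrites the piece-placement field of the FEN, so the move counters stay untouched:

```python
def expand_fen(fen: str) -> str:
    """Rewrite digit runs in the piece-placement field as explicit '1's."""
    placement, *rest = fen.split(" ")
    expanded = "".join("1" * int(c) if c.isdigit() else c for c in placement)
    return " ".join([expanded, *rest])

print(expand_fen("r2qk3/8/8/8/8/8/8/4K3 w - - 0 1"))
# -> r11qk111/11111111/11111111/11111111/11111111/11111111/11111111/1111K111 w - - 0 1
```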
|
|
|
## Use the Model |
|
|
|
The following code requires `python-chess` (in addition to `transformers`), which you can install with `pip install python-chess`.
|
|
|
```python |
|
import chess |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
|
|
def next_move(model, tokenizer, fen):
    # Prompt the model with the same format used during training.
    inputs = tokenizer(f"FEN: {fen}\nMOVE:", return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    out = model.generate(
        **inputs,
        max_new_tokens=10,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.1,  # low temperature: close to greedy decoding
    )
    out_str = tokenizer.batch_decode(out)[0]
    # Keep only the text after "MOVE:" and strip the EOS token.
    return out_str.split("MOVE:")[-1].replace("<|endoftext|>", "").strip()
|
|
|
|
|
board = chess.Board() |
|
model = AutoModelForCausalLM.from_pretrained("yp-edu/gpt2-stockfish-debug") |
|
tokenizer = AutoTokenizer.from_pretrained("yp-edu/gpt2-stockfish-debug") # or "gpt2" |
|
tokenizer.pad_token = tokenizer.eos_token |
|
for i in range(100):
    fen = board.fen()
    move_uci = next_move(model, tokenizer, fen)
    try:
        print(move_uci)
        # from_uci raises chess.InvalidMoveError for strings that are not
        # syntactically valid UCI, which sampling can produce.
        move = chess.Move.from_uci(move_uci)
        if move not in board.legal_moves:
            raise chess.IllegalMoveError
        board.push(move)
        outcome = board.outcome()
        if outcome is not None:
            print(board)
            print(outcome.result())
            break
    except (chess.InvalidMoveError, chess.IllegalMoveError):
        print(board)
        print("Illegal move", i)
        break
else:
    print(board)
|
``` |
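Since sampling can produce malformed or illegal moves, one optional refinement is to resample a few times and keep the first legal move. This is a sketch, not part of the original code: `next_legal_move` and the `propose` callable are hypothetical, where `propose` would wrap the model, e.g. `lambda fen: next_move(model, tokenizer, fen)`:

```python
import chess


def next_legal_move(board, propose, tries=5):
    """Sample up to `tries` UCI strings from `propose` and return
    the first one that is legal on `board`, or None."""
    for _ in range(tries):
        uci = propose(board.fen())
        try:
            move = chess.Move.from_uci(uci)
        except chess.InvalidMoveError:
            continue  # not even syntactically valid UCI
        if move in board.legal_moves:
            return move
    return None
```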
|
|