import argparse
import hashlib
import json
import os
import random
from pathlib import Path

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


class TicTacToeTrainer:
    def __init__(self):
        self.model = None

    def create_model(self):
        """Creates the neural network model"""
        model = keras.Sequential([
            layers.Dense(128, activation='relu', input_shape=(9,)),
            layers.Dropout(0.3),
            layers.Dense(64, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(9, activation='softmax')
        ])

        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        return model
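
    # Shape note: the model consumes a flattened 3x3 board (9 values: 1 = X,
    # -1 = O, 0 = empty) and outputs a softmax distribution over the 9 squares,
    # so predictions on an array of shape (n, 9) come back with shape (n, 9).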

    def calculate_board_hash(self, board):
        """Calculate a unique hash for the board state"""
        return hashlib.md5(str(board.tolist()).encode()).hexdigest()

    def check_two_in_line(self, board, player):
        """
        Check if player has two in a line and return the winning move position
        Returns: Position to block or None if no blocking needed
        """
        winning_combinations = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],
            [0, 3, 6], [1, 4, 7], [2, 5, 8],
            [0, 4, 8], [2, 4, 6]
        ]

        for combo in winning_combinations:
            line = board[combo]
            if sum(line == player) == 2 and sum(line == 0) == 1:
                return combo[list(line).index(0)]
        return None
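
    # Example (illustration only): with X on squares 0 and 1 and square 2 empty,
    # the winning combination [0, 1, 2] matches and the square to block is 2:
    #   trainer = TicTacToeTrainer()
    #   trainer.check_two_in_line(np.array([1, 1, 0, 0, 0, 0, 0, 0, 0]), 1)  # -> 2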

    def check_winner(self, board):
        """
        Checks whether there is a winner or two marks in a line (player X only)
        Returns: (bool, str) - (win/potential win present, type of situation)
        """
        winning_combinations = [
            [0, 1, 2], [3, 4, 5], [6, 7, 8],
            [0, 3, 6], [1, 4, 7], [2, 5, 8],
            [0, 4, 8], [2, 4, 6]
        ]

        # A completed line for X (all three cells equal to 1).
        for combo in winning_combinations:
            if sum(board[combo]) == 3:
                return True, "win"

        # X has two marks in a line with the third cell still empty.
        for combo in winning_combinations:
            line = board[combo]
            if sum(line == 1) == 2 and sum(line == 0) == 1:
                return True, "two_in_line"

        return False, "none"
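
    # Example (illustration only): X completing the main diagonal is reported as a win:
    #   trainer.check_winner(np.array([1, -1, -1, 0, 1, 0, 0, 0, 1]))  # -> (True, "win")
    # Only X (encoded as 1) is checked here; a completed line of O marks is not detected.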

    def generate_training_data(self, num_games=1000):
        """
        Generates unique training data including two-in-line positions
        """
        X = []
        y = []
        games_hash_set = set()

        # Load previously generated games so duplicate boards can be skipped.
        json_file = Path('games_data.json')
        if json_file.exists():
            try:
                with open(json_file, 'r') as file:
                    existing_games = json.load(file)
                for game in existing_games:
                    games_hash_set.add(game['hash'])
                print(f"Loaded {len(existing_games)} existing games")
            except json.JSONDecodeError:
                print("Error reading JSON file. Starting with empty games list.")
                existing_games = []
        else:
            existing_games = []

        new_games = []
        games_generated = 0
        attempts = 0
        max_attempts = num_games * 10

        while games_generated < num_games and attempts < max_attempts:
            attempts += 1
            board = np.zeros((9,), dtype=int)
            game_states = []
            game_moves = []
            full_sequence = []

            while True:
                current_state = board.copy()
                valid_moves = np.where(board == 0)[0]

                if len(valid_moves) == 0:
                    break

                # X plays a random move; the (state, move) pair becomes a training example.
                move = random.choice(valid_moves)
                move_one_hot = np.zeros(9)
                move_one_hot[move] = 1

                game_states.append(current_state.copy())
                game_moves.append(move_one_hot)
                full_sequence.append({'player': 'X', 'move': int(move)})

                board[move] = 1

                # Stop once X has won, threatens to win, or the board is full.
                is_winning, situation = self.check_winner(board)
                if is_winning or len(np.where(board == 0)[0]) == 0:
                    break

                # O responds: block X's two-in-line if one exists, otherwise play randomly.
                valid_moves = np.where(board == 0)[0]
                if len(valid_moves) > 0:
                    blocking_move = self.check_two_in_line(board, 1)
                    if blocking_move is not None and board[blocking_move] == 0:
                        opponent_move = blocking_move
                    else:
                        opponent_move = random.choice(valid_moves)
                    board[opponent_move] = -1
                    full_sequence.append({'player': 'O', 'move': int(opponent_move)})

            game_hash = self.calculate_board_hash(board)

            # Keep the game only if its final board is new and ends in a win or threat for X.
            is_winning, situation = self.check_winner(board)
            if game_hash not in games_hash_set and is_winning:
                games_hash_set.add(game_hash)
                games_generated += 1

                game_data = {
                    'hash': game_hash,
                    'moves': full_sequence,
                    'final_board': board.tolist(),
                    'win': situation == "win",
                    'situation': situation
                }
                new_games.append(game_data)

                X.extend(game_states)
                y.extend(game_moves)

                if games_generated % 10 == 0:
                    print(f"Generated {games_generated}/{num_games} unique games")
                    print(f"Last game situation: {situation}")

        all_games = existing_games + new_games

        with open(json_file, 'w') as file:
            json.dump(all_games, file, indent=2)

        print(f"\nGenerated {len(new_games)} new unique games")
        print(f"Total games in database: {len(all_games)}")

        return np.array(X), np.array(y)
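
    # The returned arrays pair each recorded board state with the move X chose from it:
    # the first array has shape (n_samples, 9) and the second holds the one-hot moves,
    # also of shape (n_samples, 9).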

    def train(self, epochs=50, games=1000, model_path='model'):
        """Trains the model and saves it to a file"""
        print(f"Starting training data generation ({games} games)...")
        X_train, y_train = self.generate_training_data(games)

        if len(X_train) == 0:
            print("Failed to generate any training data!")
            return

        print(f"\nGenerated training data: {len(X_train)} examples")
        print(f"Sample board state: {X_train[0]}")

        print(f"\nStarting training ({epochs} epochs)...")
        self.model = self.create_model()

        history = self.model.fit(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=32,
            validation_split=0.1,
            verbose=1
        )

        os.makedirs(model_path, exist_ok=True)

        self.model.save(model_path + "/model.keras")
        print(f"\nModel saved to: {model_path}")

        metrics = {
            'accuracy': float(history.history['accuracy'][-1]),
            'val_accuracy': float(history.history['val_accuracy'][-1]),
            'loss': float(history.history['loss'][-1]),
            'val_loss': float(history.history['val_loss'][-1])
        }

        print("\nTraining results:")
        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Validation accuracy: {metrics['val_accuracy']:.4f}")
        print(f"Loss: {metrics['loss']:.4f}")
        print(f"Validation loss: {metrics['val_loss']:.4f}")


def main():
    parser = argparse.ArgumentParser(description='Train an AI model to play tic-tac-toe')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Number of training epochs (default: 50)')
    parser.add_argument('--games', type=int, default=1000,
                        help='Number of training games (default: 1000)')
    parser.add_argument('--model-path', type=str, default='model',
                        help='Path where the model is saved (default: "model")')

    args = parser.parse_args()

    trainer = TicTacToeTrainer()
    trainer.train(epochs=args.epochs, games=args.games, model_path=args.model_path)


if __name__ == "__main__":
    main()
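

# Example usage (the script filename below is an assumption):
#   python train.py --epochs 100 --games 2000 --model-path model
#
# A minimal sketch of using the saved model afterwards, assuming the same board
# encoding as above (1 = X, -1 = O, 0 = empty):
#   model = keras.models.load_model('model/model.keras')
#   board = np.zeros((9,), dtype=int)
#   probs = model.predict(board.reshape(1, 9), verbose=0)[0]
#   probs[board != 0] = 0              # mask squares that are already occupied
#   best_move = int(np.argmax(probs))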