import numpy as np |
from tensorflow import keras |
from tensorflow.keras import layers |
import json |
import random |
import argparse |
import os |
import hashlib |
from pathlib import Path |
class TicTacToeTrainer:
    """Generate random tic-tac-toe games and train a Keras move-prediction model.

    Boards are length-9 numpy integer arrays in row-major order: 1 marks an
    X (the player whose moves are learned), -1 marks an O and 0 marks an
    empty cell.
    """

    # Index triples of the eight lines (three rows, three columns, two
    # diagonals). Shared by every line-checking method so they cannot drift
    # apart. Index with ``board[list(combo)]``: a tuple index would be
    # interpreted by numpy as a multi-dimensional index.
    WINNING_COMBINATIONS = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),
        (0, 3, 6), (1, 4, 7), (2, 5, 8),
        (0, 4, 8), (2, 4, 6),
    )

    def __init__(self):
        # Set by train(); None until a model has been trained.
        self.model = None

    def create_model(self):
        """Build and compile the move-prediction network.

        The network maps a 9-cell board to a softmax distribution over the
        9 cells, i.e. the predicted move.
        """
        model = keras.Sequential([
            layers.Dense(128, activation='relu', input_shape=(9,)),
            layers.Dropout(0.3),
            layers.Dense(64, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(9, activation='softmax'),
        ])
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        return model

    def calculate_board_hash(self, board):
        """Return the MD5 hex digest of the board's list representation."""
        return hashlib.md5(str(board.tolist()).encode()).hexdigest()

    def check_two_in_line(self, board, player):
        """Find the empty cell that would complete a line for ``player``.

        Args:
            board: length-9 board array.
            player: cell value to look for (1 for X, -1 for O).

        Returns:
            Index of the empty cell in the first line that holds exactly two
            ``player`` marks and one empty cell, or None if there is none.
        """
        for combo in self.WINNING_COMBINATIONS:
            line = board[list(combo)]
            if np.count_nonzero(line == player) == 2 and np.count_nonzero(line == 0) == 1:
                return combo[list(line).index(0)]
        return None

    def check_winner(self, board, player=1):
        """Report whether ``player`` has won or is one move from winning.

        Args:
            board: length-9 board array.
            player: cell value to check; 1 (X) by default, -1 for O.

        Returns:
            (bool, str): ``(True, "win")`` if ``player`` fills a whole line,
            ``(True, "two_in_line")`` if some line holds two ``player`` marks
            and one empty cell, otherwise ``(False, "none")``.
        """
        lines = [board[list(combo)] for combo in self.WINNING_COMBINATIONS]
        if any(np.all(line == player) for line in lines):
            return True, "win"
        if any(
            np.count_nonzero(line == player) == 2 and np.count_nonzero(line == 0) == 1
            for line in lines
        ):
            return True, "two_in_line"
        return False, "none"

    def _load_existing_games(self, json_file):
        """Load previously saved games and the set of their hashes.

        Returns:
            (games, hashes): the stored game records and their 'hash' values.
            Both are empty when the file is missing or cannot be parsed.
        """
        if not json_file.exists():
            return [], set()
        try:
            with open(json_file, 'r', encoding='utf-8') as file:
                existing_games = json.load(file)
            if not isinstance(existing_games, list):
                raise TypeError("games file must contain a JSON list")
            hashes = {game['hash'] for game in existing_games}
        except (json.JSONDecodeError, UnicodeDecodeError, KeyError, TypeError):
            # Malformed database: start over rather than crash. The hash set is
            # rebuilt from scratch, so a partial parse cannot leak into it.
            print("Error reading JSON file. Starting with empty games list.")
            return [], set()
        print(f"Loaded {len(existing_games)} existing games")
        return existing_games, hashes

    def _play_game(self):
        """Play one random game and record X's decisions.

        X always plays a uniformly random empty cell, and so does O. The
        game stops when X wins or has an open two-in-line, when O completes
        a line, or when the board is full.

        Returns:
            (states, moves, sequence, board):
            - states: the board before each X move.
            - moves: the one-hot encoding of each X move.
            - sequence: the full list of ``{'player', 'move'}`` records.
            - board: the final board.
        """
        board = np.zeros((9,), dtype=int)
        game_states = []
        game_moves = []
        full_sequence = []
        while True:
            valid_moves = np.where(board == 0)[0]
            if len(valid_moves) == 0:
                break
            move = random.choice(valid_moves)
            move_one_hot = np.zeros(9)
            move_one_hot[move] = 1
            game_states.append(board.copy())
            game_moves.append(move_one_hot)
            full_sequence.append({'player': 'X', 'move': int(move)})
            board[move] = 1

            is_winning, _ = self.check_winner(board)
            valid_moves = np.where(board == 0)[0]
            if is_winning or len(valid_moves) == 0:
                break

            # NOTE(review): check_winner() above already ends the game whenever
            # X has an open two-in-line, so this blocking branch cannot fire and
            # O effectively always plays randomly. It is kept to preserve intent.
            blocking_move = self.check_two_in_line(board, 1)
            if blocking_move is not None:
                opponent_move = blocking_move
            else:
                opponent_move = random.choice(valid_moves)
            board[opponent_move] = -1
            full_sequence.append({'player': 'O', 'move': int(opponent_move)})

            # Stop as soon as O completes a line. Previously O's wins went
            # undetected and play continued on an already-decided board.
            if self.check_winner(board, player=-1)[1] == "win":
                break
        return game_states, game_moves, full_sequence, board

    def generate_training_data(self, num_games=1000):
        """Generate unique games, append them to ``games_data.json`` and return X's moves.

        A game is kept only if two conditions hold:
        - X ends it with a win or an open two-in-line.
        - Its final board has not been seen before, either in this run or in
          the existing JSON database.

        At most ``num_games * 10`` games are attempted.

        Args:
            num_games: number of new unique games to generate.

        Returns:
            (X, y): numpy arrays with one row per recorded X move. X holds the
            9-cell board before the move; y holds the one-hot move. Only moves
            from newly generated games are returned. Games loaded from the JSON
            file are used for de-duplication only. Both arrays are empty if no
            game was kept.
        """
        X = []
        y = []
        json_file = Path('games_data.json')
        existing_games, games_hash_set = self._load_existing_games(json_file)

        new_games = []
        games_generated = 0
        attempts = 0
        max_attempts = num_games * 10
        while games_generated < num_games and attempts < max_attempts:
            attempts += 1
            game_states, game_moves, full_sequence, board = self._play_game()
            # Uniqueness is keyed on the final board, not on the move order.
            game_hash = self.calculate_board_hash(board)
            is_winning, situation = self.check_winner(board)
            if game_hash in games_hash_set or not is_winning:
                continue

            games_hash_set.add(game_hash)
            games_generated += 1
            new_games.append({
                'hash': game_hash,
                'moves': full_sequence,
                'final_board': board.tolist(),
                'win': situation == "win",
                'situation': situation,
            })
            X.extend(game_states)
            y.extend(game_moves)
            if games_generated % 10 == 0:
                print(f"Generated {games_generated}/{num_games} unique games")
                print(f"Last game situation: {situation}")

        all_games = existing_games + new_games
        with open(json_file, 'w', encoding='utf-8') as file:
            json.dump(all_games, file, indent=2)
        print(f"\nGenerated {len(new_games)} new unique games")
        print(f"Total games in database: {len(all_games)}")
        return np.array(X), np.array(y)

    def train(self, epochs=50, games=1000, model_path='model'):
        """Generate training data, fit a fresh model and save it.

        The trained model is stored on ``self.model`` and written to
        ``<model_path>/model.keras``; final metrics are printed. Returns
        early, without training, if no data could be generated.
        """
        print(f"Rozpoczynam generowanie danych treningowych ({games} gier)...")
        X_train, y_train = self.generate_training_data(games)
        if len(X_train) == 0:
            print("Nie udało się wygenerować żadnych danych treningowych!")
            return

        print(f"\nWygenerowano dane treningowe: {len(X_train)} przykładów")
        print(f"Przykładowy stan planszy: {X_train[0]}")
        print(f"\nRozpoczynam trening ({epochs} epok)...")

        self.model = self.create_model()
        history = self.model.fit(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=32,
            validation_split=0.1,
            verbose=1,
        )

        os.makedirs(model_path, exist_ok=True)
        self.model.save(os.path.join(model_path, "model.keras"))
        print(f"\nModel został zapisany w: {model_path}")

        # Report the metrics from the last epoch only.
        metrics = {
            'accuracy': float(history.history['accuracy'][-1]),
            'val_accuracy': float(history.history['val_accuracy'][-1]),
            'loss': float(history.history['loss'][-1]),
            'val_loss': float(history.history['val_loss'][-1]),
        }
        print("\nWyniki treningu:")
        print(f"Dokładność: {metrics['accuracy']:.4f}")
        print(f"Dokładność walidacji: {metrics['val_accuracy']:.4f}")
        print(f"Strata: {metrics['loss']:.4f}")
        print(f"Strata walidacji: {metrics['val_loss']:.4f}")
def main():
    """Parse the command-line options and run one training session."""
    cli = argparse.ArgumentParser(description='Trenuj model AI do gry w kółko i krzyżyk')
    # (flag, type, default, help) for each supported option.
    option_specs = (
        ('--epochs', int, 50, 'Liczba epok treningu (domyślnie: 50)'),
        ('--games', int, 1000, 'Liczba gier treningowych (domyślnie: 1000)'),
        ('--model-path', str, 'model', 'Ścieżka do zapisania modelu (domyślnie: "model")'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    opts = cli.parse_args()

    TicTacToeTrainer().train(
        epochs=opts.epochs,
        games=opts.games,
        model_path=opts.model_path,
    )


if __name__ == "__main__":
    main()