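# "Last commit not found" - status text from the repository viewer that was captured with this file; it is not part of the program.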
import numpy as np | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
import json | |
import random | |
import argparse | |
import os | |
import hashlib | |
from pathlib import Path | |
class TicTacToeTrainer:
    """Generate tic-tac-toe training data and train a move-prediction network.

    Boards are flat NumPy integer arrays with nine cells: ``1`` marks player X
    (the side whose moves become training labels), ``-1`` marks player O and
    ``0`` is an empty cell.
    """

    # Every index triple that forms a line on the flattened 3x3 board.
    # The triples are lists (not tuples) so that ``board[combo]`` performs
    # NumPy fancy indexing instead of multi-dimensional indexing.
    WINNING_COMBINATIONS = (
        [0, 1, 2], [3, 4, 5], [6, 7, 8],  # Horizontal
        [0, 3, 6], [1, 4, 7], [2, 5, 8],  # Vertical
        [0, 4, 8], [2, 4, 6],             # Diagonal
    )

    def __init__(self):
        # The Keras model is created lazily by train().
        self.model = None

    def create_model(self):
        """Build and compile the neural network.

        The network takes a 9-element board vector and outputs a softmax
        distribution over the nine board cells.

        Returns:
            A compiled ``keras.Sequential`` model.
        """
        model = keras.Sequential([
            # An explicit Input object replaces the deprecated
            # ``input_shape=`` argument on the first Dense layer.
            keras.Input(shape=(9,)),
            layers.Dense(128, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(64, activation='relu'),
            layers.Dropout(0.3),
            layers.Dense(9, activation='softmax')
        ])
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def calculate_board_hash(self, board):
        """Return an MD5 hex digest identifying the given board state.

        The digest is computed from the string form of the board's list
        representation, so equal boards always yield equal hashes.
        """
        return hashlib.md5(str(np.asarray(board).tolist()).encode()).hexdigest()

    def _has_three_in_line(self, board, player):
        """Return True when ``player`` occupies all three cells of any line."""
        return any(
            bool(np.all(board[combo] == player))
            for combo in self.WINNING_COMBINATIONS
        )

    def check_two_in_line(self, board, player):
        """Find a line where ``player`` has two marks and the third cell is empty.

        Args:
            board: Board array of nine cells.
            player: Mark to look for (``1`` or ``-1``).

        Returns:
            The index of the empty cell in the first such line (the cell that
            completes or blocks the line), or ``None`` when no line qualifies.
        """
        for combo in self.WINNING_COMBINATIONS:
            line = board[combo]
            has_two_marks = np.count_nonzero(line == player) == 2
            has_one_gap = np.count_nonzero(line == 0) == 1
            if has_two_marks and has_one_gap:
                # Translate the gap's position within the line to a board index.
                return combo[int(np.flatnonzero(line == 0)[0])]
        return None

    def check_winner(self, board):
        """Check whether X has won or holds an open two-in-line.

        Only player X (value ``1``) is considered.

        Returns:
            ``(True, "win")`` when X has three in a line,
            ``(True, "two_in_line")`` when X has two marks in a line whose
            third cell is empty, otherwise ``(False, "none")``.
        """
        # A completed line takes precedence over a potential one.
        if self._has_three_in_line(board, 1):
            return True, "win"
        if self.check_two_in_line(board, 1) is not None:
            return True, "two_in_line"
        return False, "none"

    def _load_existing_games(self, json_file):
        """Load previously saved games; return an empty list if none are usable."""
        if not json_file.exists():
            return []
        try:
            with open(json_file, 'r', encoding='utf-8') as file:
                existing_games = json.load(file)
        except json.JSONDecodeError:
            print("Error reading JSON file. Starting with empty games list.")
            return []
        print(f"Loaded {len(existing_games)} existing games")
        return existing_games

    def _play_game(self):
        """Play one game; return its data, or None when O wins.

        Returns:
            ``(board, game_states, game_moves, full_sequence)`` where
            ``game_states`` holds the board before each X move,
            ``game_moves`` the matching one-hot X moves and ``full_sequence``
            the ordered move log of both players. ``None`` is returned when
            O completes a line, because such a game must not be recorded.
        """
        board = np.zeros((9,), dtype=int)
        game_states = []
        game_moves = []
        full_sequence = []

        while True:
            valid_moves = np.flatnonzero(board == 0)
            if len(valid_moves) == 0:
                break

            # Player X: random move; the state *before* the move is the
            # training input and the move itself the one-hot label.
            move = random.choice(valid_moves)
            move_one_hot = np.zeros(9)
            move_one_hot[move] = 1
            game_states.append(board.copy())
            game_moves.append(move_one_hot)
            full_sequence.append({'player': 'X', 'move': int(move)})
            board[move] = 1

            # The game ends on an X win, an open X two-in-line, or a full board.
            is_winning, _ = self.check_winner(board)
            if is_winning or not np.any(board == 0):
                break

            # Player O (defensive): block an open X line if there is one.
            # NOTE(review): check_winner above already ends the game whenever
            # X has an open two-in-line, so this blocking branch is currently
            # never taken and O effectively plays at random.
            valid_moves = np.flatnonzero(board == 0)
            blocking_move = self.check_two_in_line(board, 1)
            if blocking_move is not None:
                opponent_move = blocking_move
            else:
                opponent_move = random.choice(valid_moves)
            board[opponent_move] = -1
            full_sequence.append({'player': 'O', 'move': int(opponent_move)})

            # The game is over once O completes a line. Without this check
            # play continued on a finished game and could later be stored as
            # an X "win"/"two_in_line".
            if self._has_three_in_line(board, -1):
                return None

        return board, game_states, game_moves, full_sequence

    def generate_training_data(self, num_games=1000):
        """Generate unique training games and persist them to ``games_data.json``.

        Games are played until X wins or reaches an open two-in-line. A game
        is kept only if it ended in one of those situations and its final
        board has not been seen before (including boards already stored in
        the JSON file). At most ``num_games * 10`` games are attempted.

        Args:
            num_games: Number of new unique games to collect.

        Returns:
            ``(X, y)``: arrays of board states before each X move and the
            matching one-hot encoded X moves, for the newly kept games.
        """
        X = []
        y = []

        json_file = Path('games_data.json')
        existing_games = self._load_existing_games(json_file)
        games_hash_set = {game['hash'] for game in existing_games}

        new_games = []
        games_generated = 0
        attempts = 0
        max_attempts = num_games * 10

        while games_generated < num_games and attempts < max_attempts:
            attempts += 1

            result = self._play_game()
            if result is None:
                # O won: not a valid X training game.
                continue
            board, game_states, game_moves, full_sequence = result

            # Games are de-duplicated by their final board.
            game_hash = self.calculate_board_hash(board)
            is_winning, situation = self.check_winner(board)
            if not is_winning or game_hash in games_hash_set:
                continue

            games_hash_set.add(game_hash)
            games_generated += 1
            new_games.append({
                'hash': game_hash,
                'moves': full_sequence,
                'final_board': board.tolist(),
                'win': situation == "win",
                'situation': situation
            })
            X.extend(game_states)
            y.extend(game_moves)

            if games_generated % 10 == 0:
                print(f"Generated {games_generated}/{num_games} unique games")
                print(f"Last game situation: {situation}")

        all_games = existing_games + new_games
        with open(json_file, 'w', encoding='utf-8') as file:
            json.dump(all_games, file, indent=2)

        print(f"\nGenerated {len(new_games)} new unique games")
        print(f"Total games in database: {len(all_games)}")
        return np.array(X), np.array(y)

    def train(self, epochs=50, games=1000, model_path='model'):
        """Train the model and save it to disk.

        Args:
            epochs: Number of training epochs.
            games: Number of unique games to generate as training data.
            model_path: Directory in which ``model.keras`` is written; it is
                created when missing.

        Returns:
            None. Training is skipped when no training data was generated.
        """
        print(f"Rozpoczynam generowanie danych treningowych ({games} gier)...")
        X_train, y_train = self.generate_training_data(games)

        if len(X_train) == 0:
            print("Nie udało się wygenerować żadnych danych treningowych!")
            return

        print(f"\nWygenerowano dane treningowe: {len(X_train)} przykładów")
        print(f"Przykładowy stan planszy: {X_train[0]}")
        print(f"\nRozpoczynam trening ({epochs} epok)...")

        self.model = self.create_model()
        history = self.model.fit(
            X_train,
            y_train,
            epochs=epochs,
            batch_size=32,
            validation_split=0.1,
            verbose=1
        )

        # Create the target directory if it does not exist yet.
        os.makedirs(model_path, exist_ok=True)

        # Save the trained model.
        self.model.save(os.path.join(model_path, "model.keras"))
        print(f"\nModel został zapisany w: {model_path}")

        # Report the metrics of the final epoch.
        metrics = {
            'accuracy': float(history.history['accuracy'][-1]),
            'val_accuracy': float(history.history['val_accuracy'][-1]),
            'loss': float(history.history['loss'][-1]),
            'val_loss': float(history.history['val_loss'][-1])
        }
        print("\nWyniki treningu:")
        print(f"Dokładność: {metrics['accuracy']:.4f}")
        print(f"Dokładność walidacji: {metrics['val_accuracy']:.4f}")
        print(f"Strata: {metrics['loss']:.4f}")
        print(f"Strata walidacji: {metrics['val_loss']:.4f}")
def main():
    """Parse the command-line options and run one full training session."""
    cli = argparse.ArgumentParser(
        description='Trenuj model AI do gry w kółko i krzyżyk'
    )

    # (flag, value type, default, help text) for every supported option.
    option_specs = (
        ('--epochs', int, 50, 'Liczba epok treningu (domyślnie: 50)'),
        ('--games', int, 1000, 'Liczba gier treningowych (domyślnie: 1000)'),
        ('--model-path', str, 'model',
         'Ścieżka do zapisania modelu (domyślnie: "model")'),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value,
                         help=help_text)

    options = cli.parse_args()

    TicTacToeTrainer().train(
        epochs=options.epochs,
        games=options.games,
        model_path=options.model_path,
    )
# Run the command-line entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()