# app.py (Hugging Face Space), author: Grandediw
# Last change: commit cab2759 "Update app.py" (verified), 4.41 kB.
import logging

import gradio as gr
import numpy as np
import pandas as pd
import requests
from xgboost import Booster, DMatrix
# Card name -> 1-based integer ID. IDs are assigned by list position, so the
# order of the roster below is significant and the IDs are always contiguous
# (deck_to_ids subtracts 1 to get a 0-based index, and preprocess_deck sizes
# the one-hot vector with len(card_numbers)).
# NOTE(review): the roster is truncated ("Add other cards..."); the full list
# presumably has to match the card ordering used to train model.json — confirm.
card_numbers = {
    name: card_id
    for card_id, name in enumerate(
        [
            "Archers", "Archer Queen", "Baby Dragon", "Balloon", "Bandit", "Barbarians",
            "Bats", "Battle Healer", "Battle Ram", "Bomber", "Bowler", "Bush Goblins",
            # Add other cards...
        ],
        start=1,
    )
}
# Card name -> image URL in the RoyaleAPI asset repository. The file name is
# the lower-cased card name with spaces and slashes turned into hyphens and
# periods removed, e.g. "Baby Dragon" -> ".../baby-dragon.png".
base_url = "https://raw.githubusercontent.com/RoyaleAPI/cr-api-assets/master/cards/"
# One translation table instead of chained str.replace calls; equivalent here
# because no replacement produces a character that another rule would match.
_CARD_SLUG_TABLE = str.maketrans({" ": "-", "/": "-", ".": None})
card_images = {
    name: f"{base_url}{name.lower().translate(_CARD_SLUG_TABLE)}.png"
    for name in card_numbers
}
# Keep only the cards whose image URL actually resolves, so the UI never
# renders a broken image.
# - requests has no default timeout, so each HEAD is bounded; otherwise a
#   stalled connection would hang startup indefinitely.
# - Network errors are treated like a missing image (logged and skipped)
#   instead of crashing the whole app.
# - Redirects are followed, because requests.head does not follow them by
#   default and a redirected asset would otherwise be dropped.
# - One Session reuses the connection across all the HEAD requests.
IMAGE_CHECK_TIMEOUT_S = 5
valid_card_images = {}
with requests.Session() as session:
    for card, url in card_images.items():
        try:
            response = session.head(
                url, timeout=IMAGE_CHECK_TIMEOUT_S, allow_redirects=True
            )
        except requests.RequestException as exc:
            logging.warning("Skipping %s: request to %s failed (%s)", card, url, exc)
            continue
        if response.status_code == 200:
            valid_card_images[card] = url
        else:
            logging.warning(
                "Skipping %s: %s returned HTTP %s", card, url, response.status_code
            )
# --- Model loading and inference -------------------------------------------
# Serialized XGBoost model loaded once at startup. The path is relative, so it
# is resolved against the process's current working directory.
MODEL_PATH: str = "model.json"
def load_model(model_path):
    """Create an empty XGBoost ``Booster`` and fill it from *model_path*.

    Args:
        model_path: Path to a model file saved by XGBoost (this module passes
            ``MODEL_PATH``, a ``.json`` file).

    Returns:
        The populated ``xgboost.Booster``.
    """
    booster = Booster()
    booster.load_model(model_path)
    return booster
def deck_to_ids(deck, mapping):
    """Convert card names to 0-based feature indices.

    ``mapping`` assigns each card a 1-based ID; the result holds ``ID - 1``
    for every card in ``deck``, in the same order, so it can index a one-hot
    vector directly.

    Args:
        deck: Iterable of card names. It is consumed exactly once, so
            iterators and generators are fine.
        mapping: Card name -> 1-based integer ID.

    Returns:
        List of 0-based indices, one per card in ``deck``.

    Raises:
        ValueError: If a card in ``deck`` is not a key of ``mapping``.
    """
    indices = []
    for card in deck:
        try:
            card_id = mapping[card]
        except KeyError:
            # An unknown card used to map to 0 - 1 == -1, which numpy reads as
            # "last element", silently marking an unrelated card. Fail loudly.
            raise ValueError(f"Unknown card: {card!r}") from None
        indices.append(card_id - 1)
    return indices
def preprocess_deck(deck):
    """Encode a deck of card names as a single-row model input.

    Each card sets a 1 in a one-hot vector of length ``len(card_numbers)``
    at the index returned by ``deck_to_ids``; two zero columns are then
    prepended to that vector.

    Args:
        deck: Iterable of card names (keys of ``card_numbers``).

    Returns:
        A one-row ``pandas.DataFrame`` with ``2 + len(card_numbers)`` integer
        columns.
    """
    deck_ids = deck_to_ids(deck, card_numbers)
    num_choices = len(card_numbers)
    one_hot = np.zeros(num_choices, dtype=int)
    # Force an integer index array: np.array([]) defaults to float64, which
    # numpy rejects as an index, so an empty deck used to raise IndexError.
    one_hot[np.asarray(deck_ids, dtype=np.intp)] = 1
    # NOTE(review): the two leading zeros are presumably placeholders for
    # non-card features the model was trained with — confirm against the
    # training pipeline before changing the column layout.
    features = np.concatenate(([0, 0], one_hot))
    return pd.DataFrame([features])
def predict_outcome(opponent_deck):
    """Return a display string with the model's score for *opponent_deck*.

    Encodes the deck with ``preprocess_deck``, scores it with the
    module-level ``model`` Booster, and formats the first prediction as a
    percentage with two decimals.
    NOTE(review): this assumes the model outputs a probability in [0, 1]
    (e.g. a ``binary:logistic`` objective) — confirm.
    """
    features = DMatrix(preprocess_deck(opponent_deck))
    win_probability = model.predict(features)[0]
    return f"Probability of Winning: {win_probability * 100:.2f}%"
# Load the Booster once at import time; predict_outcome reads this
# module-level `model` for every prediction.
model = load_model(model_path=MODEL_PATH)
# Build the Gradio UI: a grid of card images with checkboxes, a live count of
# selected cards, a result box, and Clear / Predict buttons.
# The CSS fixes card thumbnails at 80x80 and frames each card in a bordered
# box; its class names match the elem_classes used on the components below.
with gr.Blocks(css="""
.card-container img {
width: 80px !important;
height: 80px !important;
object-fit: contain;
margin: 5px auto;
}
.card-container {
text-align: center;
padding: 5px;
border: 1px solid #ddd;
border-radius: 8px;
margin: 5px;
}
.checkbox-container {
margin-top: 5px;
}
""") as interface:
    gr.Markdown("## Clash Royale Prediction")
    gr.Markdown("Select 8 cards from the opponent's deck to predict the probability of winning!")
    # Live "Selected cards: N/8" counter. This is a Markdown component that
    # update_selection rewrites, not a gr.State.
    selected_cards_display = gr.Markdown("Selected cards: 0/8")

    def update_selection(*checkbox_values):
        """Return the counter text for the current checkbox values.

        Receives one bool per checkbox (in ``all_checkboxes`` order) and
        counts how many are ticked.
        """
        selected_count = sum(checkbox_values)
        return f"Selected cards: {selected_count}/8"

    # Lay the cards out in rows of `cards_per_row`, one column per card with
    # the image above its checkbox. Only cards whose image URL passed the
    # HEAD check are shown. Checkboxes are collected in valid_card_images
    # order; validate_and_predict relies on that order to map values to names.
    cards_per_row = 8
    cards_list = list(valid_card_images.items())
    all_checkboxes = []
    for i in range(0, len(cards_list), cards_per_row):
        with gr.Row():
            for card, url in cards_list[i:i + cards_per_row]:
                with gr.Column(elem_classes="card-container"):
                    gr.Image(value=url, show_label=False)
                    checkbox = gr.Checkbox(label=card, elem_classes="checkbox-container")
                    all_checkboxes.append(checkbox)

    # Result box plus the two action buttons on one row.
    with gr.Row():
        result = gr.Textbox(label="Prediction Result:", interactive=False)
        clear_btn = gr.Button("Clear Selection")
        predict_btn = gr.Button("Make Prediction", variant="primary")

    def clear_selection():
        """Untick every checkbox, reset the counter, and blank the result.

        The returned list is positional and must line up with the `outputs`
        list passed to clear_btn.click: all checkboxes, then the counter,
        then the result box.
        """
        return [False] * len(all_checkboxes) + ["Selected cards: 0/8", ""]

    clear_btn.click(
        clear_selection,
        outputs=all_checkboxes + [selected_cards_display, result]
    )

    def validate_and_predict(*checkbox_values):
        """Run a prediction on the ticked cards, requiring exactly 8.

        Pairs each checkbox value with its card name by position (the
        checkboxes were created in valid_card_images order). Returns an error
        message if the count is not 8, otherwise predict_outcome's text.
        """
        selected_cards = [
            card for card, checked in zip(valid_card_images.keys(), checkbox_values)
            if checked
        ]
        if len(selected_cards) != 8:
            return f"Error: Please select exactly 8 cards. You selected {len(selected_cards)}."
        return predict_outcome(selected_cards)

    predict_btn.click(
        validate_and_predict,
        inputs=all_checkboxes,
        outputs=result
    )

    # Refresh the counter whenever any checkbox changes; each listener
    # receives the values of all checkboxes so it can count the ticked ones.
    for checkbox in all_checkboxes:
        checkbox.change(
            update_selection,
            inputs=all_checkboxes,
            outputs=selected_cards_display
        )
# Start the Gradio server. This is a module-level call with no
# `if __name__ == "__main__":` guard, so importing this file also launches it.
interface.launch()