import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import gradio as gr

# Definizione del modello pre-addestrato
class PretrainedModel(nn.Module):
    def __init__(self, num_classes=19):
        super(PretrainedModel, self).__init__()
        
        weights = models.ResNet50_Weights.IMAGENET1K_V2
        self.model = models.resnet50(weights=weights)
        
        # Congela i layer iniziali
        for param in self.model.parameters():
            param.requires_grad = False
            
        # Sostituisci l'ultimo layer
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, num_classes)
    
    def forward(self, x):
        return self.model(x)

# Crea un'istanza del modello
model = PretrainedModel(num_classes=19)

# Carica i pesi con `weights_only=True` per evitare problemi di sicurezza
state_dict = torch.load('model_v11.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()  # Imposta il modello in modalità valutazione

# Trasformazioni per l'immagine
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def classify_image(img):
    # Preprocessa l'immagine
    img_tensor = preprocess(img).unsqueeze(0)  # Aggiunge una dimensione per il batch
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.softmax(output, dim=1)
        predicted_class_index = probabilities.argmax().item()
        predicted_probability = probabilities[0][predicted_class_index].item()
        
    return f"Class {predicted_class_index}, Confidence: {predicted_probability:.4f}"

# Configura Gradio per l'accesso pubblico
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs="text"
)

iface.launch(share=True)  # Abilita l'accesso pubblico e la coda