brepositorium
add model
8424a62
raw
history blame
1.11 kB
import torch
import gradio as gr
from typing import Tuple, Dict
from torchvision import models
import torch.nn as nn
from model import get_transforms, create_effnetb2_model
# Class names, in the order of the model's output positions: predict() pairs
# cns[i] with the i-th probability of the softmax output.
cns = ['negative', 'neutral', 'positive']

# Size the classifier from the label list so the number of outputs and the
# number of labels cannot drift apart when a class is added or removed.
model = create_effnetb2_model(num_classes=len(cns))

# map_location remaps the stored tensors to the CPU, so a checkpoint saved on a
# GPU machine still loads on a CPU-only host.
# NOTE(review): torch.load unpickles the file; 'model.pth' must be a trusted
# checkpoint shipped with the app.
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))

# Put the model in evaluation mode before serving predictions.
model.eval()
def predict(img) -> Dict[str, float]:
    """Return the predicted probability of every sentiment class for one image.

    Args:
        img: The image handed over by the Gradio input component; it is passed
            unchanged to the callable returned by ``get_transforms()``
            (presumably a PIL image, given ``gr.Image(type="pil")`` -- confirm
            against ``get_transforms``).

    Returns:
        A single dict mapping each class name in ``cns`` to its softmax
        probability as a Python float. This is one value, not a tuple, which
        matches the single ``gr.Label`` output of the interface.
    """
    transform = get_transforms()
    # The model is called on a batch, so add a leading batch dimension of 1.
    batch = transform(img).unsqueeze(0)

    # No gradients are needed for serving; inference_mode skips autograd tracking.
    with torch.inference_mode():
        pred_probs = torch.softmax(model(batch), dim=1)

    # Row 0 is the only sample in the batch; pair each label with its score.
    return dict(zip(cns, pred_probs[0].tolist()))
# Heading and blurb passed to the interface below.
title = "Effnetb2 Sentiment Analysis"
description = (
    "An EfficientNetB2 feature extractor computer vision model "
    "to analyse image sentiment."
)

# Wire predict() to one image input (delivered as a PIL image) and one label
# output listing the top three classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=3, label="Predictions")],
    title=title,
    description=description,
)

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()