# Hugging Face Spaces status at time of capture: Runtime error
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the tokenizer and model.
# NOTE: "roberta-large" is the pretrained base checkpoint, not a fine-tuned
# classifier, so its sequence-classification head is freshly initialised and
# its scores carry no meaning until model_name points at a trained checkpoint.
model_name = "roberta-large"  # Replace with your trained model if uploaded

# Narrative categories the classification head must produce, in logit order.
NARRATIVE_LABELS = [
    "Speculating War Outcomes",
    "Discrediting Ukraine",
    "Praise of Russia",
]

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Size the head to the label set explicitly. Without num_labels the base
# config defaults to a 2-way head, so the third narrative could never be
# scored and the label/probability pairing would be silently truncated.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(NARRATIVE_LABELS),
    id2label=dict(enumerate(NARRATIVE_LABELS)),
    label2id={label: idx for idx, label in enumerate(NARRATIVE_LABELS)},
)
# Define the prediction function
def classify_text(text):
    """Score *text* against the predefined narrative categories.

    Args:
        text: Input string from the Gradio textbox. It is tokenized with
            truncation to at most 512 tokens.

    Returns:
        dict mapping each narrative label to its softmax probability, the
        format consumed by ``gr.Label``.

    Raises:
        ValueError: If the model emits a different number of scores than
            there are labels (e.g. the head was loaded with another
            ``num_labels``). Previously ``zip`` silently dropped the extra
            labels, producing a misleading result.
    """
    labels = ["Speculating War Outcomes", "Discrediting Ukraine", "Praise of Russia"]  # Replace with your actual labels
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Batch of one -> take row 0 and convert to plain Python floats.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    if len(probabilities) != len(labels):
        raise ValueError(
            f"Model returned {len(probabilities)} scores but {len(labels)} labels are defined; "
            "load the model with num_labels matching the label list."
        )
    return dict(zip(labels, probabilities))
# Build the Gradio UI: a three-line textbox feeds classify_text, and the
# resulting label->probability dict is shown as the top three classes.
text_input = gr.Textbox(lines=3, placeholder="Enter text to classify...")
label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=classify_text,
    inputs=text_input,
    outputs=label_output,
    title="Narrative Classification",
    description="Classify text into predefined narrative categories.",
)

# Launch the app
demo.launch()