import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
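# Dependencies (e.g., for requirements.txt): gradio, torch, transformers,
# plus sentencepiece, which the slow DeBERTa-v3 tokenizer relies on.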
# Load model & tokenizer
model_name = "microsoft/deberta-v3-base"  # Change if needed
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=16)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
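# Note: with the base checkpoint, num_labels=16 attaches a freshly initialized
# classification head, so outputs are effectively random until the model is
# fine-tuned. To serve real predictions, point model_name at a fine-tuned
# checkpoint (the repo id below is a hypothetical placeholder):
#   model_name = "your-username/deberta-v3-mbti"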
# Define prediction function
def predict_mbti(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=256)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).cpu().item()
    # Map the predicted label index back to an MBTI type. This order must
    # match the label order used when the model was fine-tuned.
    mbti_types = [
        "INFJ", "ENTP", "INTP", "INTJ", "ENTJ", "ENFJ", "INFP", "ENFP",
        "ISFP", "ISTP", "ISFJ", "ISTJ", "ESTP", "ESFP", "ESTJ", "ESFJ"
    ]
    return mbti_types[prediction]
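# Optional sanity check before launching the UI (illustrative input,
# not part of the original app):
#   print(predict_mbti("I enjoy quiet weekends reading at home."))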
# Create Gradio UI
interface = gr.Interface(
    fn=predict_mbti,
    inputs=gr.Textbox(lines=3, placeholder="Enter some text to predict an MBTI type"),
    outputs="text",
    title="MBTI Personality Predictor",
    description="Enter some text and get the predicted MBTI personality type."
)
# Launch app
if __name__ == "__main__":
    interface.launch()
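# On Hugging Face Spaces the default launch() is sufficient. When running
# locally, Gradio can also expose a temporary public link:
#   interface.launch(share=True)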