# Spaces status (as captured with this file): Runtime error
import gradio as gr
import torch
from transformers import BertForSequenceClassification, BertTokenizer
# Model identifier passed to `from_pretrained` for both tokenizer and model.
model_path = "LukeJacob2023/religion-classifier"

# Class names. `predict` pairs them positionally with the model's output
# scores, so the order here is assumed to match the model's label indices
# (TODO confirm against the model config).
labels = ["基督教", "佛教", "无信仰"]

# 1. Load the tokenizer and the model.
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path)

# Make sure the model is in evaluation mode.
model.eval()
def predict(text):
    """Classify ``text`` and return a score for every entry in ``labels``.

    Args:
        text: The input string to classify.

    Returns:
        A dict mapping each label to its softmax score as a Python float.
        Labels and scores are paired positionally via ``zip``.
    """
    # Tokenize to PyTorch tensors, truncating anything beyond 512 tokens.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Inference only: no gradient tracking is needed.
    with torch.no_grad():
        outputs = model(**inputs)
        # Logits are divided by 5.0 before the softmax (temperature scaling),
        # which flattens the resulting distribution; [0] selects the row for
        # the single input text.
        probabilities = torch.nn.functional.softmax(
            outputs.logits / 5.0, dim=-1
        )[0]
    return {label: float(prob) for label, prob in zip(labels, probabilities)}
# Build the Gradio interface: one text box in, one label widget out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, label="Input Text"),
    # `flagged` is not a parameter of gr.Label; Gradio releases that reject
    # unknown component kwargs raise TypeError on it, so it is not passed.
    # NOTE(review): if the intent was to hide the Flag button, that is an
    # option of gr.Interface itself; its name depends on the installed
    # Gradio version — confirm before adding it.
    outputs=gr.Label(num_top_classes=3, label="Predictions"),
    title="Religion Classification",
    description="请输入内容(繁体中文)",
)

# Launch the Gradio app.
iface.launch()