|
import json
from typing import List, Union

import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, TFAutoModel
|
|
|
class TextClassificationPipeline:
    """Two-stage text classifier: a TF DistilBERT encoder produces [CLS]
    embeddings, and a fitted XGBoost model predicts the class from them."""

    def __init__(self, tokenizer, distilbert_model, xgb_model):
        """Keep references to the tokenizer, encoder, and classifier head."""
        self.tokenizer = tokenizer
        self.distilbert_model = distilbert_model
        self.xgb_model = xgb_model

    def __call__(self, text):
        """Classify ``text``.

        Tokenizes the input (truncated/padded to 128 tokens), runs the
        DistilBERT encoder, takes the first-token ([CLS]) embedding of each
        sequence, and returns the XGBoost prediction array.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="tf",
            padding=True,
            truncation=True,
            max_length=128,
        )
        hidden = self.distilbert_model(**encoded)
        # Position 0 along the sequence axis is the [CLS] token embedding.
        cls_vectors = hidden.last_hidden_state[:, 0, :].numpy()
        return self.xgb_model.predict(cls_vectors)
|
|
|
# Hugging Face Hub repository hosting both the DistilBERT encoder weights
# and the serialized XGBoost classifier head.
HF_MODEL_ID = "AndresR2909/suicide-related-text-classification_distilbert_xgboost"

# Download the XGBoost artifact (cached locally by huggingface_hub).
xgboost_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="xgboost_model.joblib")

# Load tokenizer and TF DistilBERT backbone from the same repository.
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
distilbert_model = TFAutoModel.from_pretrained(HF_MODEL_ID)
# NOTE(review): joblib.load unpickles arbitrary objects — acceptable only
# because the artifact comes from the pinned repository above; do not point
# this at untrusted files.
xgb_model = joblib.load(xgboost_path)

# Module-level singleton used by predict_api below.
pipeline = TextClassificationPipeline(tokenizer, distilbert_model, xgb_model)
|
|
|
|
|
|
|
def predict_api(texts: Union[str, List[str]]) -> List[int]:
    """Run the classification pipeline on one text or a batch of texts.

    Args:
        texts: A single input string (what the Gradio Textbox actually
            delivers) or a list of strings.

    Returns:
        One predicted class label per input text, cast to plain ``int`` so
        the result is JSON-serializable.
    """
    # BUG FIX: the Gradio Textbox passes a single str; iterating it directly
    # would classify each *character* separately. Wrap it into a batch of one.
    if isinstance(texts, str):
        texts = [texts]
    # pipeline(...) returns one prediction per input row; index 0 is the
    # prediction for the single text we passed in.
    return [int(pipeline(text)[0]) for text in texts]
|
|
|
|
|
# Gradio UI: one textbox in, prediction rendered as text out.
# FIX: the user-facing Spanish strings were mojibake (UTF-8 decoded as GBK,
# e.g. "aqu铆") — restored to the intended accented characters.
iface = gr.Interface(
    fn=predict_api,
    inputs=gr.Textbox(lines=2, placeholder="Introduce un texto aquí..."),
    outputs="text",
    title="Clasificación de Texto (API)",
    description="Introduce un texto para obtener una predicción en formato JSON.",
)
|
|
|
|
|
# Secondary Blocks layout with an input and a read-only output textbox.
# FIX: restored the mojibake placeholder ("aqu铆" -> "aquí").
# NOTE(review): these widgets have no event handlers wired to predict_api,
# so this UI is inert as written — confirm whether it is still needed.
with gr.Blocks() as blocks:
    gr.Textbox(lines=2, placeholder="Introduce un texto aquí...", label="Entrada de texto")
    gr.Textbox(label="Resultado", interactive=False)
|
|
|
|
|
# NOTE(review): gr.mount_gradio_app expects a FastAPI application as its
# first argument (signature: mount_gradio_app(app, blocks, path)), but an
# Interface is passed here — this likely fails or misbehaves at runtime.
# The proper fix needs `fastapi.FastAPI()` as the host app; flagged rather
# than changed to avoid adding a new dependency. Also note `app` is unused
# afterwards: launch() below serves `iface` directly.
app = gr.mount_gradio_app(iface, blocks=blocks, path="/api/predict")

# Start the Gradio server with a public share link.
iface.launch(share=True)