# app.py — Gradio Space: suicide-related text classification.
# DistilBERT (TF) produces sentence embeddings; an XGBoost head classifies them.
import json
from typing import List, Union

import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, TFAutoModel
class TextClassificationPipeline:
    """Two-stage text classifier.

    A DistilBERT encoder turns the input text into a [CLS] sentence
    embedding, and a pre-trained XGBoost model predicts the class label
    from that embedding.
    """

    def __init__(self, tokenizer, distilbert_model, xgb_model):
        # Keep references to the three pretrained components; nothing is copied.
        self.tokenizer = tokenizer
        self.distilbert_model = distilbert_model
        self.xgb_model = xgb_model

    def __call__(self, text):
        """Return the XGBoost prediction array for ``text``."""
        # Tokenize to TensorFlow tensors, truncating/padding to 128 tokens.
        encoded = self.tokenizer(
            text, return_tensors="tf", padding=True, truncation=True, max_length=128
        )
        hidden = self.distilbert_model(**encoded)
        # Position 0 is the [CLS] token — used as the sentence representation.
        cls_embedding = hidden.last_hidden_state[:, 0, :].numpy()
        return self.xgb_model.predict(cls_embedding)
# Repo on the Hugging Face Hub holding both the DistilBERT weights and the
# serialized XGBoost head.
HF_MODEL_ID = "AndresR2909/suicide-related-text-classification_distilbert_xgboost"
# Download the XGBoost artifact (cached locally by hf_hub_download).
xgboost_path = hf_hub_download(repo_id=HF_MODEL_ID, filename="xgboost_model.joblib")
# Load the models: tokenizer + TF DistilBERT from the Hub, XGBoost from the
# downloaded joblib file. NOTE(review): joblib.load unpickles arbitrary code —
# acceptable here only because the artifact comes from the pinned repo above.
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
distilbert_model = TFAutoModel.from_pretrained(HF_MODEL_ID)
xgb_model = joblib.load(xgboost_path)
# Build the pipeline once at startup so requests reuse the loaded models.
pipeline = TextClassificationPipeline(tokenizer, distilbert_model, xgb_model)
# Funci贸n para la API
def predict_api(texts: Union[str, List[str]]) -> List[int]:
    """Classify one or more texts with the preloaded pipeline.

    Args:
        texts: a single string, or a list of strings.

    Returns:
        One integer class label per input text.
    """
    # Bug fix: gr.Interface passes the Textbox value as a single *string*;
    # iterating it directly would yield one prediction per character.
    if isinstance(texts, str):
        texts = [texts]
    # int() coerces numpy integer labels so the result is JSON-serializable.
    return [int(pipeline(text)[0]) for text in texts]
# Crear la interfaz (opcional)
iface = gr.Interface(
fn=predict_api,
inputs=gr.Textbox(lines=2, placeholder="Introduce un texto aqu铆..."),
outputs="text",
title="Clasificaci贸n de Texto (API)",
description="Introduce un texto para obtener una predicci贸n en formato JSON.",
)
# Build a Gradio Blocks layout intended for the API surface.
# NOTE(review): these two Textboxes are not wired to any event handler, so
# this layout is display-only as written — confirm whether a .click/.submit
# binding to predict_api was intended.
with gr.Blocks() as blocks:
    gr.Textbox(lines=2, placeholder="Introduce un texto aquí...", label="Entrada de texto")
    gr.Textbox(label="Resultado", interactive=False)
# Mount the API.
# NOTE(review): gr.mount_gradio_app's first argument is expected to be a
# FastAPI app, not a gr.Interface — this call likely fails at startup;
# verify against the Gradio mount_gradio_app documentation.
app = gr.mount_gradio_app(iface, blocks=blocks, path="/api/predict")
# Launch the graphical interface (share=True creates a public link).
iface.launch(share=True)