docker01 / app.py
juanpablosanchez's picture
Update app.py
bdc0325 verified
raw
history blame
1.63 kB
from threading import Thread

import gradio as gr
import requests
import torch
import uvicorn
from fastapi import FastAPI, Request
from pydantic import BaseModel
from transformers import AutoModelForTokenClassification, AutoTokenizer
# Configure the FastAPI application.
app = FastAPI()
# Load the model and tokenizer once at import time (downloads weights on first run,
# which can take a while and blocks startup).
# NOTE(review): this checkpoint is a GLiNER model; GLiNER checkpoints are normally
# loaded via the `gliner` package rather than AutoModelForTokenClassification —
# confirm that this actually loads and produces meaningful labels.
model_name = "EmergentMethods/gliner_medium_news-v2.1"
model = AutoModelForTokenClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
class TextInput(BaseModel):
    """Request body for POST /predict: the raw text to run token classification over."""
    text: str
@app.post("/predict")
async def predict(payload: TextInput):
    """Run token-classification inference over the posted text.

    Returns:
        dict: ``{"entities": [{"token": str, "label": str}, ...]}`` with one
        entry per (non-special) model token.

    Note:
        Inference is synchronous inside an async handler, so it blocks the
        event loop for the duration of the forward pass — acceptable for a
        single-user demo, kept to match the original design.
    """
    text = payload.text
    # truncation=True: without it, texts longer than the model's max sequence
    # length crash the forward pass with a position-embedding size error.
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Per-token predicted class ids: shape (batch=1, seq_len).
    predictions = torch.argmax(outputs.logits, dim=2)
    id2label = model.config.id2label
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    special_ids = set(tokenizer.all_special_ids)
    # Skip special tokens ([CLS], [SEP], padding, ...) — they are tokenizer
    # artifacts, not part of the user's text, and previously leaked into the output.
    entities = [
        {"token": token, "label": id2label[pred.item()]}
        for token_id, token, pred in zip(input_ids.tolist(), tokens, predictions[0])
        if token_id not in special_ids
    ]
    return {"entities": entities}
# Run the FastAPI server in a separate daemon thread so the Gradio UI below
# can own the main thread; daemon=True lets the process exit when Gradio stops.
def start_api():
    """Serve `app` with uvicorn on 0.0.0.0:8000 (blocking call)."""
    uvicorn.run(app, host="0.0.0.0", port=8000)
api_thread = Thread(target=start_api, daemon=True)
api_thread.start()
# Gradio front end: proxies user text to the local FastAPI endpoint.
def predict_gradio(text):
    """Forward *text* to the local /predict endpoint and return its entities.

    Returns:
        list: the ``entities`` list from the API response (``[]`` if absent).

    Raises:
        requests.HTTPError: if the API returns a non-2xx status.

    Note: requires the `requests` package (imported at module level; the
    original code used it without importing it, which raised NameError).
    """
    response = requests.post(
        "http://localhost:8000/predict",
        json={"text": text},
        # Without a timeout the UI would hang forever if the API thread died.
        timeout=30,
    )
    # Fail loudly on HTTP errors instead of trying to parse an error page as JSON.
    response.raise_for_status()
    return response.json().get("entities", [])
gr.Interface(fn=predict_gradio, inputs="text", outputs="json").launch()