# docker01 / app.py
# NER demo: FastAPI /predict endpoint + Gradio front-end.
# (Hugging Face web-viewer residue — commit hash, "raw/history/blame", file
# size — removed from this header; it was not part of the source file.)
# Standard library
import json
import urllib.request

# Third-party
import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Set up the FastAPI application.
app = FastAPI()
# Load the NER model and its tokenizer from the Hugging Face Hub.
# NOTE(review): this runs at import time and downloads weights — network
# access is required the first time the script starts.
model_name = "mdarhri00/named-entity-recognition"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
class TextInput(BaseModel):
    """Request body for POST /predict: the raw text to run NER over."""
    text: str
@app.post("/predict")
async def predict(input: TextInput):
    """Run token-level named-entity recognition over the submitted text.

    Returns ``{"entities": [{"token": ..., "label": ...}, ...]}`` with one
    entry per wordpiece token produced by the tokenizer (special tokens such
    as [CLS]/[SEP] are included, since nothing filters them out).
    """
    encoded = tokenizer(input.text, return_tensors="pt")
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        output = model(**encoded)
    # Highest-scoring tag id at every token position of the (single) sequence.
    tag_ids = output.logits.argmax(dim=2)[0]
    # Map integer tag ids back to human-readable label strings.
    label_of = model.config.id2label
    wordpieces = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    entities = []
    for piece, tag in zip(wordpieces, tag_ids):
        entities.append({"token": piece, "label": label_of[tag.item()]})
    return {"entities": entities}
# Configurar Gradio
def predict_gradio(text):
    """Gradio callback: forward *text* to the local /predict endpoint.

    Fix: the original body called ``requests.post`` but ``requests`` was
    never imported anywhere in the file, so every invocation raised
    ``NameError``. Rewritten with the stdlib ``urllib.request`` so no extra
    dependency is needed; the response contract is unchanged.

    NOTE(review): requires the FastAPI server on localhost:8000 to be
    running (see the __main__ guard at the bottom of the file).
    """
    payload = json.dumps({"text": text}).encode("utf-8")
    req = urllib.request.Request(
        "http://localhost:8000/predict",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        body = json.loads(resp.read().decode("utf-8"))
    # Same fallback as the original: empty list when no "entities" key.
    return body.get("entities", [])
# Gradio UI wired to the HTTP endpoint above.
demo = gr.Interface(fn=predict_gradio, inputs="text", outputs="json")
# Fix: the original blocking launch(share=True) never returned, so the
# uvicorn.run(...) call in the __main__ guard below was unreachable and the
# /predict endpoint that predict_gradio posts to never started.
# prevent_thread_lock=True makes launch() return immediately.
demo.launch(share=True, prevent_thread_lock=True)
# Start the FastAPI server (serves POST /predict on port 8000).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)