juanpablosanchez commited on
Commit
bdc0325
·
verified ·
1 Parent(s): 56b8a63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -4
app.py CHANGED
@@ -1,26 +1,56 @@
1
- import gradio as gr
 
2
  from transformers import AutoModelForTokenClassification, AutoTokenizer
3
  import torch
 
 
 
 
 
 
4
 
5
  # Cargar el modelo y el tokenizador
6
  model_name = "EmergentMethods/gliner_medium_news-v2.1"
7
  model = AutoModelForTokenClassification.from_pretrained(model_name)
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
 
10
- def predict(text):
 
 
 
 
 
 
 
11
  inputs = tokenizer(text, return_tensors="pt")
12
 
13
  # Realizar la inferencia
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
 
 
17
  logits = outputs.logits
18
  predictions = torch.argmax(logits, dim=2)
19
 
 
20
  id2label = model.config.id2label
21
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
22
  entities = [{"token": token, "label": id2label[prediction.item()]} for token, prediction in zip(tokens, predictions[0])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  return entities
24
 
25
- demo = gr.Interface(fn=predict, inputs="text", outputs="json")
26
- demo.launch()
 
1
+ from fastapi import FastAPI, Request
2
+ from pydantic import BaseModel
3
  from transformers import AutoModelForTokenClassification, AutoTokenizer
4
  import torch
5
+ import gradio as gr
6
+ from threading import Thread
7
+ import uvicorn
8
+
9
+ # Configurar FastAPI
10
+ app = FastAPI()
11
 
12
  # Cargar el modelo y el tokenizador
13
  model_name = "EmergentMethods/gliner_medium_news-v2.1"
14
  model = AutoModelForTokenClassification.from_pretrained(model_name)
15
  tokenizer = AutoTokenizer.from_pretrained(model_name)
16
 
17
+ class TextInput(BaseModel):
18
+ text: str
19
+
20
+ @app.post("/predict")
21
+ async def predict(input: TextInput):
22
+ text = input.text
23
+
24
+ # Tokenizar el texto
25
  inputs = tokenizer(text, return_tensors="pt")
26
 
27
  # Realizar la inferencia
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
 
31
+ # Procesar los resultados
32
  logits = outputs.logits
33
  predictions = torch.argmax(logits, dim=2)
34
 
35
+ # Mapear etiquetas
36
  id2label = model.config.id2label
37
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
38
  entities = [{"token": token, "label": id2label[prediction.item()]} for token, prediction in zip(tokens, predictions[0])]
39
+
40
+ return {"entities": entities}
41
+
42
+ # Iniciar el servidor de FastAPI en un hilo separado
43
+ def start_api():
44
+ uvicorn.run(app, host="0.0.0.0", port=8000)
45
+
46
+ api_thread = Thread(target=start_api, daemon=True)
47
+ api_thread.start()
48
+
49
+ # Configurar Gradio
50
+ def predict_gradio(text):
51
+ response = requests.post("http://localhost:8000/predict", json={"text": text})
52
+ entities = response.json().get("entities", [])
53
  return entities
54
 
55
+ gr.Interface(fn=predict_gradio, inputs="text", outputs="json").launch()
56
+