avsolatorio commited on
Commit
3f9fe25
1 Parent(s): 042389b

Signed-off-by: Aivin V. Solatorio <[email protected]>

Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import json
4
  import gradio as gr
5
  import pycountry
 
6
  from datetime import datetime
7
  from typing import Dict, Union
8
  from gliner import GLiNER
@@ -17,16 +18,20 @@ print(f"Cache directory: {_CACHE_DIR}")
17
 
18
 
19
  def get_model(model_name: str = None):
 
 
20
  if model_name is None:
21
  model_name = "urchade/gliner_base"
22
 
23
  global _MODEL
24
 
25
  if _MODEL.get(model_name) is None:
26
- start = datetime.now()
27
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
28
- _MODEL[model_name].to("cuda")
29
- print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")
 
 
 
30
 
31
  return _MODEL[model_name]
32
 
@@ -38,9 +43,10 @@ def get_country(country_name: str):
38
  return None
39
 
40
 
41
- @spaces.GPU
42
- def predict_entities(model, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
43
  start = datetime.now()
 
44
 
45
  if isinstance(labels, str):
46
  labels = [i.strip() for i in labels.split(",")]
@@ -55,8 +61,7 @@ def predict_entities(model, query: str, labels: Union[str, list], threshold: flo
55
  def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
56
 
57
  entities = []
58
- model = get_model(model_name)
59
- _entities = predict_entities(model=model, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)
60
 
61
  for entity in _entities:
62
  if entity["label"] == "country":
 
3
  import json
4
  import gradio as gr
5
  import pycountry
6
+ import torch
7
  from datetime import datetime
8
  from typing import Dict, Union
9
  from gliner import GLiNER
 
18
 
19
 
20
  def get_model(model_name: str = None):
21
+ start = datetime.now()
22
+
23
  if model_name is None:
24
  model_name = "urchade/gliner_base"
25
 
26
  global _MODEL
27
 
28
  if _MODEL.get(model_name) is None:
 
29
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
30
+
31
+ if torch.cuda.is_available() and _MODEL[model_name].device.type.startswith("cuda"):
32
+ _MODEL[model_name] = _MODEL[model_name].to("cuda")
33
+
34
+ print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")
35
 
36
  return _MODEL[model_name]
37
 
 
43
  return None
44
 
45
 
46
+ @spaces.GPU(enable_queue=True)
47
+ def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
48
  start = datetime.now()
49
+ model = get_model(model_name)
50
 
51
  if isinstance(labels, str):
52
  labels = [i.strip() for i in labels.split(",")]
 
61
  def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
62
 
63
  entities = []
64
+ _entities = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)
 
65
 
66
  for entity in _entities:
67
  if entity["label"] == "country":