Prajith04 commited on
Commit
9691cf7
·
verified ·
1 Parent(s): ab86f05

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +1 -3
main.py CHANGED
@@ -6,7 +6,6 @@ from fastapi.staticfiles import StaticFiles
6
  from fastapi.templating import Jinja2Templates
7
  from dotenv import load_dotenv
8
  import os
9
- from transformers import AutoTokenizer
10
  # Load environment variables
11
  load_dotenv()
12
  from gliner import GLiNER
@@ -20,7 +19,6 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
20
  # Load models
21
  cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
22
  os.makedirs(cache_dir, exist_ok=True)
23
- tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", cache_dir=cache_dir) # Replace with appropriate tokenizer for GLiNER
24
  gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
25
  groq_client = Groq(api_key=GROQ_API_KEY)
26
 
@@ -32,7 +30,7 @@ def extract_entities(text):
32
  labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
33
 
34
  # Predict entities
35
- return gliner_model.predict_entities(inputs['input_ids'], labels)
36
 
37
  def validate_answer(user_query, retrieved_answer):
38
  prompt = f"""
 
6
  from fastapi.templating import Jinja2Templates
7
  from dotenv import load_dotenv
8
  import os
 
9
  # Load environment variables
10
  load_dotenv()
11
  from gliner import GLiNER
 
19
  # Load models
20
  cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
21
  os.makedirs(cache_dir, exist_ok=True)
 
22
  gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
23
  groq_client = Groq(api_key=GROQ_API_KEY)
24
 
 
30
  labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
31
 
32
  # Predict entities
33
+ return gliner_model.predict_entities(inputs, labels)
34
 
35
  def validate_answer(user_query, retrieved_answer):
36
  prompt = f"""