Prajith04 commited on
Commit
74fa4c4
·
verified ·
1 Parent(s): c2fcac9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +7 -1
main.py CHANGED
@@ -6,6 +6,7 @@ from fastapi.staticfiles import StaticFiles
6
  from fastapi.templating import Jinja2Templates
7
  from dotenv import load_dotenv
8
  import os
 
9
  # Load environment variables
10
  load_dotenv()
11
  from gliner import GLiNER
@@ -19,14 +20,19 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
19
  # Load models
20
  cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
21
  os.makedirs(cache_dir, exist_ok=True)
 
22
  gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
23
  groq_client = Groq(api_key=GROQ_API_KEY)
24
 
25
  init_qdrant_collection()
26
 
27
  def extract_entities(text):
 
 
28
  labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
29
- return gliner_model.predict_entities(text, labels)
 
 
30
 
31
  def validate_answer(user_query, retrieved_answer):
32
  prompt = f"""
 
6
  from fastapi.templating import Jinja2Templates
7
  from dotenv import load_dotenv
8
  import os
9
+ from transformers import AutoTokenizer
10
  # Load environment variables
11
  load_dotenv()
12
  from gliner import GLiNER
 
20
  # Load models
21
  cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
22
  os.makedirs(cache_dir, exist_ok=True)
23
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", cache_dir=cache_dir) # Replace with appropriate tokenizer for GLiNER
24
  gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
25
  groq_client = Groq(api_key=GROQ_API_KEY)
26
 
27
  init_qdrant_collection()
28
 
29
  def extract_entities(text):
30
+ # Tokenize the input text first
31
+ inputs = tokenizer(text, return_tensors="pt") # Assuming PyTorch backend
32
  labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
33
+
34
+ # Predict entities
35
+ return gliner_model.predict_entities(inputs['input_ids'], labels)
36
 
37
  def validate_answer(user_query, retrieved_answer):
38
  prompt = f"""