Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -6,6 +6,7 @@ from fastapi.staticfiles import StaticFiles
|
|
6 |
from fastapi.templating import Jinja2Templates
|
7 |
from dotenv import load_dotenv
|
8 |
import os
|
|
|
9 |
# Load environment variables
|
10 |
load_dotenv()
|
11 |
from gliner import GLiNER
|
@@ -19,14 +20,19 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
19 |
# Load models
|
20 |
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
|
21 |
os.makedirs(cache_dir, exist_ok=True)
|
|
|
22 |
gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
|
23 |
groq_client = Groq(api_key=GROQ_API_KEY)
|
24 |
|
25 |
init_qdrant_collection()
|
26 |
|
27 |
def extract_entities(text):
|
|
|
|
|
28 |
labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
|
29 |
-
|
|
|
|
|
30 |
|
31 |
def validate_answer(user_query, retrieved_answer):
|
32 |
prompt = f"""
|
|
|
6 |
from fastapi.templating import Jinja2Templates
|
7 |
from dotenv import load_dotenv
|
8 |
import os
|
9 |
+
from transformers import AutoTokenizer
|
10 |
# Load environment variables
|
11 |
load_dotenv()
|
12 |
from gliner import GLiNER
|
|
|
20 |
# Load models
|
21 |
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
|
22 |
os.makedirs(cache_dir, exist_ok=True)
|
23 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", cache_dir=cache_dir) # Replace with appropriate tokenizer for GLiNER
|
24 |
gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
|
25 |
groq_client = Groq(api_key=GROQ_API_KEY)
|
26 |
|
27 |
init_qdrant_collection()
|
28 |
|
29 |
def extract_entities(text):
    """Run the module-level GLiNER model over ``text`` and return its entities.

    Args:
        text: The raw input string to scan for entities.

    Returns:
        The result of ``gliner_model.predict_entities`` for ``text`` and the
        fixed label set defined below.
    """
    labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]

    # GLiNER.predict_entities expects the raw string and performs its own
    # tokenization. Feeding it input_ids produced by a separate Hugging Face
    # tokenizer (as the previous version did) hands it a tensor instead of
    # text, so the raw string is passed through unchanged here.
    return gliner_model.predict_entities(text, labels)
|
36 |
|
37 |
def validate_answer(user_query, retrieved_answer):
|
38 |
prompt = f"""
|