Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -6,7 +6,6 @@ from fastapi.staticfiles import StaticFiles
|
|
6 |
from fastapi.templating import Jinja2Templates
|
7 |
from dotenv import load_dotenv
|
8 |
import os
|
9 |
-
from transformers import AutoTokenizer
|
10 |
# Load environment variables
|
11 |
load_dotenv()
|
12 |
from gliner import GLiNER
|
@@ -20,7 +19,6 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
20 |
# Load models
|
21 |
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
|
22 |
os.makedirs(cache_dir, exist_ok=True)
|
23 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", cache_dir=cache_dir) # Replace with appropriate tokenizer for GLiNER
|
24 |
gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
|
25 |
groq_client = Groq(api_key=GROQ_API_KEY)
|
26 |
|
@@ -32,7 +30,7 @@ def extract_entities(text):
|
|
32 |
labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
|
33 |
|
34 |
# Predict entities
|
35 |
-
return gliner_model.predict_entities(inputs
|
36 |
|
37 |
def validate_answer(user_query, retrieved_answer):
|
38 |
prompt = f"""
|
|
|
6 |
from fastapi.templating import Jinja2Templates
|
7 |
from dotenv import load_dotenv
|
8 |
import os
|
|
|
9 |
# Load environment variables
|
10 |
load_dotenv()
|
11 |
from gliner import GLiNER
|
|
|
19 |
# Load models
|
20 |
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback to /app/cache
|
21 |
os.makedirs(cache_dir, exist_ok=True)
|
|
|
22 |
gliner_model = GLiNER.from_pretrained("urchade/gliner_medium-v2.1",cache_dir=cache_dir)
|
23 |
groq_client = Groq(api_key=GROQ_API_KEY)
|
24 |
|
|
|
30 |
labels = ["PRODUCT", "ISSUE", "PROBLEM", "SERVICE"]
|
31 |
|
32 |
# Predict entities
|
33 |
+
return gliner_model.predict_entities(inputs, labels)
|
34 |
|
35 |
def validate_answer(user_query, retrieved_answer):
|
36 |
prompt = f"""
|