midrees2806 commited on
Commit
018c04e
·
verified ·
1 Parent(s): 21d81a3

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +30 -0
rag.py CHANGED
@@ -18,6 +18,10 @@ groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
18
  # Load models and dataset
19
  similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
20
 
 
 
 
 
21
  # Load dataset (automatically using the path)
22
  with open('dataset.json', 'r') as f:
23
  dataset = json.load(f)
@@ -27,6 +31,32 @@ dataset_questions = [item.get("input", "").lower().strip() for item in dataset]
27
  dataset_answers = [item.get("response", "") for item in dataset]
28
  dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def query_groq_llm(prompt, model_name="llama3-70b-8192"):
31
  try:
32
  chat_completion = groq_client.chat.completions.create(
 
18
  # Load models and dataset
19
  similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
20
 
21
+ # Configuration
22
+ HF_DATASET_REPO = "midrees2806/unmatched_queries" # Your dataset repo
23
+ HF_TOKEN = os.getenv("HF_TOKEN") # From Space secrets
24
+
25
  # Load dataset (automatically using the path)
26
  with open('dataset.json', 'r') as f:
27
  dataset = json.load(f)
 
31
  dataset_answers = [item.get("response", "") for item in dataset]
32
  dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
33
 
34
+ # --- Unmatched Queries Handler ---
35
+ def manage_unmatched_queries(query: str):
36
+ """Save unmatched queries to HF Dataset with error handling"""
37
+ try:
38
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
39
+
40
+ # Load existing dataset or create new
41
+ try:
42
+ ds = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
43
+ df = ds["train"].to_pandas()
44
+ except:
45
+ df = pd.DataFrame(columns=["Query", "Timestamp", "Processed"])
46
+
47
+ # Append new query (avoid duplicates)
48
+ if query not in df["Query"].values:
49
+ new_entry = {"Query": query, "Timestamp": timestamp, "Processed": False}
50
+ df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
51
+
52
+ # Push to Hub
53
+ updated_ds = Dataset.from_pandas(df)
54
+ updated_ds.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
55
+ except Exception as e:
56
+ print(f"Failed to save query: {e}")
57
+
58
+ # --- Enhanced LLM Query ---
59
+
60
  def query_groq_llm(prompt, model_name="llama3-70b-8192"):
61
  try:
62
  chat_completion = groq_client.chat.completions.create(