midrees2806 commited on
Commit
ade1780
Β·
verified Β·
1 Parent(s): baf9f53

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +47 -40
rag.py CHANGED
@@ -1,9 +1,12 @@
1
  import json
 
 
 
2
  from sentence_transformers import SentenceTransformer, util
3
  from groq import Groq
4
- import os
5
- import csv
6
  from dotenv import load_dotenv
 
 
7
 
8
  # Load environment variables
9
  load_dotenv()
@@ -11,21 +14,34 @@ load_dotenv()
11
  # Initialize Groq client
12
  groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
13
 
14
- # Load similarity model
15
  similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
16
 
17
  # Load dataset
18
- with open('dataset.json', 'r', encoding='utf-8') as f:
19
- dataset = json.load(f)
 
 
 
 
 
 
 
20
 
21
  # Precompute embeddings
22
  dataset_questions = [item.get("input", "").lower().strip() for item in dataset]
23
  dataset_answers = [item.get("response", "") for item in dataset]
24
  dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
25
 
26
- # Use absolute path for unmatched_queries.csv
27
- base_dir = os.path.dirname(os.path.abspath(__file__))
28
- file_path = os.path.join(base_dir, "unmatched_queries.csv")
 
 
 
 
 
 
29
 
30
  def query_groq_llm(prompt, model_name="llama3-70b-8192"):
31
  try:
@@ -40,48 +56,39 @@ def query_groq_llm(prompt, model_name="llama3-70b-8192"):
40
  )
41
  return chat_completion.choices[0].message.content.strip()
42
  except Exception as e:
43
- print(f"[ERROR] Groq API: {e}")
44
- return ""
45
 
46
- def log_unmatched_query(query):
47
  try:
48
- # Create file with header if not exists
49
- if not os.path.exists(file_path):
50
- with open(file_path, mode="w", newline="", encoding="utf-8") as file:
51
- writer = csv.writer(file)
52
- writer.writerow(["Unmatched Queries"])
53
-
54
- # Append unmatched query
55
- with open(file_path, mode="a", newline="", encoding="utf-8") as file:
56
- writer = csv.writer(file)
57
- writer.writerow([query])
58
- print(f"[DEBUG] Logged unmatched query: {query}")
59
-
60
  except Exception as e:
61
- print(f"[ERROR] Logging unmatched query failed: {e}")
62
 
63
  def get_best_answer(user_input):
64
  user_input_lower = user_input.lower().strip()
65
 
66
- # 🧾 Fee-specific shortcut
67
  if any(keyword in user_input_lower for keyword in ["fee", "fees", "charges", "semester fee"]):
68
  return (
69
  "πŸ’° For complete and up-to-date fee details for this program, we recommend visiting the official University of Education fee structure page.\n"
70
- "You’ll find comprehensive information regarding tuition, admission charges, and other applicable fees there.\n"
71
  "πŸ”— https://ue.edu.pk/allfeestructure.php"
72
  )
73
 
74
- # πŸ” Similarity matching
75
  user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
76
  similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
77
  best_match_idx = similarities.argmax().item()
78
  best_score = similarities[best_match_idx].item()
79
 
80
- # ✏️ Log unmatched queries
81
  if best_score < 0.65:
82
- log_unmatched_query(user_input)
83
 
84
- # 🧠 Prompt for LLM
85
  if best_score >= 0.65:
86
  original_answer = dataset_answers[best_match_idx]
87
  prompt = f"""As an official assistant for University of Education Lahore, provide a clear response:
@@ -95,19 +102,19 @@ def get_best_answer(user_input):
95
  Question: {user_input}
96
  Official Answer:"""
97
 
98
- # πŸ”— Query Groq LLM
99
  llm_response = query_groq_llm(prompt)
100
 
101
- # βœ‚οΈ Process LLM output
102
  if llm_response:
103
  for marker in ["Improved Answer:", "Official Answer:"]:
104
  if marker in llm_response:
105
- return llm_response.split(marker)[-1].strip()
106
- return llm_response
 
 
107
  else:
108
- return dataset_answers[best_match_idx] if best_score >= 0.65 else (
109
- "For official information:\n"
110
- "πŸ“ž +92-42-99262231-33\n"
111
- "βœ‰οΈ info@ue.edu.pk\n"
112
- "🌐 ue.edu.pk"
113
- )
 
1
  import json
2
+ import csv
3
+ from pathlib import Path
4
+ from datetime import datetime
5
  from sentence_transformers import SentenceTransformer, util
6
  from groq import Groq
 
 
7
  from dotenv import load_dotenv
8
+ import os
9
+ import pandas as pd
10
 
11
  # Load environment variables
12
  load_dotenv()
 
14
  # Initialize Groq client
15
  groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
16
 
17
+ # Load models and dataset
18
  similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
19
 
20
  # Load dataset
21
+ try:
22
+ with open('dataset.json', 'r') as f:
23
+ dataset = json.load(f)
24
+ # Validate dataset structure
25
+ if not all(isinstance(item, dict) and 'input' in item and 'response' in item for item in dataset):
26
+ raise ValueError("Invalid dataset structure")
27
+ except (json.JSONDecodeError, ValueError, FileNotFoundError) as e:
28
+ print(f"Error loading dataset: {e}")
29
+ dataset = []
30
 
31
  # Precompute embeddings
32
  dataset_questions = [item.get("input", "").lower().strip() for item in dataset]
33
  dataset_answers = [item.get("response", "") for item in dataset]
34
  dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
35
 
36
+ # Initialize unmatched queries CSV
37
+ def init_unmatched_queries_file():
38
+ csv_file = Path('unmatched_queries.csv')
39
+ if not csv_file.exists():
40
+ with open(csv_file, 'w', newline='', encoding='utf-8') as f:
41
+ writer = csv.writer(f)
42
+ writer.writerow(['Unmatched Queries', 'Timestamp'])
43
+
44
+ init_unmatched_queries_file()
45
 
46
  def query_groq_llm(prompt, model_name="llama3-70b-8192"):
47
  try:
 
56
  )
57
  return chat_completion.choices[0].message.content.strip()
58
  except Exception as e:
59
+ print(f"Error querying Groq API: {e}")
60
+ return None
61
 
62
+ def save_unmatched_query(query):
63
  try:
64
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
65
+ with open('unmatched_queries.csv', 'a', newline='', encoding='utf-8') as f:
66
+ writer = csv.writer(f)
67
+ writer.writerow([query, timestamp])
 
 
 
 
 
 
 
 
68
  except Exception as e:
69
+ print(f"Error saving unmatched query: {e}")
70
 
71
  def get_best_answer(user_input):
72
  user_input_lower = user_input.lower().strip()
73
 
74
+ # Handle fee-related questions
75
  if any(keyword in user_input_lower for keyword in ["fee", "fees", "charges", "semester fee"]):
76
  return (
77
  "πŸ’° For complete and up-to-date fee details for this program, we recommend visiting the official University of Education fee structure page.\n"
78
+ "You'll find comprehensive information regarding tuition, admission charges, and other applicable fees there.\n"
79
  "πŸ”— https://ue.edu.pk/allfeestructure.php"
80
  )
81
 
82
+ # Similarity matching
83
  user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
84
  similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
85
  best_match_idx = similarities.argmax().item()
86
  best_score = similarities[best_match_idx].item()
87
 
88
+ # Save unmatched queries
89
  if best_score < 0.65:
90
+ save_unmatched_query(user_input)
91
 
 
92
  if best_score >= 0.65:
93
  original_answer = dataset_answers[best_match_idx]
94
  prompt = f"""As an official assistant for University of Education Lahore, provide a clear response:
 
102
  Question: {user_input}
103
  Official Answer:"""
104
 
 
105
  llm_response = query_groq_llm(prompt)
106
 
 
107
  if llm_response:
108
  for marker in ["Improved Answer:", "Official Answer:"]:
109
  if marker in llm_response:
110
+ response = llm_response.split(marker)[-1].strip()
111
+ break
112
+ else:
113
+ response = llm_response
114
  else:
115
+ response = dataset_answers[best_match_idx] if best_score >= 0.65 else """For official information:
116
+ πŸ“ž +92-42-99262231-33
117
+ βœ‰οΈ [email protected]
118
+ 🌐 ue.edu.pk"""
119
+
120
+ return response