midrees2806 commited on
Commit
1bdac65
Β·
verified Β·
1 Parent(s): ade1780

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +71 -77
rag.py CHANGED
@@ -1,30 +1,30 @@
1
  import json
2
- import csv
3
- from pathlib import Path
4
  from datetime import datetime
5
  from sentence_transformers import SentenceTransformer, util
6
  from groq import Groq
7
  from dotenv import load_dotenv
8
  import os
 
9
  import pandas as pd
10
 
11
  # Load environment variables
12
  load_dotenv()
13
 
14
- # Initialize Groq client
15
  groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
16
-
17
- # Load models and dataset
18
  similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
19
 
20
- # Load dataset
 
 
 
 
21
  try:
22
  with open('dataset.json', 'r') as f:
23
  dataset = json.load(f)
24
- # Validate dataset structure
25
  if not all(isinstance(item, dict) and 'input' in item and 'response' in item for item in dataset):
26
  raise ValueError("Invalid dataset structure")
27
- except (json.JSONDecodeError, ValueError, FileNotFoundError) as e:
28
  print(f"Error loading dataset: {e}")
29
  dataset = []
30
 
@@ -33,88 +33,82 @@ dataset_questions = [item.get("input", "").lower().strip() for item in dataset]
33
  dataset_answers = [item.get("response", "") for item in dataset]
34
  dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
35
 
36
- # Initialize unmatched queries CSV
37
- def init_unmatched_queries_file():
38
- csv_file = Path('unmatched_queries.csv')
39
- if not csv_file.exists():
40
- with open(csv_file, 'w', newline='', encoding='utf-8') as f:
41
- writer = csv.writer(f)
42
- writer.writerow(['Unmatched Queries', 'Timestamp'])
43
-
44
- init_unmatched_queries_file()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- def query_groq_llm(prompt, model_name="llama3-70b-8192"):
 
47
  try:
48
- chat_completion = groq_client.chat.completions.create(
49
- messages=[{
50
- "role": "user",
51
- "content": prompt
52
- }],
53
- model=model_name,
54
  temperature=0.7,
55
- max_tokens=500
 
56
  )
57
- return chat_completion.choices[0].message.content.strip()
58
  except Exception as e:
59
- print(f"Error querying Groq API: {e}")
60
  return None
61
 
62
- def save_unmatched_query(query):
63
- try:
64
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
65
- with open('unmatched_queries.csv', 'a', newline='', encoding='utf-8') as f:
66
- writer = csv.writer(f)
67
- writer.writerow([query, timestamp])
68
- except Exception as e:
69
- print(f"Error saving unmatched query: {e}")
70
-
71
- def get_best_answer(user_input):
72
- user_input_lower = user_input.lower().strip()
73
 
74
- # Handle fee-related questions
75
- if any(keyword in user_input_lower for keyword in ["fee", "fees", "charges", "semester fee"]):
76
- return (
77
- "πŸ’° For complete and up-to-date fee details for this program, we recommend visiting the official University of Education fee structure page.\n"
78
- "You'll find comprehensive information regarding tuition, admission charges, and other applicable fees there.\n"
79
- "πŸ”— https://ue.edu.pk/allfeestructure.php"
80
- )
81
 
82
- # Similarity matching
83
- user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
84
- similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
85
- best_match_idx = similarities.argmax().item()
86
- best_score = similarities[best_match_idx].item()
87
 
88
- # Save unmatched queries
89
  if best_score < 0.65:
90
- save_unmatched_query(user_input)
91
 
 
92
  if best_score >= 0.65:
93
- original_answer = dataset_answers[best_match_idx]
94
- prompt = f"""As an official assistant for University of Education Lahore, provide a clear response:
95
  Question: {user_input}
96
- Original Answer: {original_answer}
97
- Improved Answer:"""
98
  else:
99
- prompt = f"""As an official assistant for University of Education Lahore, provide a helpful response:
100
- Include relevant details about university policies.
101
- If unsure, direct to official channels.
102
  Question: {user_input}
103
- Official Answer:"""
104
-
105
- llm_response = query_groq_llm(prompt)
106
-
107
- if llm_response:
108
- for marker in ["Improved Answer:", "Official Answer:"]:
109
- if marker in llm_response:
110
- response = llm_response.split(marker)[-1].strip()
111
- break
112
- else:
113
- response = llm_response
114
- else:
115
- response = dataset_answers[best_match_idx] if best_score >= 0.65 else """For official information:
116
- πŸ“ž +92-42-99262231-33
117
- βœ‰οΈ [email protected]
118
- 🌐 ue.edu.pk"""
119
-
120
- return response
 
1
  import json
 
 
2
  from datetime import datetime
3
  from sentence_transformers import SentenceTransformer, util
4
  from groq import Groq
5
  from dotenv import load_dotenv
6
  import os
7
+ from datasets import load_dataset, Dataset, DatasetDict
8
  import pandas as pd
9
 
10
  # Load environment variables
11
  load_dotenv()
12
 
13
+ # Initialize clients
14
  groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
 
 
15
  similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
16
 
17
+ # Configuration
18
+ HF_DATASET_REPO = "midrees2806/unmatched_queries" # Your dataset repo
19
+ HF_TOKEN = os.getenv("HF_TOKEN") # From Space secrets
20
+
21
+ # --- Dataset Loading ---
22
  try:
23
  with open('dataset.json', 'r') as f:
24
  dataset = json.load(f)
 
25
  if not all(isinstance(item, dict) and 'input' in item and 'response' in item for item in dataset):
26
  raise ValueError("Invalid dataset structure")
27
+ except Exception as e:
28
  print(f"Error loading dataset: {e}")
29
  dataset = []
30
 
 
33
  dataset_answers = [item.get("response", "") for item in dataset]
34
  dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
35
 
36
+ # --- Unmatched Queries Handler ---
37
+ def manage_unmatched_queries(query: str):
38
+ """Save unmatched queries to HF Dataset with error handling"""
39
+ try:
40
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
41
+
42
+ # Load existing dataset or create new
43
+ try:
44
+ ds = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
45
+ df = ds["train"].to_pandas()
46
+ except:
47
+ df = pd.DataFrame(columns=["Query", "Timestamp", "Processed"])
48
+
49
+ # Append new query (avoid duplicates)
50
+ if query not in df["Query"].values:
51
+ new_entry = {"Query": query, "Timestamp": timestamp, "Processed": False}
52
+ df = pd.concat([df, pd.DataFrame([new_entry])], ignore_index=True)
53
+
54
+ # Push to Hub
55
+ updated_ds = Dataset.from_pandas(df)
56
+ updated_ds.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
57
+ except Exception as e:
58
+ print(f"Failed to save query: {e}")
59
 
60
+ # --- Enhanced LLM Query ---
61
+ def query_llm(prompt: str, model: str = "llama3-70b-8192") -> str:
62
  try:
63
+ response = groq_client.chat.completions.create(
64
+ messages=[{"role": "user", "content": prompt}],
65
+ model=model,
 
 
 
66
  temperature=0.7,
67
+ max_tokens=1024,
68
+ top_p=0.9
69
  )
70
+ return response.choices[0].message.content.strip()
71
  except Exception as e:
72
+ print(f"LLM Error: {e}")
73
  return None
74
 
75
+ # --- Main Chat Function ---
76
+ def get_best_answer(user_input: str) -> str:
77
+ user_input = user_input.strip()
78
+ lower_input = user_input.lower()
 
 
 
 
 
 
 
79
 
80
+ # 1. Handle special cases
81
+ if any(kw in lower_input for kw in ["fee", "fees", "tuition"]):
82
+ return ("πŸ’° Fee information:\n"
83
+ "Please visit: https://ue.edu.pk/allfeestructure.php\n"
84
+ "For personalized help, contact [email protected]")
 
 
85
 
86
+ # 2. Semantic similarity search
87
+ query_embedding = similarity_model.encode(lower_input, convert_to_tensor=True)
88
+ scores = util.pytorch_cos_sim(query_embedding, dataset_embeddings)[0]
89
+ best_idx = scores.argmax().item()
90
+ best_score = scores[best_idx].item()
91
 
92
+ # 3. Save unmatched queries (threshold = 0.65)
93
  if best_score < 0.65:
94
+ manage_unmatched_queries(user_input)
95
 
96
+ # 4. Generate response
97
  if best_score >= 0.65:
98
+ context = dataset_answers[best_idx]
99
+ prompt = f"""University Assistant Task:
100
  Question: {user_input}
101
+ Context: {context}
102
+ Generate a helpful, accurate response using the context. If unsure, say "Please contact [email protected]" """
103
  else:
104
+ prompt = f"""As an official University of Education assistant, answer:
 
 
105
  Question: {user_input}
106
+ Guidelines:
107
+ - Be polite and professional
108
+ - Direct to official channels if uncertain
109
+ - Keep responses under 3 sentences"""
110
+
111
+ response = query_llm(prompt)
112
+ return response or """For official assistance:
113
+ πŸ“ž +92-42-99262231-33
114
+ βœ‰οΈ [email protected]"""