midrees2806 committed on
Commit 3e83acd · verified · 1 Parent(s): 6a82855

Update rag.py

Files changed (1)
  1. rag.py +72 -72
rag.py CHANGED
@@ -1,48 +1,50 @@
+import os
 import json
+import requests
+import pandas as pd
+from dotenv import load_dotenv
+from datetime import datetime
 from sentence_transformers import SentenceTransformer, util
 from groq import Groq
-from datetime import datetime
-import os
-import pandas as pd
 from datasets import load_dataset, Dataset
-from dotenv import load_dotenv
 
-# Load environment variables
+# ✅ Load environment variables from .env
 load_dotenv()
 
-# Initialize Groq client
-groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
+# ✅ API Keys
+HF_TOKEN = os.getenv("HF_TOKEN")
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
-# Load similarity model
-similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
+# ✅ Initialize Groq client
+groq_client = Groq(api_key=GROQ_API_KEY)
 
-# Config
+# ✅ Hugging Face Dataset Repo
 HF_DATASET_REPO = "midrees2806/unmatched_queries"
-HF_TOKEN = os.getenv("HF_TOKEN")
 
-# Greeting list
+# ✅ Sentence Transformer model for semantic similarity
+similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
+
+# ✅ Greeting keywords
 GREETINGS = [
     "hi", "hello", "hey", "good morning", "good afternoon", "good evening",
-    "assalam o alaikum", "salam", "aoa", "hi there",
-    "hey there", "greetings"
+    "assalam o alaikum", "salam", "aoa", "hi there", "hey there", "greetings"
 ]
 
-# Load local dataset
+# ✅ Load dataset
 try:
     with open('dataset.json', 'r') as f:
         dataset = json.load(f)
-    if not all(isinstance(item, dict) and 'input' in item and 'response' in item for item in dataset):
-        raise ValueError("Invalid dataset structure")
+    assert all('input' in d and 'response' in d for d in dataset), "Invalid dataset format"
 except Exception as e:
-    print(f"Error loading dataset: {e}")
+    print(f"[ERROR] Loading dataset: {e}")
     dataset = []
 
-# Precompute embeddings
-dataset_questions = [item.get("input", "").lower().strip() for item in dataset]
-dataset_answers = [item.get("response", "") for item in dataset]
+# ✅ Prepare embeddings
+dataset_questions = [d["input"].lower().strip() for d in dataset]
+dataset_answers = [d["response"] for d in dataset]
 dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
 
-# Save unmatched queries to Hugging Face
+# ✅ Function: Save unmatched queries to Hugging Face Hub
 def manage_unmatched_queries(query: str):
     try:
         timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -57,95 +59,93 @@ def manage_unmatched_queries(query: str):
         updated_ds = Dataset.from_pandas(df)
         updated_ds.push_to_hub(HF_DATASET_REPO, token=HF_TOKEN)
     except Exception as e:
-        print(f"Failed to save query: {e}")
+        print(f"[ERROR] Logging unmatched query: {e}")
 
-# Query Groq LLM
-def query_groq_llm(prompt, model_name="llama3-70b-8192"):
+# ✅ Function: Call Groq LLM
+def query_groq_llm(prompt: str, model_name="llama3-70b-8192") -> str:
     try:
-        chat_completion = groq_client.chat.completions.create(
-            messages=[{
-                "role": "user",
-                "content": prompt
-            }],
+        completion = groq_client.chat.completions.create(
+            messages=[{"role": "user", "content": prompt}],
             model=model_name,
             temperature=0.7,
            max_tokens=500
        )
-        return chat_completion.choices[0].message.content.strip()
+        return completion.choices[0].message.content.strip()
     except Exception as e:
-        print(f"Error querying Groq API: {e}")
+        print(f"[ERROR] Groq LLM call failed: {e}")
         return ""
 
-# Main logic function to be called from Gradio
-def get_best_answer(user_input):
+# ✅ Main RAG logic
+def get_best_answer(user_input: str) -> str:
     if not user_input.strip():
-        return "Please enter a valid question."
+        return "⚠️ Please enter a valid question."
 
     user_input_lower = user_input.lower().strip()
 
-    if len(user_input_lower.split()) < 3 and not any(greet in user_input_lower for greet in GREETINGS):
-        return "Please ask your question properly with at least 3 words."
+    # Handle short or vague questions
+    if len(user_input_lower.split()) < 3 and not any(g in user_input_lower for g in GREETINGS):
+        return "🔎 Please provide more details or ask a complete question (at least 3 words)."
 
+    # Handle greetings
     if any(greet in user_input_lower for greet in GREETINGS):
-        greeting_response = query_groq_llm(
-            f"You are an official assistant for University of Education Lahore. "
-            f"Respond to this greeting in a friendly and professional manner: {user_input}"
-        )
-        return greeting_response if greeting_response else "Hello! How can I assist you today?"
+        prompt = f"You are an official assistant for University of Education Lahore. Respond to this greeting in a professional and friendly tone: {user_input}"
+        return query_groq_llm(prompt) or "👋 Hello! How can I assist you today?"
 
+    # Handle direct FAQ (e.g., fee structure)
    if any(keyword in user_input_lower for keyword in ["fee structure", "fees structure", "semester fees", "semester fee"]):
        return (
-            "💰 For complete and up-to-date fee details for this program, we recommend visiting the official University of Education fee structure page.\n"
-            "You'll find comprehensive information regarding tuition, admission charges, and other applicable fees there.\n"
-            "🔗 https://ue.edu.pk/allfeestructure.php"
+            "💰 For complete and up-to-date fee details for this program, please visit:\n"
+            "🔗 https://ue.edu.pk/allfeestructure.php\n"
+            "It contains all relevant information including tuition, admission, and semester-wise fees."
        )
 
+    # Semantic search for best matching question
     user_embedding = similarity_model.encode(user_input_lower, convert_to_tensor=True)
     similarities = util.pytorch_cos_sim(user_embedding, dataset_embeddings)[0]
     best_match_idx = similarities.argmax().item()
     best_score = similarities[best_match_idx].item()
 
-    if best_score < 0.65:
-        manage_unmatched_queries(user_input)
+    # Threshold to determine match quality
+    SIMILARITY_THRESHOLD = 0.65
 
-    if best_score >= 0.65:
+    if best_score >= SIMILARITY_THRESHOLD:
         original_answer = dataset_answers[best_match_idx]
-        prompt = f"""Name is UOE AI Assistant! You are an official assistant for the University of Education Lahore.
-
-Rephrase the following official answer clearly and professionally.
-Use structured formatting (like headings, bullet points, or numbered lists) where appropriate.
-DO NOT add any new or extra information. ONLY rephrase and improve the clarity and formatting of the original answer.
-
+        prompt = f"""You are UOE AI Assistant! As an official assistant for University of Education Lahore, rephrase the following answer clearly and professionally.
+Use bullet points or headings if helpful. Do NOT add extra information.
 ### Question:
 {user_input}
-
 ### Original Answer:
 {original_answer}
-
-### Rephrased Answer:
-"""
+### Rephrased Answer:"""
     else:
-        prompt = f"""Name is UOE AI Assistant! As an official assistant for University of Education Lahore, provide a helpful response:
-Include relevant details about university policies.
-If unsure, direct to official channels.
-
+        manage_unmatched_queries(user_input)
+        prompt = f"""You are UOE AI Assistant. As an official assistant for University of Education Lahore, provide a helpful response to this query.
+If unsure, direct the user to the official university contact options.
 ### Question:
 {user_input}
+### Official Answer:"""
 
-### Official Answer:
-"""
+    # Get the response from LLM
+    response = query_groq_llm(prompt)
 
-    llm_response = query_groq_llm(prompt)
-
-    if llm_response:
-        for marker in ["Improved Answer:", "Official Answer:", "Rephrased Answer:"]:
-            if marker in llm_response:
-                return llm_response.split(marker)[-1].strip()
-        return llm_response
+    if response:
+        for marker in ["Rephrased Answer:", "Official Answer:", "Improved Answer:"]:
+            if marker in response:
+                return response.split(marker)[-1].strip()
+        return response  # if no marker found
     else:
-        return dataset_answers[best_match_idx] if best_score >= 0.65 else (
+        return dataset_answers[best_match_idx] if best_score >= SIMILARITY_THRESHOLD else (
             "For official information:\n"
             "📞 +92-42-99262231-33\n"
             "✉️ [email protected]\n"
             "🌐 https://ue.edu.pk"
         )
+
+# ✅ Example (for direct testing)
+if __name__ == "__main__":
+    while True:
+        user_input = input("\n🧑‍🎓 You: ")
+        if user_input.lower() in ["exit", "quit"]:
+            break
+        answer = get_best_answer(user_input)
+        print(f"\n🤖 UOE Assistant:\n{answer}")
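Note: both hunks skip the middle of manage_unmatched_queries (old lines 49-56, new lines 51-58), so the code that actually builds df before Dataset.from_pandas(df) is not shown on this page. As a rough sketch of the load-append-push pattern the visible context implies (this is not the file's hidden code; the column names and the load_dataset fallback are assumptions):

# Hypothetical sketch of appending an unmatched query to a Hub dataset.
import pandas as pd
from datasets import load_dataset, Dataset
from datetime import datetime

def log_query_example(query: str, repo_id: str, token: str):
    # Load the existing dataset if it exists, otherwise start an empty table (assumed schema).
    try:
        existing = load_dataset(repo_id, split="train").to_pandas()
    except Exception:
        existing = pd.DataFrame(columns=["Query", "Timestamp"])
    # Append the new query with a timestamp and push the updated table back to the Hub.
    row = {"Query": query, "Timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
    df = pd.concat([existing, pd.DataFrame([row])], ignore_index=True)
    Dataset.from_pandas(df).push_to_hub(repo_id, token=token)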
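The removed comment on get_best_answer ("Main logic function to be called from Gradio") suggests the Space exposes this function through a Gradio UI. A minimal sketch of that wiring, assuming a separate app file imports rag.py, is shown below; it is not part of this commit, and the interface labels are illustrative.

# Hypothetical Gradio wiring for rag.py (illustrative only).
import gradio as gr
from rag import get_best_answer

demo = gr.Interface(
    fn=get_best_answer,                          # reuse the function defined in rag.py
    inputs=gr.Textbox(label="Ask the UOE Assistant"),
    outputs=gr.Textbox(label="Answer"),
    title="UOE AI Assistant",
)

if __name__ == "__main__":
    demo.launch()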