midrees2806 committed on
Commit
0e95308
·
verified ·
1 Parent(s): 84f2063

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +45 -0
rag.py CHANGED
@@ -1,3 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  def get_best_answer(user_input):
2
  user_input_lower = user_input.lower().strip()
3
 
 
1
+ import json
2
+ from sentence_transformers import SentenceTransformer, util
3
+ from groq import Groq
4
+ import datetime
5
+ import requests
6
+ from io import BytesIO
7
+ from PIL import Image, ImageDraw, ImageFont
8
+ import numpy as np
9
+ from dotenv import load_dotenv
10
+ import os
11
# ---------------------------------------------------------------------------
# Module-level initialization: environment variables, Groq client, the
# sentence-similarity model, and the QA dataset with precomputed embeddings.
# ---------------------------------------------------------------------------

# Load environment variables from a local .env file (provides GROQ_API_KEY).
load_dotenv()

# Initialize the Groq chat-completion client.
# NOTE(review): os.getenv returns None when GROQ_API_KEY is unset; the first
# API call would then fail with an auth error — confirm the key is configured.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Sentence-embedding model used for semantic similarity between the user
# query and the dataset questions.
similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Load the QA dataset — presumably a list of {"input": ..., "response": ...}
# records, given the .get() calls below. Explicit UTF-8 encoding avoids
# platform-dependent default encodings when reading the JSON file.
with open('dataset.json', 'r', encoding='utf-8') as f:
    dataset = json.load(f)

# Precompute embeddings for every dataset question once at import time, so
# each incoming query only needs to embed the user input.
dataset_questions = [item.get("input", "").lower().strip() for item in dataset]
dataset_answers = [item.get("response", "") for item in dataset]
dataset_embeddings = similarity_model.encode(dataset_questions, convert_to_tensor=True)
29
+
30
def query_groq_llm(prompt, model_name="llama3-70b-8192", temperature=0.7, max_tokens=500):
    """Send a single-turn prompt to the Groq chat-completion API.

    Args:
        prompt: User message sent as the sole conversation turn.
        model_name: Groq model identifier to use.
        temperature: Sampling temperature (optional; same default as before).
        max_tokens: Cap on the number of tokens in the reply (optional;
            same default as before).

    Returns:
        The stripped text of the first completion choice, or "" when the
        API call or response parsing fails (best-effort: callers treat an
        empty string as "no answer").
    """
    try:
        chat_completion = groq_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return chat_completion.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort fallback: report the failure and return an empty
        # string instead of propagating the exception to the caller.
        print(f"Error querying Groq API: {e}")
        return ""
45
+
46
  def get_best_answer(user_input):
47
  user_input_lower = user_input.lower().strip()
48