# app.py
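"""UoE Chatbot: a retrieval-augmented Gradio app.

Embeds the midrees2806/7K_Dataset with a SentenceTransformer, indexes the
embeddings in FAISS, and answers questions about University of Education
Lahore via the Groq chat API, using the retrieved passages as context.
"""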
import os

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import gradio as gr
import torch
from tqdm import tqdm
from groq import Groq

# Load dataset
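# (the dataset is expected to expose a "train" split with a "text" column)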
dataset = load_dataset("midrees2806/7K_Dataset")
print("Dataset sample:", dataset['train'][0])

# Initialize sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

# Encode a batch of texts on the selected device and return a NumPy array
def generate_embeddings_batch(texts):
    return model.encode(texts, batch_size=1024, convert_to_tensor=True, device=device).cpu().numpy()

# Prepare embeddings
train_dataset = dataset['train']
texts = [data['text'] for data in train_dataset]

batch_size = 1024
chunked_embeddings = []
for i in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
    batch = texts[i:i+batch_size]
    embeddings = generate_embeddings_batch(batch)
    chunked_embeddings.append(embeddings)

chunked_embeddings = np.vstack(chunked_embeddings)

# Initialize FAISS index
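# IndexFlatL2 performs exact (brute-force) nearest-neighbour search;
# the distances it returns are squared L2 distances, so smaller is closer.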
dimension = chunked_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(chunked_embeddings)

# Initialize Groq client (read the key from the environment instead of
# hard-coding it in source; an exposed key should be revoked and rotated)
client = Groq(api_key=os.environ["GROQ_API_KEY"])

def get_groq_response(query):
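    """Answer `query` with the Groq chat model, grounded only in FAISS-retrieved context."""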
    try:
        # Get relevant context from FAISS
        faiss_results = search_in_faiss(query)
        context = "\n".join([result[0] for result in faiss_results])
        
        # Create a prompt that forces the model to only use the provided context
        prompt = f"""
        You are an expert assistant for University of Education Lahore and its sub-campuses ONLY.
        You must ONLY use the following context to answer questions. If the answer isn't in the context,
        say "I don't have information about that in the University of Education Lahore dataset."
        
        Context:
        {context}
        
        Question: {query}
        
        Answer:"""
        
        chat_completion = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model="llama3-70b-8192",
            temperature=0.3,
            max_tokens=1024
        )
        
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(f"Error in Groq response: {str(e)}")
        return "I encountered an error while processing your request."

def search_in_faiss(query):
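    """Return the top-3 (text, distance) pairs for `query`; smaller distance means a closer match."""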
    query_embedding = model.encode([query], convert_to_tensor=True, device=device).cpu().numpy()
    distances, indices = index.search(query_embedding, k=3)
    return [(dataset['train'][int(idx)]['text'], float(dist)) for idx, dist in zip(indices[0], distances[0])]

def respond(message, chat_history):
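    """Gradio callback: show the retrieved passages followed by the model's answer."""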
    try:
        faiss_results = search_in_faiss(message)
        model_response = get_groq_response(message)

        bot_response = "**Relevant Information from Dataset:**\n\n"
        for result in faiss_results:
            bot_response += f"- {result[0]} (Similarity: {result[1]:.4f})\n\n"
        bot_response = "\n**Model Response:**\n\n" + model_response
                    #+
        return "", chat_history + [(message, bot_response)]
    except Exception as e:
        print(f"Error: {str(e)}")
        return "", chat_history + [(message, f"Error processing request: {str(e)}")]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# <center>UoE Chatbot</center>")
    gr.Markdown("<center>University of Education Lahore Information Bot</center>")
    gr.Markdown("<center>This bot only answers questions about University of Education Lahore and its sub-campuses</center>")

    chatbot = gr.Chatbot(height=500, bubble_full_width=False)
    with gr.Row():
        msg = gr.Textbox(label="Type your message here...", placeholder="Ask about University of Education Lahore...", scale=7)
        submit_btn = gr.Button("Submit", variant="primary")
    clear_btn = gr.Button("Clear Chat")

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    submit_btn.click(respond, [msg, chatbot], [msg, chatbot])
    clear_btn.click(lambda: None, None, chatbot, queue=False)

demo.launch()