Update app.py

app.py CHANGED
@@ -1,4 +1,3 @@
-
 import gradio as gr
 import faiss
 import numpy as np
@@ -6,12 +5,10 @@ import openai
 from sentence_transformers import SentenceTransformer
 from nltk.tokenize import sent_tokenize
 import nltk
-from transformers import AutoTokenizer, AutoModel
 import torch
 
 # Download the required NLTK data
 nltk.download('punkt')
-nltk.download('punkt_tab')
 
 # Paths to your files
 faiss_path = "manual_chunked_faiss_index_500.bin"
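Review note on the hunk above: `sent_tokenize` resolves the `punkt_tab` resource on NLTK 3.9 and later, so removing `nltk.download('punkt_tab')` can surface as a `LookupError` at tokenization time. A defensive sketch, assuming the Space does not pin an NLTK version:

    import nltk

    # punkt serves NLTK <= 3.8.x; punkt_tab is what sent_tokenize looks up on >= 3.9.
    # download() is a no-op when the resource is already present, and it returns
    # False rather than raising when an older release does not know the name.
    for resource in ("punkt", "punkt_tab"):
        nltk.download(resource, quiet=True)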
@@ -52,19 +49,14 @@ except Exception as e:
     raise RuntimeError(f"Failed to load FAISS index: {e}")
 
 # Load the tokenizer and model for embeddings
-
-embedding_model = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
+embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
 # OpenAI API key
-openai.api_key =
+openai.api_key = 'sk-proj-...'
 
 # Function to create embeddings
 def embed_text(text_list):
-
-    with torch.no_grad():
-        outputs = embedding_model(**inputs)
-    embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()  # Use the CLS token representation
-    return embeddings
+    return np.array(embedding_model.encode(text_list), dtype=np.float32)
 
 # Function to retrieve relevant chunks for a user query
 def retrieve_chunks(query, k=5):
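Two review notes on the hunk above. First, a real API key must never be committed; a key that has appeared in a public diff needs to be rotated, and the value belongs in the environment rather than the source. Second, `all-MiniLM-L6-v2` and the CLS-pooled `MiniLM-L12-H384-uncased` both emit 384-dimensional vectors, so an index built with the old encoder still passes FAISS's dimension check while returning meaningless neighbors; the index has to be rebuilt with the new model. A sketch of both guards, assuming the loaded index object is named `index` (its name sits outside this hunk):

    import os

    # Pull the key from a Space secret / environment variable, not the repo.
    openai.api_key = os.environ["OPENAI_API_KEY"]

    # Catches gross mismatches only: a stale-but-384-d index still needs a rebuild
    # after the encoder swap, which this check cannot detect.
    expected_dim = embedding_model.get_sentence_embedding_dimension()
    if index.d != expected_dim:
        raise RuntimeError(
            f"FAISS index dimension {index.d} != embedding dimension {expected_dim}; "
            "rebuild the index with the current embedding model."
        )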
@@ -78,51 +70,50 @@ def retrieve_chunks(query, k=5):
         raise RuntimeError(f"FAISS search failed: {e}")
 
     if len(indices[0]) == 0:
-        return []
+        return [], distances, indices
 
     valid_indices = [i for i in indices[0] if i < len(manual_chunks)]
     if not valid_indices:
-        return []
+        return [], distances, indices
 
     relevant_chunks = [manual_chunks[i] for i in valid_indices]
-    return relevant_chunks
-
-# Load the tokenizer and model for generation
-generator_tokenizer = AutoTokenizer.from_pretrained("gpt-3.5-turbo")  # Replace with correct tokenizer if needed
-generator_model = AutoModel.from_pretrained("gpt-3.5-turbo")  # Replace with correct model if needed
-
-# Function to truncate long inputs
-def truncate_input(text, max_length=512):
-    inputs = generator_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
-    return inputs
+    return relevant_chunks, distances, indices
 
 # Function to perform RAG: Retrieve chunks and generate a response
-def rag_response(query, k=5,
+def rag_response(query, k=5, max_tokens=150):
     try:
-        relevant_chunks = retrieve_chunks(query, k=k)
+        relevant_chunks, distances, indices = retrieve_chunks(query, k=k)
 
         if not relevant_chunks:
-            return "Sorry, I couldn't find relevant information."
+            return "Sorry, I couldn't find relevant information.", distances, indices
 
         augmented_input = query + "\n" + "\n".join(relevant_chunks)
 
-
-
-
-
-
-
-
+        # Generate response using OpenAI API
+        response = openai.Completion.create(
+            model="gpt-3.5-turbo",
+            prompt=augmented_input,
+            max_tokens=max_tokens,
+            temperature=0.7
+        )
+        generated_text = response.choices[0].text.strip()
+        return generated_text, distances, indices
     except Exception as e:
-        return f"An error occurred: {e}"
+        return f"An error occurred: {e}", [], []
 
 # Gradio Interface
+def format_output(response, distances, indices):
+    # Format output to include distances and indices
+    formatted_response = f"Response: {response}\n\nDistances: {distances}\n\nIndices: {indices}"
+    return formatted_response
+
 iface = gr.Interface(
     fn=rag_response,
     inputs="text",
     outputs="text",
     title="RAG Chatbot with FAISS and GPT-3.5",
-    description="Ask me anything!"
+    description="Ask me anything!",
+    live=True
 )
 
 if __name__ == "__main__":
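Review note on the generation block added above: `gpt-3.5-turbo` is a chat model, and under the pre-1.0 `openai` SDK that the module-level `openai.api_key` style implies, `openai.Completion.create` rejects chat models, so `rag_response` would always fall through to the `except` branch. A sketch of the equivalent chat call, assuming that SDK generation:

    # Chat-completions variant of the generation step (openai SDK < 1.0 assumed).
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "Answer using only the provided manual excerpts."},
            {"role": "user", "content": augmented_input},
        ],
        max_tokens=max_tokens,
        temperature=0.7,
    )
    generated_text = response.choices[0].message["content"].strip()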
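One more wiring note: `rag_response` now returns a three-tuple while the interface declares a single `outputs="text"`, the new `format_output` helper is never called, and `live=True` fires a retrieval plus an OpenAI call on every keystroke. A sketch that reconciles the pieces, keeping the existing function names and assuming the truncated tail of the file ends with the usual `launch()` call:

    def chat_fn(query):
        # Collapse the three return values into the single text output declared below.
        response, distances, indices = rag_response(query)
        return format_output(response, distances, indices)

    iface = gr.Interface(
        fn=chat_fn,
        inputs="text",
        outputs="text",
        title="RAG Chatbot with FAISS and GPT-3.5",
        description="Ask me anything!",  # submit-on-click avoids per-keystroke API calls
    )

    if __name__ == "__main__":
        iface.launch()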