FridayMaster commited on
Commit
b29db7c
·
verified ·
1 Parent(s): 63f18ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -35
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import gradio as gr
3
  import faiss
4
  import numpy as np
@@ -6,12 +5,10 @@ import openai
6
  from sentence_transformers import SentenceTransformer
7
  from nltk.tokenize import sent_tokenize
8
  import nltk
9
- from transformers import AutoTokenizer, AutoModel
10
  import torch
11
 
12
  # Download the required NLTK data
13
  nltk.download('punkt')
14
- nltk.download('punkt_tab')
15
 
16
  # Paths to your files
17
  faiss_path = "manual_chunked_faiss_index_500.bin"
@@ -52,19 +49,14 @@ except Exception as e:
52
  raise RuntimeError(f"Failed to load FAISS index: {e}")
53
 
54
  # Load the tokenizer and model for embeddings
55
- embedding_tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
56
- embedding_model = AutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
57
 
58
  # OpenAI API key
59
- openai.api_key = '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'
60
 
61
  # Function to create embeddings
62
  def embed_text(text_list):
63
- inputs = embedding_tokenizer(text_list, padding=True, truncation=True, return_tensors="pt")
64
- with torch.no_grad():
65
- outputs = embedding_model(**inputs)
66
- embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() # Use the CLS token representation
67
- return embeddings
68
 
69
  # Function to retrieve relevant chunks for a user query
70
  def retrieve_chunks(query, k=5):
@@ -78,51 +70,50 @@ def retrieve_chunks(query, k=5):
78
  raise RuntimeError(f"FAISS search failed: {e}")
79
 
80
  if len(indices[0]) == 0:
81
- return []
82
 
83
  valid_indices = [i for i in indices[0] if i < len(manual_chunks)]
84
  if not valid_indices:
85
- return []
86
 
87
  relevant_chunks = [manual_chunks[i] for i in valid_indices]
88
- return relevant_chunks
89
-
90
- # Load the tokenizer and model for generation
91
- generator_tokenizer = AutoTokenizer.from_pretrained("gpt-3.5-turbo") # Replace with correct tokenizer if needed
92
- generator_model = AutoModel.from_pretrained("gpt-3.5-turbo") # Replace with correct model if needed
93
-
94
- # Function to truncate long inputs
95
- def truncate_input(text, max_length=512):
96
- inputs = generator_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
97
- return inputs
98
 
99
  # Function to perform RAG: Retrieve chunks and generate a response
100
- def rag_response(query, k=5, max_new_tokens=150):
101
  try:
102
- relevant_chunks = retrieve_chunks(query, k=k)
103
 
104
  if not relevant_chunks:
105
- return "Sorry, I couldn't find relevant information."
106
 
107
  augmented_input = query + "\n" + "\n".join(relevant_chunks)
108
 
109
- inputs = truncate_input(augmented_input)
110
-
111
- # Generate response
112
- outputs = generator_model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens)
113
- generated_text = generator_tokenizer.decode(outputs[0], skip_special_tokens=True)
114
-
115
- return generated_text
 
 
116
  except Exception as e:
117
- return f"An error occurred: {e}"
118
 
119
  # Gradio Interface
 
 
 
 
 
120
  iface = gr.Interface(
121
  fn=rag_response,
122
  inputs="text",
123
  outputs="text",
124
  title="RAG Chatbot with FAISS and GPT-3.5",
125
- description="Ask me anything!"
 
126
  )
127
 
128
  if __name__ == "__main__":
 
 
1
  import gradio as gr
2
  import faiss
3
  import numpy as np
 
5
  from sentence_transformers import SentenceTransformer
6
  from nltk.tokenize import sent_tokenize
7
  import nltk
 
8
  import torch
9
 
10
  # Download the required NLTK data
11
  nltk.download('punkt')
 
12
 
13
  # Paths to your files
14
  faiss_path = "manual_chunked_faiss_index_500.bin"
 
49
  raise RuntimeError(f"Failed to load FAISS index: {e}")
50
 
51
  # Load the tokenizer and model for embeddings
52
+ embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
53
 
54
  # OpenAI API key
55
+ openai.api_key = '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'
56
 
57
  # Function to create embeddings
58
  def embed_text(text_list):
59
+ return np.array(embedding_model.encode(text_list), dtype=np.float32)
 
 
 
 
60
 
61
  # Function to retrieve relevant chunks for a user query
62
  def retrieve_chunks(query, k=5):
 
70
  raise RuntimeError(f"FAISS search failed: {e}")
71
 
72
  if len(indices[0]) == 0:
73
+ return [], distances, indices
74
 
75
  valid_indices = [i for i in indices[0] if i < len(manual_chunks)]
76
  if not valid_indices:
77
+ return [], distances, indices
78
 
79
  relevant_chunks = [manual_chunks[i] for i in valid_indices]
80
+ return relevant_chunks, distances, indices
 
 
 
 
 
 
 
 
 
81
 
82
  # Function to perform RAG: Retrieve chunks and generate a response
83
+ def rag_response(query, k=5, max_tokens=150):
84
  try:
85
+ relevant_chunks, distances, indices = retrieve_chunks(query, k=k)
86
 
87
  if not relevant_chunks:
88
+ return "Sorry, I couldn't find relevant information.", distances, indices
89
 
90
  augmented_input = query + "\n" + "\n".join(relevant_chunks)
91
 
92
+ # Generate response using OpenAI API
93
+ response = openai.Completion.create(
94
+ model="gpt-3.5-turbo",
95
+ prompt=augmented_input,
96
+ max_tokens=max_tokens,
97
+ temperature=0.7
98
+ )
99
+ generated_text = response.choices[0].text.strip()
100
+ return generated_text, distances, indices
101
  except Exception as e:
102
+ return f"An error occurred: {e}", [], []
103
 
104
  # Gradio Interface
105
+ def format_output(response, distances, indices):
106
+ # Format output to include distances and indices
107
+ formatted_response = f"Response: {response}\n\nDistances: {distances}\n\nIndices: {indices}"
108
+ return formatted_response
109
+
110
  iface = gr.Interface(
111
  fn=rag_response,
112
  inputs="text",
113
  outputs="text",
114
  title="RAG Chatbot with FAISS and GPT-3.5",
115
+ description="Ask me anything!",
116
+ live=True
117
  )
118
 
119
  if __name__ == "__main__":