Yoxas committed on
Commit
fcb1289
·
verified ·
1 Parent(s): a63f6c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -12
app.py CHANGED
@@ -3,12 +3,9 @@ import torch
3
  from sentence_transformers import SentenceTransformer, util
4
  import gradio as gr
5
  import json
6
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
7
  import spaces
8
 
9
- # Ensure you have GPU support
10
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
-
12
  # Load the CSV file with embeddings
13
  df = pd.read_csv('RBDx10kstats.csv')
14
  df['embedding'] = df['embedding'].apply(json.loads) # Convert JSON string back to list
@@ -19,9 +16,12 @@ embeddings = torch.tensor(df['embedding'].tolist(), device=device)
19
  # Load the Sentence Transformer model
20
  model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
21
 
22
- # Load the LLaMA model for response generation
23
- llama_tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased-distilled-squad")
24
- llama_model = AutoModelForQuestionAnswering.from_pretrained("distilbert/distilbert-base-uncased-distilled-squad").to(device)
 
 
 
25
 
26
  # Define the function to find the most relevant document
27
  @spaces.GPU(duration=120)
@@ -31,14 +31,29 @@ def retrieve_relevant_doc(query):
31
  best_match_idx = torch.argmax(similarities).item()
32
  return df.iloc[best_match_idx]['Abstract']
33
 
 
 
 
 
 
 
 
 
34
  # Define the function to generate a response
35
  @spaces.GPU(duration=120)
36
  def generate_response(query):
37
  relevant_doc = retrieve_relevant_doc(query)
38
- input_text = f"Document: {relevant_doc}\n\nQuestion: {query}\n\nAnswer:"
39
- inputs = llama_tokenizer(input_text, return_tensors="pt").to(device)
40
- outputs = llama_model.generate(inputs["input_ids"], max_length=500)
41
- response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
42
  return response
43
 
44
  # Create a Gradio interface
@@ -47,7 +62,7 @@ iface = gr.Interface(
47
  inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
48
  outputs="text",
49
  title="RAG Chatbot",
50
- description="This chatbot retrieves relevant documents based on your query and generates responses using LLaMA."
51
  )
52
 
53
  # Launch the Gradio interface
 
3
  from sentence_transformers import SentenceTransformer, util
4
  import gradio as gr
5
  import json
6
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, AutoModelForSequenceClassification
7
  import spaces
8
 
 
 
 
9
  # Load the CSV file with embeddings
10
  df = pd.read_csv('RBDx10kstats.csv')
11
  df['embedding'] = df['embedding'].apply(json.loads) # Convert JSON string back to list
 
16
  # Load the Sentence Transformer model
17
  model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
18
 
19
+ # Load the AI model for response generation
20
+ tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased-distilled-squad")
21
+ model_response = AutoModelForQuestionAnswering.from_pretrained("distilbert/distilbert-base-uncased-distilled-squad").to(device)
22
+
23
+ # Load the NLU model for intent detection
24
+ nlu_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased-finetuned-sst-2-english").to(device)
25
 
26
  # Define the function to find the most relevant document
27
  @spaces.GPU(duration=120)
 
31
  best_match_idx = torch.argmax(similarities).item()
32
  return df.iloc[best_match_idx]['Abstract']
33
 
34
+ # Define the function to detect intent
35
+ @spaces.GPU(duration=120)
36
+ def detect_intent(query):
37
+ inputs = tokenizer(query, return_tensors="pt").to(device)
38
+ outputs = nlu_model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
39
+ intent = torch.argmax(outputs.logits).item()
40
+ return intent
41
+
42
  # Define the function to generate a response
43
  @spaces.GPU(duration=120)
44
  def generate_response(query):
45
  relevant_doc = retrieve_relevant_doc(query)
46
+ intent = detect_intent(query)
47
+ if intent == 0: # Handle intent 0 (e.g., informational query)
48
+ input_text = f"Document: {relevant_doc}\n\nQuestion: {query}\n\nAnswer:"
49
+ inputs = tokenizer(input_text, return_tensors="pt").to(device)
50
+ outputs = model_response.generate(inputs["input_ids"], max_length=500)
51
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ elif intent == 1: # Handle intent 1 (e.g., opinion-based query)
53
+ # Generate a response based on the detected intent
54
+ response = "I'm not sure I understand your question. Can you please rephrase?"
55
+ else:
56
+ response = "I'm not sure I understand your question. Can you please rephrase?"
57
  return response
58
 
59
  # Create a Gradio interface
 
62
  inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
63
  outputs="text",
64
  title="RAG Chatbot",
65
+ description="This chatbot retrieves relevant documents based on your query and generates responses using ai models."
66
  )
67
 
68
  # Launch the Gradio interface