aquibmoin commited on
Commit
1d831e0
1 Parent(s): 6e99860

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -27
app.py CHANGED
@@ -2,56 +2,62 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModel
3
  from openai import OpenAI
4
  import os
 
 
5
 
6
  # Load the NASA-specific bi-encoder model and tokenizer
7
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
8
  bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
9
  bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
10
 
11
- # Set up OpenAI API key
12
-
13
- openaiapi = os.getenv('OPENAI_API_KEY')
14
- client = OpenAI(api_key=openaiapi)
15
-
16
 
17
  def encode_text(text):
18
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
19
  outputs = bi_model(**inputs)
20
- # Ensure the output is 2D by averaging the last hidden state along the sequence dimension
21
- return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
22
 
23
- def generate_response(user_input, context_embedding):
24
- # Create a structured prompt for GPT-4
25
- context_str = ' '.join(map(str, context_embedding)) # Convert context embedding to a string
26
- combined_input = f"Question: {user_input}\nContext: {context_str}\nAnswer:"
 
 
 
 
 
27
 
28
- # Generate a response using GPT-4
29
  response = client.chat.completions.create(
30
- model="gpt-4",
31
- messages=[
32
- {"role": "user", "content": combined_input}
33
- ],
34
- max_tokens=400,
35
- temperature=0.5,
36
- top_p=0.9,
37
- frequency_penalty=0.5,
38
- presence_penalty=0.0
39
  )
40
-
41
- return response.choices[0].message.content.strip()
42
 
43
  def chatbot(user_input, context=""):
44
- context_embedding = encode_text(context) if context else ""
45
- response = generate_response(user_input, context_embedding)
 
46
  return response
47
 
48
  # Create the Gradio interface
49
  iface = gr.Interface(
50
  fn=chatbot,
51
- inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Textbox(lines=2, placeholder="Enter context here (optional)...")],
 
 
 
52
  outputs="text",
53
  title="Context-Aware Dynamic Response Chatbot",
54
- description="A chatbot using a NASA-specific bi-encoder model to understand the input context and GPT-4 to generate dynamic responses."
55
  )
56
 
57
  # Launch the interface
@@ -63,3 +69,4 @@ iface.launch()
63
 
64
 
65
 
 
 
2
  from transformers import AutoTokenizer, AutoModel
3
  from openai import OpenAI
4
  import os
5
+ import numpy as np
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
 
8
  # Load the NASA-specific bi-encoder model and tokenizer
9
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
10
  bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
11
  bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
12
 
13
+ # Set up OpenAI client
14
+ api_key = os.getenv('OPENAI_API_KEY')
15
+ client = OpenAI(api_key=api_key)
 
 
16
 
17
  def encode_text(text):
18
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
19
  outputs = bi_model(**inputs)
20
+ return outputs.last_hidden_state.mean(dim=1).detach().numpy()
 
21
 
22
+ def retrieve_relevant_context(user_input, context_texts):
23
+ user_embedding = encode_text(user_input)
24
+ context_embeddings = np.array([encode_text(text) for text in context_texts])
25
+ similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
26
+ most_relevant_idx = np.argmax(similarities)
27
+ return context_texts[most_relevant_idx]
28
+
29
+ def generate_response(user_input, relevant_context):
30
+ combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
31
 
 
32
  response = client.chat.completions.create(
33
+ model="gpt-4",
34
+ messages=[
35
+ {"role": "user", "content": combined_input}
36
+ ],
37
+ max_tokens=150,
38
+ temperature=0.7,
39
+ top_p=0.9,
40
+ frequency_penalty=0.5,
41
+ presence_penalty=0.0
42
  )
43
+ return response.choices[0].message['content'].strip()
 
44
 
45
  def chatbot(user_input, context=""):
46
+ context_texts = context.split("\n")
47
+ relevant_context = retrieve_relevant_context(user_input, context_texts) if context else ""
48
+ response = generate_response(user_input, relevant_context)
49
  return response
50
 
51
  # Create the Gradio interface
52
  iface = gr.Interface(
53
  fn=chatbot,
54
+ inputs=[
55
+ gr.Textbox(lines=2, placeholder="Enter your message here..."),
56
+ gr.Textbox(lines=5, placeholder="Enter context here, separated by new lines...")
57
+ ],
58
  outputs="text",
59
  title="Context-Aware Dynamic Response Chatbot",
60
+ description="A chatbot using a NASA-specific bi-encoder model to understand the input context and GPT-4 to generate dynamic responses. Enter context to get more refined and relevant responses."
61
  )
62
 
63
  # Launch the interface
 
69
 
70
 
71
 
72
+