aquibmoin commited on
Commit
5cdeca9
1 Parent(s): 3643e08

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -1,40 +1,38 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
3
- import torch
 
4
 
5
  # Load the NASA-specific bi-encoder model and tokenizer
6
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
7
  bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
8
  bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
9
 
10
- # Load the GPT-2 model and tokenizer for response generation
11
- gpt2_model_name = "gpt2"
12
- gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
13
- gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
14
 
15
  def encode_text(text):
16
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
17
  outputs = bi_model(**inputs)
18
  # Ensure the output is 2D by averaging the last hidden state along the sequence dimension
19
- return outputs.last_hidden_state.mean(dim=1).detach().numpy()
20
 
21
  def generate_response(user_input, context_embedding):
22
- # Create a structured prompt for GPT-2
23
- combined_input = f"Question: {user_input}\nContext: {context_embedding}\nAnswer:"
 
24
 
25
- # Generate a response using GPT-2 with adjusted parameters
26
- gpt2_inputs = gpt2_tokenizer.encode(combined_input, return_tensors='pt')
27
- gpt2_outputs = gpt2_model.generate(
28
- gpt2_inputs,
29
- max_length=150,
30
- num_return_sequences=1,
31
- temperature=0.7,
32
  top_p=0.9,
33
- repetition_penalty=1.2
 
34
  )
35
- generated_text = gpt2_tokenizer.decode(gpt2_outputs[0], skip_special_tokens=True)
36
-
37
- return generated_text
38
 
39
  def chatbot(user_input, context=""):
40
  context_embedding = encode_text(context) if context else ""
@@ -47,7 +45,7 @@ iface = gr.Interface(
47
  inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Textbox(lines=2, placeholder="Enter context here (optional)...")],
48
  outputs="text",
49
  title="Context-Aware Dynamic Response Chatbot",
50
- description="A chatbot using a NASA-specific bi-encoder model to understand the input context and GPT-2 to generate dynamic responses."
51
  )
52
 
53
  # Launch the interface
@@ -57,3 +55,4 @@ iface.launch()
57
 
58
 
59
 
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import openai
4
+ import os
5
 
6
  # Load the NASA-specific bi-encoder model and tokenizer
7
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
8
  bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
9
  bi_model = AutoModel.from_pretrained(bi_encoder_model_name)
10
 
11
+ # Set up OpenAI API key
12
+ openai.api_key = os.getenv('OPENAI_API_KEY')
 
 
13
 
14
  def encode_text(text):
15
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
16
  outputs = bi_model(**inputs)
17
  # Ensure the output is 2D by averaging the last hidden state along the sequence dimension
18
+ return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
19
 
20
  def generate_response(user_input, context_embedding):
21
+ # Create a structured prompt for GPT-4
22
+ context_str = ' '.join(map(str, context_embedding)) # Convert context embedding to a string
23
+ combined_input = f"Question: {user_input}\nContext: {context_str}\nAnswer:"
24
 
25
+ # Generate a response using GPT-4
26
+ response = openai.Completion.create(
27
+ engine="gpt-4-turbo", # Use GPT-4 engine if available, otherwise use text-davinci-003
28
+ prompt=combined_input,
29
+ max_tokens=150,
30
+ temperature=0.5,
 
31
  top_p=0.9,
32
+ frequency_penalty=0.5,
33
+ presence_penalty=0.0
34
  )
35
+ return response.choices[0].text.strip()
 
 
36
 
37
  def chatbot(user_input, context=""):
38
  context_embedding = encode_text(context) if context else ""
 
45
  inputs=[gr.Textbox(lines=2, placeholder="Enter your message here..."), gr.Textbox(lines=2, placeholder="Enter context here (optional)...")],
46
  outputs="text",
47
  title="Context-Aware Dynamic Response Chatbot",
48
+ description="A chatbot using a NASA-specific bi-encoder model to understand the input context and GPT-4 to generate dynamic responses."
49
  )
50
 
51
  # Launch the interface
 
55
 
56
 
57
 
58
+