Fred808 commited on
Commit
1569bc7
·
verified ·
1 Parent(s): be7f0e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -121,18 +121,31 @@ class SearchEngine:
121
  return results
122
 
123
 
124
- # Conversational Model using GPT-2
125
  class Chatbot:
126
- def __init__(self, model_name="gpt2"):
 
 
 
127
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
128
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
129
 
130
  def generate_response(self, prompt, max_length=100):
131
  """
132
- Generates a response to a user query using GPT-2.
133
  """
134
- inputs = self.tokenizer.encode(prompt, return_tensors="pt")
135
- outputs = self.model.generate(inputs, max_length=max_length, num_return_sequences=1)
 
 
 
 
 
 
 
 
 
 
 
136
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
137
  return response
138
 
 
121
  return results
122
 
123
 
 
124
  class Chatbot:
125
+ def __init__(self, model_name="EleutherAI/gpt-neo-125M"):
126
+ """
127
+ Initializes the chatbot with GPT-Neo.
128
+ """
129
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
130
  self.model = AutoModelForCausalLM.from_pretrained(model_name)
131
 
132
  def generate_response(self, prompt, max_length=100):
133
  """
134
+ Generates a response to a user query using GPT-Neo.
135
  """
136
+ # Tokenize the input prompt
137
+ inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
138
+
139
+ # Generate the response
140
+ outputs = self.model.generate(
141
+ inputs.input_ids,
142
+ attention_mask=inputs.attention_mask, # Pass the attention mask
143
+ max_length=max_length,
144
+ num_return_sequences=1,
145
+ pad_token_id=self.tokenizer.eos_token_id, # Set pad_token_id to eos_token_id
146
+ )
147
+
148
+ # Decode the generated response
149
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
150
  return response
151