Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -121,18 +121,31 @@ class SearchEngine:
|
|
121 |
return results
|
122 |
|
123 |
|
124 |
-
# Conversational Model using GPT-2
|
125 |
class Chatbot:
|
126 |
-
def __init__(self, model_name="
|
|
|
|
|
|
|
127 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
128 |
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
129 |
|
130 |
def generate_response(self, prompt, max_length=100):
|
131 |
"""
|
132 |
-
Generates a response to a user query using GPT-
|
133 |
"""
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
137 |
return response
|
138 |
|
|
|
121 |
return results
|
122 |
|
123 |
|
|
|
124 |
class Chatbot:
|
125 |
+
def __init__(self, model_name="EleutherAI/gpt-neo-125M"):
|
126 |
+
"""
|
127 |
+
Initializes the chatbot with GPT-Neo.
|
128 |
+
"""
|
129 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
130 |
self.model = AutoModelForCausalLM.from_pretrained(model_name)
|
131 |
|
132 |
def generate_response(self, prompt, max_length=100):
|
133 |
"""
|
134 |
+
Generates a response to a user query using GPT-Neo.
|
135 |
"""
|
136 |
+
# Tokenize the input prompt
|
137 |
+
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
|
138 |
+
|
139 |
+
# Generate the response
|
140 |
+
outputs = self.model.generate(
|
141 |
+
inputs.input_ids,
|
142 |
+
attention_mask=inputs.attention_mask, # Pass the attention mask
|
143 |
+
max_length=max_length,
|
144 |
+
num_return_sequences=1,
|
145 |
+
pad_token_id=self.tokenizer.eos_token_id, # Set pad_token_id to eos_token_id
|
146 |
+
)
|
147 |
+
|
148 |
+
# Decode the generated response
|
149 |
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
150 |
return response
|
151 |
|