# Hugging Face Space application (page status when captured: Sleeping)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
class TextGenerationBot:
    """Wrap a Hugging Face causal language model so it can answer chat messages."""

    def __init__(self, model_name="umairrrkhan/english-text-generation"):
        """
        Create the bot and load the model and tokenizer immediately.

        Args:
            model_name: Hub id (or local path) passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.setup_model()

    def setup_model(self):
        """
        Load the model and tokenizer, and ensure pad_token and pad_token_id are set.

        When the tokenizer defines no pad token, the EOS token is reused so that
        ``padding=True`` in ``generate_text`` has a token to pad with.
        """
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Ensure tokenizer has a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Ensure model config has pad_token_id.
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def generate_text(
        self,
        input_text,
        temperature=0.7,
        max_length=100,
        max_new_tokens=None,
        return_full_text=True,
    ):
        """
        Sample a continuation of ``input_text`` from the model.

        Args:
            input_text: Prompt string.
            temperature: Sampling temperature (sampling is always enabled,
                with top_k=50 and top_p=0.95).
            max_length: Cap on the TOTAL sequence length, prompt included.
                Used only when ``max_new_tokens`` is None.
            max_new_tokens: When given, caps only the number of newly generated
                tokens, so a long prompt still leaves room for output.
            return_full_text: If True (the default, matching the original
                behavior), return prompt + continuation. If False, return only
                the newly generated text.

        Returns:
            The decoded text with special tokens skipped.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"]

        # max_length counts prompt tokens too; max_new_tokens counts only new ones.
        if max_new_tokens is not None:
            length_kwargs = {"max_new_tokens": max_new_tokens}
        else:
            length_kwargs = {"max_length": max_length}

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=inputs["attention_mask"],
                temperature=temperature,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                **length_kwargs,
            )

        generated = outputs[0]
        if not return_full_text:
            # The generated sequence starts with the prompt tokens; drop them.
            generated = generated[input_ids.shape[-1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def chat(self, message, history):
        """
        Handle one turn of a chat conversation.

        Args:
            message: The user's message, used directly as the prompt.
            history: List of ``(user_message, bot_reply)`` tuples, or a falsy
                value (e.g. None) on the first turn. A non-empty list is
                appended to in place.

        Returns:
            The updated history list, returned twice as a tuple.
        """
        if not history:
            history = []
        # Ask only for the reply (not the echoed prompt) and bound its length
        # independently of how long the user's message is.
        bot_response = self.generate_text(
            message, max_new_tokens=100, return_full_text=False
        ).strip()
        history.append((message, bot_response))
        return history, history
class ChatbotInterface:
    """Gradio front end that exposes a TextGenerationBot as a message/reply form."""

    def __init__(self):
        """Create the bot (which loads the model) and build the Gradio interface."""
        self.bot = TextGenerationBot()
        self.interface = None
        self.setup_interface()

    def _respond(self, message, history):
        """
        Adapt ``self.bot.chat`` to this interface's (Textbox, State) outputs.

        ``chat`` returns the updated history twice. Feeding that straight into
        the "Bot Response" Textbox would display the whole list of tuples, so
        only the latest reply goes to the Textbox and the history goes to the State.
        """
        history, _ = self.bot.chat(message, history)
        latest_reply = history[-1][1]
        return latest_reply, history

    def setup_interface(self):
        """
        Set up the Gradio interface for the chatbot.

        Uses top-level ``gr.Textbox`` / ``gr.State`` components. The legacy
        ``gr.inputs`` / ``gr.outputs`` namespaces are gone in Gradio 4, and
        ``gr.State`` takes no ``label``.
        """
        self.interface = gr.Interface(
            fn=self._respond,
            inputs=[
                gr.Textbox(label="Your Message"),
                gr.State([]),  # chat history, carried between calls
            ],
            outputs=[
                gr.Textbox(label="Bot Response"),
                gr.State(),  # updated chat history
            ],
            title="AI Text Generation Chatbot",
            description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
            examples=[
                ["Tell me a short story about a brave knight"],
                ["What are the benefits of exercise?"],
                ["Write a poem about nature"],
            ],
        )

    def launch(self, **kwargs):
        """Launch the Gradio interface, forwarding all keyword arguments to ``Interface.launch``."""
        self.interface.launch(**kwargs)
def main():
    """
    Build the chatbot UI and start serving it.

    Launch options: listen on 0.0.0.0:7860, share=True, debug=True.
    """
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "debug": True,
    }
    app = ChatbotInterface()
    app.launch(**launch_options)


if __name__ == "__main__":
    # Start the server only when run as a script, not when imported.
    main()