Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
class TextGenerationBot:
    """Wrap a Hugging Face causal language model for prompt-in / text-out chat."""

    def __init__(self, model_name="umairrrkhan/english-text-generation"):
        """Remember the model id and load the model and tokenizer right away.

        Args:
            model_name: Hub id (or local path) handed to ``from_pretrained``.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        # Conversation returned by the most recent ``chat`` call, as a list of
        # (user_message, bot_reply) tuples.
        self.history = []
        self.setup_model()

    def setup_model(self):
        """Load model and tokenizer, and make sure a pad token is defined."""
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Some causal-LM checkpoints ship without a pad token; reuse EOS so
        # that tokenizer padding and generate() have a valid pad id.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.model.config.eos_token_id

    def generate_text(self, input_text, temperature=0.7, max_length=100, *,
                      max_new_tokens=None, return_full_text=True):
        """Generate a continuation of *input_text* with the loaded model.

        Args:
            input_text: Prompt string.
            temperature: Sampling temperature. A value <= 0 switches to greedy
                decoding, because ``generate`` rejects a non-positive
                temperature when ``do_sample=True``.
            max_length: Cap on the TOTAL sequence length (prompt + generated
                tokens). Used only when *max_new_tokens* is None.
            max_new_tokens: Optional cap on the number of newly generated
                tokens, independent of prompt length. Takes precedence over
                *max_length*.
            return_full_text: When True (default) the decoded prompt is part
                of the result; when False only the newly generated text is
                returned.

        Returns:
            The decoded text with special tokens removed.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs["input_ids"]
        # generate() on a causal LM returns prompt tokens followed by new ones;
        # remember the prompt length so the continuation can be sliced out.
        prompt_length = input_ids.shape[-1]

        generation_config = {
            "input_ids": input_ids,
            "attention_mask": inputs["attention_mask"],
            "num_return_sequences": 1,
            "no_repeat_ngram_size": 2,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if max_new_tokens is not None:
            generation_config["max_new_tokens"] = max_new_tokens
        else:
            generation_config["max_length"] = max_length
        if temperature > 0:
            generation_config.update(
                do_sample=True, temperature=temperature, top_p=0.95, top_k=50
            )
        else:
            # Greedy decoding: sampling knobs would be ignored or rejected.
            generation_config["do_sample"] = False

        with torch.no_grad():
            outputs = self.model.generate(**generation_config)

        sequence = outputs[0] if return_full_text else outputs[0][prompt_length:]
        return self.tokenizer.decode(sequence, skip_special_tokens=True)

    def chat(self, message, history=None):
        """Reply to *message* and return the updated conversation.

        Args:
            message: The user's message, used verbatim as the prompt.
            history: Previous (user, bot) turns, or None. The list is copied
                and never modified in place.

        Returns:
            A new list: the previous turns plus a ``(message, reply)`` tuple.
            The same list is also stored on ``self.history``.
        """
        turns = list(history) if history else []
        # Only the continuation is the bot's reply (otherwise the user's text
        # would be echoed back). max_new_tokens keeps a reply budget even
        # for long prompts.
        reply = self.generate_text(
            message, max_new_tokens=100, return_full_text=False
        ).strip()
        turns.append((message, reply))
        self.history = turns
        return turns
class ChatbotInterface:
    """Gradio chat UI built around a TextGenerationBot."""

    def __init__(self):
        """Load the bot and build the Gradio interface."""
        self.bot = TextGenerationBot()
        self.setup_interface()

    def _respond(self, message, history):
        """Adapter for ``gr.ChatInterface``: return only the new reply string.

        ChatInterface calls ``fn(message, history)`` and shows the return value
        as the assistant's reply, while ``TextGenerationBot.chat`` returns the
        whole list of (user, bot) turns. This pulls the last turn's reply out
        of that list. A copy of *history* is passed so the list Gradio owns is
        never mutated (mutating it could duplicate turns in the UI).
        """
        turns = self.bot.chat(message, list(history or []))
        return turns[-1][1]

    def setup_interface(self):
        """Create the ``gr.ChatInterface`` and store it on ``self.interface``."""
        # retry_btn / undo_btn / clear_btn are intentionally not passed; this
        # Gradio version does not accept them.
        self.interface = gr.ChatInterface(
            fn=self._respond,
            title="AI Text Generation Chatbot",
            description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
            examples=[
                ["Tell me a short story about a brave knight"],
                ["What are the benefits of exercise?"],
                ["Write a poem about nature"]
            ],
            theme=gr.themes.Soft()  # Optional
        )

    def launch(self, **kwargs):
        """Start the Gradio server; keyword arguments are forwarded to ``launch``."""
        self.interface.launch(**kwargs)
def main():
    """Build the chatbot UI and serve it on port 7860 on all interfaces.

    ``debug=True`` makes Gradio's ``launch`` block the main thread, so this
    call returns only after the server stops.
    """
    launch_options = {
        "server_name": "0.0.0.0",  # listen on every network interface
        "server_port": 7860,
        # NOTE(review): Gradio ignores share=True on Hugging Face Spaces
        # (it logs a warning) — confirm whether this is still wanted.
        "share": True,
        "debug": True,
    }
    ChatbotInterface().launch(**launch_options)


if __name__ == "__main__":
    # Script entry point, e.g. `python app.py`.
    main()