import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class TextGenerationBot:
    def __init__(self, model_name="umairrrkhan/english-text-generation"):
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.history = []
        self.setup_model()

    def setup_model(self):
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Set pad_token if not defined
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.model.config.eos_token_id

    def generate_text(self, input_text, temperature=0.7, max_length=100):
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        generation_config = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask'],
            'max_length': max_length,
            'num_return_sequences': 1,
            'no_repeat_ngram_size': 2,
            'temperature': temperature,
            'top_p': 0.95,
            'top_k': 50,
            'do_sample': True,
            'pad_token_id': self.tokenizer.pad_token_id,
        }
        with torch.no_grad():
            outputs = self.model.generate(**generation_config)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def chat(self, message, history=None):
        # gr.ChatInterface expects the chat function to return the bot's reply as a string
        self.history = history or []
        bot_response = self.generate_text(message)
        self.history.append((message, bot_response))
        return bot_response
class ChatbotInterface:
    def __init__(self):
        self.bot = TextGenerationBot()
        self.setup_interface()

    def setup_interface(self):
        # Removed invalid arguments (retry_btn, undo_btn, clear_btn)
        self.interface = gr.ChatInterface(
            fn=self.bot.chat,
            title="AI Text Generation Chatbot",
            description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
            examples=[
                "Tell me a short story about a brave knight",
                "What are the benefits of exercise?",
                "Write a poem about nature",
            ],
            theme=gr.themes.Soft()  # Optional
        )

    def launch(self, **kwargs):
        self.interface.launch(**kwargs)
def main():
    chatbot = ChatbotInterface()
    chatbot.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True
    )

if __name__ == "__main__":
    main()
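
# A minimal sketch of exercising the generator without the Gradio UI
# (assumes the "umairrrkhan/english-text-generation" model and tokenizer
# download successfully from the Hub):
#
#     bot = TextGenerationBot()
#     print(bot.generate_text("Once upon a time", temperature=0.8, max_length=80))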