Spaces:
Sleeping
Sleeping
File size: 3,447 Bytes
9ca89bc 2534236 9ca89bc 2534236 9ca89bc 2534236 fa7a443 2534236 65c80c2 fa7a443 2534236 65c80c2 fa7a443 2534236 fa7a443 9ca89bc 2534236 fa7a443 2534236 65c80c2 fa7a443 2534236 fa7a443 2534236 9ca89bc fa7a443 2534236 fa7a443 9ca89bc 65c80c2 2534236 fa7a443 2534236 9ca89bc 2534236 fa7a443 2534236 fa7a443 2534236 fa7a443 2534236 9ca89bc 2534236 fa7a443 2534236 9ca89bc 65c80c2 2534236 fa7a443 2534236 9ca89bc 65c80c2 9ca89bc 65c80c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class TextGenerationBot:
    """Wrap a Hugging Face causal language model for simple chat-style generation."""

    def __init__(self, model_name="umairrrkhan/english-text-generation"):
        """
        Store the model identifier and eagerly load the model and tokenizer.

        Args:
            model_name: Hub id (or local path) passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.setup_model()

    def setup_model(self):
        """
        Load the model and tokenizer, and ensure pad_token and pad_token_id are set.
        """
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so the
        # padding=True call in generate_text() has a pad token to use.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Keep the model config's pad_token_id in sync with the tokenizer.
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def generate_text(self, input_text, temperature=0.7, max_length=100):
        """
        Generate a continuation of ``input_text``.

        Args:
            input_text: Prompt string.
            temperature: Sampling temperature. A value <= 0 selects greedy
                (deterministic) decoding, because sampling requires a
                strictly positive temperature.
            max_length: Forwarded to ``generate(max_length=...)``; it bounds
                the total token count (prompt plus generated tokens).

        Returns:
            The decoded output sequence (prompt included) with special tokens
            removed, or ``""`` when ``input_text`` is empty.
        """
        # An empty prompt tokenizes to zero tokens, leaving generate() nothing
        # to continue from.
        if not input_text:
            return ""

        inputs = self.tokenizer(
            input_text, return_tensors="pt", padding=True, truncation=True
        ).to(self.model.device)  # keep inputs on the same device as the model

        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_length": max_length,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if temperature is not None and temperature <= 0:
            # transformers rejects sampling with a non-positive temperature;
            # fall back to greedy decoding instead of raising.
            generation_kwargs["do_sample"] = False
        else:
            generation_kwargs.update(
                do_sample=True, temperature=temperature, top_k=50, top_p=0.95
            )

        with torch.no_grad():
            outputs = self.model.generate(**generation_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def chat(self, message, history):
        """
        Handle one turn of a chat conversation.

        Args:
            message: The user's latest message.
            history: Previous ``(message, response)`` pairs, or ``None`` /
                empty for a fresh conversation.

        Returns:
            ``(history, history)``: the same updated list twice, once for
            display and once to be carried forward as session state.
        """
        # Build a new list rather than appending to the caller's object.
        history = list(history) if history else []
        bot_response = self.generate_text(message)
        history.append((message, bot_response))
        return history, history
class ChatbotInterface:
    """Gradio front end that exposes a TextGenerationBot as a chat UI."""

    def __init__(self):
        """Create the bot (which loads its model) and build the Gradio interface."""
        self.bot = TextGenerationBot()
        self.interface = None
        self.setup_interface()

    def setup_interface(self):
        """
        Set up the Gradio interface for the chatbot.

        ``bot.chat`` returns the updated history twice: once for display and
        once as session state. The display component is a ``gr.Chatbot``
        because that history is a list of (user, bot) pairs, which a plain
        Textbox would show as a raw Python list.
        """
        self.interface = gr.Interface(
            fn=self.bot.chat,
            # The gr.inputs / gr.outputs namespaces were removed in Gradio 4;
            # components now live at the top level. State takes no label.
            inputs=[
                gr.Textbox(label="Your Message"),
                gr.State(),
            ],
            outputs=[
                gr.Chatbot(label="Bot Response"),
                gr.State(),
            ],
            title="AI Text Generation Chatbot",
            description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
            examples=[
                ["Tell me a short story about a brave knight"],
                ["What are the benefits of exercise?"],
                ["Write a poem about nature"],
            ],
        )

    def launch(self, **kwargs):
        """
        Launch the Gradio interface.

        Args:
            **kwargs: Forwarded unchanged to ``gr.Interface.launch``.
        """
        self.interface.launch(**kwargs)
def main():
    """Build the chatbot UI and serve it on 0.0.0.0:7860 with a share link and debug mode."""
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "debug": True,
    }
    app = ChatbotInterface()
    app.launch(**launch_options)


if __name__ == "__main__":
    main()
|