File size: 3,447 Bytes
9ca89bc
2534236
 
9ca89bc
2534236
 
 
 
 
 
9ca89bc
2534236
fa7a443
 
 
2534236
 
65c80c2
fa7a443
2534236
 
65c80c2
fa7a443
2534236
fa7a443
9ca89bc
2534236
fa7a443
 
 
 
2534236
65c80c2
fa7a443
2534236
fa7a443
 
 
 
 
 
 
 
 
 
 
 
 
2534236
9ca89bc
fa7a443
 
 
 
 
 
2534236
fa7a443
 
9ca89bc
65c80c2
2534236
 
 
fa7a443
2534236
9ca89bc
2534236
fa7a443
 
 
 
2534236
fa7a443
 
 
 
 
 
 
 
2534236
 
 
 
 
fa7a443
2534236
 
9ca89bc
2534236
fa7a443
 
 
2534236
9ca89bc
65c80c2
2534236
 
 
 
 
 
fa7a443
2534236
9ca89bc
65c80c2
9ca89bc
65c80c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class TextGenerationBot:
    """Text-generation chatbot backed by a Hugging Face causal language model.

    Loads the model and tokenizer eagerly in ``__init__`` so the first chat
    turn does not pay the download/load cost.
    """

    def __init__(self, model_name="umairrrkhan/english-text-generation"):
        # model/tokenizer start as None and are populated by setup_model().
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        self.setup_model()

    def setup_model(self):
        """
        Load the model and tokenizer, and ensure pad_token and pad_token_id are set.

        Many causal-LM tokenizers (GPT-2 style) ship without a pad token;
        generation with padding then warns or fails, so fall back to EOS.
        """
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Ensure tokenizer has a pad token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Ensure model config has pad_token_id
        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def generate_text(self, input_text, temperature=0.7, max_length=100):
        """
        Generate a continuation of ``input_text``.

        Args:
            input_text: Prompt string to continue.
            temperature: Sampling temperature (higher = more random).
            max_length: Total length budget (prompt + generated tokens),
                matching ``model.generate``'s ``max_length`` semantics.

        Returns:
            Only the newly generated text. The prompt tokens are stripped
            before decoding (bug fix: the previous version decoded the whole
            sequence, so every reply echoed the user's message back).
        """
        # Tokenize input
        inputs = self.tokenizer(
            input_text, return_tensors="pt", padding=True, truncation=True
        )

        # Generate output (no gradients needed for inference)
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                temperature=temperature,
                top_k=50,
                top_p=0.95,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the tokens generated after the prompt.
        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    def chat(self, message, history):
        """
        Handle one chat turn.

        Returns:
            ``(bot_response, updated_history)`` — bug fix: the first value
            used to be the whole history list, which rendered as a Python
            repr in the interface's response textbox; the Textbox output
            expects the latest reply string.
        """
        if not history:
            history = []
        bot_response = self.generate_text(message)
        history.append((message, bot_response))
        return bot_response, history


class ChatbotInterface:
    """Wraps a TextGenerationBot in a Gradio web UI."""

    def __init__(self):
        # Building the bot here triggers the model download/load up front.
        self.bot = TextGenerationBot()
        self.interface = None
        self.setup_interface()

    def setup_interface(self):
        """
        Set up the Gradio interface for the chatbot.

        Bug fix: ``gr.inputs.*`` / ``gr.outputs.*`` are the deprecated
        Gradio 1.x/2.x namespaces, removed in modern Gradio — the top-level
        component classes (``gr.Textbox``, ``gr.State``) are the supported
        API. ``gr.State`` also takes no ``label`` argument.
        """
        self.interface = gr.Interface(
            fn=self.bot.chat,
            inputs=[
                gr.Textbox(label="Your Message"),
                gr.State(),  # conversation history carried between calls
            ],
            outputs=[
                gr.Textbox(label="Bot Response"),
                gr.State(),  # updated history fed back into the next call
            ],
            title="AI Text Generation Chatbot",
            description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
            examples=[
                ["Tell me a short story about a brave knight"],
                ["What are the benefits of exercise?"],
                ["Write a poem about nature"],
            ],
        )

    def launch(self, **kwargs):
        """
        Launch the Gradio interface, forwarding all kwargs to Gradio.
        """
        self.interface.launch(**kwargs)


def main():
    """Build the chatbot UI and serve it over HTTP."""
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all interfaces
        "server_port": 7860,
        "share": True,  # also expose a public Gradio share link
        "debug": True,
    }
    ChatbotInterface().launch(**launch_options)


if __name__ == "__main__":
    main()