File size: 2,787 Bytes
9ca89bc
2534236
 
9ca89bc
65c80c2
2534236
 
 
 
 
 
 
9ca89bc
2534236
 
 
65c80c2
 
2534236
 
65c80c2
2534236
 
9ca89bc
2534236
 
65c80c2
2534236
 
 
 
 
 
 
 
 
 
 
 
65c80c2
2534236
 
65c80c2
2534236
9ca89bc
65c80c2
2534236
 
 
 
9ca89bc
65c80c2
2534236
 
 
 
9ca89bc
2534236
65c80c2
2534236
 
 
 
 
 
 
 
 
65c80c2
2534236
9ca89bc
2534236
 
9ca89bc
65c80c2
2534236
 
 
 
 
 
 
 
9ca89bc
65c80c2
9ca89bc
65c80c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class TextGenerationBot:
    """Wrap a Hugging Face causal LM as a simple single-turn chatbot.

    The model and tokenizer are loaded eagerly in ``__init__`` through
    :meth:`setup_model`. :meth:`chat` sends only the latest message to the
    model. Earlier turns are carried along in the returned history but are
    not used as model context.
    """

    def __init__(self, model_name="umairrrkhan/english-text-generation"):
        """Load *model_name* and prepare it for generation.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and its tokenizer.
        """
        self.model_name = model_name
        self.model = None
        self.tokenizer = None
        # Most recent conversation, as (user_message, bot_response) tuples.
        self.history = []
        self.setup_model()

    def setup_model(self):
        """Load the model and tokenizer and make sure a pad token is set.

        If the tokenizer or model config has no pad token, the EOS token is
        reused, because generation below passes an explicit pad token id.
        """
        self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Set pad_token if not defined
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.model.config.eos_token_id

    def generate_text(self, input_text, temperature=0.7, max_length=100):
        """Sample a continuation of *input_text* with the loaded model.

        Args:
            input_text: Prompt string.
            temperature: Sampling temperature (sampling is always enabled,
                with top-p 0.95 and top-k 50).
            max_length: Forwarded to ``generate`` as ``max_length``. In
                Transformers this limits prompt and new tokens together, so a
                prompt close to this length leaves little or no room for new
                text.

        Returns:
            The first generated sequence, decoded with special tokens removed.
            For decoder-only models ``generate`` returns the prompt ids
            followed by the new ids, so the text normally starts with the
            prompt.
        """
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

        generation_config = {
            'input_ids': inputs['input_ids'],
            'max_length': max_length,
            'num_return_sequences': 1,
            'no_repeat_ngram_size': 2,
            'temperature': temperature,
            'top_p': 0.95,
            'top_k': 50,
            'do_sample': True,
            'pad_token_id': self.tokenizer.pad_token_id,
            'attention_mask': inputs['attention_mask']
        }

        # Inference only: no gradient bookkeeping needed.
        with torch.no_grad():
            outputs = self.model.generate(**generation_config)

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def chat(self, message, history=None):
        """Generate a reply to *message* and return the updated history.

        Args:
            message: The user's latest message. It is the only text sent to
                the model.
            history: Prior turns as (message, response) pairs, or ``None``.
                It is copied and never modified in place, so caller-owned
                lists (for example, Gradio's chat state) stay untouched.

        Returns:
            A new list with the prior turns followed by
            ``(message, response)``. The same list is stored on
            ``self.history``.
        """
        bot_response = self.generate_text(message)
        self.history = self._build_history(history, message, bot_response)
        return self.history

    @staticmethod
    def _build_history(history, message, response):
        """Return a copy of *history* (``[]`` if falsy) with one turn appended."""
        turns = list(history) if history else []
        turns.append((message, response))
        return turns


class ChatbotInterface:
    """Gradio ``ChatInterface`` front end for ``TextGenerationBot``."""

    def __init__(self):
        """Create the bot (which loads the model) and build the UI."""
        self.bot = TextGenerationBot()
        self.setup_interface()

    def _respond(self, message, history):
        """Adapt ``TextGenerationBot.chat`` to the ``gr.ChatInterface`` callback.

        ``gr.ChatInterface`` calls ``fn(message, history)`` and uses the
        return value as the bot's reply. ``TextGenerationBot.chat`` returns
        the whole history list instead, so only the newest turn's response is
        returned here. The Gradio-owned *history* is copied before it is
        handed to the bot, so it is never mutated in place.

        Args:
            message: The user's latest chat message.
            history: Gradio's prior chat history (may be ``None`` or empty).

        Returns:
            The bot's reply text for *message*.
        """
        turns = self.bot.chat(message, list(history or []))
        return turns[-1][1]

    def setup_interface(self):
        """Build the ``gr.ChatInterface`` and store it on ``self.interface``."""
        # retry_btn / undo_btn / clear_btn are deliberately not passed; they
        # were rejected as invalid arguments by the Gradio version in use.
        self.interface = gr.ChatInterface(
            fn=self._respond,
            title="AI Text Generation Chatbot",
            description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
            examples=[
                ["Tell me a short story about a brave knight"],
                ["What are the benefits of exercise?"],
                ["Write a poem about nature"]
            ],
            theme=gr.themes.Soft()  # Optional
        )

    def launch(self, **kwargs):
        """Start the interface, forwarding *kwargs* to its ``launch`` method."""
        self.interface.launch(**kwargs)


def main():
    """Build the chatbot UI and serve it on 0.0.0.0:7860.

    The server is launched with ``share=True`` and ``debug=True``.
    """
    # NOTE(review): binding 0.0.0.0 together with share=True likely makes the
    # app reachable from outside this machine. Confirm that this is intended.
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        debug=True,
    )
    ChatbotInterface().launch(**launch_options)


if __name__ == "__main__":
    main()