umairrrkhan commited on
Commit
fa7a443
·
verified ·
1 Parent(s): 65c80c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -28
app.py CHANGED
@@ -2,74 +2,96 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
-
6
  class TextGenerationBot:
7
  def __init__(self, model_name="umairrrkhan/english-text-generation"):
8
  self.model_name = model_name
9
  self.model = None
10
  self.tokenizer = None
11
- self.history = []
12
  self.setup_model()
13
 
14
  def setup_model(self):
 
 
 
15
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
16
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
 
18
- # Set pad_token if not defined
19
  if self.tokenizer.pad_token is None:
20
  self.tokenizer.pad_token = self.tokenizer.eos_token
21
 
 
22
  if self.model.config.pad_token_id is None:
23
- self.model.config.pad_token_id = self.model.config.eos_token_id
24
 
25
  def generate_text(self, input_text, temperature=0.7, max_length=100):
 
 
 
 
26
  inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
27
 
28
- generation_config = {
29
- 'input_ids': inputs['input_ids'],
30
- 'max_length': max_length,
31
- 'num_return_sequences': 1,
32
- 'no_repeat_ngram_size': 2,
33
- 'temperature': temperature,
34
- 'top_p': 0.95,
35
- 'top_k': 50,
36
- 'do_sample': True,
37
- 'pad_token_id': self.tokenizer.pad_token_id,
38
- 'attention_mask': inputs['attention_mask']
39
- }
40
-
41
  with torch.no_grad():
42
- outputs = self.model.generate(**generation_config)
43
-
 
 
 
 
 
 
 
 
 
 
 
44
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
45
 
46
- def chat(self, message, history=None):
47
- self.history = history or []
 
 
 
 
48
  bot_response = self.generate_text(message)
49
- self.history.append((message, bot_response))
50
- return self.history
51
 
52
 
53
  class ChatbotInterface:
54
  def __init__(self):
55
  self.bot = TextGenerationBot()
 
56
  self.setup_interface()
57
 
58
  def setup_interface(self):
59
- # Removed invalid arguments (retry_btn, undo_btn, clear_btn)
60
- self.interface = gr.ChatInterface(
 
 
61
  fn=self.bot.chat,
 
 
 
 
 
 
 
 
62
  title="AI Text Generation Chatbot",
63
  description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
64
  examples=[
65
  ["Tell me a short story about a brave knight"],
66
  ["What are the benefits of exercise?"],
67
- ["Write a poem about nature"]
68
  ],
69
- theme=gr.themes.Soft() # Optional
70
  )
71
 
72
  def launch(self, **kwargs):
 
 
 
73
  self.interface.launch(**kwargs)
74
 
75
 
@@ -79,7 +101,7 @@ def main():
79
  server_name="0.0.0.0",
80
  server_port=7860,
81
  share=True,
82
- debug=True
83
  )
84
 
85
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
5
  class TextGenerationBot:
6
  def __init__(self, model_name="umairrrkhan/english-text-generation"):
7
  self.model_name = model_name
8
  self.model = None
9
  self.tokenizer = None
 
10
  self.setup_model()
11
 
12
  def setup_model(self):
13
+ """
14
+ Load the model and tokenizer, and ensure pad_token and pad_token_id are set.
15
+ """
16
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
17
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
18
 
19
+ # Ensure tokenizer has a pad token
20
  if self.tokenizer.pad_token is None:
21
  self.tokenizer.pad_token = self.tokenizer.eos_token
22
 
23
+ # Ensure model config has pad_token_id
24
  if self.model.config.pad_token_id is None:
25
+ self.model.config.pad_token_id = self.tokenizer.pad_token_id
26
 
27
  def generate_text(self, input_text, temperature=0.7, max_length=100):
28
+ """
29
+ Generate text based on user input.
30
+ """
31
+ # Tokenize input
32
  inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
33
 
34
+ # Generate output
 
 
 
 
 
 
 
 
 
 
 
 
35
  with torch.no_grad():
36
+ outputs = self.model.generate(
37
+ input_ids=inputs["input_ids"],
38
+ attention_mask=inputs["attention_mask"],
39
+ max_length=max_length,
40
+ temperature=temperature,
41
+ top_k=50,
42
+ top_p=0.95,
43
+ do_sample=True,
44
+ pad_token_id=self.tokenizer.pad_token_id,
45
+ eos_token_id=self.tokenizer.eos_token_id,
46
+ )
47
+
48
+ # Decode and return the generated text
49
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
50
 
51
+ def chat(self, message, history):
52
+ """
53
+ Handle a chat conversation.
54
+ """
55
+ if not history:
56
+ history = []
57
  bot_response = self.generate_text(message)
58
+ history.append((message, bot_response))
59
+ return history, history
60
 
61
 
62
  class ChatbotInterface:
63
  def __init__(self):
64
  self.bot = TextGenerationBot()
65
+ self.interface = None
66
  self.setup_interface()
67
 
68
  def setup_interface(self):
69
+ """
70
+ Set up the Gradio interface for the chatbot.
71
+ """
72
+ self.interface = gr.Interface(
73
  fn=self.bot.chat,
74
+ inputs=[
75
+ gr.inputs.Textbox(label="Your Message"),
76
+ gr.inputs.State(label="Chat History"),
77
+ ],
78
+ outputs=[
79
+ gr.outputs.Textbox(label="Bot Response"),
80
+ gr.outputs.State(label="Updated Chat History"),
81
+ ],
82
  title="AI Text Generation Chatbot",
83
  description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
84
  examples=[
85
  ["Tell me a short story about a brave knight"],
86
  ["What are the benefits of exercise?"],
87
+ ["Write a poem about nature"],
88
  ],
 
89
  )
90
 
91
  def launch(self, **kwargs):
92
+ """
93
+ Launch the Gradio interface.
94
+ """
95
  self.interface.launch(**kwargs)
96
 
97
 
 
101
  server_name="0.0.0.0",
102
  server_port=7860,
103
  share=True,
104
+ debug=True,
105
  )
106
 
107