umairrrkhan commited on
Commit
65c80c2
·
verified ·
1 Parent(s): e5d4fb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -11
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
 
5
  class TextGenerationBot:
6
  def __init__(self, model_name="umairrrkhan/english-text-generation"):
7
  self.model_name = model_name
@@ -13,16 +14,17 @@ class TextGenerationBot:
13
  def setup_model(self):
14
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
15
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
16
-
 
17
  if self.tokenizer.pad_token is None:
18
  self.tokenizer.pad_token = self.tokenizer.eos_token
19
-
20
  if self.model.config.pad_token_id is None:
21
  self.model.config.pad_token_id = self.model.config.eos_token_id
22
 
23
  def generate_text(self, input_text, temperature=0.7, max_length=100):
24
  inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
25
-
26
  generation_config = {
27
  'input_ids': inputs['input_ids'],
28
  'max_length': max_length,
@@ -35,24 +37,26 @@ class TextGenerationBot:
35
  'pad_token_id': self.tokenizer.pad_token_id,
36
  'attention_mask': inputs['attention_mask']
37
  }
38
-
39
  with torch.no_grad():
40
  outputs = self.model.generate(**generation_config)
41
-
42
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
43
 
44
- def chat(self, message, history):
45
  self.history = history or []
46
  bot_response = self.generate_text(message)
47
  self.history.append((message, bot_response))
48
  return self.history
49
 
 
50
  class ChatbotInterface:
51
  def __init__(self):
52
  self.bot = TextGenerationBot()
53
  self.setup_interface()
54
 
55
  def setup_interface(self):
 
56
  self.interface = gr.ChatInterface(
57
  fn=self.bot.chat,
58
  title="AI Text Generation Chatbot",
@@ -62,15 +66,13 @@ class ChatbotInterface:
62
  ["What are the benefits of exercise?"],
63
  ["Write a poem about nature"]
64
  ],
65
- theme=gr.themes.Soft(),
66
- retry_btn="Regenerate Response",
67
- undo_btn="Remove Last Message",
68
- clear_btn="Clear Conversation",
69
  )
70
 
71
  def launch(self, **kwargs):
72
  self.interface.launch(**kwargs)
73
 
 
74
  def main():
75
  chatbot = ChatbotInterface()
76
  chatbot.launch(
@@ -80,5 +82,6 @@ def main():
80
  debug=True
81
  )
82
 
 
83
  if __name__ == "__main__":
84
- main()
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+
6
  class TextGenerationBot:
7
  def __init__(self, model_name="umairrrkhan/english-text-generation"):
8
  self.model_name = model_name
 
14
  def setup_model(self):
15
  self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
16
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
17
+
18
+ # Set pad_token if not defined
19
  if self.tokenizer.pad_token is None:
20
  self.tokenizer.pad_token = self.tokenizer.eos_token
21
+
22
  if self.model.config.pad_token_id is None:
23
  self.model.config.pad_token_id = self.model.config.eos_token_id
24
 
25
  def generate_text(self, input_text, temperature=0.7, max_length=100):
26
  inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
27
+
28
  generation_config = {
29
  'input_ids': inputs['input_ids'],
30
  'max_length': max_length,
 
37
  'pad_token_id': self.tokenizer.pad_token_id,
38
  'attention_mask': inputs['attention_mask']
39
  }
40
+
41
  with torch.no_grad():
42
  outputs = self.model.generate(**generation_config)
43
+
44
  return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
45
 
46
+ def chat(self, message, history=None):
47
  self.history = history or []
48
  bot_response = self.generate_text(message)
49
  self.history.append((message, bot_response))
50
  return self.history
51
 
52
+
53
  class ChatbotInterface:
54
  def __init__(self):
55
  self.bot = TextGenerationBot()
56
  self.setup_interface()
57
 
58
  def setup_interface(self):
59
+ # Removed invalid arguments (retry_btn, undo_btn, clear_btn)
60
  self.interface = gr.ChatInterface(
61
  fn=self.bot.chat,
62
  title="AI Text Generation Chatbot",
 
66
  ["What are the benefits of exercise?"],
67
  ["Write a poem about nature"]
68
  ],
69
+ theme=gr.themes.Soft() # Optional
 
 
 
70
  )
71
 
72
  def launch(self, **kwargs):
73
  self.interface.launch(**kwargs)
74
 
75
+
76
  def main():
77
  chatbot = ChatbotInterface()
78
  chatbot.launch(
 
82
  debug=True
83
  )
84
 
85
+
86
  if __name__ == "__main__":
87
+ main()