umairrrkhan commited on
Commit
bf8e0d3
·
verified ·
1 Parent(s): fa7a443

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -103
app.py CHANGED
@@ -1,109 +1,26 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
- class TextGenerationBot:
6
- def __init__(self, model_name="umairrrkhan/english-text-generation"):
7
- self.model_name = model_name
8
- self.model = None
9
- self.tokenizer = None
10
- self.setup_model()
11
-
12
- def setup_model(self):
13
- """
14
- Load the model and tokenizer, and ensure pad_token and pad_token_id are set.
15
- """
16
- self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
17
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
18
-
19
- # Ensure tokenizer has a pad token
20
- if self.tokenizer.pad_token is None:
21
- self.tokenizer.pad_token = self.tokenizer.eos_token
22
-
23
- # Ensure model config has pad_token_id
24
- if self.model.config.pad_token_id is None:
25
- self.model.config.pad_token_id = self.tokenizer.pad_token_id
26
-
27
- def generate_text(self, input_text, temperature=0.7, max_length=100):
28
- """
29
- Generate text based on user input.
30
- """
31
- # Tokenize input
32
- inputs = self.tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
33
-
34
- # Generate output
35
- with torch.no_grad():
36
- outputs = self.model.generate(
37
- input_ids=inputs["input_ids"],
38
- attention_mask=inputs["attention_mask"],
39
- max_length=max_length,
40
- temperature=temperature,
41
- top_k=50,
42
- top_p=0.95,
43
- do_sample=True,
44
- pad_token_id=self.tokenizer.pad_token_id,
45
- eos_token_id=self.tokenizer.eos_token_id,
46
- )
47
-
48
- # Decode and return the generated text
49
- return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
50
-
51
- def chat(self, message, history):
52
- """
53
- Handle a chat conversation.
54
- """
55
- if not history:
56
- history = []
57
- bot_response = self.generate_text(message)
58
- history.append((message, bot_response))
59
- return history, history
60
-
61
-
62
- class ChatbotInterface:
63
- def __init__(self):
64
- self.bot = TextGenerationBot()
65
- self.interface = None
66
- self.setup_interface()
67
-
68
- def setup_interface(self):
69
- """
70
- Set up the Gradio interface for the chatbot.
71
- """
72
- self.interface = gr.Interface(
73
- fn=self.bot.chat,
74
- inputs=[
75
- gr.inputs.Textbox(label="Your Message"),
76
- gr.inputs.State(label="Chat History"),
77
- ],
78
- outputs=[
79
- gr.outputs.Textbox(label="Bot Response"),
80
- gr.outputs.State(label="Updated Chat History"),
81
- ],
82
- title="AI Text Generation Chatbot",
83
- description="Chat with an AI model trained on English text. Try asking questions or providing prompts!",
84
- examples=[
85
- ["Tell me a short story about a brave knight"],
86
- ["What are the benefits of exercise?"],
87
- ["Write a poem about nature"],
88
- ],
89
- )
90
-
91
- def launch(self, **kwargs):
92
- """
93
- Launch the Gradio interface.
94
- """
95
- self.interface.launch(**kwargs)
96
-
97
-
98
- def main():
99
- chatbot = ChatbotInterface()
100
- chatbot.launch(
101
- server_name="0.0.0.0",
102
- server_port=7860,
103
- share=True,
104
- debug=True,
105
  )
 
106
 
 
 
107
 
108
- if __name__ == "__main__":
109
- main()
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
 
4
+ # Load the model and tokenizer from Hugging Face
5
+ model_name = "your-username/your-repo-name" # Replace with your actual model repo name
6
+ model = AutoModelForCausalLM.from_pretrained(model_name)
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+
9
+ # Define a function to generate text
10
+ def generate_text(prompt):
11
+ inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
12
+ outputs = model.generate(
13
+ inputs['input_ids'],
14
+ max_length=50,
15
+ attention_mask=inputs['attention_mask'],
16
+ do_sample=True,
17
+ temperature=0.7,
18
+ top_k=50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  )
20
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
21
 
22
+ # Create a Gradio interface
23
+ iface = gr.Interface(fn=generate_text, inputs="text", outputs="text")
24
 
25
+ # Launch the interface
26
+ iface.launch()