whan12 committed
Commit 78a718b · verified · 1 Parent(s): 8ed6e93

Update app.py

Files changed (1)
  1. app.py +89 -174
app.py CHANGED
@@ -4,19 +4,19 @@ import copy
4
  import gradio as gr
5
  import PIL.Image
6
  import torch
7
- from transformers import BitsAndBytesConfig, pipeline,LlavaNextProcessor, LlavaNextForConditionalGeneration
8
- import torch
9
  import re
10
  import time
 
11
 
12
- DESCRIPTION = "# LLaVA πŸ’ͺ - THE IRON PUMPING MACHINE VISION BEAST"
13
-
14
- model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
15
-
16
-
17
- pipe = LlavaNextForConditionalGeneration.from_pretrained(model_id , torch_dtype=torch.float16, low_cpu_mem_usage=True)
18
-
19
 
 
 
 
 
 
 
20
 
21
  def extract_response_pairs(text):
22
  turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
@@ -25,72 +25,72 @@ def extract_response_pairs(text):
25
  for i in range(0, len(turns[1::2]), 2):
26
  if i + 1 < len(turns[1::2]):
27
  conv_list.append([turns[1::2][i].lstrip(":"), turns[1::2][i + 1].lstrip(":")])
28
-
29
  return conv_list
30
 
31
-
32
-
33
  def add_text(history, text):
34
- history = history.append([text, None])
35
- return history, text
36
-
37
- def infer(image, prompt,
38
- temperature,
39
- length_penalty,
40
- repetition_penalty,
41
- max_length,
42
- min_length,
43
- top_p):
44
-
45
- outputs = pipe(images=image, prompt=prompt,
46
- generate_kwargs={"temperature":temperature,
47
- "length_penalty":length_penalty,
48
- "repetition_penalty":repetition_penalty,
49
- "max_length":max_length,
50
- "min_length":min_length,
51
- "top_p":top_p})
52
- inference_output = outputs[0]["generated_text"]
53
- return inference_output
54
-
55
-
56
-
57
- def bot(history_chat, text_input, image,
58
- temperature,
59
- length_penalty,
60
- repetition_penalty,
61
- max_length,
62
- min_length,
63
- top_p):
64
-
 
 
 
65
  if text_input == "":
66
  gr.Warning("Please input text")
67
-
68
- if image==None:
69
  gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
70
- chat_history = " ".join(history_chat) # history as a str to be passed to model
71
- chat_history = "you are a bodybuilding coach,and you sounds like arnold schwarzenegger, give advice on my gains, training and inspire me at the end"+chat_history + f"USER: <image>\n{text_input}\nASSISTANT:" # add text input for prompting
72
- inference_result = infer(image, chat_history,
73
- temperature,
74
- length_penalty,
75
- repetition_penalty,
76
- max_length,
77
- min_length,
78
- top_p)
79
- # return inference and parse for new history
 
80
  chat_val = extract_response_pairs(inference_result)
81
 
82
- # create history list for yielding the last inference response
83
  chat_state_list = copy.deepcopy(chat_val)
84
- chat_state_list[-1][1] = "" # empty last response
 
 
 
 
85
 
86
- # add characters iteratively
87
- for character in chat_val[-1][1]:
88
  chat_state_list[-1][1] += character
89
  time.sleep(0.05)
90
- # yield history but with last response being streamed
91
  yield chat_state_list
92
 
93
-
94
  css = """
95
  #mkd {
96
  height: 500px;
@@ -98,137 +98,52 @@ css = """
98
  border: 1px solid #ccc;
99
  }
100
  """
101
- with gr.Blocks(css="style.css") as demo:
 
102
  gr.Markdown(DESCRIPTION)
103
  gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚑️
104
  See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
105
  chatbot = gr.Chatbot(label="Chat", show_label=False)
106
  gr.Markdown("Input image and text and start chatting πŸ‘‡")
107
  with gr.Row():
108
-
109
- image = gr.Image(type="pil")
110
- text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
111
 
112
  history_chat = gr.State(value=[])
 
113
 
114
  with gr.Accordion(label="Advanced settings", open=False):
115
- temperature = gr.Slider(
116
- label="Temperature",
117
- info="Used with nucleus sampling.",
118
- minimum=0.5,
119
- maximum=1.0,
120
- step=0.1,
121
- value=1.0,
122
- )
123
- length_penalty = gr.Slider(
124
- label="Length Penalty",
125
- info="Set to larger for longer sequence, used with beam search.",
126
- minimum=-1.0,
127
- maximum=2.0,
128
- step=0.2,
129
- value=1.0,
130
- )
131
- repetition_penalty = gr.Slider(
132
- label="Repetition Penalty",
133
- info="Larger value prevents repetition.",
134
- minimum=1.0,
135
- maximum=5.0,
136
- step=0.5,
137
- value=1.5,
138
- )
139
- max_length = gr.Slider(
140
- label="Max Length",
141
- minimum=1,
142
- maximum=500,
143
- step=1,
144
- value=200,
145
- )
146
- min_length = gr.Slider(
147
- label="Minimum Length",
148
- minimum=1,
149
- maximum=100,
150
- step=1,
151
- value=1,
152
- )
153
- top_p = gr.Slider(
154
- label="Top P",
155
- info="Used with nucleus sampling.",
156
- minimum=0.5,
157
- maximum=1.0,
158
- step=0.1,
159
- value=0.9,
160
- )
161
- chat_output = [
162
- chatbot,
163
- history_chat
164
- ]
165
 
 
166
 
167
- chat_inputs = [
168
- image,
169
- text_input,
170
- temperature,
171
- length_penalty,
172
- repetition_penalty,
173
- max_length,
174
- min_length,
175
- top_p,
176
- history_chat
177
- ]
178
  with gr.Row():
179
- clear_chat_button = gr.Button("Clear")
180
- cancel_btn = gr.Button("Stop Generation")
181
- chat_button = gr.Button("Submit", variant="primary")
182
 
183
- chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(bot, [chatbot, text_input,
184
- image, temperature,
185
- length_penalty,
186
- repetition_penalty,
187
- max_length,
188
- min_length,
189
- top_p], chatbot)
190
-
191
- chat_event2 = text_input.submit(
192
- add_text,
193
- [chatbot, text_input],
194
- [chatbot, text_input]
195
- ).then(
196
- fn=bot,
197
- inputs=[chatbot, text_input, image, temperature,
198
- length_penalty,
199
- repetition_penalty,
200
- max_length,
201
- min_length,
202
- top_p],
203
- outputs=chatbot
204
  )
205
- clear_chat_button.click(
206
- fn=lambda: ([], []),
207
- inputs=None,
208
- outputs=[
209
- chatbot,
210
- history_chat
211
- ],
212
- queue=False,
213
- api_name="clear",
214
  )
215
- image.change(
216
- fn=lambda: ([], []),
217
- inputs=None,
218
- outputs=[
219
- chatbot,
220
- history_chat
221
- ],
222
- queue=False)
223
- cancel_btn.click(
224
- None, [], [],
225
- cancels=[chat_event1, chat_event2]
226
- )
227
- examples = [["./examples/baklava.png", "How to make this pastry?"],["./examples/bee.png","Describe this image."]]
228
- gr.Examples(examples=examples, inputs=[image, text_input, chat_inputs])
229
 
 
 
 
230
 
231
-
 
 
 
 
232
 
233
  if __name__ == "__main__":
234
  demo.queue(max_size=10).launch(debug=True)
 
4
  import gradio as gr
5
  import PIL.Image
6
  import torch
7
+ from transformers import BitsAndBytesConfig, pipeline
 
8
  import re
9
  import time
10
+ import random
11
 
12
+ DESCRIPTION = "# LLaVA πŸŒ‹πŸ’ͺ - Now with Arnold Mode!"
 
 
 
 
 
 
13
 
14
+ model_id = "llava-hf/llava-1.5-7b-hf"
15
+ quantization_config = BitsAndBytesConfig(
16
+ load_in_4bit=True,
17
+ bnb_4bit_compute_dtype=torch.float16
18
+ )
19
+ pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
20
 
21
  def extract_response_pairs(text):
22
  turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
 
25
  for i in range(0, len(turns[1::2]), 2):
26
  if i + 1 < len(turns[1::2]):
27
  conv_list.append([turns[1::2][i].lstrip(":"), turns[1::2][i + 1].lstrip(":")])
 
28
  return conv_list
29
 
 
 
30
  def add_text(history, text):
31
+ history = history + [[text, None]]
32
+ return history, text
33
+
34
+ def arnold_speak(text):
35
+ arnold_phrases = [
36
+ "Come with me if you want to lift!",
37
+ "I'll be back... after my protein shake.",
38
+ "Hasta la vista, baby weight!",
39
+ "Get to da choppa... I mean, da squat rack!",
40
+ "You lack discipline! But don't worry, I'm here to pump you up!"
41
+ ]
42
+
43
+ text = text.replace(".", "!") # More enthusiastic punctuation
44
+ text = text.replace("gym", "iron paradise")
45
+ text = text.replace("exercise", "pump iron")
46
+ text = text.replace("workout", "sculpt your physique")
47
+
48
+ # Add random Arnold phrase to the end
49
+ text += " " + random.choice(arnold_phrases)
50
+
51
+ return text
52
+
53
+ def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
54
+ outputs = pipe(images=image, prompt=prompt,
55
+ generate_kwargs={"temperature": temperature,
56
+ "length_penalty": length_penalty,
57
+ "repetition_penalty": repetition_penalty,
58
+ "max_length": max_length,
59
+ "min_length": min_length,
60
+ "top_p": top_p})
61
+ inference_output = outputs[0]["generated_text"]
62
+ return inference_output
63
+
64
+ def bot(history_chat, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
65
  if text_input == "":
66
  gr.Warning("Please input text")
67
+ if image is None:
 
68
  gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
69
+
70
+ chat_history = " ".join([item for sublist in history_chat for item in sublist]) # Flatten history
71
+
72
+ if arnold_mode:
73
+ system_prompt = "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his iconic style, using his catchphrases and focusing on fitness and motivation."
74
+ else:
75
+ system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."
76
+
77
+ chat_history = f"{system_prompt}\n{chat_history}\nUSER: <image>\n{text_input}\nASSISTANT:"
78
+
79
+ inference_result = infer(image, chat_history, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
80
  chat_val = extract_response_pairs(inference_result)
81
 
 
82
  chat_state_list = copy.deepcopy(chat_val)
83
+ chat_state_list[-1][1] = "" # empty last response
84
+
85
+ response = chat_val[-1][1]
86
+ if arnold_mode:
87
+ response = arnold_speak(response)
88
 
89
+ for character in response:
 
90
  chat_state_list[-1][1] += character
91
  time.sleep(0.05)
 
92
  yield chat_state_list
93
 
 
94
  css = """
95
  #mkd {
96
  height: 500px;
 
98
  border: 1px solid #ccc;
99
  }
100
  """
101
+
102
+ with gr.Blocks(css=css) as demo:
103
  gr.Markdown(DESCRIPTION)
104
  gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚑️
105
  See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
106
  chatbot = gr.Chatbot(label="Chat", show_label=False)
107
  gr.Markdown("Input image and text and start chatting πŸ‘‡")
108
  with gr.Row():
109
+ image = gr.Image(type="pil")
110
+ text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
 
111
 
112
  history_chat = gr.State(value=[])
113
+ arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode", value=False)
114
 
115
  with gr.Accordion(label="Advanced settings", open=False):
116
+ temperature = gr.Slider(label="Temperature", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=1.0)
117
+ length_penalty = gr.Slider(label="Length Penalty", info="Set to larger for longer sequence, used with beam search.", minimum=-1.0, maximum=2.0, step=0.2, value=1.0)
118
+ repetition_penalty = gr.Slider(label="Repetition Penalty", info="Larger value prevents repetition.", minimum=1.0, maximum=5.0, step=0.5, value=1.5)
119
+ max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, step=1, value=200)
120
+ min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
121
+ top_p = gr.Slider(label="Top P", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=0.9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ chat_inputs = [chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, history_chat, arnold_mode]
124
 
 
 
 
 
 
 
 
 
 
 
 
125
  with gr.Row():
126
+ clear_chat_button = gr.Button("Clear")
127
+ cancel_btn = gr.Button("Stop Generation")
128
+ chat_button = gr.Button("Submit", variant="primary")
129
 
130
+ chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
131
+ bot, chat_inputs, chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  )
133
+
134
+ chat_event2 = text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
135
+ bot, chat_inputs, chatbot
 
 
 
 
 
 
136
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ clear_chat_button.click(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False, api_name="clear")
139
+ image.change(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False)
140
+ cancel_btn.click(None, [], [], cancels=[chat_event1, chat_event2])
141
 
142
+ examples = [
143
+ ["./examples/baklava.png", "How to make this pastry?"],
144
+ ["./examples/bee.png", "Describe this image."]
145
+ ]
146
+ gr.Examples(examples=examples, inputs=[image, text_input])
147
 
148
  if __name__ == "__main__":
149
  demo.queue(max_size=10).launch(debug=True)