whan12 commited on
Commit
239c1b0
Β·
verified Β·
1 Parent(s): 78a718b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -66
app.py CHANGED
@@ -29,7 +29,7 @@ def extract_response_pairs(text):
29
 
30
  def add_text(history, text):
31
  history = history + [[text, None]]
32
- return history, text
33
 
34
  def arnold_speak(text):
35
  arnold_phrases = [
@@ -51,93 +51,81 @@ def arnold_speak(text):
51
  return text
52
 
53
  def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
54
- outputs = pipe(images=image, prompt=prompt,
55
- generate_kwargs={"temperature": temperature,
56
- "length_penalty": length_penalty,
57
- "repetition_penalty": repetition_penalty,
58
- "max_length": max_length,
59
- "min_length": min_length,
60
- "top_p": top_p})
61
- inference_output = outputs[0]["generated_text"]
62
- return inference_output
63
-
64
- def bot(history_chat, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
 
 
 
 
65
  if text_input == "":
66
- gr.Warning("Please input text")
 
 
67
  if image is None:
68
- gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
 
69
 
70
- chat_history = " ".join([item for sublist in history_chat for item in sublist]) # Flatten history
71
-
72
- if arnold_mode:
73
- system_prompt = "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his iconic style, using his catchphrases and focusing on fitness and motivation."
74
- else:
75
- system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."
76
 
77
- chat_history = f"{system_prompt}\n{chat_history}\nUSER: <image>\n{text_input}\nASSISTANT:"
78
 
79
- inference_result = infer(image, chat_history, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
80
- chat_val = extract_response_pairs(inference_result)
81
 
82
- chat_state_list = copy.deepcopy(chat_val)
83
- chat_state_list[-1][1] = "" # empty last response
84
 
85
- response = chat_val[-1][1]
86
  if arnold_mode:
87
  response = arnold_speak(response)
88
 
89
- for character in response:
90
- chat_state_list[-1][1] += character
 
91
  time.sleep(0.05)
92
- yield chat_state_list
93
 
94
- css = """
95
- #mkd {
96
- height: 500px;
97
- overflow: auto;
98
- border: 1px solid #ccc;
99
- }
100
- """
101
-
102
- with gr.Blocks(css=css) as demo:
103
  gr.Markdown(DESCRIPTION)
104
  gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚑️
105
  See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
106
- chatbot = gr.Chatbot(label="Chat", show_label=False)
107
- gr.Markdown("Input image and text and start chatting πŸ‘‡")
 
108
  with gr.Row():
109
  image = gr.Image(type="pil")
110
- text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
111
-
112
- history_chat = gr.State(value=[])
113
- arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode", value=False)
114
-
115
  with gr.Accordion(label="Advanced settings", open=False):
116
- temperature = gr.Slider(label="Temperature", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=1.0)
117
- length_penalty = gr.Slider(label="Length Penalty", info="Set to larger for longer sequence, used with beam search.", minimum=-1.0, maximum=2.0, step=0.2, value=1.0)
118
- repetition_penalty = gr.Slider(label="Repetition Penalty", info="Larger value prevents repetition.", minimum=1.0, maximum=5.0, step=0.5, value=1.5)
119
- max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, step=1, value=200)
120
- min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
121
- top_p = gr.Slider(label="Top P", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=0.9)
122
-
123
- chat_inputs = [chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, history_chat, arnold_mode]
124
 
125
  with gr.Row():
126
- clear_chat_button = gr.Button("Clear")
127
- cancel_btn = gr.Button("Stop Generation")
128
- chat_button = gr.Button("Submit", variant="primary")
129
-
130
- chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
131
- bot, chat_inputs, chatbot
132
- )
133
-
134
- chat_event2 = text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
135
- bot, chat_inputs, chatbot
136
  )
137
 
138
- clear_chat_button.click(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False, api_name="clear")
139
- image.change(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False)
140
- cancel_btn.click(None, [], [], cancels=[chat_event1, chat_event2])
141
 
142
  examples = [
143
  ["./examples/baklava.png", "How to make this pastry?"],
 
29
 
30
  def add_text(history, text):
31
  history = history + [[text, None]]
32
+ return history, "" # Clear the input field after submission
33
 
34
  def arnold_speak(text):
35
  arnold_phrases = [
 
51
  return text
52
 
53
  def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
54
+ try:
55
+ outputs = pipe(images=image, prompt=prompt,
56
+ generate_kwargs={"temperature": temperature,
57
+ "length_penalty": length_penalty,
58
+ "repetition_penalty": repetition_penalty,
59
+ "max_length": max_length,
60
+ "min_length": min_length,
61
+ "top_p": top_p})
62
+ inference_output = outputs[0]["generated_text"]
63
+ return inference_output
64
+ except Exception as e:
65
+ print(f"Error during inference: {str(e)}")
66
+ return f"An error occurred during inference: {str(e)}"
67
+
68
+ def bot(history, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
69
  if text_input == "":
70
+ yield history + [["Please input text", None]]
71
+ return
72
+
73
  if image is None:
74
+ yield history + [["Please input image or wait for image to be uploaded before clicking submit.", None]]
75
+ return
76
 
77
+ chat_history = " ".join([item for sublist in history for item in sublist if item is not None]) # Flatten history
 
 
 
 
 
78
 
79
+ system_prompt = "You are a helpful AI assistant. " if not arnold_mode else "You are Arnold Schwarzenegger, the famous bodybuilder and actor. Respond in his iconic style, using his catchphrases and focusing on fitness and motivation."
80
 
81
+ prompt = f"{system_prompt}\n{chat_history}\nUSER: <image>\n{text_input}\nASSISTANT:"
 
82
 
83
+ response = infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
 
84
 
 
85
  if arnold_mode:
86
  response = arnold_speak(response)
87
 
88
+ history.append([text_input, ""])
89
+ for i in range(len(response)):
90
+ history[-1][1] = response[:i+1]
91
  time.sleep(0.05)
92
+ yield history
93
 
94
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
95
  gr.Markdown(DESCRIPTION)
96
  gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚑️
97
  See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
98
+
99
+ chatbot = gr.Chatbot()
100
+
101
  with gr.Row():
102
  image = gr.Image(type="pil")
103
+ with gr.Column():
104
+ text_input = gr.Textbox(label="Chat Input", lines=3)
105
+ arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode")
106
+
 
107
  with gr.Accordion(label="Advanced settings", open=False):
108
+ temperature = gr.Slider(label="Temperature", minimum=0.5, maximum=1.0, value=1.0, step=0.1)
109
+ length_penalty = gr.Slider(label="Length Penalty", minimum=-1.0, maximum=2.0, value=1.0, step=0.2)
110
+ repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=5.0, value=1.5, step=0.5)
111
+ max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, value=200, step=1)
112
+ min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, value=1, step=1)
113
+ top_p = gr.Slider(label="Top P", minimum=0.5, maximum=1.0, value=0.9, step=0.1)
 
 
114
 
115
  with gr.Row():
116
+ clear_button = gr.Button("Clear")
117
+ submit_button = gr.Button("Submit", variant="primary")
118
+
119
+ submit_button.click(
120
+ fn=bot,
121
+ inputs=[chatbot, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode],
122
+ outputs=chatbot
123
+ ).then(
124
+ fn=lambda: "",
125
+ outputs=text_input
126
  )
127
 
128
+ clear_button.click(lambda: ([], None), outputs=[chatbot, image], queue=False)
 
 
129
 
130
  examples = [
131
  ["./examples/baklava.png", "How to make this pastry?"],