Files changed (1) hide show
  1. app.py +132 -76
app.py CHANGED
@@ -20,6 +20,12 @@ from minigpt4.runners import *
20
  from minigpt4.tasks import *
21
 
22
  def parse_args():
 
 
 
 
 
 
23
  parser = argparse.ArgumentParser(description="Demo")
24
  parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
25
  parser.add_argument(
@@ -32,8 +38,13 @@ def parse_args():
32
  args = parser.parse_args()
33
  return args
34
 
35
-
36
  def setup_seeds(config):
 
 
 
 
 
 
37
  seed = config.run_cfg.seed + get_rank()
38
 
39
  random.seed(seed)
@@ -42,37 +53,39 @@ def setup_seeds(config):
42
 
43
  cudnn.benchmark = False
44
  cudnn.deterministic = True
45
-
46
- # ========================================
47
- # Model Initialization
48
- # ========================================
49
 
50
- SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
 
 
51
 
52
- You can duplicate and use it with a paid private GPU.
 
 
 
 
53
 
54
- <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
 
 
55
 
56
- Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
57
- '''
 
 
58
 
59
- print('Initializing Chat')
60
- cfg = Config(parse_args())
61
 
62
- model_config = cfg.model_cfg
63
- model_cls = registry.get_model_class(model_config.arch)
64
- model = model_cls.from_config(model_config).to('cuda:0')
65
-
66
- vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
67
- vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
68
- chat = Chat(model, vis_processor)
69
- print('Initialization Finished')
70
 
71
- # ========================================
72
- # Gradio Setting
73
- # ========================================
74
 
75
- def gradio_reset(chat_state, img_list):
 
 
76
  if chat_state is not None:
77
  chat_state.messages = []
78
  if img_list is not None:
@@ -80,6 +93,17 @@ def gradio_reset(chat_state, img_list):
80
  return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
81
 
82
  def upload_img(gr_img, text_input, chat_state):
 
 
 
 
 
 
 
 
 
 
 
83
  if gr_img is None:
84
  return None, None, gr.update(interactive=True), chat_state, None
85
  chat_state = CONV_VISION.copy()
@@ -88,67 +112,99 @@ def upload_img(gr_img, text_input, chat_state):
88
  return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
89
 
90
  def gradio_ask(user_message, chatbot, chat_state):
 
 
 
 
 
 
 
 
 
 
 
91
  if len(user_message) == 0:
92
  return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
93
  chat.ask(user_message, chat_state)
94
  chatbot = chatbot + [[user_message, None]]
95
  return '', chatbot, chat_state
96
 
97
-
98
  def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
100
  chatbot[-1][1] = llm_message
101
  return chatbot, chat_state, img_list
102
 
103
- title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
104
- description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
105
- article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
106
- """
107
-
108
- #TODO show examples below
109
-
110
- with gr.Blocks() as demo:
111
- gr.Markdown(title)
112
- gr.Markdown(SHARED_UI_WARNING)
113
- gr.Markdown(description)
114
- gr.Markdown(article)
115
-
116
- with gr.Row():
117
- with gr.Column(scale=0.5):
118
- image = gr.Image(type="pil")
119
- upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
120
- clear = gr.Button("Restart")
121
-
122
- num_beams = gr.Slider(
123
- minimum=1,
124
- maximum=5,
125
- value=1,
126
- step=1,
127
- interactive=True,
128
- label="beam search numbers)",
129
- )
130
-
131
- temperature = gr.Slider(
132
- minimum=0.1,
133
- maximum=2.0,
134
- value=1.0,
135
- step=0.1,
136
- interactive=True,
137
- label="Temperature",
138
- )
139
-
140
-
141
- with gr.Column():
142
- chat_state = gr.State()
143
- img_list = gr.State()
144
- chatbot = gr.Chatbot(label='MiniGPT-4')
145
- text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
146
-
147
- upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
148
-
149
- text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
150
- gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
151
- )
152
- clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
 
 
 
 
 
 
 
 
 
 
153
 
154
- demo.launch(enable_queue=True)
 
20
  from minigpt4.tasks import *
21
 
22
  def parse_args():
23
+ """
24
+ Parse command line arguments.
25
+
26
+ Returns:
27
+ argparse.Namespace: Parsed command line arguments.
28
+ """
29
  parser = argparse.ArgumentParser(description="Demo")
30
  parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
31
  parser.add_argument(
 
38
  args = parser.parse_args()
39
  return args
40
 
 
41
  def setup_seeds(config):
42
+ """
43
+ Set up random seeds for reproducibility.
44
+
45
+ Parameters:
46
+ config (Config): Configuration object.
47
+ """
48
  seed = config.run_cfg.seed + get_rank()
49
 
50
  random.seed(seed)
 
53
 
54
  cudnn.benchmark = False
55
  cudnn.deterministic = True
 
 
 
 
56
 
57
+ def initialize_chat():
58
+ """
59
+ Initialize the chat model.
60
 
61
+ Returns:
62
+ Chat: Initialized chat model.
63
+ """
64
+ print('Initializing Chat')
65
+ config = Config(parse_args())
66
 
67
+ model_config = config.model_cfg
68
+ model_cls = registry.get_model_class(model_config.arch)
69
+ model = model_cls.from_config(model_config).to('cuda:0')
70
 
71
+ vis_processor_cfg = config.datasets_cfg.cc_align.vis_processor.train
72
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
73
+ chat = Chat(model, vis_processor)
74
+ print('Initialization Finished')
75
 
76
+ return chat
 
77
 
78
+ def gradio_reset(chat_state, img_list):
79
+ """
80
+ Reset the Gradio interface.
 
 
 
 
 
81
 
82
+ Parameters:
83
+ chat_state (gr.State): The current state of the chat.
84
+ img_list (gr.State): The current list of images.
85
 
86
+ Returns:
87
+ tuple: Updated Gradio interface elements.
88
+ """
89
  if chat_state is not None:
90
  chat_state.messages = []
91
  if img_list is not None:
 
93
  return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
94
 
95
  def upload_img(gr_img, text_input, chat_state):
96
+ """
97
+ Upload an image and update the Gradio interface.
98
+
99
+ Parameters:
100
+ gr_img (gr.Image): The uploaded image.
101
+ text_input (gr.Textbox): The text input box.
102
+ chat_state (gr.State): The current state of the chat.
103
+
104
+ Returns:
105
+ tuple: Updated Gradio interface elements.
106
+ """
107
  if gr_img is None:
108
  return None, None, gr.update(interactive=True), chat_state, None
109
  chat_state = CONV_VISION.copy()
 
112
  return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
113
 
114
  def gradio_ask(user_message, chatbot, chat_state):
115
+ """
116
+ Process user message and update the Gradio interface.
117
+
118
+ Parameters:
119
+ user_message (str): The message input by the user.
120
+ chatbot (list): The current state of the chatbot.
121
+ chat_state (gr.State): The current state of the chat.
122
+
123
+ Returns:
124
+ tuple: Updated Gradio interface elements.
125
+ """
126
  if len(user_message) == 0:
127
  return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
128
  chat.ask(user_message, chat_state)
129
  chatbot = chatbot + [[user_message, None]]
130
  return '', chatbot, chat_state
131
 
 
132
  def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
133
+ """
134
+ Generate a chatbot answer and update the Gradio interface.
135
+
136
+ Parameters:
137
+ chatbot (list): The current state of the chatbot.
138
+ chat_state (gr.State): The current state of the chat.
139
+ img_list (gr.State): The current list of images.
140
+ num_beams (int): The number of beams for the beam search.
141
+ temperature (float): The temperature for the generation.
142
+
143
+ Returns:
144
+ tuple: Updated Gradio interface elements.
145
+ """
146
  llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
147
  chatbot[-1][1] = llm_message
148
  return chatbot, chat_state, img_list
149
 
150
+ def main():
151
+ """
152
+ Main function to run the Gradio interface.
153
+ """
154
+ # Initialize the chat model
155
+ chat = initialize_chat()
156
+
157
+ # Set up the Gradio interface
158
+ title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
159
+ description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
160
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
161
+ """
162
+
163
+ with gr.Blocks() as demo:
164
+ gr.Markdown(title)
165
+ gr.Markdown(SHARED_UI_WARNING)
166
+ gr.Markdown(description)
167
+ gr.Markdown(article)
168
+
169
+ with gr.Row():
170
+ with gr.Column(scale=0.5):
171
+ image = gr.Image(type="pil")
172
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
173
+ clear = gr.Button("Restart")
174
+
175
+ num_beams = gr.Slider(
176
+ minimum=1,
177
+ maximum=5,
178
+ value=1,
179
+ step=1,
180
+ interactive=True,
181
+ label="beam search numbers)",
182
+ )
183
+
184
+ temperature = gr.Slider(
185
+ minimum=0.1,
186
+ maximum=2.0,
187
+ value=1.0,
188
+ step=0.1,
189
+ interactive=True,
190
+ label="Temperature",
191
+ )
192
+
193
+ with gr.Column():
194
+ chat_state = gr.State()
195
+ img_list = gr.State()
196
+ chatbot = gr.Chatbot(label='MiniGPT-4')
197
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
198
+
199
+ upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
200
+
201
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
202
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
203
+ )
204
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
205
+
206
+ demo.launch(enable_queue=True)
207
+
208
+ if __name__ == "__main__":
209
+ main()
210