Spaces:
Runtime error
Update app.py
Browse files
Code Comments: Add more comments to your code. This will make it easier for others (and you in the future) to understand what each part of the code does.
Error Handling: Add error handling to your code. This will make your application more robust and easier to debug. For example, you could add try/except blocks around areas of your code that might raise exceptions.
Function Documentation: Add docstrings to your functions. This will make it clear what each function does, what parameters it takes, and what it returns.
Code Organization: Consider organizing your code into classes or modules. This can make your code easier to read and maintain. For example, you could have a separate module for all your Gradio interface functions.
Variable Naming: Use more descriptive variable names. This can make your code easier to understand. For example, instead of cfg, you could use config.
Code Formatting: Follow the PEP 8 style guide for Python code. This will make your code easier to read and more consistent. For example, you should have spaces around operators and after commas, and your lines should not be too long.
@@ -20,6 +20,12 @@ from minigpt4.runners import *
|
|
20 |
from minigpt4.tasks import *
|
21 |
|
22 |
def parse_args():
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
parser = argparse.ArgumentParser(description="Demo")
|
24 |
parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
|
25 |
parser.add_argument(
|
@@ -32,8 +38,13 @@ def parse_args():
|
|
32 |
args = parser.parse_args()
|
33 |
return args
|
34 |
|
35 |
-
|
36 |
def setup_seeds(config):
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
seed = config.run_cfg.seed + get_rank()
|
38 |
|
39 |
random.seed(seed)
|
@@ -42,37 +53,39 @@ def setup_seeds(config):
|
|
42 |
|
43 |
cudnn.benchmark = False
|
44 |
cudnn.deterministic = True
|
45 |
-
|
46 |
-
# ========================================
|
47 |
-
# Model Initialization
|
48 |
-
# ========================================
|
49 |
|
50 |
-
|
|
|
|
|
51 |
|
52 |
-
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
|
|
|
|
|
55 |
|
56 |
-
|
57 |
-
|
|
|
|
|
58 |
|
59 |
-
|
60 |
-
cfg = Config(parse_args())
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
vis_processor_cfg = cfg.datasets_cfg.cc_align.vis_processor.train
|
67 |
-
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
|
68 |
-
chat = Chat(model, vis_processor)
|
69 |
-
print('Initialization Finished')
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
|
75 |
-
|
|
|
|
|
76 |
if chat_state is not None:
|
77 |
chat_state.messages = []
|
78 |
if img_list is not None:
|
@@ -80,6 +93,17 @@ def gradio_reset(chat_state, img_list):
|
|
80 |
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
|
81 |
|
82 |
def upload_img(gr_img, text_input, chat_state):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
if gr_img is None:
|
84 |
return None, None, gr.update(interactive=True), chat_state, None
|
85 |
chat_state = CONV_VISION.copy()
|
@@ -88,67 +112,99 @@ def upload_img(gr_img, text_input, chat_state):
|
|
88 |
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
|
89 |
|
90 |
def gradio_ask(user_message, chatbot, chat_state):
    """
    Record a user message in the conversation and in the chatbot history.

    Parameters:
        user_message (str): Text submitted by the user.
        chatbot (list): Chat history; a new ``[user_message, None]`` entry is appended to a copy.
        chat_state: Conversation object handed to ``chat.ask``.

    Returns:
        tuple: (textbox value or update, chatbot history, chat_state).
    """
    if len(user_message) > 0:
        # Register the question with the model-side conversation first,
        # then add a turn whose answer slot is still empty.
        chat.ask(user_message, chat_state)
        history = chatbot + [[user_message, None]]
        return '', history, chat_state

    # Empty submission: leave history untouched and only change the placeholder.
    empty_hint = gr.update(interactive=True, placeholder='Input should not be empty!')
    return empty_hint, chatbot, chat_state
|
96 |
|
97 |
-
|
98 |
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """
    Ask the chat model for a reply and store it in the newest history entry.

    Parameters:
        chatbot (list): Chat history; the second element of the last entry is overwritten.
        chat_state: Conversation object passed to ``chat.answer`` as ``conv``.
        img_list: Passed through to ``chat.answer`` as ``img_list``.
        num_beams: Accepted but not forwarded to ``chat.answer``.
        temperature: Forwarded to ``chat.answer``.

    Returns:
        tuple: (chatbot, chat_state, img_list).
    """
    # NOTE(review): decoding is pinned to a single beam regardless of the
    # ``num_beams`` argument -- confirm this is intentional.
    generation = chat.answer(
        conv=chat_state,
        img_list=img_list,
        max_new_tokens=300,
        num_beams=1,
        temperature=temperature,
        max_length=2000,
    )
    reply = generation[0]

    latest_turn = chatbot[-1]
    latest_turn[1] = reply
    return chatbot, chat_state, img_list
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
"""
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
with gr.
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
value=
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
-
demo.launch(enable_queue=True)
|
|
|
20 |
from minigpt4.tasks import *
|
21 |
|
22 |
def parse_args():
|
23 |
+
"""
|
24 |
+
Parse command line arguments.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
argparse.Namespace: Parsed command line arguments.
|
28 |
+
"""
|
29 |
parser = argparse.ArgumentParser(description="Demo")
|
30 |
parser.add_argument("--cfg-path", type=str, default='eval_configs/minigpt4.yaml', help="path to configuration file.")
|
31 |
parser.add_argument(
|
|
|
38 |
args = parser.parse_args()
|
39 |
return args
|
40 |
|
|
|
41 |
def setup_seeds(config):
|
42 |
+
"""
|
43 |
+
Set up random seeds for reproducibility.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
config (Config): Configuration object.
|
47 |
+
"""
|
48 |
seed = config.run_cfg.seed + get_rank()
|
49 |
|
50 |
random.seed(seed)
|
|
|
53 |
|
54 |
cudnn.benchmark = False
|
55 |
cudnn.deterministic = True
|
|
|
|
|
|
|
|
|
56 |
|
57 |
+
def initialize_chat():
    """
    Build the chat wrapper from the parsed command-line configuration.

    Loads the configuration via ``Config(parse_args())``, instantiates the
    model class registered under ``model_cfg.arch``, moves it to ``cuda:0``,
    builds the ``cc_align`` training visual processor, and wraps both in a
    ``Chat`` object.

    Returns:
        Chat: Chat object wrapping the model and the visual processor.
    """
    print('Initializing Chat')
    settings = Config(parse_args())

    # Model: look the class up in the registry, build it, place it on the GPU.
    arch_cfg = settings.model_cfg
    network = registry.get_model_class(arch_cfg.arch).from_config(arch_cfg)
    network = network.to('cuda:0')

    # Visual processor: built from the cc_align dataset's "train" processor config.
    processor_cfg = settings.datasets_cfg.cc_align.vis_processor.train
    processor_cls = registry.get_processor_class(processor_cfg.name)
    session = Chat(network, processor_cls.from_config(processor_cfg))
    print('Initialization Finished')

    return session
|
|
|
77 |
|
78 |
+
def gradio_reset(chat_state, img_list):
|
79 |
+
"""
|
80 |
+
Reset the Gradio interface.
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
+
Parameters:
|
83 |
+
chat_state (gr.State): The current state of the chat.
|
84 |
+
img_list (gr.State): The current list of images.
|
85 |
|
86 |
+
Returns:
|
87 |
+
tuple: Updated Gradio interface elements.
|
88 |
+
"""
|
89 |
if chat_state is not None:
|
90 |
chat_state.messages = []
|
91 |
if img_list is not None:
|
|
|
93 |
return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
|
94 |
|
95 |
def upload_img(gr_img, text_input, chat_state):
|
96 |
+
"""
|
97 |
+
Upload an image and update the Gradio interface.
|
98 |
+
|
99 |
+
Parameters:
|
100 |
+
gr_img (gr.Image): The uploaded image.
|
101 |
+
text_input (gr.Textbox): The text input box.
|
102 |
+
chat_state (gr.State): The current state of the chat.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
tuple: Updated Gradio interface elements.
|
106 |
+
"""
|
107 |
if gr_img is None:
|
108 |
return None, None, gr.update(interactive=True), chat_state, None
|
109 |
chat_state = CONV_VISION.copy()
|
|
|
112 |
return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
|
113 |
|
114 |
def gradio_ask(user_message, chatbot, chat_state):
    """
    Record a user message in the conversation and in the chatbot history.

    Parameters:
        user_message (str): Text submitted by the user.
        chatbot (list): Chat history; a new ``[user_message, None]`` entry is appended to a copy.
        chat_state: Conversation object handed to ``chat.ask``.

    Returns:
        tuple: (textbox value or update, chatbot history, chat_state).
    """
    if len(user_message) > 0:
        # Register the question with the model-side conversation first,
        # then add a turn whose answer slot is still empty.
        chat.ask(user_message, chat_state)
        history = chatbot + [[user_message, None]]
        return '', history, chat_state

    # Empty submission: leave history untouched and only change the placeholder.
    empty_hint = gr.update(interactive=True, placeholder='Input should not be empty!')
    return empty_hint, chatbot, chat_state
|
131 |
|
|
|
132 |
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """
    Ask the chat model for a reply and store it in the newest history entry.

    Parameters:
        chatbot (list): Chat history; the second element of the last entry is overwritten.
        chat_state: Conversation object passed to ``chat.answer`` as ``conv``.
        img_list: Passed through to ``chat.answer`` as ``img_list``.
        num_beams: Accepted but not forwarded to ``chat.answer``.
        temperature: Forwarded to ``chat.answer``.

    Returns:
        tuple: (chatbot, chat_state, img_list).
    """
    # NOTE(review): decoding is pinned to a single beam regardless of the
    # ``num_beams`` argument -- confirm this is intentional.
    generation = chat.answer(
        conv=chat_state,
        img_list=img_list,
        max_new_tokens=300,
        num_beams=1,
        temperature=temperature,
        max_length=2000,
    )
    reply = generation[0]

    latest_turn = chatbot[-1]
    latest_turn[1] = reply
    return chatbot, chat_state, img_list
|
149 |
|
150 |
+
def main():
    """
    Build the Gradio demo UI for MiniGPT-4 and launch it.

    The chat model is created once via ``initialize_chat()`` and published
    as the module-level name ``chat``. The Gradio callbacks in this file
    (for example ``gradio_ask`` and ``gradio_answer``) look ``chat`` up as a
    global, so binding it only as a local of ``main`` would make them raise
    ``NameError`` when the first UI event fires.
    """
    # Rebind the module-level ``chat`` (not a function-local) so the event
    # callbacks defined at module scope can reach the initialized model.
    global chat
    chat = initialize_chat()

    # Static page content rendered above the widgets.
    title = """<h1 align="center">Demo of MiniGPT-4</h1>"""
    description = """<h3>This is the demo of MiniGPT-4. Upload your images and start chatting!</h3>"""
    article = """<div style='display:flex; gap: 0.25rem; '><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://github.com/TsuTikgiau/blip2-llm/blob/release_prepare/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
    """

    with gr.Blocks() as demo:
        gr.Markdown(title)
        # NOTE(review): SHARED_UI_WARNING is not defined in the visible part
        # of this file -- confirm it exists at module level.
        gr.Markdown(SHARED_UI_WARNING)
        gr.Markdown(description)
        gr.Markdown(article)

        with gr.Row():
            # Left column: image upload and generation controls.
            with gr.Column(scale=0.5):
                image = gr.Image(type="pil")
                upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
                clear = gr.Button("Restart")

                num_beams = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=1,
                    step=1,
                    interactive=True,
                    label="beam search numbers)",
                )

                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )

            # Right column: per-session state and the conversation widgets.
            with gr.Column():
                chat_state = gr.State()
                img_list = gr.State()
                chatbot = gr.Chatbot(label='MiniGPT-4')
                text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)

        upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])

        # Two-step submit: first record the question, then generate the answer.
        text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
            gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
        )
        clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)

    demo.launch(enable_queue=True)


if __name__ == "__main__":
    main()
|
210 |
|
|