YucYux committed · Commit db20615 · verified · 1 Parent(s): 724cd35

Added support for MMaDA-8B-MixCoT

Files changed (1):
  1. app.py +240 -25
app.py CHANGED
@@ -47,22 +47,23 @@ def get_num_transfer_tokens(mask_index, steps):
     return num_transfer_tokens
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
+DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT" # Default
 MASK_ID = 126336
 MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
 TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
 uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"), ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
 VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
 
-CURRENT_MODEL_PATH = None
+CURRENT_MODEL_PATH = DEFAULT_MODEL_PATH
 
 MODEL_CHOICES = [
     "MMaDA-8B-Base",
-    "MMaDA-8B-MixCoT (coming soon)",
+    "MMaDA-8B-MixCoT",
     "MMaDA-8B-Max (coming soon)"
 ]
 MODEL_ACTUAL_PATHS = {
-    "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
+    "MMaDA-8B-Base": "Gen-Verse/MMaDA-8B-Base",
+    "MMaDA-8B-MixCoT": "Gen-Verse/MMaDA-8B-MixCoT"
 }
 
 def clear_outputs_action():
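
Note on the registry change above: availability is now encoded by presence in `MODEL_ACTUAL_PATHS`, so the handler can treat a failed lookup as "not available" instead of string-matching UI labels. A minimal sketch of that contract (`resolve_checkpoint` is an illustrative helper, not part of app.py):

```python
# Sketch only: mirrors the mapping added above; resolve_checkpoint is hypothetical.
MODEL_ACTUAL_PATHS = {
    "MMaDA-8B-Base": "Gen-Verse/MMaDA-8B-Base",
    "MMaDA-8B-MixCoT": "Gen-Verse/MMaDA-8B-MixCoT",
}

def resolve_checkpoint(display_name):
    # "coming soon" entries are deliberately absent from the mapping,
    # so a None result doubles as the availability check.
    return MODEL_ACTUAL_PATHS.get(display_name)

assert resolve_checkpoint("MMaDA-8B-MixCoT") == "Gen-Verse/MMaDA-8B-MixCoT"
assert resolve_checkpoint("MMaDA-8B-Max (coming soon)") is None
```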
@@ -116,19 +117,91 @@ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_status):
         # return f"Error loading model '{model_display_name_for_status}': {str(e)}"
 
 def handle_model_selection_change(selected_model_name_ui):
-    if "coming soon" in selected_model_name_ui.lower():
-        global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
+    global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
+
+    status_msg = ""
+    # Initialize the visibility updates for the Examples blocks
+    vis_lm_base = gr.update(visible=False)
+    vis_lm_mixcot = gr.update(visible=False)
+    vis_lm_max = gr.update(visible=False)
+    vis_mmu_base = gr.update(visible=False)
+    vis_mmu_mixcot = gr.update(visible=False)
+    vis_mmu_max = gr.update(visible=False)
+
+    # Decide the default thinking-mode state based on the selected model
+    is_mixcot_model_selected = (selected_model_name_ui == "MMaDA-8B-MixCoT")
+
+    # Initial thinking-mode state and button labels:
+    # defaults to True (enabled) for the MixCoT model
+    current_thinking_mode_lm_state = is_mixcot_model_selected
+    current_thinking_mode_mmu_state = is_mixcot_model_selected
+
+    lm_think_button_label = "Thinking Mode ✅" if current_thinking_mode_lm_state else "Thinking Mode ❌"
+    mmu_think_button_label = "Thinking Mode ✅" if current_thinking_mode_mmu_state else "Thinking Mode ❌"
+
+    update_think_button_lm = gr.update(value=lm_think_button_label)
+    update_think_button_mmu = gr.update(value=mmu_think_button_label)
+
+    if selected_model_name_ui == "MMaDA-8B-Max (coming soon)":
         MODEL = None
         TOKENIZER = None
         MASK_ID = None
         CURRENT_MODEL_PATH = None
-        return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
-
-    actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
-    if not actual_path:
-        return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
-
-    return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
+        status_msg = f"'{selected_model_name_ui}' is not yet available. Please select another model."
+        vis_lm_max = gr.update(visible=True)
+        vis_mmu_max = gr.update(visible=True)
+        # For non-MixCoT models, thinking mode was already set to False above via is_mixcot_model_selected
+    else:
+        actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
+        if not actual_path:
+            MODEL = None
+            TOKENIZER = None
+            MASK_ID = None
+            CURRENT_MODEL_PATH = None
+            status_msg = f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
+            # If the path is undefined (i.e. not a valid MixCoT load), thinking mode should be False
+            if is_mixcot_model_selected:  # it should have been MixCoT, but the path is missing
+                current_thinking_mode_lm_state = False
+                current_thinking_mode_mmu_state = False
+                update_think_button_lm = gr.update(value="Thinking Mode ❌")
+                update_think_button_mmu = gr.update(value="Thinking Mode ❌")
+        else:
+            # Try to load the model
+            status_msg = _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
+
+            # Check whether the model loaded successfully
+            if "Error loading model" in status_msg or MODEL is None:
+                # If the MixCoT model failed to load, turn thinking mode off
+                if is_mixcot_model_selected:
+                    current_thinking_mode_lm_state = False
+                    current_thinking_mode_mmu_state = False
+                    update_think_button_lm = gr.update(value="Thinking Mode ❌")
+                    update_think_button_mmu = gr.update(value="Thinking Mode ❌")
+                if MODEL is None and "Error" not in status_msg:  # add a generic error message
+                    status_msg = f"Failed to properly load model '{selected_model_name_ui}'. {status_msg}"
+            else:  # model loaded successfully
+                if selected_model_name_ui == "MMaDA-8B-Base":
+                    vis_lm_base = gr.update(visible=True)
+                    vis_mmu_base = gr.update(visible=True)
+                elif selected_model_name_ui == "MMaDA-8B-MixCoT":
+                    vis_lm_mixcot = gr.update(visible=True)
+                    vis_mmu_mixcot = gr.update(visible=True)
+                    # thinking mode was already set to True at the top via is_mixcot_model_selected
+
+    return (
+        status_msg,
+        vis_lm_base,
+        vis_lm_mixcot,
+        vis_lm_max,
+        vis_mmu_base,
+        vis_mmu_mixcot,
+        vis_mmu_max,
+        # New return values that update the thinking-mode states and buttons
+        current_thinking_mode_lm_state,  # plain value, returned directly to gr.State
+        update_think_button_lm,          # gr.update object for gr.Button
+        current_thinking_mode_mmu_state,
+        update_think_button_mmu
+    )
 
 
 def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
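
The handler now returns eleven values in a fixed positional order, and every event that calls it must list its `outputs` in exactly that order. A sketch that documents the contract (the namedtuple is illustrative; the commit returns a plain tuple):

```python
from collections import namedtuple

# Illustrative only: names the eleven positional slots that the
# outputs=[...] lists later in this diff must match, component for component.
SelectionUpdate = namedtuple("SelectionUpdate", [
    "status_msg",      # -> model_load_status_box
    "vis_lm_base",     # -> examples_lm_base
    "vis_lm_mixcot",   # -> examples_lm_mixcot
    "vis_lm_max",      # -> examples_lm_max
    "vis_mmu_base",    # -> examples_mmu_base
    "vis_mmu_mixcot",  # -> examples_mmu_mixcot
    "vis_mmu_max",     # -> examples_mmu_max
    "thinking_lm",     # -> thinking_mode_lm (gr.State, plain bool)
    "think_btn_lm",    # -> think_button_lm (gr.update with new label)
    "thinking_mmu",    # -> thinking_mode_mmu (gr.State, plain bool)
    "think_btn_mmu",   # -> think_button_mmu (gr.update with new label)
])
```

`initialize_app_state` later prepends the radio value, which is why `demo.load` lists twelve outputs while `model_select_radio.change` lists eleven.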
@@ -618,7 +691,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             model_select_radio = gr.Radio(
                 label="Select Text Generation Model",
                 choices=MODEL_CHOICES,
-                value=MODEL_CHOICES[0]
+                value="MMaDA-8B-MixCoT"
             )
             model_load_status_box = gr.Textbox(
                 label="Model Load Status",
@@ -663,7 +736,27 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
 
 
 
-            gr.Examples(
+            examples_lm_base = gr.Examples(
+                examples=[
+                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
+                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
+                ],
+                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
+                outputs=[output_visualization_box_lm, output_final_text_box_lm],
+                fn=generate_viz_wrapper_lm,
+                cache_examples=False
+            )
+            examples_lm_mixcot = gr.Examples(
+                examples=[
+                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
+                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
+                ],
+                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
+                outputs=[output_visualization_box_lm, output_final_text_box_lm],
+                fn=generate_viz_wrapper_lm,
+                cache_examples=False
+            )
+            examples_lm_max = gr.Examples(
                 examples=[
                     ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
                     ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
@@ -681,7 +774,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             prompt_input_box_mmu = gr.Textbox(
                 label="Enter your prompt:",
                 lines=3,
-                value="Please describe this image in detail."
+                value=""
             )
             think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
             with gr.Accordion("Generation Parameters", open=True):
@@ -689,7 +782,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
                 gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
                 steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
                 with gr.Row():
-                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
+                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length", info="gen_length must be divisible by this.")
                     remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                 with gr.Row():
                     cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
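
The slider `info` strings encode two divisibility constraints: `gen_length` must be divisible by `block_length`, and `steps` must be divisible by the number of blocks (`gen_length / block_length`). The new MMU defaults satisfy both; a small check, assuming those semantics (`validate_schedule` is an illustrative helper, not in app.py):

```python
def validate_schedule(gen_length, block_length, steps):
    # gen_length must split evenly into blocks...
    if gen_length % block_length != 0:
        raise ValueError("gen_length must be divisible by block_length")
    num_blocks = gen_length // block_length
    # ...and the step budget must split evenly across those blocks.
    if steps % num_blocks != 0:
        raise ValueError("steps must be divisible by gen_length // block_length")
    return steps // num_blocks  # steps spent per block

# New MMU defaults: 512 tokens in blocks of 64, 256 steps -> 32 steps per block.
assert validate_schedule(512, 64, 256) == 32
```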
@@ -716,7 +809,81 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
 
 
-            gr.Examples(
+            examples_mmu_base = gr.Examples(
+                examples=[
+                    [
+                        "figs/sunflower.jpg",
+                        "Please describe this image in detail.",
+                        256,
+                        512,
+                        128,
+                        1,
+                        0,
+                        "low_confidence"
+                    ],
+                    [
+                        "figs/woman.jpg",
+                        "Please describe this image in detail.",
+                        256,
+                        512,
+                        128,
+                        1,
+                        0,
+                        "low_confidence"
+                    ]
+                ],
+                inputs=[
+                    image_upload_box,
+                    prompt_input_box_mmu,
+                    steps_slider_mmu,
+                    gen_length_slider_mmu,
+                    block_length_slider_mmu,
+                    temperature_slider_mmu,
+                    cfg_scale_slider_mmu,
+                    remasking_dropdown_mmu
+                ],
+                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
+                fn=generate_viz_wrapper,
+                cache_examples=False
+            )
+            examples_mmu_mixcot = gr.Examples(
+                examples=[
+                    [
+                        "figs/geo.png",
+                        "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?",
+                        256,
+                        512,
+                        64,
+                        1,
+                        0,
+                        "low_confidence"
+                    ],
+                    [
+                        "figs/bus.jpg",
+                        "What are the colors of the bus?",
+                        256,
+                        512,
+                        64,
+                        1,
+                        0,
+                        "low_confidence"
+                    ]
+                ],
+                inputs=[
+                    image_upload_box,
+                    prompt_input_box_mmu,
+                    steps_slider_mmu,
+                    gen_length_slider_mmu,
+                    block_length_slider_mmu,
+                    temperature_slider_mmu,
+                    cfg_scale_slider_mmu,
+                    remasking_dropdown_mmu
+                ],
+                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
+                fn=generate_viz_wrapper,
+                cache_examples=False
+            )
+            examples_mmu_max = gr.Examples(
                 examples=[
                     [
                         "figs/sunflower.jpg",
@@ -823,21 +990,69 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
         inputs=[thinking_mode_mmu],
         outputs=[thinking_mode_mmu, think_button_mmu]
     )
-
 
 
-    def initialize_default_model():
-        default_model = "MMaDA-8B-Base"
-        result = handle_model_selection_change(default_model)
-        return default_model, result
+    def initialize_app_state():
+        default_model_choice = "MMaDA-8B-MixCoT"  # load MixCoT by default
 
+        # handle_model_selection_change now returns more items
+        status, lm_b_vis, lm_m_vis, lm_x_vis, \
+        mmu_b_vis, mmu_m_vis, mmu_x_vis, \
+        init_thinking_lm_state, init_think_lm_btn_update, \
+        init_thinking_mmu_state, init_think_mmu_btn_update = handle_model_selection_change(default_model_choice)
+
+        return (
+            default_model_choice,
+            status,
+            lm_b_vis,
+            lm_m_vis,
+            lm_x_vis,
+            mmu_b_vis,
+            mmu_m_vis,
+            mmu_x_vis,
+            init_thinking_lm_state,
+            init_think_lm_btn_update,
+            init_thinking_mmu_state,
+            init_think_mmu_btn_update
+        )
 
     demo.load(
-        fn=initialize_default_model,
+        fn=initialize_app_state,
         inputs=None,
-        outputs=[model_select_radio, model_load_status_box],
+        outputs=[
+            model_select_radio,
+            model_load_status_box,
+            examples_lm_base,
+            examples_lm_mixcot,
+            examples_lm_max,
+            examples_mmu_base,
+            examples_mmu_mixcot,
+            examples_mmu_max,
+            thinking_mode_lm,   # gr.State for LM thinking mode
+            think_button_lm,    # gr.Button for LM thinking mode
+            thinking_mode_mmu,  # gr.State for MMU thinking mode
+            think_button_mmu    # gr.Button for MMU thinking mode
+        ],
         queue=True
     )
 
+    model_select_radio.change(
+        fn=handle_model_selection_change,
+        inputs=[model_select_radio],
+        outputs=[
+            model_load_status_box,
+            examples_lm_base,
+            examples_lm_mixcot,
+            examples_lm_max,
+            examples_mmu_base,
+            examples_mmu_mixcot,
+            examples_mmu_max,
+            thinking_mode_lm,
+            think_button_lm,
+            thinking_mode_mmu,
+            think_button_mmu
+        ]
+    )
+
     def clear_outputs():
         return None, None, None # Clear image, visualization, and final text
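
`demo.load` fires once per page load, and the commit routes it through the same handler as `model_select_radio.change`, so the initial UI state and later selection changes cannot drift apart. A minimal sketch of that reuse pattern (toy components and names, not the app's):

```python
import gradio as gr

def on_select(choice):
    # Shared by the change event and the initial page load.
    return f"Loaded: {choice}"

def init():
    default = "MixCoT"
    # Reuse the selection handler, and also push the default back into
    # the radio so the UI reflects what was actually loaded.
    return default, on_select(default)

with gr.Blocks() as demo:
    radio = gr.Radio(["Base", "MixCoT"], label="Model")
    status = gr.Textbox(label="Status")
    radio.change(fn=on_select, inputs=radio, outputs=status)
    demo.load(fn=init, inputs=None, outputs=[radio, status], queue=True)

if __name__ == "__main__":
    demo.launch()
```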