YucYux committed
Commit d08f144 · 1 Parent(s): db20615

Revert "Added support for MMaDA-8B-MixCoT"

This reverts commit db20615a9ccddd7b9c1ee9043750d591e46628a2.
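
For reference, a revert commit with exactly this message is what git's built-in revert produces; a minimal sketch, assuming the branch containing db20615 is checked out:

    git revert db20615a9ccddd7b9c1ee9043750d591e46628a2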

Files changed (1)
  1. app.py +25 -240
app.py CHANGED
@@ -47,23 +47,22 @@ def get_num_transfer_tokens(mask_index, steps):
     return num_transfer_tokens
 
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-MixCoT" # Default
+DEFAULT_MODEL_PATH = "Gen-Verse/MMaDA-8B-Base" # Default
 MASK_ID = 126336
 MODEL = MMadaModelLM.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True, torch_dtype=torch.bfloat16).to(DEVICE).eval()
 TOKENIZER = AutoTokenizer.from_pretrained(DEFAULT_MODEL_PATH, trust_remote_code=True)
 uni_prompting = UniversalPrompting(TOKENIZER, max_text_len=512, special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),ignore_id=-100, cond_dropout_prob=0.1, use_reserved_token=True)
 VQ_MODEL = MAGVITv2().from_pretrained("showlab/magvitv2").to(DEVICE)
 
-CURRENT_MODEL_PATH = DEFAULT_MODEL_PATH
+CURRENT_MODEL_PATH = None
 
 MODEL_CHOICES = [
     "MMaDA-8B-Base",
-    "MMaDA-8B-MixCoT",
+    "MMaDA-8B-MixCoT (coming soon)",
     "MMaDA-8B-Max (coming soon)"
 ]
 MODEL_ACTUAL_PATHS = {
-    "MMaDA-8B-Base": "Gen-Verse/MMaDA-8B-Base",
-    "MMaDA-8B-MixCoT": "Gen-Verse/MMaDA-8B-MixCoT"
+    "MMaDA-8B-Base": DEFAULT_MODEL_PATH,
 }
 
 def clear_outputs_action():
@@ -117,91 +116,19 @@ def _load_model_and_tokenizer_core(model_path_to_load, model_display_name_for_st
     # return f"Error loading model '{model_display_name_for_status}': {str(e)}"
 
 def handle_model_selection_change(selected_model_name_ui):
-    global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH, DEVICE, uni_prompting
-
-    status_msg = ""
-    # Initialize the visibility updates for the Examples blocks
-    vis_lm_base = gr.update(visible=False)
-    vis_lm_mixcot = gr.update(visible=False)
-    vis_lm_max = gr.update(visible=False)
-    vis_mmu_base = gr.update(visible=False)
-    vis_mmu_mixcot = gr.update(visible=False)
-    vis_mmu_max = gr.update(visible=False)
-
-    # Decide the default thinking-mode state from the selected model
-    is_mixcot_model_selected = (selected_model_name_ui == "MMaDA-8B-MixCoT")
-
-    # Initial thinking-mode state and button labels
-    # Defaults to True (on) when the MixCoT model is selected
-    current_thinking_mode_lm_state = is_mixcot_model_selected
-    current_thinking_mode_mmu_state = is_mixcot_model_selected
-
-    lm_think_button_label = "Thinking Mode ✅" if current_thinking_mode_lm_state else "Thinking Mode ❌"
-    mmu_think_button_label = "Thinking Mode ✅" if current_thinking_mode_mmu_state else "Thinking Mode ❌"
-
-    update_think_button_lm = gr.update(value=lm_think_button_label)
-    update_think_button_mmu = gr.update(value=mmu_think_button_label)
-
-    if selected_model_name_ui == "MMaDA-8B-Max (coming soon)":
+    if "coming soon" in selected_model_name_ui.lower():
+        global MODEL, TOKENIZER, MASK_ID, CURRENT_MODEL_PATH
         MODEL = None
         TOKENIZER = None
         MASK_ID = None
         CURRENT_MODEL_PATH = None
-        status_msg = f"'{selected_model_name_ui}' is not yet available. Please select another model."
-        vis_lm_max = gr.update(visible=True)
-        vis_mmu_max = gr.update(visible=True)
-        # For non-MixCoT models, thinking mode was already set to False above via is_mixcot_model_selected
-    else:
-        actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
-        if not actual_path:
-            MODEL = None
-            TOKENIZER = None
-            MASK_ID = None
-            CURRENT_MODEL_PATH = None
-            status_msg = f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
-            # If the path is undefined (i.e. not a valid MixCoT load), thinking mode should be False
-            if is_mixcot_model_selected:  # should be MixCoT, but the path is missing
-                current_thinking_mode_lm_state = False
-                current_thinking_mode_mmu_state = False
-                update_think_button_lm = gr.update(value="Thinking Mode ❌")
-                update_think_button_mmu = gr.update(value="Thinking Mode ❌")
-        else:
-            # Try to load the model
-            status_msg = _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
-
-            # Check whether the model loaded successfully
-            if "Error loading model" in status_msg or MODEL is None:
-                # If the MixCoT model was selected but failed to load, turn thinking mode off
-                if is_mixcot_model_selected:
-                    current_thinking_mode_lm_state = False
-                    current_thinking_mode_mmu_state = False
-                    update_think_button_lm = gr.update(value="Thinking Mode ❌")
-                    update_think_button_mmu = gr.update(value="Thinking Mode ❌")
-                if MODEL is None and "Error" not in status_msg:  # add a generic error message
-                    status_msg = f"Failed to properly load model '{selected_model_name_ui}'. {status_msg}"
-            else:  # model loaded successfully
-                if selected_model_name_ui == "MMaDA-8B-Base":
-                    vis_lm_base = gr.update(visible=True)
-                    vis_mmu_base = gr.update(visible=True)
-                elif selected_model_name_ui == "MMaDA-8B-MixCoT":
-                    vis_lm_mixcot = gr.update(visible=True)
-                    vis_mmu_mixcot = gr.update(visible=True)
-                # thinking mode was already set to True at the top of the function via is_mixcot_model_selected
-
-    return (
-        status_msg,
-        vis_lm_base,
-        vis_lm_mixcot,
-        vis_lm_max,
-        vis_mmu_base,
-        vis_mmu_mixcot,
-        vis_mmu_max,
-        # New return values that update the thinking_mode states and buttons
-        current_thinking_mode_lm_state,   # plain value for the gr.State
-        update_think_button_lm,           # gr.update object for the gr.Button
-        current_thinking_mode_mmu_state,
-        update_think_button_mmu
-    )
+        return f"'{selected_model_name_ui}' is not yet available. Please select 'Model A'."
+
+    actual_path = MODEL_ACTUAL_PATHS.get(selected_model_name_ui)
+    if not actual_path:
+        return f"Path for '{selected_model_name_ui}' is not defined. Cannot load."
+
+    return _load_model_and_tokenizer_core(actual_path, selected_model_name_ui)
 
 
 def get_highlighted_text_tuples(current_x_ids_batch, prompt_input_ids, prompt_len, tk, current_mask_id, raw_prompt_attention_mask):
@@ -691,7 +618,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             model_select_radio = gr.Radio(
                 label="Select Text Generation Model",
                 choices=MODEL_CHOICES,
-                value="MMaDA-8B-MixCoT"
+                value=MODEL_CHOICES[0]
             )
             model_load_status_box = gr.Textbox(
                 label="Model Load Status",
@@ -736,27 +663,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
 
 
 
-            examples_lm_base = gr.Examples(
-                examples=[
-                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
-                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
-                ],
-                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
-                outputs=[output_visualization_box_lm, output_final_text_box_lm],
-                fn=generate_viz_wrapper_lm,
-                cache_examples=False
-            )
-            examples_lm_mixcot = gr.Examples(
-                examples=[
-                    ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
-                    ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
-                ],
-                inputs=[prompt_input_box_lm, steps_slider_lm, gen_length_slider_lm, block_length_slider_lm, temperature_slider_lm, cfg_scale_slider_lm, remasking_dropdown_lm],
-                outputs=[output_visualization_box_lm, output_final_text_box_lm],
-                fn=generate_viz_wrapper_lm,
-                cache_examples=False
-            )
-            examples_lm_max = gr.Examples(
+            gr.Examples(
                 examples=[
                     ["A rectangular prism has a length of 5 units, a width of 4 units, and a height of 3 units. What is the volume of the prism?", 256, 512, 128, 1, 0, "low_confidence"],
                     ["Lily can run 12 kilometers per hour for 4 hours. After that, she can run 6 kilometers per hour. How many kilometers can she run in 8 hours?", 256, 512, 64, 1, 0, "low_confidence"]
@@ -774,7 +681,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             prompt_input_box_mmu = gr.Textbox(
                 label="Enter your prompt:",
                 lines=3,
-                value=""
+                value="Please describe this image in detail."
             )
             think_button_mmu = gr.Button("🧠 Enable Thinking Mode", elem_id="think_btn")
             with gr.Accordion("Generation Parameters", open=True):
@@ -782,7 +689,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
                 gen_length_slider_mmu = gr.Slider(minimum=64, maximum=1024, value=512, step=64, label="Generation Length", info="Number of tokens to generate.")
                 steps_slider_mmu = gr.Slider(minimum=1, maximum=512, value=256, step=32, label="Total Sampling Steps", info="Must be divisible by (gen_length / block_length).")
                 with gr.Row():
-                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=64, step=32, label="Block Length", info="gen_length must be divisible by this.")
+                    block_length_slider_mmu = gr.Slider(minimum=32, maximum=1024, value=128, step=32, label="Block Length", info="gen_length must be divisible by this.")
                     remasking_dropdown_mmu = gr.Dropdown(choices=['low_confidence', 'random'], value='low_confidence', label="Remasking Strategy")
                 with gr.Row():
                     cfg_scale_slider_mmu = gr.Slider(minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="CFG Scale", info="Classifier-Free Guidance. 0 disables it.")
@@ -809,81 +716,7 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
             output_final_text_box_mmu = gr.Textbox(label="Final Output", lines=8, interactive=False, show_copy_button=True)
 
 
-            examples_mmu_base = gr.Examples(
-                examples=[
-                    [
-                        "figs/sunflower.jpg",
-                        "Please describe this image in detail.",
-                        256,
-                        512,
-                        128,
-                        1,
-                        0,
-                        "low_confidence"
-                    ],
-                    [
-                        "figs/woman.jpg",
-                        "Please describe this image in detail.",
-                        256,
-                        512,
-                        128,
-                        1,
-                        0,
-                        "low_confidence"
-                    ]
-                ],
-                inputs=[
-                    image_upload_box,
-                    prompt_input_box_mmu,
-                    steps_slider_mmu,
-                    gen_length_slider_mmu,
-                    block_length_slider_mmu,
-                    temperature_slider_mmu,
-                    cfg_scale_slider_mmu,
-                    remasking_dropdown_mmu
-                ],
-                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
-                fn=generate_viz_wrapper,
-                cache_examples=False
-            )
-            examples_mmu_mixcot = gr.Examples(
-                examples=[
-                    [
-                        "figs/geo.png",
-                        "In the given figure, a square ABCD is inscribed in a circle with center O. Point P is located on side CD. What is the value of angle APB?",
-                        256,
-                        512,
-                        64,
-                        1,
-                        0,
-                        "low_confidence"
-                    ],
-                    [
-                        "figs/bus.jpg",
-                        "What are the colors of the bus?",
-                        256,
-                        512,
-                        64,
-                        1,
-                        0,
-                        "low_confidence"
-                    ]
-                ],
-                inputs=[
-                    image_upload_box,
-                    prompt_input_box_mmu,
-                    steps_slider_mmu,
-                    gen_length_slider_mmu,
-                    block_length_slider_mmu,
-                    temperature_slider_mmu,
-                    cfg_scale_slider_mmu,
-                    remasking_dropdown_mmu
-                ],
-                outputs=[output_visualization_box_mmu, output_final_text_box_mmu],
-                fn=generate_viz_wrapper,
-                cache_examples=False
-            )
-            examples_mmu_max = gr.Examples(
+            gr.Examples(
                 examples=[
                     [
                         "figs/sunflower.jpg",
@@ -990,69 +823,21 @@ with gr.Blocks(css=css_styles, theme=theme) as demo:
         inputs=[thinking_mode_mmu],
         outputs=[thinking_mode_mmu, think_button_mmu]
     )
+
 
-    def initialize_app_state():
-        default_model_choice = "MMaDA-8B-MixCoT"  # Load MixCoT by default
-
-        # handle_model_selection_change now returns more items
-        status, lm_b_vis, lm_m_vis, lm_x_vis, \
-        mmu_b_vis, mmu_m_vis, mmu_x_vis, \
-        init_thinking_lm_state, init_think_lm_btn_update, \
-        init_thinking_mmu_state, init_think_mmu_btn_update = handle_model_selection_change(default_model_choice)
-
-        return (
-            default_model_choice,
-            status,
-            lm_b_vis,
-            lm_m_vis,
-            lm_x_vis,
-            mmu_b_vis,
-            mmu_m_vis,
-            mmu_x_vis,
-            init_thinking_lm_state,
-            init_think_lm_btn_update,
-            init_thinking_mmu_state,
-            init_think_mmu_btn_update
-        )
+    def initialize_default_model():
+        default_model = "MMaDA-8B-Base"
+        result = handle_model_selection_change(default_model)
+        return default_model, result
 
     demo.load(
-        fn=initialize_app_state,
+        fn=initialize_default_model,
         inputs=None,
-        outputs=[
-            model_select_radio,
-            model_load_status_box,
-            examples_lm_base,
-            examples_lm_mixcot,
-            examples_lm_max,
-            examples_mmu_base,
-            examples_mmu_mixcot,
-            examples_mmu_max,
-            thinking_mode_lm,   # gr.State for LM thinking mode
-            think_button_lm,    # gr.Button for LM thinking mode
-            thinking_mode_mmu,  # gr.State for MMU thinking mode
-            think_button_mmu    # gr.Button for MMU thinking mode
-        ],
+        outputs=[model_select_radio, model_load_status_box],
         queue=True
     )
 
-    model_select_radio.change(
-        fn=handle_model_selection_change,
-        inputs=[model_select_radio],
-        outputs=[
-            model_load_status_box,
-            examples_lm_base,
-            examples_lm_mixcot,
-            examples_lm_max,
-            examples_mmu_base,
-            examples_mmu_mixcot,
-            examples_mmu_max,
-            thinking_mode_lm,
-            think_button_lm,
-            thinking_mode_mmu,
-            think_button_mmu
-        ]
-    )
-
     def clear_outputs():
         return None, None, None  # Clear image, visualization, and final text