"""Gradio tool for labelling subject/object pairs with state, action and spatial relations."""
import json
import os
import re
from collections import defaultdict

import gradio as gr

# Matches the "image_<digits>" token used to group pairs cut from the same source image.
_IMAGE_ID_RE = re.compile(r"(image_\d+)")

# Shown when an item's image file cannot be found on disk.
_PLACEHOLDER_IMAGE = "placeholder.jpg"

# The three relation categories every pair is labelled with, as (options key, heading).
_CATEGORIES = (("state", "State"), ("action", "Action"), ("spatial", "Spatial"))


def load_main_data(json_file_path):
    """Load the items to annotate.

    The JSON file is expected to hold a list whose elements look like::

        {
            'image_path': 'some_path',
            'subject': 'xxx',
            'object': 'yyy',
            'options': {'state': [...], 'action': [...], 'spatial': [...]}
        }
    """
    with open(json_file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_output_dict(output_file):
    """Read previously saved annotations.

    Returns an empty dict when the file is missing, is not valid JSON, or
    does not contain a JSON object at the top level.
    """
    if not os.path.exists(output_file):
        return {}
    try:
        with open(output_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except json.JSONDecodeError:
        # Stay best-effort, but say so: the next save will replace this file.
        print(f"Warning: {output_file} is not valid JSON; starting with no saved annotations")
        return {}
    return data if isinstance(data, dict) else {}


def save_output_dict(output_file, data):
    """Write the annotation dict to ``output_file`` as UTF-8 JSON.

    The data is written to a sibling temporary file first and then moved
    over the target, so an interrupted write cannot leave a truncated file.
    """
    print("Try to save output")
    tmp_path = f"{output_file}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, output_file)


def extract_image_id_from_path(full_path):
    """Return the ``image_<digits>`` part of the file name in ``full_path``.

    Falls back to the bare file name when the pattern is not found. Example::

        annotated_image_folder\\split_4\\output_images_1592503\\image_1592503_pair_2_black bus_parked on.jpg
        -> 'image_1592503'
    """
    # Normalise Windows separators so basename works on every platform.
    filename = os.path.basename(full_path.replace("\\", "/"))
    match = _IMAGE_ID_RE.search(filename)
    if match:
        return match.group(1)
    return filename


def gradio_interface(json_file_path='sample_4.json', output_file='output.json'):
    """Build the annotation UI and return the ``gr.Blocks`` app.

    Items are grouped by the image id extracted from their path so the status
    line can show "Pair x/y for this image" in addition to the global
    position. Subject and object are shown side by side above the options.

    Args:
        json_file_path: JSON file holding the list of items to annotate.
        output_file: JSON file annotations are loaded from and saved to.

    Raises:
        ValueError: if the input file does not contain a non-empty list.
    """
    data = load_main_data(json_file_path)
    if not isinstance(data, list) or not data:
        # Fail early: everything below indexes data[0] and divides by len(data).
        raise ValueError(f"{json_file_path} must contain a non-empty JSON list of items")

    labeled_data = load_output_dict(output_file)

    # ---------------------------------------------------
    # 1) Pre-processing: group item indices by image id
    # ---------------------------------------------------
    image_to_indices = defaultdict(list)
    for idx, item in enumerate(data):
        image_id = extract_image_id_from_path(item.get("image_path", ""))
        image_to_indices[image_id].append(idx)

    # For every global index: its position inside its image group and the group size.
    local_index_map = {}
    local_count_map = {}
    for idx_list in image_to_indices.values():
        for local_i, real_idx in enumerate(idx_list):  # keeps order of appearance
            local_index_map[real_idx] = local_i
            local_count_map[real_idx] = len(idx_list)

    # ---------------------------------------------------
    # 2) Helpers
    # ---------------------------------------------------
    def get_item_info(idx):
        """Return (image_path, subject, object, options) for item ``idx``."""
        item = data[idx]
        image_path = item.get("image_path", "")
        if not os.path.exists(image_path):
            # Use the placeholder only if it exists; otherwise show no image.
            image_path = _PLACEHOLDER_IMAGE if os.path.exists(_PLACEHOLDER_IMAGE) else None
        subject = item.get("subject", "")
        obj = item.get("object", "")
        opts = item.get("options") or {}  # tolerate "options": null
        return image_path, subject, obj, opts

    def split_options(options_list, top_n=5):
        """Split options: the first ``top_n`` go to the Radio, the rest to the Dropdown."""
        return options_list[:top_n], options_list[top_n:]

    def update_final_selection(radio_val, dropdown_val):
        """Resolve the effective choice: the Radio wins, otherwise the Dropdown."""
        if radio_val:
            return radio_val
        return dropdown_val or None

    def update_skip_value(checked):
        """Mirror the skip checkbox into the hidden textbox as 'True'/'False'."""
        return str(checked)

    def restore_selection(saved, radio_choices, dropdown_choices):
        """Return (radio_value, dropdown_value) that re-displays a saved label."""
        if saved in radio_choices:
            return saved, None
        if saved in dropdown_choices:
            return None, saved
        return None, None

    def build_category(title, radio_choices, dropdown_choices):
        """Create heading, Radio, Dropdown and hidden final-value box for one category."""
        gr.Markdown(f"### {title}")
        radio = gr.Radio(choices=radio_choices, value=None, label="Top 5")
        dropdown = gr.Dropdown(choices=dropdown_choices, value=None, label="More Options")
        final = gr.Textbox(value=None, visible=False, label=f"Final {title}")
        for widget in (radio, dropdown):
            widget.change(
                fn=update_final_selection,
                inputs=[radio, dropdown],
                outputs=final
            )
        return radio, dropdown, final

    # ---------------------------------------------------
    # 3) Initial values for idx=0
    # ---------------------------------------------------
    init_idx = 0
    init_image, init_sub, init_obj, init_opts = get_item_info(init_idx)
    init_skip_val = False

    # ---------------------------------------------------
    # 4) Build the Gradio UI
    # ---------------------------------------------------
    with gr.Blocks() as demo:
        cur_idx_state = gr.State(init_idx)

        with gr.Row():
            # Left: image, status, details and the paging buttons
            with gr.Column(scale=1):
                img_view = gr.Image(value=init_image, label="Image")
                # Shows global progress plus progress within the current image
                status_box = gr.Textbox(
                    value="",
                    label="Status",
                    interactive=False
                )
                info_box = gr.Textbox(
                    value="Details: (will be updated...)",
                    label="Details",
                    interactive=False
                )
                with gr.Row():
                    btn_prev = gr.Button("← Previous", variant="secondary")
                    btn_next = gr.Button("Next →", variant="primary")

            # Right: the annotation controls
            with gr.Column(scale=1):
                # (Subject -> Object) and the skip checkbox share one row
                with gr.Row():
                    subject_object_md = gr.Markdown(
                        f"**{init_sub} → {init_obj}**",
                        elem_id="subject_object_header"
                    )
                    skip_checkbox = gr.Checkbox(
                        value=init_skip_val,
                        label="No relation (skip this pair)"
                    )
                skip_final = gr.Textbox(value=str(init_skip_val), visible=False)
                skip_checkbox.change(
                    fn=update_skip_value,
                    inputs=[skip_checkbox],
                    outputs=[skip_final]
                )

                category_widgets = {}
                for key, title in _CATEGORIES:
                    radio_choices, dropdown_choices = split_options(init_opts.get(key) or [])
                    category_widgets[key] = build_category(title, radio_choices, dropdown_choices)

                state_radio, state_dd, state_final = category_widgets["state"]
                action_radio, action_dd, action_final = category_widgets["action"]
                spatial_radio, spatial_dd, spatial_final = category_widgets["spatial"]

                # Save button at the bottom
                with gr.Row():
                    btn_save = gr.Button("Save", variant="primary")

        # ---------------------------------------------------
        # 5) Paging
        # ---------------------------------------------------
        def _jump_to_index(new_idx):
            """Return the values for every navigation output when showing ``new_idx``."""
            image_path, sub, obj, opts = get_item_info(new_idx)

            global_status = f"Currently showing: {new_idx+1}/{len(data)}"
            local_idx = local_index_map[new_idx]  # zero-based
            local_count = local_count_map[new_idx]
            new_status = f"{global_status}. (Pair {local_idx+1}/{local_count} for this image.)"
            new_info = f"Subject: {sub}, Object: {obj}"
            subobj_md = f"**{sub} → {obj}**"

            rec = labeled_data.get(str(new_idx), {})
            skip_flag = rec.get("skip", False) is True

            category_updates = []
            for key, _title in _CATEGORIES:
                radio_choices, dropdown_choices = split_options(opts.get(key) or [])
                saved = None if skip_flag else rec.get(key)
                # Put the saved label back into the visible widget. The
                # widgets' change handlers also fire on programmatic updates,
                # so resetting them to None would wipe the hidden final value.
                radio_val, dropdown_val = restore_selection(saved, radio_choices, dropdown_choices)
                category_updates += [
                    gr.update(choices=radio_choices, value=radio_val),
                    gr.update(choices=dropdown_choices, value=dropdown_val),
                    update_final_selection(radio_val, dropdown_val),
                ]

            return (
                new_idx,
                image_path,
                new_status,
                new_info,
                subobj_md,
                skip_flag,
                str(skip_flag),
                *category_updates,
            )

        def go_next(cur_idx):
            """Advance to the next item, wrapping around at the end."""
            return _jump_to_index((cur_idx + 1) % len(data))

        def go_prev(cur_idx):
            """Go back to the previous item, wrapping around at the start."""
            return _jump_to_index((cur_idx - 1) % len(data))

        # Must stay in the same order as the tuple returned by _jump_to_index.
        nav_outputs = [
            cur_idx_state,
            img_view,
            status_box, info_box,
            subject_object_md,
            skip_checkbox, skip_final,
            state_radio, state_dd, state_final,
            action_radio, action_dd, action_final,
            spatial_radio, spatial_dd, spatial_final
        ]

        btn_next.click(fn=go_next, inputs=[cur_idx_state], outputs=nav_outputs)
        btn_prev.click(fn=go_prev, inputs=[cur_idx_state], outputs=nav_outputs)

        # Render the first item on page load so its status line and any
        # previously saved annotation show up without pressing Next/Previous.
        demo.load(fn=lambda: _jump_to_index(init_idx), inputs=None, outputs=nav_outputs)

        # ---------------------------------------------------
        # 6) Saving
        # ---------------------------------------------------
        def handle_save(st_val, ac_val, sp_val, cur_idx, skip_val):
            """Store the annotation for ``cur_idx`` and return a status message."""
            skip_flag = (skip_val == "True")
            _image_path, sub, obj, _ = get_item_info(cur_idx)

            if skip_flag:
                labeled_data[str(cur_idx)] = {
                    "subject": sub,
                    "object": obj,
                    "skip": True
                }
                save_output_dict(output_file, labeled_data)
                return f"Skipped pair: {sub} - {obj}."

            if not st_val or not ac_val or not sp_val:
                return "Please select all 3 categories or check 'No relation (skip this pair)'!"

            labeled_data[str(cur_idx)] = {
                "subject": sub,
                "object": obj,
                "skip": False,
                "state": st_val,
                "action": ac_val,
                "spatial": sp_val
            }
            save_output_dict(output_file, labeled_data)
            return f"Saved: {sub}, {obj}, state={st_val}, action={ac_val}, spatial={sp_val}"

        btn_save.click(
            fn=handle_save,
            inputs=[state_final, action_final, spatial_final, cur_idx_state, skip_final],
            outputs=status_box
        )

    return demo


if __name__ == '__main__':
    gradio_interface().launch(share=True)