def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Copy weights from ``old_state_dict`` into ``new_state_dict`` by layer index.

    For every ``idx`` in ``range(model_size)`` two groups are collected:

    * the keys of ``new_state_dict`` whose first dot-separated component is
      ``str(idx)``;
    * the entries of ``old_state_dict`` whose key contains
      ``"model.{idx + shift}."`` (``shift`` starts at 1).

    If both groups have the same size, values are copied pairwise in
    iteration order. Otherwise every old entry is handled as a head weight
    named ``<a>.<b>.<conv_name>.<conv_idx>.<details...>`` and stored under
    ``"<layer>.heads.<conv_idx>.<task>.<details...>"``, where ``layer`` and
    ``task`` come from the ``head_targets`` table below. Entries whose key
    contains ``"dfl"`` are skipped.

    Args:
        old_state_dict: Mapping of source weight names to values. Only read.
        new_state_dict: Mapping of target weight names to values. Updated in
            place (existing keys overwritten, remapped head keys possibly
            added).
        model_size: Number of leading layer indices to process.

    Returns:
        The same ``new_state_dict`` object that was passed in.

    Raises:
        ValueError: If an old key handled in the mismatch branch has fewer
            than four dot-separated components (raised by the unpacking).
    """
    import warnings

    # TODO: need to refactor

    # conv_name -> (hard-coded target layer index, target sub-module name).
    # NOTE(review): 22 and 37 presumably are the head layers of the target
    # model -- confirm against the model definition.
    head_targets = {
        "cv2": (37, "anchor_conv"),
        "cv3": (37, "class_conv"),
        "cv4": (22, "anchor_conv"),
        "cv5": (22, "class_conv"),
    }

    shift = 1
    for idx in range(model_size):
        layer_prefix = str(idx)
        # Re-scanned on every iteration on purpose: the mismatch branch below
        # may have added keys to new_state_dict in an earlier iteration.
        new_names = [
            name for name in new_state_dict if name.split(".")[0] == layer_prefix
        ]
        old_marker = f"model.{idx + shift}."
        old_items = [
            (name, value)
            for name, value in old_state_dict.items()
            if old_marker in name
        ]

        if len(new_names) == len(old_items):
            # Same number of tensors on both sides: copy positionally.
            for new_name, (_, value) in zip(new_names, old_items):
                new_state_dict[new_name] = value
            continue

        for old_name, value in old_items:
            if "dfl" in old_name:
                continue
            _, _, conv_name, conv_idx, *details = old_name.split(".")
            target = head_targets.get(conv_name)
            if target is None:
                # Previously this either raised UnboundLocalError or reused
                # the task name left over from the preceding entry, storing
                # the weight under a wrong key. Skip it explicitly instead.
                warnings.warn(
                    f"convert_weight: skipping unrecognized head weight {old_name!r}",
                    stacklevel=2,
                )
                continue
            layer_idx, conv_task = target
            if conv_name in ("cv4", "cv5"):
                # From here on, old layer numbers are offset by 2 instead of
                # 1; the new offset stays in effect for all later indices.
                shift = 2
            new_name = ".".join(
                [str(layer_idx), "heads", conv_idx, conv_task, *details]
            )
            new_state_dict[new_name] = value
    return new_state_dict