smhh24's picture
Upload 90 files
560b597 verified
raw
history blame
12.9 kB
from collections import defaultdict
from functools import partial, wraps
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from scipy import interpolate
def max_stack(tensors):
    """
    Fuse a list of same-shaped tensors by taking their element-wise maximum.

    A single-element list is returned as-is (the very same tensor object),
    skipping the stack/reduce round trip.
    """
    if len(tensors) == 1:
        return tensors[0]
    stacked = torch.stack(tensors, dim=-1)
    values, _ = torch.max(stacked, dim=-1)
    return values
def last_stack(tensors):
    """
    Fuse a list of tensors by keeping only the final entry.
    """
    last = tensors[-1]
    return last
def first_stack(tensors):
    """
    Fuse a list of tensors by keeping only the initial entry.
    """
    head = tensors[0]
    return head
def softmax_stack(tensors, temperature=1.0):
    """
    Fuse a list of same-shaped tensors with a softmax-weighted sum.

    The tensors are stacked along a new trailing axis, softmax weights are
    computed over that axis from the values themselves (divided by
    ``temperature``), and the weighted values are summed back. A low
    temperature approaches the element-wise maximum, a high temperature
    approaches the element-wise mean.

    The previous implementation summed the softmax weights alone, which is
    identically 1.0 for every element regardless of the inputs.

    Args:
        tensors: list of tensors with identical shapes.
        temperature: softmax temperature applied to the stacked values.

    Returns:
        A tensor shaped like each input; for a single-element list the input
        tensor itself is returned.
    """
    if len(tensors) == 1:
        return tensors[0]
    stacked = torch.stack(tensors, dim=-1)
    weights = F.softmax(stacked / temperature, dim=-1)
    # Weights sum to 1 along the last axis, so this is a convex combination.
    return (weights * stacked).sum(dim=-1)
def mean_stack(tensors):
    """
    Fuse a list of same-shaped tensors by averaging them element-wise.

    A single-element list short-circuits and returns that tensor unchanged.
    """
    if len(tensors) == 1:
        return tensors[0]
    stacked = torch.stack(tensors, dim=-1)
    return torch.mean(stacked, dim=-1)
def sum_stack(tensors):
    """
    Fuse a list of same-shaped tensors by adding them element-wise.

    A single-element list short-circuits and returns that tensor unchanged.
    """
    if len(tensors) == 1:
        return tensors[0]
    stacked = torch.stack(tensors, dim=-1)
    return torch.sum(stacked, dim=-1)
def convert_module_to_f16(l):
    """
    Cast the parameters of a convolution module to float16, in place.

    Only nn.Conv1d / nn.Conv2d / nn.Conv3d instances are touched; any other
    module is left exactly as it was. The bias is cast only when present.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(l, conv_types):
        return
    l.weight.data = l.weight.data.half()
    if l.bias is None:
        return
    l.bias.data = l.bias.data.half()
def convert_module_to_f32(l):
    """
    Cast the parameters of a convolution module to float32, in place.

    This is the inverse of convert_module_to_f16(): only nn.Conv1d /
    nn.Conv2d / nn.Conv3d instances are touched, and the bias is cast only
    when present.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(l, conv_types):
        return
    l.weight.data = l.weight.data.float()
    if l.bias is None:
        return
    l.bias.data = l.bias.data.float()
def format_seconds(seconds):
    """
    Render a duration in seconds as ``H:MM:SS``.

    The duration is truncated to a whole number of seconds first, so float
    inputs (e.g. differences of ``time.time()`` readings) are accepted; the
    ``d`` format code used below rejects floats outright.

    Args:
        seconds: duration in seconds, int or float.

    Returns:
        A string with unpadded hours and zero-padded minutes and seconds.
    """
    whole_seconds = int(seconds)
    minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:d}:{minutes:02d}:{secs:02d}"
def get_params(module, lr, wd):
    """
    Split a module's trainable parameters into two optimizer groups.

    A parameter goes to the no-decay group when its name is listed by
    ``module.no_weight_decay()``, when its name contains any keyword from
    ``module.no_weight_decay_keywords()``, or when it is one-dimensional.
    Both hooks are optional. Parameters with ``requires_grad=False`` are
    left out entirely.

    Returns:
        ``([decay_group, no_decay_group], [lr, lr])`` where each group is a
        dict carrying the parameter list together with its learning-rate and
        weight-decay bookkeeping entries.
    """
    exempt_names = (
        module.no_weight_decay() if hasattr(module, "no_weight_decay") else {}
    )
    exempt_keywords = (
        module.no_weight_decay_keywords()
        if hasattr(module, "no_weight_decay_keywords")
        else {}
    )

    decayed, undecayed = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            # Frozen weights never reach the optimizer.
            continue
        exempt = (
            name in exempt_names
            or any(kw in name for kw in exempt_keywords)
            or len(param.shape) == 1
        )
        bucket = undecayed if exempt else decayed
        bucket.append(param)

    decay_group = dict(
        params=decayed,
        weight_decay=wd,
        lr=lr,
        weight_decay_init=wd,
        weight_decay_base=wd,
        lr_init=lr,
        lr_base=lr,
    )
    # Only the no-decay group carries an explicit "weight_decay_final" entry.
    no_decay_group = dict(
        params=undecayed,
        weight_decay=0.0,
        lr=lr,
        weight_decay_init=0.0,
        weight_decay_base=0.0,
        weight_decay_final=0.0,
        lr_init=lr,
        lr_base=lr,
    )
    return [decay_group, no_decay_group], [lr, lr]
def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
    """
    Map a Swin parameter name to its depth index for layer-wise LR decay.

    Rules, in order:
      * token / positional-embedding names and anything under
        ``patch_embed`` -> 0;
      * ``layers.<stage>.blocks.<block>...`` -> 1 + block index + number of
        blocks in all earlier stages;
      * ``layers.<stage>.downsample...`` -> number of blocks up to and
        including that stage;
      * everything else -> ``num_max_layer - 1``.

    The final fallback also covers ``layers.*`` names that are neither a
    block nor a downsample (or that have too few dotted components), so the
    function always returns an int rather than falling through to ``None``.

    Args:
        var_name: dotted parameter name as yielded by ``named_parameters()``.
        num_max_layer: total number of depth indices.
        layers_per_stage: sequence with the number of blocks in each stage.

    Returns:
        The integer depth index of the parameter.
    """
    if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
        return 0
    if var_name.startswith("patch_embed"):
        return 0
    if var_name.startswith("layers"):
        parts = var_name.split(".")
        kind = parts[2] if len(parts) > 2 else None
        if kind == "blocks":
            stage_id = int(parts[1])
            layer_id = int(parts[3]) + sum(layers_per_stage[:stage_id])
            return layer_id + 1
        if kind == "downsample":
            stage_id = int(parts[1])
            return sum(layers_per_stage[: stage_id + 1])
    # Unrecognised names are treated as belonging to the last depth index.
    return num_max_layer - 1
def get_params_layerdecayswin(module, lr, wd, ld):
    """
    Build one optimizer group per trainable parameter with layer-wise LR decay.

    Each parameter's learning rate is ``lr * ld ** (num_layers - layer_id - 1)``
    where ``layer_id`` comes from get_num_layer_for_swin() and ``num_layers``
    is ``sum(module.depths) + 1``. Weight decay is disabled only for names
    listed by ``module.no_weight_decay()`` or containing a keyword from
    ``module.no_weight_decay_keywords()`` (both hooks optional); unlike
    get_params(), one-dimensional parameters are not exempted here.

    Frozen parameters are reported on stdout and skipped.

    Returns:
        ``(param_groups, lrs)``: the list of per-parameter group dicts and the
        matching list of learning rates.
    """
    exempt_names = (
        module.no_weight_decay() if hasattr(module, "no_weight_decay") else {}
    )
    exempt_keywords = (
        module.no_weight_decay_keywords()
        if hasattr(module, "no_weight_decay_keywords")
        else {}
    )
    depths = module.depths
    total_layers = sum(depths) + 1

    param_groups, group_lrs = [], []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            print(f"{name} frozen")
            continue
        depth_index = get_num_layer_for_swin(name, total_layers, depths)
        scaled_lr = lr * ld ** (total_layers - depth_index - 1)
        exempt = name in exempt_names or any(kw in name for kw in exempt_keywords)
        param_groups.append(
            {
                "params": param,
                "weight_decay": 0.0 if exempt else wd,
                "lr": scaled_lr,
            }
        )
        group_lrs.append(scaled_lr)
    return param_groups, group_lrs
def log(t, eps: float = 1e-5):
    """
    Natural logarithm of ``t`` with values below ``eps`` clamped up to ``eps``.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()
def l2norm(t):
    """
    Scale ``t`` to unit Euclidean norm along its last dimension.
    """
    return F.normalize(t, p=2.0, dim=-1)
def exists(val):
    """
    Return True for anything except ``None`` (falsy values still count).
    """
    missing = val is None
    return not missing
def identity(t, *args, **kwargs):
    """
    Return the first argument untouched; every other argument is ignored.
    """
    del args, kwargs  # accepted only so the function can stand in for any callable
    return t
def divisible_by(numer, denom):
    """
    Tell whether ``numer`` is an exact multiple of ``denom``.
    """
    remainder = numer % denom
    return remainder == 0
def first(arr, d=None):
    """
    Return the element at index 0 of ``arr``, or ``d`` when ``arr`` is empty.
    """
    return arr[0] if len(arr) != 0 else d
def default(val, d):
    """
    Return ``val`` unless it is ``None``, in which case fall back to ``d``.

    A callable ``d`` is invoked (with no arguments) to produce the fallback,
    and only when the fallback is actually needed.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def maybe(fn):
    """
    Wrap a one-argument function so that ``None`` passes straight through.

    The wrapped function calls ``fn`` for any non-None argument and returns
    ``None`` without calling ``fn`` otherwise.
    """

    @wraps(fn)
    def inner(x):
        return x if x is None else fn(x)

    return inner
def once(fn):
    """
    Wrap a one-argument function so that only the first call runs it.

    The first call forwards to ``fn`` and returns its result; every later
    call returns ``None`` without invoking ``fn``. The "already called" flag
    is set before ``fn`` executes, so a first call that raises still counts.
    """
    already_ran = False

    @wraps(fn)
    def inner(x):
        nonlocal already_ran
        if not already_ran:
            already_ran = True
            return fn(x)
        return None

    return inner
def _many(fn):
@wraps(fn)
def inner(tensors, pattern, **kwargs):
return (fn(tensor, pattern, **kwargs) for tensor in tensors)
return inner
# Batched variants of the einops primitives: each takes an iterable of tensors
# and one pattern, and lazily applies the primitive to every tensor.
rearrange_many, repeat_many, reduce_many = (
    _many(einops_op) for einops_op in (rearrange, repeat, reduce)
)
def load_pretrained(state_dict, checkpoint):
    """
    Prepare a pre-trained Swin checkpoint for loading and return it.

    If any key of ``checkpoint["model"]`` contains ``"encoder."``, only the
    keys starting with ``"encoder."`` are kept and that prefix is removed.
    The result is then passed through load_checkpoint_swin() to interpolate
    or drop size-dependent entries.

    Previously the remapped weights were bound to a local variable and
    dropped, so the function had no observable effect; they are now
    returned to the caller.

    Args:
        state_dict: despite its name, this is forwarded as the ``model``
            argument of load_checkpoint_swin(), which calls
            ``.state_dict()`` on it.
        checkpoint: mapping with a ``"model"`` entry holding the weights.

    Returns:
        The remapped checkpoint weight dict.
    """
    checkpoint_model = checkpoint["model"]
    if any("encoder." in k for k in checkpoint_model.keys()):
        checkpoint_model = {
            k.replace("encoder.", ""): v
            for k, v in checkpoint_model.items()
            if k.startswith("encoder.")
        }
        print("Detect pre-trained model, remove [encoder.] prefix.")
    else:
        print("Detect non-pre-trained model, pass without doing anything.")
    print(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
    return load_checkpoint_swin(state_dict, checkpoint_model)
def load_checkpoint_swin(model, checkpoint_model):
    """
    Adapt a Swin checkpoint's weight dict to ``model``; mutates and returns it.

    Steps:
      1. Every ``relative_position_bias_table`` whose length differs from
         the model's table is resampled per head with a bicubic spline from
         a geometrically spaced source grid onto an integer target grid.
         Tables whose head count differs are reported and left untouched.
      2. ``relative_position_index``, ``relative_coords_table`` and
         ``attn_mask`` entries are removed.
      3. ``cpb_mlp`` is renamed to ``rpe_mlp`` in keys.
      4. A leading ``encoder.`` is stripped from keys that start with it.

    The spline uses ``RectBivariateSpline`` because ``interp2d`` is no
    longer available in current SciPy; on a regular grid it is the
    documented drop-in replacement (transposed input, transposed output).

    Args:
        model: object exposing ``state_dict()`` with the target shapes.
        checkpoint_model: dict of checkpoint tensors, modified in place.

    Returns:
        The same ``checkpoint_model`` dict after remapping.
    """
    state_dict = model.state_dict()

    def geometric_progression(a, r, n):
        return a * (1.0 - r**n) / (1.0 - r)

    # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size
    for key in list(checkpoint_model.keys()):
        if "relative_position_bias_table" not in key:
            continue
        table_pretrained = checkpoint_model[key]
        table_current = state_dict[key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 != nH2:
            print(f"Error in loading {key}, passing......")
            continue
        if L1 == L2:
            continue

        print(f"{key}: Interpolate relative_position_bias_table using geo.")
        src_size = int(L1**0.5)
        dst_size = int(L2**0.5)

        # Bisection on the ratio q so that the geometric series over half
        # the source grid spans half the target grid.
        left, right = 1.01, 1.5
        while right - left > 1e-6:
            q = (left + right) / 2.0
            gp = geometric_progression(1, q, src_size // 2)
            if gp > dst_size // 2:
                right = q
            else:
                left = q

        # Positive-side source coordinates: 1, 1+q, 1+q+q^2, ...
        dis = []
        cur = 1
        for i in range(src_size // 2):
            dis.append(cur)
            cur += q ** (i + 1)
        r_ids = [-d for d in reversed(dis)]

        x = r_ids + [0] + dis
        y = r_ids + [0] + dis

        t = dst_size // 2.0
        dx = np.arange(-t, t + 0.1, 1.0)
        dy = np.arange(-t, t + 0.1, 1.0)

        print("Original positions = %s" % str(x))
        print("Target positions = %s" % str(dx))

        all_rel_pos_bias = []
        for i in range(nH1):
            # detach()/cpu() make the NumPy export work for tensors that
            # require grad or live on an accelerator.
            z = (
                table_pretrained[:, i]
                .view(src_size, src_size)
                .detach()
                .float()
                .cpu()
                .numpy()
            )
            # z is indexed [y, x]; RectBivariateSpline expects [x, y] and
            # returns [x, y], hence the two transposes.
            spline = interpolate.RectBivariateSpline(x, y, z.T, kx=3, ky=3)
            resampled = np.ascontiguousarray(spline(dx, dy).T)
            all_rel_pos_bias.append(
                torch.Tensor(resampled)
                .contiguous()
                .view(-1, 1)
                .to(table_pretrained.device)
            )
        checkpoint_model[key] = torch.cat(all_rel_pos_bias, dim=-1)

    # delete relative_position_index since we always re-init it
    for k in [k for k in checkpoint_model.keys() if "relative_position_index" in k]:
        del checkpoint_model[k]

    # delete relative_coords_table since we always re-init it
    for k in [k for k in checkpoint_model.keys() if "relative_coords_table" in k]:
        del checkpoint_model[k]

    # re-map keys due to name change
    for k in [k for k in checkpoint_model.keys() if "cpb_mlp" in k]:
        checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)

    # delete attn_mask since we always re-init it
    for k in [k for k in checkpoint_model.keys() if "attn_mask" in k]:
        del checkpoint_model[k]

    for k in [k for k in checkpoint_model.keys() if k.startswith("encoder.")]:
        checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)

    return checkpoint_model
def add_padding_metas(out, image_metas):
    """
    Zero-pad every sample of a batch by the amounts stored in its metadata.

    Each entry of ``image_metas`` may carry a ``"padding_size"`` of four
    values (tensor, list or tuple); a missing entry means no padding. The
    four values are handed to ``F.pad`` unchanged, which applies them to
    the last two dimensions in (left, right, top, bottom) order.

    NOTE(review): remove_padding() indexes its paddings as
    (left, top, right, bottom), so the two are not inverses of each other
    for asymmetric paddings -- confirm which convention the metadata uses.

    Args:
        out: batched tensor; iterated along its first dimension.
        image_metas: one metadata dict per sample.

    Returns:
        The padded samples stacked back into a batch. Stacking requires
        every sample to end up with the same shape.
    """
    # left, right, top, bottom
    paddings = [
        # as_tensor + tolist accepts tensors and plain sequences alike and
        # yields the Python ints that F.pad expects.
        [int(p) for p in torch.as_tensor(img_meta.get("padding_size", [0] * 4)).tolist()]
        for img_meta in image_metas
    ]
    outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)]
    return torch.stack(outs)
def remove_padding(out, paddings):
    """
    Crop per-sample padding off a batch of 4-D shape (B, C, H, W).

    ``paddings`` is a list with one 4-element tensor per sample. As indexed
    here, element 0 is cut from the left, element 1 from the top, element 2
    from the right and element 3 from the bottom.

    NOTE(review): the original inline comment described the order as
    "left, right, top, bottom", which does not match the indexing above --
    confirm the convention against the code that produces the paddings.

    Returns:
        The cropped samples stacked into a batch. Stacking requires every
        crop to have the same shape.
    """
    _, _, height, width = out.shape
    per_sample = torch.stack(paddings).to(out.device)

    cropped = []
    for pad, sample in zip(per_sample, out):
        rows = slice(pad[1], height - pad[3])
        cols = slice(pad[0], width - pad[2])
        cropped.append(sample[:, rows, cols])
    return torch.stack(cropped)
def remove_padding_metas(out, image_metas):
    """
    Crop padding off a batch using the ``"padding_size"`` of each metadata dict.

    A metadata dict without ``"padding_size"`` counts as four zeros. The
    values are converted to tensors and handed to remove_padding(), which
    defines how the four entries are interpreted.
    """
    no_padding = [0] * 4
    paddings = []
    for img_meta in image_metas:
        paddings.append(torch.tensor(img_meta.get("padding_size", no_padding)))
    return remove_padding(out, paddings)
def ssi_helper(tensor1, tensor2):
    """
    Least-squares scale and shift such that ``scale * tensor2 + shift``
    approximates ``tensor1``.

    Solves the normal equations for the design matrix ``[tensor2, 1]`` with
    ``1e-4 * I`` added to the 2x2 Gram matrix to keep it invertible.

    Assumes both inputs are 1-D tensors of equal length -- TODO confirm
    against callers.

    Returns:
        ``(scale, shift)``, the two halves of the solved 2-vector.
    """
    design = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
    design_t = design.T
    regulariser = 1e-4 * torch.eye(2, device=tensor1.device)
    gram = design_t @ design + regulariser
    rhs = design_t @ tensor1.unsqueeze(1)
    solution = torch.inverse(gram) @ rhs
    scale, shift = solution.squeeze().chunk(2, dim=0)
    return scale, shift
def calculate_mean_values(names, values):
    """
    Average ``values`` grouped by the name at the same position in ``names``.

    Returns:
        A dict mapping each distinct name (in order of first appearance) to
        the mean of the values paired with it.
    """
    # Fix the output key order up front, before pairing names with values.
    ordered_names = dict.fromkeys(names)

    totals, counts = {}, {}
    for name, value in zip(names, values):
        totals[name] = totals.get(name, 0.0) + value
        counts[name] = counts.get(name, 0.0) + 1

    return {name: totals[name] / counts[name] for name in ordered_names}
def remove_leading_dim(infos):
    """
    Recursively apply ``squeeze(0)`` to every tensor in a (nested) dict.

    Tensors whose first dimension is not of size 1 come back with an
    unchanged shape, and non-tensor, non-dict values are returned as they
    are.
    """
    if isinstance(infos, torch.Tensor):
        return infos.squeeze(0)
    if isinstance(infos, dict):
        return {key: remove_leading_dim(value) for key, value in infos.items()}
    return infos
def to_cpu(infos):
    """
    Recursively detach every tensor in a (nested) dict and move it to the CPU.

    The earlier version only detached tensors from the autograd graph and
    left them on their original device, which contradicts the function's
    name and keeps accelerator memory alive for anything that stores the
    result.

    Args:
        infos: a tensor, a dict (possibly nested) of such values, or any
            other object.

    Returns:
        The same structure with each tensor replaced by a detached CPU
        tensor; non-tensor, non-dict values are returned unchanged.
    """
    if isinstance(infos, dict):
        return {k: to_cpu(v) for k, v in infos.items()}
    if isinstance(infos, torch.Tensor):
        # .cpu() is a no-op for tensors that already live on the CPU.
        return infos.detach().cpu()
    return infos