|
import torch |
|
from timm import create_model |
|
|
|
|
|
def load_model(model_path, model_name='tf_efficientnet_b3_ns', num_classes=1605):
    """Build a timm model and load weights from a checkpoint file.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            content is a state dict accepted by ``load_state_dict``.
        model_name: Architecture name passed to ``create_model``. Defaults to
            the EfficientNet-B3 variant this script was written for.
        num_classes: Size of the classification head passed to
            ``create_model``.

    Returns:
        The newly created model with the checkpoint weights loaded.
    """
    model = create_model(model_name, num_classes=num_classes, pretrained=False)

    # Deserialize onto the CPU so a checkpoint saved from a GPU can still be
    # read on a machine without that device; load_state_dict then copies the
    # values into the model's own parameters.
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)

    return model
|
|
|
|
|
def model_soup(models, weights):
    """Fuse several models into one by weighted-averaging their state dicts.

    Args:
        models: Non-empty sequence of models that share one architecture
            (identical state-dict keys and tensor shapes).
        weights: One number per model giving its relative contribution. The
            values are rescaled to sum to 1, so only their ratios matter.

    Returns:
        A new model (a deep copy of ``models[0]``) whose state-dict entries
        are the weighted average of the corresponding entries of ``models``.
        The input models are not modified.

    Raises:
        ValueError: If ``models`` and ``weights`` differ in length, if no
            model is given, or if the weights sum to zero.
        KeyError: If a model lacks a state-dict key present in ``models[0]``.
    """
    import copy  # local import so the block does not rely on file-level imports

    if len(models) != len(weights):
        raise ValueError("Number of models and weights must match")
    if len(models) == 0:
        raise ValueError("At least one model is required")

    total = sum(weights)
    if total == 0:
        raise ValueError("Weights must not sum to zero")
    weights = [w / total for w in weights]

    # Snapshot each state dict once instead of rebuilding it for every key.
    state_dicts = [model.state_dict() for model in models]

    fused_state = {}
    for key, reference in state_dicts[0].items():
        blended = sum(w * sd[key] for w, sd in zip(weights, state_dicts))
        # Scaling an integer buffer by a float weight promotes it to floating
        # point; cast back so each entry keeps the dtype of the original
        # tensor (a no-op for ordinary floating-point parameters).
        fused_state[key] = blended.to(reference.dtype)

    # Copy the first model rather than rebuilding a hard-coded architecture,
    # so the soup works for whatever model type the caller passes in.
    fused_model = copy.deepcopy(models[0])
    fused_model.load_state_dict(fused_state)

    return fused_model
|
|
|
|
|
# Checkpoints to blend; the order must line up with the entries of `weights`.
model_paths = [
    "/data/cjm/FungiCLEF2024/EfficientNet/output/trick_1.4.3/efficientnet_b3_epoch_28.pth",
    "/data/cjm/FungiCLEF2024/EfficientNet/output/trick_1.4.3.2/efficientnet_b3_epoch_28.pth",
    "/data/cjm/FungiCLEF2024/EfficientNet/output/trick_1.4.1/efficientnet_b3_epoch_23.pth",
    "/data/cjm/FungiCLEF2024/EfficientNet/output/trick_1.5.2/efficientnet_b3_epoch_21.pth",
]

# One loaded model per checkpoint path, in the same order.
models = list(map(load_model, model_paths))

# Relative contribution of each checkpoint (model_soup rescales them to sum to 1).
weights = [4, 2.6, 2.4, 1]

fused_model = model_soup(models, weights)

# Persist only the state dict of the fused model.
fused_model_path = "/data/cjm/FungiCLEF2024/EfficientNet/output/fused_model_soup.pth"
torch.save(fused_model.state_dict(), fused_model_path)

print(f"Fused model saved to {fused_model_path}")
|
|