# SALT-SAM / AllinonSAM / driver_scratchpad.py
import argparse
import yaml
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from data_utils import *
from model import *
from test import *
import pandas as pd
from train import *
import sys
import torch
import os
source_path = os.path.join("/home/abdelrahman.elsayed/CVPR/AllinonSAM/datasets")
sys.path.append(source_path)
from arcade import ArcadeDataset
from crfseg import CRF
import itertools
from utils import CosineAnnealingWarmupScheduler
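# Driver scratchpad: small entry points for sanity-checking the data pipeline
# (main_datautils), model construction (main_model), evaluation (main_test),
# and the main fine-tuning loop (main_train).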
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_config",
default="/home/abdelrahman.elsayed/CVPR/AllinonSAM/config_arcade.yml",
help="data config file path",
)
parser.add_argument(
"--model_config",
default="/home/abdelrahman.elsayed/CVPR/AllinonSAM/model_svdtuning.yml",
help="model config file path",
)
parser.add_argument("--pretrained_path", default=None, help="pretrained model path")
    parser.add_argument(
        "--save_path", default="checkpoints/temp.pth", help="path to save the trained model"
    )
parser.add_argument(
"--training_strategy", default="svdtuning", help="how to train the model"
)
parser.add_argument("--device", default="cuda:0", help="device to train on")
args = parser.parse_args()
return args
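# One-off preprocessing: build the positive/negative prompt list dicts for each split.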
def main_onetime_functions(config):
dataset_dict, dataset_sizes, label_dict = get_data(
config,
tr_folder_start=0,
tr_folder_end=78000,
val_folder_start=0,
val_folder_end=104000,
use_norm=False,
)
for x in dataset_dict:
dataset_dict[x].one_time_generate_pos_neg_list_dicts(x)
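# Sanity-check data loading: build the datasets, run generate_examples, and
# visualize a few samples from the training split.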
def main_datautils(config, use_norm=True):
selected_idxs = [0, 12, 42, 79, 100]
print(config)
dataset_dict, dataset_sizes, label_dict = get_data(
config,
tr_folder_start=0,
tr_folder_end=78000,
val_folder_start=0,
val_folder_end=104000,
use_norm=use_norm,
)
# test without generating examples for legacy
# print(len(dataset_dict['train']))
# for i in selected_idxs:
# temp = (dataset_dict['train'][i])
# print(temp[-1])
# print(temp[-2])
# print(temp[0].shape)
# print(temp[1].shape)
# plt.imshow(temp[0].permute(1,2,0), cmap='gray')
# plt.show()
# plt.imshow(temp[1], cmap='gray')
# plt.show()
# test generate examples function
print("testing generate examples\n")
try:
dataset_dict["train"].generate_examples()
    except Exception:
        # generate_examples may not be implemented for every dataset class; skip if so
        pass
print(len(dataset_dict["train"]))
for i in selected_idxs:
temp = dataset_dict["train"][i]
print(temp[-1])
print(temp[-2])
print(temp[0].shape)
print(temp[1].shape)
try:
plt.imshow(temp[1], cmap="gray")
plt.show()
print(temp[0].min(), temp[0].max())
plt.imshow(temp[0].permute(1, 2, 0), cmap="gray")
plt.show()
        except Exception:
print("temp range: ", temp[0][0].min(), temp[0][0].max())
plt.imshow(temp[0][0].permute(1, 2, 0), cmap="gray")
plt.show()
print("temp label range: ", temp[1][0].min(), temp[1][0].max())
plt.imshow(temp[1][0], cmap="gray")
plt.show()
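# Sanity-check model construction: build Prompt_Adapted_SAM, apply the
# freezing rules, and report the number of trainable parameters.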
def main_model(config):
print(config)
training_strategy = "svdtuning"
label_dict = {"liver": 0, "tumor": 1}
model = Prompt_Adapted_SAM(config, label_dict)
# freeze correct weights
for p in model.parameters():
p.requires_grad = True
# unfreeze according to strategy:
for name, p in model.named_parameters():
# if training_strategy=='svdtuning':
# if 'trainable' in name.lower():
# p.requires_grad = True
# elif training_strategy=='biastuning':
# if ('bias' in name.lower()) and ('clip' not in name.lower()):
# p.requires_grad = True
# elif training_strategy=='svdbiastuning':
# if 'trainable' in name.lower():
# p.requires_grad = True
# if ('bias' in name.lower()) and ('clip' not in name.lower()):
# p.requires_grad = True
        if config["prompts"]["USE_TEXT_PROMPT"]:
if "Text_Embedding_Affine" in name:
p.requires_grad = True
if "clip" in name:
p.requires_grad = False
# for name, p in model.named_parameters():
# if p.requires_grad:
# print(name)
print(
"number of trainable parameters: ",
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
return
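# Run evaluation on the held-out folders via test() from test.py.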
def main_test(data_config, model_config, pretrained_path):
test_start = 104
test_end = 131
test(
data_config,
model_config,
pretrained_path,
test_start,
test_end,
device="cuda:0",
)
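# LambdaLR multiplier: linear warm-up, flat until the first milestone, then two step decays.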
def lr_lambda(step):
if step < model_config["training"]["warmup_steps"]:
return step / model_config["training"]["warmup_steps"] # Linear warm-up
elif step < model_config["training"]["steps"][0]:
return 1.0 # Maintain initial learning rate
elif step < model_config["training"]["steps"][1]:
return 1 / model_config["training"]["decay_factor"] # First decay
else:
return 1 / (model_config["training"]["decay_factor"] ** 2) # Second decay
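# End-to-end fine-tuning: build datasets/dataloaders for the configured dataset,
# freeze/unfreeze parameters according to the training strategy, and train.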
def main_train(
data_config,
model_config,
pretrained_path,
save_path,
training_strategy="biastuning",
device="cuda:0",
):
print(data_config)
print(model_config)
# load data
if data_config["data"]["name"] == "LITS":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=78,
val_folder_start=78,
val_folder_end=104,
)
elif data_config["data"]["name"] == "AMOS22":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=78,
val_folder_start=78,
val_folder_end=104,
)
elif data_config["data"]["name"] == "IDRID":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=40,
val_folder_start=40,
val_folder_end=104,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "ENDOVIS":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=180,
val_folder_start=180,
val_folder_end=304,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "ENDOVIS 18":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "CHESTXDET":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "CHOLEC 8K":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "ULTRASOUND":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "KVASIRSEG":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "LITS2":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "ISIC2018":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "Polyp":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "RITE":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "GLAS":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "Refuge":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "BTCV":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "ATR":
dataset_dict, dataset_sizes, label_dict = get_data(
data_config,
tr_folder_start=0,
tr_folder_end=18000,
val_folder_start=0,
val_folder_end=34444,
)
dataloader_dict = {}
for x in ["train", "val"]:
dataloader_dict[x] = torch.utils.data.DataLoader(
dataset_dict[x],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
elif data_config["data"]["name"] == "ArcadeDataset":
print("HERE")
data_split_csv_path = data_config["data"]["data_split_csv"]
data_split = pd.read_csv(data_split_csv_path)
dataset_dict = {}
dataloader_dict = {}
use_norm = True
no_text_mode = False
for split in ["train", "val"]:
# Filter the CSV for the current split
split_data = data_split[data_split["split"] == split]["imgs"].tolist()
# Pass the filtered data to the dataset class (ArcadeDataset)
dataset_dict[split] = ArcadeDataset(
config=data_config,
                file_list=split_data,  # filtered image paths for this split
shuffle_list=True,
is_train=(split == "train"),
apply_norm=use_norm,
no_text_mode=no_text_mode,
)
# Create DataLoader for each dataset
dataloader_dict[split] = torch.utils.data.DataLoader(
dataset_dict[split],
batch_size=model_config["training"]["batch_size"],
shuffle=True,
num_workers=4,
)
# Get dataset sizes
dataset_sizes = {split: len(dataset_dict[split]) for split in ["train", "val"]}
# Create label dictionary
label_dict = {
name: i for i, name in enumerate(data_config["data"]["label_names"])
}
# Print dataset sizes
print(f"Train dataset size: {dataset_sizes['train']}")
print(f"Val dataset size: {dataset_sizes['val']}")
# Get dataset sizes
dataset_sizes = {split: len(dataset_dict[split]) for split in ["train", "val"]}
# Create label dictionary
label_dict = {
name: i for i, name in enumerate(data_config["data"]["label_names"])
}
# Print dataset sizes
print(f"Train dataset size: {dataset_sizes['train']}")
print(f"Val dataset size: {dataset_sizes['val']}")
# load model
# change the img size in model config according to data config
model_config["sam"]["img_size"] = data_config["data_transforms"]["img_size"]
model_config["sam"]["num_classes"] = len(data_config["data"]["label_list"])
if training_strategy == "lora":
model_config["use_lora"] = True
else:
model_config["use_lora"] = False
if training_strategy == "biastuning":
model_config["decoder_training"] = "full"
if model_config["arch"] == "Prompt Adapted SAM":
model = Prompt_Adapted_SAM(
model_config, label_dict, device, training_strategy=training_strategy
)
# load model weights
if pretrained_path is not None:
model.load_state_dict(torch.load(pretrained_path))
# freeze correct weights
for p in model.parameters():
# p.requires_grad=True
p.requires_grad = False
# unfreeze according to strategy:
for name, p in model.named_parameters():
if training_strategy == "svdtuning":
if "trainable" in name.lower():
p.requires_grad = True
elif training_strategy == "biastuning":
if ("bias" in name.lower()) and ("clip" not in name.lower()):
p.requires_grad = True
elif training_strategy == "svdbiastuning":
if "trainable" in name.lower():
p.requires_grad = True
if ("bias" in name.lower()) and ("clip" not in name.lower()):
p.requires_grad = True
elif training_strategy == "lora":
if "trainable_lora" in name.lower():
p.requires_grad = True
if model_config["prompts"]["USE_TEXT_PROMPT"]:
if "Text_Embedding_Affine" in name:
p.requires_grad = True
if model_config["prompts"]["USE_SLICE_NUM"]:
if "slice" in name:
p.requires_grad = True
if model_config["decoder_training"] == "full":
if ("decoder" in name.lower()) and ("clip" not in name.lower()):
p.requires_grad = True
elif model_config["decoder_training"] == "svdtuning":
if "trainable" in name.lower():
p.requires_grad = True
elif model_config["decoder_training"] == "none":
if "decoder" in name.lower():
p.requires_grad = False
if "prompt_encoder" in name.lower():
p.requires_grad = False
# p.requires_grad = True
# common parameters
if "norm" in name.lower():
p.requires_grad = True
if "pos_embed" in name.lower():
p.requires_grad = True
if "clip" in name.lower():
p.requires_grad = False
# training parameters
training_params = model_config["training"]
if training_params["optimizer"] == "adamw":
optimizer = optim.AdamW(
model.parameters(),
lr=float(training_params["lr"]),
weight_decay=float(training_params["weight_decay"]),
)
elif training_params["optimizer"] == "sgd":
optimizer = optim.SGD(
model.parameters(),
lr=float(training_params["lr"]),
weight_decay=float(training_params["weight_decay"]),
momentum=0.9,
)
    # Use LambdaLR or CosineAnnealingWarmupScheduler instead of StepLR
    if training_params["schedular"] == "cosine_warmup":
        exp_lr_scheduler = CosineAnnealingWarmupScheduler(
            optimizer,
            warmup_epochs=training_params["warmup_epochs"],  # TODO: move into the config file (organize it better)
            total_epochs=training_params["num_epochs"],
            min_lr=training_params["min_lr"],  # TODO: move into the config file (organize it better)
            warmup_start_lr=training_params["lr"],
        )
    # StepLR is still used for some experiments, so it is kept around
    elif training_params["schedular"] == "step":
        exp_lr_scheduler = lr_scheduler.StepLR(
            optimizer,
            step_size=training_params["schedule_step"],
            gamma=training_params["schedule_step_factor"],
        )
    else:
        exp_lr_scheduler = lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda,
        )
criterion = []
if "dice" in training_params["loss"]:
criterion.append(dice_loss)
if "focal" in training_params["loss"]:
criterion.append(focal_loss)
if "CE" in training_params["loss"]:
criterion.append(nn.BCELoss())
if "weighted CE" in training_params["loss"]:
criterion.append(weighted_ce_loss)
if criterion == []:
criterion = [nn.BCELoss()]
# retain_graph = False if model_config['decoder_training']=='none' else True
retain_graph = False
# train the model
if data_config["data"]["name"] == "LITS":
model = train(
model,
dataset_dict["train"],
dataset_dict["val"],
criterion,
optimizer,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
)
elif data_config["data"]["name"] == "AMOS22":
model = train(
model,
dataset_dict["train"],
dataset_dict["val"],
criterion,
optimizer,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
)
# model = train_dl(model, dataset_dict, dataset_sizes, criterion, optimizer, exp_lr_scheduler, save_path, num_epochs=training_params['num_epochs'], bs=training_params['batch_size'], device=device, retain_graph=retain_graph, neg2pos_ratio=data_config['data']['negative_to_positive_ratio'], reg_multiplier=model_config['training']['reg_multiplier'])
elif data_config["data"]["name"] == "IDRID":
model = train_dl(
model,
dataloader_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "ENDOVIS":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "ENDOVIS 18":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "CHOLEC 8K":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "ULTRASOUND":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "KVASIRSEG":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "CHESTXDET":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "LITS2":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "ISIC2018":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "Polyp":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "RITE":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "GLAS":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "Refuge":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "BTCV":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "ATR":
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
num_epochs=training_params["num_epochs"],
bs=training_params["batch_size"],
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
elif data_config["data"]["name"] == "ArcadeDataset":
        save_path = "./models/" + data_config["data"]["root_path"].split("/")[-1]
model = train_dl(
model,
dataset_dict,
dataset_sizes,
criterion,
optimizer,
exp_lr_scheduler,
save_path,
            save_dir=f"./{training_strategy}/{data_config['data']['root_path'].split('/')[-1]}",
num_epochs=training_params["num_epochs"],
bs=5,
device=device,
retain_graph=retain_graph,
neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
reg_multiplier=model_config["training"]["reg_multiplier"],
)
# print("Starting RLHF fine-tuning...")
# model.train()
# # get the training dataloader
# train_datatloader = dataloader_dict["train"]
# val_dataloader = dataloader_dict["val"]
# rewardmodel = RewardModel(save_dir="DIAS_rhlf_30")
# rewardmodel = rewardmodel.to(device)
# rlhf_model = train_rlhf(
# model,
# model_config,
# label_dict,
# rewardmodel,
# train_datatloader,
# val_dataloader,
# 40,
# )
# # more tuning
# optimizer = optim.AdamW(
# rlhf_model.parameters(),
# lr=float(training_params["lr"]),
# weight_decay=float(training_params["weight_decay"]),
# )
# exp_lr_scheduler = lr_scheduler.StepLR(
# optimizer,
# step_size=training_params["schedule_step"],
# gamma=training_params["schedule_step_factor"],
# )
# final_model = train_dl(
# rlhf_model,
# dataset_dict,
# dataset_sizes,
# criterion,
# optimizer,
# exp_lr_scheduler,
# save_path,
# save_dir=f"./{args.training_strategy}/{data_config['data']['root_path'].split('/')[-1]}",
# num_epochs=50,
# bs=5,
# device=device,
# retain_graph=retain_graph,
# neg2pos_ratio=data_config["data"]["negative_to_positive_ratio"],
# reg_multiplier=model_config["training"]["reg_multiplier"],
# )
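# Example invocation (assuming the default config paths exist on disk):
#   python driver_scratchpad.py \
#       --data_config config_arcade.yml \
#       --model_config model_svdtuning.yml \
#       --training_strategy svdtuning \
#       --save_path checkpoints/temp.pth \
#       --device cuda:0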
if __name__ == "__main__":
args = parse_args()
with open(args.data_config, "r") as f:
data_config = yaml.load(f, Loader=yaml.FullLoader)
with open(args.model_config, "r") as f:
model_config = yaml.load(f, Loader=yaml.FullLoader)
# main_onetime_functions(data_config)
# #for checking data_utils
# main_datautils(data_config, use_norm=False)
# #for checking model
# main_model(config=model_config)
# #for testing on the test dataset
# main_test(data_config, model_config, args.pretrained_path)
# # for training the model
main_train(
data_config,
model_config,
args.pretrained_path,
args.save_path,
args.training_strategy,
device=args.device,
)