csuhan commited on
Commit
8d8ef52
·
1 Parent(s): 719a1dd
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -66,7 +66,7 @@ def load(
66
  # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
67
  # ckpt_path = checkpoints[local_rank]
68
  print("Loading")
69
- checkpoint = torch.load(ckpt_path, map_location="cuda")
70
  instruct_adapter_checkpoint = torch.load(
71
  instruct_adapter_path, map_location="cpu")
72
  caption_adapter_checkpoint = torch.load(
@@ -92,6 +92,7 @@ def load(
92
  torch.set_default_tensor_type(torch.FloatTensor)
93
  model.load_state_dict(checkpoint, strict=False)
94
  del checkpoint
 
95
  model.load_state_dict(instruct_adapter_checkpoint, strict=False)
96
  model.load_state_dict(caption_adapter_checkpoint, strict=False)
97
  vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
 
66
  # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
67
  # ckpt_path = checkpoints[local_rank]
68
  print("Loading")
69
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
70
  instruct_adapter_checkpoint = torch.load(
71
  instruct_adapter_path, map_location="cpu")
72
  caption_adapter_checkpoint = torch.load(
 
92
  torch.set_default_tensor_type(torch.FloatTensor)
93
  model.load_state_dict(checkpoint, strict=False)
94
  del checkpoint
95
+ torch.cuda.empty_cache()
96
  model.load_state_dict(instruct_adapter_checkpoint, strict=False)
97
  model.load_state_dict(caption_adapter_checkpoint, strict=False)
98
  vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)