Sukanyaaa commited on
Commit
1634179
·
1 Parent(s): bee09fe

fix inference_app.py

Browse files
Files changed (1) hide show
  1. inference_app.py +9 -2
inference_app.py CHANGED
@@ -104,6 +104,7 @@ def get_system(system_id: str) -> PinderSystem:
104
  return PinderSystem(system_id)
105
  from Bio import PDB
106
  from Bio.PDB.PDBIO import PDBIO
 
107
 
108
  def extract_coordinates_from_pdb(filename):
109
  r"""
@@ -873,8 +874,14 @@ class FinalMPNNModelight(pl.LightningModule):
873
  },
874
  }
875
 
876
- model_path = "/home/sukanya/iitm_bisect_pinder_submission/EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt"
877
- model = FinalMPNNModelight.load_from_checkpoint(model_path)
 
 
 
 
 
 
878
  trainer = pl.Trainer(
879
 
880
 
 
104
  return PinderSystem(system_id)
105
  from Bio import PDB
106
  from Bio.PDB.PDBIO import PDBIO
107
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
108
 
109
  def extract_coordinates_from_pdb(filename):
110
  r"""
 
874
  },
875
  }
876
 
877
+ # model_path = "/home/sukanya/iitm_bisect_pinder_submission/EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt"
878
+ # model = FinalMPNNModelight.load_from_checkpoint(model_path)
879
+ model = FinalMPNNModelight()
880
+ model = FSDP(model)
881
+ checkpoint = torch.load("/home/sukanya/iitm_bisect_pinder_submission/EquiMPNN-epoch=413-val_loss=9.25-val_acc=0.00.ckpt")
882
+ model_state_dict = checkpoint['state_dict']
883
+ model.load_state_dict(model_state_dict)
884
+
885
  trainer = pl.Trainer(
886
 
887