anjikum committed on
Commit
3241ea4
·
verified ·
1 Parent(s): 2c7c575

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -3
app.py CHANGED
@@ -184,9 +184,36 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
184
  tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
185
  model = SmolLM2(SmolLM2Config())
186
 
187
- # Load trained weights
188
- checkpoint = torch.load('checkpoint_step_5000.pt', map_location=device)
189
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  model.to(device)
191
  model.eval()
192
 
 
184
  tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
185
  model = SmolLM2(SmolLM2Config())
186
 
187
# Load trained weights from zip
def load_checkpoint_from_zip(zip_path, device, member_name='checkpoint.pt'):
    """Load model weights from a checkpoint file stored inside a zip archive.

    Args:
        zip_path: Path to the zip archive.
        device: Passed to ``torch.load`` as ``map_location``.
        member_name: Name of the checkpoint file inside the archive. If no
            member has this name and the archive holds exactly one ``.pt`` /
            ``.pth`` file, that file is used instead.

    Returns:
        ``checkpoint['model_state_dict']`` when the loaded object is a dict
        with that key, otherwise the loaded object itself. Returns ``None``
        (after printing the error) if anything goes wrong, so the caller can
        decide how to fall back.
    """
    # Imported locally so the helper does not depend on the module header
    # already providing these stdlib modules.
    import io
    import zipfile

    try:
        with zipfile.ZipFile(zip_path, 'r') as zf:
            names = zf.namelist()
            if member_name not in names:
                # Fall back to the only checkpoint-looking member, if the
                # choice is unambiguous.
                candidates = [n for n in names if n.endswith(('.pt', '.pth'))]
                if len(candidates) != 1:
                    raise KeyError(
                        f"no member named {member_name!r} in {zip_path!r} "
                        f"(archive contains: {names})"
                    )
                member_name = candidates[0]
            with zf.open(member_name) as f:
                # Read everything into memory; the archive can then be closed
                # before deserialisation starts.
                buffer = io.BytesIO(f.read())

        # NOTE(security): torch.load unpickles its input. Only use archives
        # from a trusted source.
        checkpoint = torch.load(buffer, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
        return checkpoint
    except Exception as e:
        # Deliberate best-effort: report the problem and signal failure with
        # None instead of raising.
        print(f"Error loading checkpoint: {e}")
        return None
200
+
201
# Try to load the checkpoint. Any failure is reported and the model keeps
# the weights it was constructed with.
try:
    checkpoint_path = 'checkpoint_step_5000.zip'  # or .pt file
    if checkpoint_path.endswith('.zip'):
        # The helper prints its own error and returns None on failure.
        state_dict = load_checkpoint_from_zip(checkpoint_path, device)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # A raw state dict is itself a dict, so check for the wrapper key
        # before indexing (same rule as the zip helper).
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

    if state_dict is not None:
        model.load_state_dict(state_dict)
        print("Checkpoint loaded successfully")
    else:
        # The zip helper failed; make the fallback explicit.
        print("Initializing model with random weights")
except Exception as e:
    print(f"Error loading checkpoint: {e}")
    print("Initializing model with random weights")
216
+
217
  model.to(device)
218
  model.eval()
219