Upload app.py
Browse files
app.py
CHANGED
@@ -184,9 +184,36 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
184 |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
185 |
model = SmolLM2(SmolLM2Config())
|
186 |
|
187 |
-
# Load trained weights
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
model.to(device)
|
191 |
model.eval()
|
192 |
|
|
|
184 |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
|
185 |
model = SmolLM2(SmolLM2Config())
|
186 |
|
187 |
+
# Load trained weights from zip
|
188 |
+
def load_checkpoint_from_zip(zip_path, device):
|
189 |
+
try:
|
190 |
+
with zipfile.ZipFile(zip_path, 'r') as zf:
|
191 |
+
with zf.open('checkpoint.pt') as f:
|
192 |
+
buffer = io.BytesIO(f.read())
|
193 |
+
checkpoint = torch.load(buffer, map_location=device)
|
194 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
195 |
+
return checkpoint['model_state_dict']
|
196 |
+
return checkpoint
|
197 |
+
except Exception as e:
|
198 |
+
print(f"Error loading checkpoint: {e}")
|
199 |
+
return None
|
200 |
+
|
201 |
+
# Try to load the checkpoint
|
202 |
+
try:
|
203 |
+
checkpoint_path = 'checkpoint_step_5000.zip' # or .pt file
|
204 |
+
if checkpoint_path.endswith('.zip'):
|
205 |
+
state_dict = load_checkpoint_from_zip(checkpoint_path, device)
|
206 |
+
else:
|
207 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
208 |
+
state_dict = checkpoint['model_state_dict'] if isinstance(checkpoint, dict) else checkpoint
|
209 |
+
|
210 |
+
if state_dict is not None:
|
211 |
+
model.load_state_dict(state_dict)
|
212 |
+
print("Checkpoint loaded successfully")
|
213 |
+
except Exception as e:
|
214 |
+
print(f"Error loading checkpoint: {e}")
|
215 |
+
print("Initializing model with random weights")
|
216 |
+
|
217 |
model.to(device)
|
218 |
model.eval()
|
219 |
|