Update AKSHAYRAJAA/inference.py

AKSHAYRAJAA/inference.py CHANGED (+17 -6)
@@ -49,21 +49,32 @@ def load_model():
     """
     Loads the model with the vocabulary and checkpoint.
     """
-
+    st.write("Loading dataset and vocabulary...")
     dataset = load_dataset()  # Load dataset to access vocabulary
     vocabulary = dataset.vocab  # Assuming 'vocab' is an attribute of the dataset

-
+    st.write("Initializing the model...")
     model = get_model_instance(vocabulary)  # Initialize the model

     if can_load_checkpoint():
-
-
+        st.write("Loading checkpoint...")
+        checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=DEVICE)
+
+        # Print out the checkpoint layer sizes for debugging
+        print({k: v.shape for k, v in checkpoint['state_dict'].items()})
+
+        # Try loading the checkpoint with strict=False to ignore mismatched layers
+        try:
+            model.load_state_dict(checkpoint['state_dict'], strict=False)
+            st.write("Checkpoint loaded successfully.")
+        except RuntimeError as e:
+            st.write(f"Error loading checkpoint: {e}")
+            st.write("Starting with untrained model.")
     else:
-
+        st.write("No checkpoint found, starting with untrained model.")

     model.eval()  # Set the model to evaluation mode
-
+    st.write("Model is ready for inference.")
     return model
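For context, a minimal sketch of the lenient-loading idea used in this change, not part of the commit itself: torch.load with map_location plus load_state_dict(strict=False). Note that strict=False only tolerates missing or unexpected keys; a shape mismatch on a shared key still raises RuntimeError, which is why the diff wraps the call in try/except. The helper name load_checkpoint_lenient is hypothetical; config.CHECKPOINT_FILE, DEVICE, and st are assumed to be the same names used in the diff above.

import torch
import streamlit as st

def load_checkpoint_lenient(model, checkpoint_path, device):
    """Load a state dict leniently and report exactly what was skipped."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    try:
        # strict=False tolerates keys missing from, or extra in, the checkpoint;
        # shape mismatches still raise RuntimeError, handled below.
        result = model.load_state_dict(checkpoint["state_dict"], strict=False)
        if result.missing_keys:
            st.write(f"Keys left at initialised values: {result.missing_keys}")
        if result.unexpected_keys:
            st.write(f"Checkpoint keys ignored: {result.unexpected_keys}")
        st.write("Checkpoint loaded.")
    except RuntimeError as err:
        st.write(f"Error loading checkpoint: {err}")
        st.write("Continuing with the untrained model.")
    return model

A call such as load_checkpoint_lenient(model, config.CHECKPOINT_FILE, DEVICE) could stand in for the inline block under if can_load_checkpoint(): and would surface the mismatched layers through the Streamlit UI rather than only printing raw tensor shapes to the console.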