adithiyyha committed on
Commit
58af0b8
·
verified ·
1 Parent(s): 05d6807

Update AKSHAYRAJAA/inference.py

Browse files
Files changed (1) hide show
  1. AKSHAYRAJAA/inference.py +17 -6
AKSHAYRAJAA/inference.py CHANGED
@@ -49,21 +49,32 @@ def load_model():
49
  """
50
  Loads the model with the vocabulary and checkpoint.
51
  """
52
- print("Loading dataset and vocabulary...")
53
  dataset = load_dataset() # Load dataset to access vocabulary
54
  vocabulary = dataset.vocab # Assuming 'vocab' is an attribute of the dataset
55
 
56
- print("Initializing the model...")
57
  model = get_model_instance(vocabulary) # Initialize the model
58
 
59
  if can_load_checkpoint():
60
- print("Loading checkpoint...")
61
- load_checkpoint(model)
 
 
 
 
 
 
 
 
 
 
 
62
  else:
63
- print("No checkpoint found, starting with untrained model.")
64
 
65
  model.eval() # Set the model to evaluation mode
66
- print("Model is ready for inference.")
67
  return model
68
 
69
 
 
49
  """
50
  Loads the model with the vocabulary and checkpoint.
51
  """
52
+ st.write("Loading dataset and vocabulary...")
53
  dataset = load_dataset() # Load dataset to access vocabulary
54
  vocabulary = dataset.vocab # Assuming 'vocab' is an attribute of the dataset
55
 
56
+ st.write("Initializing the model...")
57
  model = get_model_instance(vocabulary) # Initialize the model
58
 
59
  if can_load_checkpoint():
60
+ st.write("Loading checkpoint...")
61
+ checkpoint = torch.load(config.CHECKPOINT_FILE, map_location=DEVICE)
62
+
63
+ # Print out the checkpoint layer sizes for debugging
64
+ print({k: v.shape for k, v in checkpoint['state_dict'].items()})
65
+
66
+ # Try loading the checkpoint with strict=False to tolerate missing or unexpected keys (shape mismatches still raise RuntimeError, handled below)
67
+ try:
68
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
69
+ st.write("Checkpoint loaded successfully.")
70
+ except RuntimeError as e:
71
+ st.write(f"Error loading checkpoint: {e}")
72
+ st.write("Starting with untrained model.")
73
  else:
74
+ st.write("No checkpoint found, starting with untrained model.")
75
 
76
  model.eval() # Set the model to evaluation mode
77
+ st.write("Model is ready for inference.")
78
  return model
79
 
80