ppak10 committed
Commit 01f699b · 1 Parent(s): 3e60839

Updates model.py file.

Files changed (2)
  1. model.py +30 -11
  2. pipeline.ipynb +0 -0
model.py CHANGED
@@ -1,19 +1,38 @@
  import torch.nn as nn
- from transformers import AutoModel, PreTrainedModel
+ import torch

- class LlamaClassificationModel(PreTrainedModel):
-     def __init__(self, config):
-         super().__init__(config)
-         self.base_model = AutoModel.from_pretrained(config.model_path, config=config)
-         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-         self.config = config
+ from transformers import AutoModel
+
+ NUM_LABELS = 4
+
+ # Model with frozen LLaMA weights
+ class LlamaClassificationModel(nn.Module):
+     def __init__(self, model_path = "meta-llama/Llama-3.2-1B", freeze_weights = True):
+         super(LlamaClassificationModel, self).__init__()
+         self.base_model = AutoModel.from_pretrained(model_path)
+
+         # For push to hub.
+         self.config = self.base_model.config
+         print(self.base_model.config)
+
+         # Freeze the base model's weights
+         if freeze_weights:
+             for param in self.base_model.parameters():
+                 param.requires_grad = False
+
+         # Add a classification head
+         self.classifier = nn.Linear(self.base_model.config.hidden_size, NUM_LABELS)

      def forward(self, input_ids, attention_mask, labels=None):
-         outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
-         summed_representation = outputs.last_hidden_state.sum(dim=1)
-         logits = self.classifier(summed_representation)
+         with torch.no_grad():  # No gradients for the base model
+             outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
+
+         # Sum hidden states over the sequence dimension
+         summed_representation = outputs.last_hidden_state.sum(dim=1)  # Summing over sequence length
+
+         logits = self.classifier(summed_representation)  # Pass the summed representation to the classifier
          loss = None
          if labels is not None:
              loss_fn = nn.BCEWithLogitsLoss()
              loss = loss_fn(logits, labels.float())
-         return {"loss": loss, "logits": logits}
+         return {"loss": loss, "logits": logits}
pipeline.ipynb ADDED
File without changes
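
The new model.py replaces the PreTrainedModel subclass with a plain nn.Module that loads the meta-llama/Llama-3.2-1B backbone, freezes its weights, and trains only a linear head over the hidden states summed across the sequence, using BCEWithLogitsLoss for multi-label targets. The sketch below shows one way the class might be exercised after this commit; the tokenizer setup, the example texts, and the multi-hot labels tensor are illustrative assumptions, not part of the repository.

# Usage sketch (not part of the commit): one forward pass through the
# frozen-backbone classifier with multi-hot labels for the 4 classes.
# Assumes access to the "meta-llama/Llama-3.2-1B" checkpoint and its tokenizer.
import torch
from transformers import AutoTokenizer

from model import LlamaClassificationModel

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token  # LLaMA ships without a pad token

model = LlamaClassificationModel()  # backbone frozen, classifier head trainable

batch = tokenizer(
    ["example input text", "another example"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

# Multi-hot targets: one column per label (NUM_LABELS = 4).
labels = torch.tensor([[1, 0, 0, 1], [0, 1, 0, 0]], dtype=torch.float)

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=labels,
)
print(outputs["loss"], outputs["logits"].shape)  # logits: (batch_size, 4)

Because the backbone's parameters have requires_grad set to False, an optimizer for fine-tuning would typically be built over model.classifier.parameters() only.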