ppak10 committed
Commit fd94ff6
Parent(s): 5ea142c

Updates model.py.

Files changed (1):
  1. model.py +11 -30
model.py CHANGED
@@ -1,38 +1,19 @@
 import torch.nn as nn
-import torch
+from transformers import AutoModel, PreTrainedModel
 
-from transformers import AutoModel
+class LlamaClassificationModel(PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.base_model = AutoModel.from_pretrained(config.model_path, config=config)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.config = config
 
-NUM_LABELS = 4
-
-# Model with frozen LLaMA weights
-class LlamaClassificationModel(nn.Module):
-    def __init__(self, model_path = "meta-llama/Llama-3.2-1B", freeze_weights = True):
-        super(LlamaClassificationModel, self).__init__()
-        self.base_model = AutoModel.from_pretrained(model_path)
-
-        # For push to hub.
-        self.config = self.base_model.config
-        print(self.base_model.config)
-
-        # Freeze the base model's weights
-        if freeze_weights:
-            for param in self.base_model.parameters():
-                param.requires_grad = False
-
-        # Add a classification head
-        self.classifier = nn.Linear(self.base_model.config.hidden_size, NUM_LABELS)
-
     def forward(self, input_ids, attention_mask, labels=None):
-        with torch.no_grad(): # No gradients for the base model
-            outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
-
-        # Sum hidden states over the sequence dimension
-        summed_representation = outputs.last_hidden_state.sum(dim=1) # Summing over sequence length
-
-        logits = self.classifier(summed_representation) # Pass the summed representation to the classifier
+        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
+        summed_representation = outputs.last_hidden_state.sum(dim=1)
+        logits = self.classifier(summed_representation)
         loss = None
         if labels is not None:
             loss_fn = nn.BCEWithLogitsLoss()
             loss = loss_fn(logits, labels.float())
-        return {"loss": loss, "logits": logits}
+        return {"loss": loss, "logits": logits}
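For context, here is a minimal usage sketch of the refactored class. It is not part of this commit: it assumes access to the gated meta-llama/Llama-3.2-1B checkpoint, that the file above is importable as `model`, and that the caller attaches the custom `model_path` attribute (and sets `num_labels`) that `__init__` reads; the batch shapes and values are placeholders.

```python
# Hypothetical usage sketch for the refactored model.py above; not from the commit.
import torch
from transformers import AutoConfig

from model import LlamaClassificationModel  # assumes model.py is on the path

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config.model_path = "meta-llama/Llama-3.2-1B"  # custom attribute read in __init__
config.num_labels = 4  # matches the old NUM_LABELS constant; illustrative value

model = LlamaClassificationModel(config)

# Dummy batch: 2 sequences of 16 token ids with multi-label targets.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = torch.zeros(2, config.num_labels)

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(out["loss"], out["logits"].shape)  # scalar BCE loss and (2, num_labels)
```

Subclassing PreTrainedModel is what gives the class save_pretrained/push_to_hub support, which the deleted `# For push to hub.` comment was previously approximating by copying the base model's config. Note the new version also drops the weight freezing and torch.no_grad wrapper, so the base model now receives gradients during training.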