huseinzol05 committed
Commit 0f99f61
1 Parent(s): 0208218

Upload model

Files changed (3):
  1. config.json +4 -1
  2. model.safetensors +1 -1
  3. modeling.py +82 -0
config.json CHANGED
@@ -1,8 +1,11 @@
 {
-  "_name_or_path": "embedding-model-llama-600m/checkpoint-96900",
+  "_name_or_path": "embedding-model-llama-600m/checkpoint-99300",
   "architectures": [
     "LlamaModelEmbedding"
   ],
+  "auto_map": {
+    "AutoModel": "modeling.LlamaModelEmbedding"
+  },
   "bos_token_id": 1,
   "eos_token_id": 2,
   "hidden_act": "silu",
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:539fad9026cd397dc0fdbca8ccb7c01c83eca98ab65afd1f279b48124ac6069c
+oid sha256:f80197bdfb20ec2d8f902185c65142feb7e305f1f1e241d7829a771b51873460
 size 2168545568
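The weights themselves live in Git LFS; this diff only swaps the pointer's sha256 oid. A quick integrity check for a downloaded copy, assuming the file sits in the working directory, is to hash it and compare against the new oid:

import hashlib

# Stream the ~2.1 GB file in 1 MiB chunks to keep memory flat.
h = hashlib.sha256()
with open('model.safetensors', 'rb') as f:
    for chunk in iter(lambda: f.read(1 << 20), b''):
        h.update(chunk)
assert h.hexdigest() == 'f80197bdfb20ec2d8f902185c65142feb7e305f1f1e241d7829a771b51873460'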
modeling.py ADDED
@@ -0,0 +1,82 @@
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import torch
+from torch import Tensor, nn
+from transformers import LlamaConfig, LlamaModel
+from transformers.file_utils import ModelOutput
+
+
+@dataclass
+class EncoderOutput(ModelOutput):
+    q_reps: Optional[Tensor] = None
+    p_reps: Optional[Tensor] = None
+    loss: Optional[Tensor] = None
+    scores: Optional[Tensor] = None
+
+
+class LlamaModelEmbedding(LlamaModel):
+    def __init__(self, config: LlamaConfig, **kwargs):
+        super().__init__(config, **kwargs)
+        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
+        # Project hidden states down to a 1536-dim embedding space.
+        self.dense_layer = nn.Linear(self.config.hidden_size, 1536)
+
+    def sentence_embedding(self, hidden_state, mask):
+        if self.config.sentence_pooling_method == 'mean':
+            # Mask-aware mean pooling over the sequence dimension.
+            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
+            d = mask.sum(dim=1, keepdim=True).float()
+            return s / d
+        elif self.config.sentence_pooling_method == 'cls':
+            # Use the first token's representation as the sentence embedding.
+            return hidden_state[:, 0]
+
+    def encode(self, features):
+        if features is None:
+            return None
+        psg_out = super().forward(**features, return_dict=True)
+        output = self.dense_layer(psg_out.last_hidden_state)
+        p_reps = self.sentence_embedding(output, features['attention_mask'])
+        if self.config.normalized:
+            p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
+        return p_reps.contiguous()
+
+    def compute_similarity(self, q_reps, p_reps):
+        # Dot-product similarity; batched matmul when p_reps has a group dim.
+        if len(p_reps.size()) == 2:
+            return torch.matmul(q_reps, p_reps.transpose(0, 1))
+        return torch.matmul(q_reps, p_reps.transpose(-2, -1))
+
+    def compute_loss(self, scores, target):
+        return self.cross_entropy(scores, target)
+
+    def forward(self, query: Dict[str, Tensor] = None,
+                passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
+        q_reps = self.encode(query)
+        p_reps = self.encode(passage)
+
+        if self.training:
+            # Contrastive training with in-batch negatives: score every
+            # query against every passage in the batch.
+            scores = self.compute_similarity(q_reps, p_reps)
+            scores = scores / self.config.temperature
+            scores = scores.view(q_reps.size(0), -1)
+
+            # The positive passage for query i sits at index i * group_size,
+            # where group_size is the number of passages per query.
+            target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
+            target = target * (p_reps.size(0) // q_reps.size(0))
+            loss = self.compute_loss(scores, target)
+        else:
+            scores = self.compute_similarity(q_reps, p_reps)
+            loss = None
+
+        return EncoderOutput(
+            loss=loss,
+            scores=scores,
+            q_reps=q_reps,
+            p_reps=p_reps,
+        )
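A rough usage sketch for the encoder above. It assumes the repository ships a compatible tokenizer and that sentence_pooling_method, normalized, and temperature are set in config.json (encode and forward read them); the repo id and example sentences are placeholders:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('your-username/your-repo')
model = AutoModel.from_pretrained('your-username/your-repo', trust_remote_code=True)
model.eval()

# Tokenize one query and one passage; in eval mode, forward returns
# raw similarity scores and the pooled representations, with loss=None.
query = tokenizer(['what is the capital of Malaysia?'], return_tensors='pt')
passage = tokenizer(['Kuala Lumpur is the capital of Malaysia.'], return_tensors='pt')

with torch.no_grad():
    out = model(query=query, passage=passage)

print(out.scores)        # query-passage similarity matrix
print(out.q_reps.shape)  # (1, 1536) pooled query embedding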