Text Classification
Safetensors
deberta-v2
tcapelle committed
Commit 9a08e1a
1 Parent(s): a3fe9de

missing files and rename

Files changed (3)
  1. custom_pipeline.py +29 -0
  2. model.py +0 -20
  3. modelling_deberta_multi.py +31 -0
custom_pipeline.py ADDED
@@ -0,0 +1,29 @@
+from transformers import TextClassificationPipeline, AutoTokenizer
+
+class CustomTextClassificationPipeline(TextClassificationPipeline):
+    def __init__(self, model, tokenizer=None, **kwargs):
+        # Initialize tokenizer first
+        if tokenizer is None:
+            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
+        # Make sure we store the tokenizer before calling super().__init__
+        self.tokenizer = tokenizer
+        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
+
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, inputs):
+        return self.tokenizer(inputs, return_tensors='pt', truncation=False)
+
+    def _forward(self, model_inputs):
+        input_ids = model_inputs['input_ids']
+        attention_mask = (input_ids != 0).long()
+        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        return outputs
+
+    def postprocess(self, model_outputs):
+        predictions = model_outputs.logits.argmax(dim=-1).squeeze().tolist()
+        categories = ["Race/Origin", "Gender/Sex", "Religion", "Ability", "Violence", "Other"]
+        return dict(zip(categories, predictions))
+
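For context, a minimal usage sketch of this pipeline (not part of the commit; the repo id is a placeholder, and trust_remote_code=True is assumed since the model class is defined inside the repository):

    from transformers import AutoModel, AutoTokenizer
    from custom_pipeline import CustomTextClassificationPipeline

    # Placeholder repo id -- substitute the actual Hub repository.
    model = AutoModel.from_pretrained("user/multi-head-deberta", trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained("user/multi-head-deberta")

    pipe = CustomTextClassificationPipeline(model=model, tokenizer=tokenizer)
    pipe("some input text")  # -> {"Race/Origin": 0, "Gender/Sex": 0, ...}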
model.py DELETED
@@ -1,20 +0,0 @@
-import torch
-import torch.nn as nn
-from transformers import DebertaV2Model, DebertaV2PreTrainedModel
-
-class MultiHeadDebertaForSequenceClassification(DebertaV2PreTrainedModel):
-    def __init__(self, config, num_heads=5):
-        super().__init__(config)
-        self.num_heads = num_heads
-        self.deberta = DebertaV2Model(config)
-        self.heads = nn.ModuleList([nn.Linear(config.hidden_size, 4) for _ in range(num_heads)])
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        self.post_init()
-
-    def forward(self, input_ids=None, attention_mask=None):
-        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
-        sequence_output = outputs[0]
-        logits_list = [head(self.dropout(sequence_output[:, 0, :])) for head in self.heads]
-        logits = torch.stack(logits_list, dim=1)
-        return logits
-
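Note: this deleted class is superseded by modelling_deberta_multi.py below, where the hard-coded num_heads=5 moves into the model config (config.num_heads); this is the rename referred to in the commit message.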
modelling_deberta_multi.py ADDED
@@ -0,0 +1,31 @@
+import torch
+from torch import nn, Tensor
+from typing import Optional
+from transformers import DebertaV2PreTrainedModel, DebertaV2Model
+from .configuration_deberta_multi import MultiHeadDebertaV2Config
+
+class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
+
+    config_class = MultiHeadDebertaV2Config
+    def __init__(self, config):  # type: ignore
+        super().__init__(config)
+        self.deberta = DebertaV2Model(config)
+        self.heads = nn.ModuleList(
+            [nn.Linear(config.hidden_size, 4) for _ in range(config.num_heads)]
+        )
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional["Tensor"] = None,
+        attention_mask: Optional["Tensor"] = None,
+    ) -> "Tensor":
+        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
+        sequence_output = outputs[0]
+        logits_list = [
+            head(self.dropout(sequence_output[:, 0, :])) for head in self.heads
+        ]
+        logits = torch.stack(logits_list, dim=1)
+        outputs["logits"] = logits
+        return outputs
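The import of MultiHeadDebertaV2Config references configuration_deberta_multi.py, which is not included in this commit. A plausible minimal sketch of that file (an assumption, not the author's code), defaulting num_heads to the 5 used by the deleted model.py:

    from transformers import DebertaV2Config

    class MultiHeadDebertaV2Config(DebertaV2Config):
        # Assumed sketch: extends the base DeBERTa-v2 config with the
        # number of classification heads read by the model above.
        def __init__(self, num_heads=5, **kwargs):
            super().__init__(**kwargs)
            self.num_heads = num_heads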