jupyterjazz commited on
Commit
4ff8c15
·
verified ·
1 Parent(s): 343dbf5

feat: support setting a default task

Browse files
Files changed (1) hide show
  1. custom_st.py +20 -5
custom_st.py CHANGED
@@ -65,6 +65,7 @@ class Transformer(nn.Module):
65
  self._adaptation_map = {
66
  name: idx for idx, name in enumerate(self._lora_adaptations)
67
  }
 
68
 
69
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
70
  tokenizer_args["model_max_length"] = max_seq_length
@@ -88,17 +89,31 @@ class Transformer(nn.Module):
88
  if tokenizer_name_or_path is not None:
89
  self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
90
 
91
- def forward(
92
- self, features: Dict[str, torch.Tensor], task: Optional[str] = None
93
- ) -> Dict[str, torch.Tensor]:
94
- """Returns token_embeddings, cls_token"""
 
 
 
 
 
 
 
 
95
  if task and task not in self._lora_adaptations:
96
  raise ValueError(
97
  f"Unsupported task '{task}'. "
98
- f"Supported tasks are: {', '.join(self.config.lora_adaptations)}."
99
  f"Alternatively, don't pass the `task` argument to disable LoRA."
100
  )
101
 
 
 
 
 
 
 
102
  adapter_mask = None
103
  if task:
104
  task_id = self._adaptation_map[task]
 
65
  self._adaptation_map = {
66
  name: idx for idx, name in enumerate(self._lora_adaptations)
67
  }
68
+ self._default_task = None
69
 
70
  if max_seq_length is not None and "model_max_length" not in tokenizer_args:
71
  tokenizer_args["model_max_length"] = max_seq_length
 
89
  if tokenizer_name_or_path is not None:
90
  self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
91
 
92
+
93
+ @property
94
+ def default_task(self):
95
+ return self._default_task
96
+
97
+ @default_task.setter
98
+ def default_task(self, task: Union[None, str]):
99
+ self._validate_task(task)
100
+ self._default_task = task
101
+
102
+
103
+ def _validate_task(self, task: str):
104
  if task and task not in self._lora_adaptations:
105
  raise ValueError(
106
  f"Unsupported task '{task}'. "
107
+ f"Supported tasks are: {', '.join(self.config.lora_adaptations)}. "
108
  f"Alternatively, don't pass the `task` argument to disable LoRA."
109
  )
110
 
111
+ def forward(
112
+ self, features: Dict[str, torch.Tensor], task: Optional[str] = None
113
+ ) -> Dict[str, torch.Tensor]:
114
+ """Returns token_embeddings, cls_token"""
115
+ self._validate_task(task)
116
+ task = task or self.default_task
117
  adapter_mask = None
118
  if task:
119
  task_id = self._adaptation_map[task]