DavidGF commited on
Commit
351ee27
·
verified ·
1 Parent(s): 27402b0

Update kraken_model/modeling_kraken.py

Browse files
Files changed (1) hide show
  1. kraken_model/modeling_kraken.py +10 -7
kraken_model/modeling_kraken.py CHANGED
@@ -40,11 +40,6 @@ class KrakenForCausalLM(PreTrainedModel):
40
  model_decision_index = self.models_indices[prediction]
41
  model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
42
  return model_keys[model_decision_index]
43
-
44
- def expert_tokenizer(self, text):
45
- model_key = self.determine_model(text)
46
- return self.tokenizers[model_key]
47
-
48
 
49
  def generate(self, input_ids, **generate_kwargs):
50
  # Tokenize the input_ids
@@ -75,8 +70,16 @@ class KrakenForCausalLM(PreTrainedModel):
75
  tok_input_ids = tok.input_ids.to(current_device)
76
  tok_attention_mask = tok.attention_mask.to(current_device)
77
 
78
- # Generate text using the retrieved model
79
- return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
 
 
 
 
 
 
 
 
80
 
81
 
82
 
 
40
  model_decision_index = self.models_indices[prediction]
41
  model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
42
  return model_keys[model_decision_index]
 
 
 
 
 
43
 
44
  def generate(self, input_ids, **generate_kwargs):
45
  # Tokenize the input_ids
 
70
  tok_input_ids = tok.input_ids.to(current_device)
71
  tok_attention_mask = tok.attention_mask.to(current_device)
72
 
73
+ # Generate text using the modified model
74
+ output_ids = model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
75
+
76
+ # Decode the output using the expert tokenizer
77
+ decoded_text = self.tokenizers[model_key].decode(output_ids[0], skip_special_tokens=True)
78
+
79
+ # Retokenize the decoded text using the base tokenizer for external compatibility
80
+ retokenized_ids = self.tokenizer(decoded_text, return_tensors="pt").input_ids.to(current_device)
81
+
82
+ return retokenized_ids
83
 
84
 
85