Upload 3 files
Implemented processor.
- config.json +7 -5
- modeling_qwen2.py +175 -107
- processing_qwen2_ts.py +171 -0
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "chatts_release",
   "architectures": [
     "Qwen2TSForCausalLM"
   ],
@@ -7,7 +7,8 @@
   "auto_map": {
     "AutoConfig": "configuration_qwen2.Qwen2TSConfig",
     "AutoModel": "modeling_qwen2.Qwen2TSForCausalLM",
-    "AutoModelForCausalLM": "modeling_qwen2.Qwen2TSForCausalLM"
+    "AutoModelForCausalLM": "modeling_qwen2.Qwen2TSForCausalLM",
+    "AutoProcessor": "processing_qwen2_ts.Qwen2TSProcessor"
   },
   "bos_token_id": 151643,
   "eos_token_id": 151645,
@@ -33,10 +34,11 @@
     "hidden_size": 5120,
     "num_features": 2,
     "num_layers": 5,
-    "patch_size": 16
+    "patch_size": 16,
+    "max_length": 2048
   },
-  "ts_token_end_index":
-  "ts_token_start_index":
+  "ts_token_end_index": 151666,
+  "ts_token_start_index": 151665,
   "use_cache": false,
   "use_sliding_window": false,
   "vocab_size": 152064
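With the new `AutoProcessor` entry in `auto_map`, the processor ships with the checkpoint and can be loaded through the transformers auto classes. A minimal loading sketch, assuming a local or hub copy of this repo (the path below is a placeholder):

from transformers import AutoModelForCausalLM, AutoProcessor

ckpt = "path/to/chatts_checkpoint"  # placeholder; any checkpoint carrying this config works
# trust_remote_code=True lets the auto_map entries above resolve to the custom
# Qwen2TSConfig / Qwen2TSForCausalLM / Qwen2TSProcessor classes in this repo.
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)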
modeling_qwen2.py
CHANGED
@@ -26,7 +26,7 @@
 import inspect
 import math
 import copy
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Any
 from dataclasses import dataclass
 
 import torch
@@ -68,6 +68,44 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
 _CONFIG_FOR_DOC = "Qwen2TSConfig"
 
 
+@dataclass
+class Qwen2TSCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Qwen2TS causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        attention_mask (`torch.FloatTensor`, *optional*):
+            Attentions mask, used to update attention mask and position_ids.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    attention_mask: Optional[torch.FloatTensor] = None
+
 ########################Naive TS Embedding#####################
 class TimeSeriesEmbedding(nn.Module):
     def __init__(self, config):
@@ -1187,147 +1225,127 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
-    def _get_real_length(self, timeseries, input_ids):
-        # Return the embed length after inserting timeseries features
-        if timeseries is None:
-            return input_ids.size(1)
-
-        num_time_steps = timeseries.size(1) * timeseries.size(2) // self.config.ts['num_features']
-        num_patches = num_time_steps // self.config.ts['patch_size']
-        special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
-        num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
-        return num_special_ts_tokens * (num_patches - 2) + input_ids.size(1)
-
-    def _get_original_length(self, timeseries, input_ids, past_length):
-        if timeseries is None:
-            if isinstance(past_length, int):
-                original_length = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
-            else:
-                original_length = past_length
-            num_special_ts_tokens_within_past = torch.zeros(input_ids.size(0), dtype=torch.long, device=input_ids.device)
-            return original_length, num_special_ts_tokens_within_past
-
-        patch_size = self.config.ts['patch_size']
-        num_patches = timeseries.size(1) * timeseries.size(2) // patch_size // self.config.ts['num_features']
-        ts_token_start_index = self.config.ts_token_start_index
-
-        ts_mask = (input_ids == ts_token_start_index).long()  # (batch_size, seq_length)
-
-        cumsum_ts = torch.cumsum(ts_mask, dim=1)  # (batch_size, seq_length)
-
-        seq_length = input_ids.size(1)
-        positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_length)
-
-        transformed_length = positions + cumsum_ts * (num_patches - 2)  # (batch_size, seq_length)
-
-        if isinstance(past_length, int):
-            past_length_tensor = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
-        else:
-            past_length_tensor = past_length.to(input_ids.device)
-
-        mask = transformed_length <= past_length_tensor.unsqueeze(1)  # (batch_size, seq_length)
-
-        original_length = torch.sum(mask, dim=1)  # (batch_size,)
-        original_positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_length)
-        original_mask = original_positions <= original_length.unsqueeze(1)  # (batch_size, seq_length)
-        ts_within_original_mask = ts_mask.bool() & original_mask.bool()  # (batch_size, seq_length)
-        num_special_ts_tokens_within_past = torch.sum(ts_within_original_mask, dim=1)  # (batch_size,)
-
-        original_length = torch.clamp(original_length, min=0)
-
-        return original_length, num_special_ts_tokens_within_past
-
     def _merge_input_ids_with_time_series_features(
-
-
-        total_time_steps, embed_dim = time_series_features.shape
+        self, time_series_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
+    ):
         batch_size, sequence_length = input_ids.shape
+        _left_padding = torch.any(attention_mask[:, 0] == 0)
+        _right_padding = torch.any(attention_mask[:, -1] == 0)
         left_padding = False
-
+        if batch_size > 1:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            elif not _left_padding and _right_padding:
+                left_padding = False
+            elif not _left_padding and not _right_padding:
+                left_padding = False
+            else:
+                raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
+        else:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            else:
+                left_padding = False
+
         # 1. Create a mask to know where special time series tokens are
         special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
         special_ts_token_mask_end = input_ids == self.config.ts_token_end_index
         special_ts_token_mask = special_ts_token_mask_start | special_ts_token_mask_end
+
+        # 2. Calculate patch count
         num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
+        total_time_steps, embed_dim = time_series_features.shape
+
         # Correctly calculate the total number of patches per batch
+        patch_index = 0
         num_total_patches = torch.zeros(batch_size, dtype=patch_cnt.dtype, device=patch_cnt.device)
         special_ts_token_mask_start_nonzero = special_ts_token_mask_start.nonzero()
         special_ts_token_mask_start_with_size = special_ts_token_mask_start.clone().long()
-
+
+        attn_mask_cnt = attention_mask.sum(dim=-1)
         for i in range(batch_size):
             num_ts_in_batch = num_special_ts_tokens[i]
-            num_total_patches[i] = patch_cnt[patch_index:patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
+            num_total_patches[i] = patch_cnt[patch_index : patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
             for idx in range(patch_index, patch_index + num_ts_in_batch):
-
-                special_ts_token_mask_start_with_size[
+                b_idx, pos = special_ts_token_mask_start_nonzero[idx]
+                special_ts_token_mask_start_with_size[b_idx, pos] *= (patch_cnt[idx].item() - 2)
             patch_index += num_ts_in_batch
-
-
+            attn_mask_cnt[i] += num_total_patches[i].item()
+
+        # 3. Embeding length
         max_embed_dim = sequence_length + num_total_patches.max()
-
-        #
+
+        # 4. Non ts tokens
        batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
 
-        #
+        # 5. Text token in final text positions
         new_token_positions = torch.cumsum((special_ts_token_mask_start_with_size + 1), dim=-1) - 1
+
+        # nb_ts_pad
         nb_ts_pad = max_embed_dim - 1 - new_token_positions[:, -1]
         if left_padding:
-            new_token_positions += nb_ts_pad[:, None]
+            new_token_positions += nb_ts_pad[:, None]
+
         text_to_overwrite = new_token_positions[batch_indices, non_ts_indices]
-
-        #
+
+        # 6. Final embedding and attention masks
         final_embedding = torch.zeros(
             batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
         )
-
-
-        )
+
+        final_attention_mask = torch.zeros(batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device)
+        for i in range(attention_mask.size(0)):
+            if left_padding:
+                final_attention_mask[i, max_embed_dim - attn_mask_cnt[i] :] = 1
+            else:
+                final_attention_mask[i, : attn_mask_cnt[i]] = 1
+
+        final_labels = None
         if labels is not None:
             final_labels = torch.full(
                 (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
             )
+
         target_device = inputs_embeds.device
         batch_indices, non_ts_indices, text_to_overwrite = (
             batch_indices.to(target_device),
             non_ts_indices.to(target_device),
             text_to_overwrite.to(target_device),
         )
-
-
-        # 4. Fill the embeddings based on the mask
+
+        # 7. Move embedding and labels to final positions
         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_ts_indices]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_ts_indices]
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_ts_indices]
-
-        #
+
+        # 8. Move time series to final positions
         ts_to_overwrite = torch.full(
             (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
         )
         ts_to_overwrite[batch_indices, text_to_overwrite] = False
+
         reversed_cumsum = ts_to_overwrite.flip(dims=[-1]).cumsum(-1).flip(dims=[-1]) - 1
         ts_to_overwrite &= reversed_cumsum >= nb_ts_pad[:, None].to(target_device)
-
+
+        # Check that the number of time series tokens is correct
         if ts_to_overwrite.sum() != time_series_features.shape[:-1].numel():
             raise ValueError(
                 f"The input provided to the model are wrong. The number of time series tokens is {torch.sum(special_ts_token_mask_start)} while"
                 f" the number of time series given to the model is {len(patch_cnt)}. This prevents correct indexing and breaks batch generation."
             )
-
         final_embedding[ts_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
-
+
+        # 9. Calculate position ids
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-
-
-
-
-
+        if position_ids.size(-1) < input_ids.size(-1):
+            position_ids = position_ids[:, -input_ids.size(-1) :]
+
+        # 10. Move attention mask to final positions
+        pad_batch_indices, pad_indices = torch.where(input_ids == self.config.pad_token_id)
+        if len(pad_batch_indices) > 0:
+            indices_to_mask = new_token_positions[pad_batch_indices, pad_indices]
+            final_embedding[pad_batch_indices, indices_to_mask] = 0
 
-        if labels is None:
-            final_labels = None
-
         return final_embedding, final_attention_mask, position_ids, final_labels
 
     @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@@ -1382,10 +1400,8 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
             if timeseries is not None and timeseries.shape[0] > 0:
-                #
-                use_cache = False
+                # use_cache = False
                 ts_features, patch_cnt = self.ts_encoder(timeseries)
-
                 inputs_embeds = inputs_embeds.to(ts_features.dtype)
 
                 inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_time_series_features(
@@ -1424,14 +1440,63 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-
+
+        return Qwen2TSCausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+            attention_mask=attention_mask
         )
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
+    ) -> Dict[str, Any]:
+        # update past_key_values keeping its naming used in model code
+        cache_name, cache = self._extract_past_from_model_output(outputs)
+        model_kwargs[cache_name] = cache
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update attention_mask
+        if getattr(outputs, "attention_mask", None) is not None:
+            model_kwargs["attention_mask"] = outputs.attention_mask
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
+
+        if model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+        else:
+            past_positions = model_kwargs.pop("cache_position")
+            new_positions = torch.arange(
+                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
+            ).to(past_positions.device)
+            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
+        return model_kwargs
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, timeseries=None, **kwargs
@@ -1446,20 +1511,23 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
             cache_length = past_length = past_key_values[0][0].shape[2]
             max_cache_length = None
 
-
-
-
-
-
-
-
+            has_ts = timeseries is not None and len(timeseries) > 0
+
+            if has_ts and kwargs.get("attention_mask") is not None:
+                attention_mask = kwargs["attention_mask"]
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+
+            # Set attention mask and input_ids
+            if has_ts and past_length > 0:
+                # We have only one token added and timeseries are already inferenced
+                input_ids = input_ids[:, -1:]
+                timeseries = None
+            elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-
-
-            elif past_length < real_len:
-                input_ids = input_ids[:, origin_past_len:]
-                if timeseries is not None:
-                    timeseries = timeseries[past_num_ts:]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
@@ -1476,7 +1544,7 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
-                position_ids = position_ids[:, -input_ids.
+                position_ids = position_ids[:, -input_ids.shape[1] :]
 
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
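For intuition about the merge step above: every `<ts>`/`<ts/>` pair in the prompt is replaced in the embedding sequence by the patches coming out of `ts_encoder`, so the merged length grows by `patch_cnt - 2` per time series. A small standalone sketch of that bookkeeping with made-up numbers (not the model code):

# Hypothetical values, only to illustrate the length arithmetic of
# _merge_input_ids_with_time_series_features.
seq_length = 50        # tokenized prompt length, <ts> and <ts/> tokens included
patch_cnt = [16, 4]    # patches returned by ts_encoder for two attached series (made up)
num_ts = len(patch_cnt)

# Mirrors: num_total_patches[i] = patch_cnt[...].sum() - 2 * num_ts_in_batch
#          max_embed_dim        = sequence_length + num_total_patches.max()
num_total_patches = sum(patch_cnt) - 2 * num_ts
max_embed_dim = seq_length + num_total_patches
print(num_total_patches, max_embed_dim)  # 16 66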
processing_qwen2_ts.py
ADDED
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2024 Tsinghua University and ByteDance.
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/license/mit
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from typing import List, Union, Tuple, Optional
+import torch
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import (
+    PreTokenizedInput,
+    TextInput,
+    PaddingStrategy,
+)
+
+def sp_encoding(timeseries: np.ndarray, eots_token: bool = True) -> Tuple[np.ndarray, str, dict]:
+    """
+    Encodes a time series with scalar normalization.
+
+    Args:
+        timeseries (np.ndarray): The raw time series data (1D or 2D).
+
+    Returns:
+        result_timeseries (np.ndarray): The encoded time series, shape [seq_len, 1].
+        prompt (str): The placeholder string with offset and scaling info.
+        metadata (dict): Metadata containing the offset and scaling factor.
+    """
+    mean = np.mean(timeseries)
+    scaled_timeseries = timeseries - mean
+    scale_factor = 1.0
+    if np.any(np.abs(scaled_timeseries) >= 3.0):
+        scale_factor = np.max(np.abs(scaled_timeseries)) / 3.0
+        scaled_timeseries /= scale_factor
+
+    prompt = f"[Value Offset: {-mean:.4f}|Value Scaling: {scale_factor:.4f}]<ts>"
+    if eots_token:
+        prompt += '<ts/>'
+
+    result_timeseries = np.stack([scaled_timeseries, np.ones_like(scaled_timeseries)], axis=-1).reshape(-1, 1)
+
+    return result_timeseries, prompt, {"offset": float(-mean), "scale_factor": float(scale_factor)}
+
+class Qwen2TSProcessor(ProcessorMixin):
+    """
+    A processor for ChatTS that integrates text prompt processing and time series encoding.
+    """
+
+    attributes = ["tokenizer"]
+    feature_extractor_class = None  # You can add a feature extractor if needed
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, tokenizer=None):
+        """
+        Args:
+            tokenizer: An optional tokenizer to process text prompts.
+        """
+        super().__init__(tokenizer=tokenizer)
+
+    def __call__(
+        self,
+        text: List[str],
+        timeseries: List[List[np.ndarray]],
+        padding: Union[bool, str, PaddingStrategy] = False,
+        padding_side: str = 'left',
+        vllm_flag: bool = False,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Encodes a prompt and its associated time series.
+
+        Args:
+            prompt (List[str]): The input prompt containing <ts><ts/> placeholders.
+            timeseries (List[np.ndarray]): A list of time series matched to placeholders in the prompt.
+            padding (bool or str or PaddingStrategy, optional): Passed to the tokenizer for text padding.
+            return_tensors (str, optional): "pt" to return PyTorch tensors; None to return NumPy arrays.
+            **kwargs: Additional tokenizer parameters.
+
+        Returns:
+            BatchFeature: Contains processed prompt, encoded time series, and tokenizer outputs.
+        """
+        if type(text) == str:
+            text = [text]
+
+        encoded_ts_arrays = []
+        reconstructed_prompts = []
+        total_ts_cnt = 0
+        for idx, prompt in enumerate(text):
+            # Split prompt by <ts><ts/> placeholders
+            last_ts_cnt = total_ts_cnt
+            prompt_segments = prompt.split("<ts><ts/>")
+            total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
+
+            # Encode each time series and rebuild the prompt
+            reconstructed_prompt = prompt_segments[0]
+
+            for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
+                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
+                reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
+                # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
+                encoded_ts_arrays.append(encoded_ts[None, ...])
+
+            reconstructed_prompts.append(reconstructed_prompt)
+
+        if len(timeseries) != len(encoded_ts_arrays):
+            raise ValueError(
+                f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
+                f"and time series ({len(encoded_ts_arrays)})."
+            )
+
+        if len(encoded_ts_arrays) > 0:
+            # Pad time series to the same length
+            max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
+            padded_ts_arrays = [
+                np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
+                for ts in encoded_ts_arrays
+            ]
+            concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+
+            # Convert to torch
+            concatenated_ts = torch.from_numpy(concatenated_ts).half()
+        else:
+            concatenated_ts = None
+
+        # Tokenize the processed prompt
+        tokenizer_outputs = {}
+        if self.tokenizer is not None:
+            tokenizer_outputs = self.tokenizer(reconstructed_prompts, padding=padding, padding_side=padding_side, **kwargs)
+
+        # Create the final output
+        outputs = {
+            "timeseries": concatenated_ts
+        }
+        outputs.update(tokenizer_outputs)
+
+        return BatchFeature(data=outputs)
+
+    @property
+    def model_input_names(self):
+        """
+        Define the input names expected by the model.
+        """
+        tokenizer_input_names = []
+        if self.tokenizer and hasattr(self.tokenizer, "model_input_names"):
+            tokenizer_input_names = self.tokenizer.model_input_names
+        return list(dict.fromkeys(["processed_prompt", "time_series"] + tokenizer_input_names))
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
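A minimal end-to-end sketch of how the new processor is meant to be used with the model (the checkpoint path, prompt wording, and sine-wave input are placeholders; `return_tensors="pt"` is forwarded to the tokenizer through `**kwargs`):

import numpy as np
from transformers import AutoModelForCausalLM, AutoProcessor

ckpt = "path/to/chatts_checkpoint"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True, device_map="auto")
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)

ts = np.sin(np.arange(256) / 10.0)  # made-up univariate series of length 256
prompt = "I have a time series: <ts><ts/>. Please describe its overall trend."

# One <ts><ts/> placeholder per series; sp_encoding() rewrites each placeholder into
# "[Value Offset: ...|Value Scaling: ...]<ts>...<ts/>" and returns the scaled values.
inputs = processor(text=[prompt], timeseries=[ts], padding=True, return_tensors="pt")
inputs = {k: (v.to(model.device) if hasattr(v, "to") else v) for k, v in inputs.items()}

output_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])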