xiezhe24 committed on
Commit b08d47e · verified · 1 Parent(s): e0a50c6

Upload 3 files

Implemented processor.

Files changed (3)
  1. config.json +7 -5
  2. modeling_qwen2.py +175 -107
  3. processing_qwen2_ts.py +171 -0
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "/mnt/bn/mllmhl/sft_checkpoints/qwen2.5-14b-ts-explaints-1124-stage1-sp/checkpoint-400",
+  "_name_or_path": "chatts_release",
   "architectures": [
     "Qwen2TSForCausalLM"
   ],
@@ -7,7 +7,8 @@
   "auto_map": {
     "AutoConfig": "configuration_qwen2.Qwen2TSConfig",
     "AutoModel": "modeling_qwen2.Qwen2TSForCausalLM",
-    "AutoModelForCausalLM": "modeling_qwen2.Qwen2TSForCausalLM"
+    "AutoModelForCausalLM": "modeling_qwen2.Qwen2TSForCausalLM",
+    "AutoProcessor": "processing_qwen2_ts.Qwen2TSProcessor"
   },
   "bos_token_id": 151643,
   "eos_token_id": 151645,
@@ -33,10 +34,11 @@
     "hidden_size": 5120,
     "num_features": 2,
     "num_layers": 5,
-    "patch_size": 16
+    "patch_size": 16,
+    "max_length": 2048
   },
-  "ts_token_end_index": 151665,
-  "ts_token_start_index": 151666,
+  "ts_token_end_index": 151666,
+  "ts_token_start_index": 151665,
   "use_cache": false,
   "use_sliding_window": false,
   "vocab_size": 152064
modeling_qwen2.py CHANGED
@@ -26,7 +26,7 @@
 import inspect
 import math
 import copy
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Dict, Any
 from dataclasses import dataclass
 
 import torch
@@ -68,6 +68,44 @@ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
 _CONFIG_FOR_DOC = "Qwen2TSConfig"
 
 
+@dataclass
+class Qwen2TSCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for Qwen2TS causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        attention_mask (`torch.FloatTensor`, *optional*):
+            Attention mask, used to update attention_mask and position_ids.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[List[torch.FloatTensor]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    attention_mask: Optional[torch.FloatTensor] = None
+
 ########################Naive TS Embedding#####################
 class TimeSeriesEmbedding(nn.Module):
     def __init__(self, config):
@@ -1187,147 +1225,127 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
 
     def get_decoder(self):
         return self.model
-
-    def _get_real_length(self, timeseries, input_ids):
-        # Return the embed length after inserting timeseries features
-        if timeseries is None:
-            return input_ids.size(1)
-
-        num_time_steps = timeseries.size(1) * timeseries.size(2) // self.config.ts['num_features']
-        num_patches = num_time_steps // self.config.ts['patch_size']
-        special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
-        num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
-        return num_special_ts_tokens * (num_patches - 2) + input_ids.size(1)
-
-    def _get_original_length(self, timeseries, input_ids, past_length):
-        if timeseries is None:
-            if isinstance(past_length, int):
-                original_length = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
-            else:
-                original_length = past_length
-            num_special_ts_tokens_within_past = torch.zeros(input_ids.size(0), dtype=torch.long, device=input_ids.device)
-            return original_length, num_special_ts_tokens_within_past
-
-        patch_size = self.config.ts['patch_size']
-        num_patches = timeseries.size(1) * timeseries.size(2) // patch_size // self.config.ts['num_features']
-        ts_token_start_index = self.config.ts_token_start_index
-
-        ts_mask = (input_ids == ts_token_start_index).long()  # (batch_size, seq_length)
-
-        cumsum_ts = torch.cumsum(ts_mask, dim=1)  # (batch_size, seq_length)
-
-        seq_length = input_ids.size(1)
-        positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_length)
-
-        transformed_length = positions + cumsum_ts * (num_patches - 2)  # (batch_size, seq_length)
-
-        if isinstance(past_length, int):
-            past_length_tensor = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
-        else:
-            past_length_tensor = past_length.to(input_ids.device)
-
-        mask = transformed_length <= past_length_tensor.unsqueeze(1)  # (batch_size, seq_length)
-
-        original_length = torch.sum(mask, dim=1)  # (batch_size,)
-        original_positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_length)
-        original_mask = original_positions <= original_length.unsqueeze(1)  # (batch_size, seq_length)
-        ts_within_original_mask = ts_mask.bool() & original_mask.bool()  # (batch_size, seq_length)
-        num_special_ts_tokens_within_past = torch.sum(ts_within_original_mask, dim=1)  # (batch_size,)
-
-        original_length = torch.clamp(original_length, min=0)
-
-        return original_length, num_special_ts_tokens_within_past
-
     def _merge_input_ids_with_time_series_features(
-        self, time_series_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
-    ):
-        total_time_steps, embed_dim = time_series_features.shape
+        self, time_series_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
+    ):
         batch_size, sequence_length = input_ids.shape
+        _left_padding = torch.any(attention_mask[:, 0] == 0)
+        _right_padding = torch.any(attention_mask[:, -1] == 0)
         left_padding = False
-
+        if batch_size > 1:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            elif not _left_padding and _right_padding:
+                left_padding = False
+            elif not _left_padding and not _right_padding:
+                left_padding = False
+            else:
+                raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
+        else:
+            if _left_padding and not _right_padding:
+                left_padding = True
+            else:
+                left_padding = False
+
         # 1. Create a mask to know where special time series tokens are
         special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
         special_ts_token_mask_end = input_ids == self.config.ts_token_end_index
         special_ts_token_mask = special_ts_token_mask_start | special_ts_token_mask_end
+
+        # 2. Calculate patch count
         num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
+        total_time_steps, embed_dim = time_series_features.shape
+
         # Correctly calculate the total number of patches per batch
+        patch_index = 0
         num_total_patches = torch.zeros(batch_size, dtype=patch_cnt.dtype, device=patch_cnt.device)
         special_ts_token_mask_start_nonzero = special_ts_token_mask_start.nonzero()
         special_ts_token_mask_start_with_size = special_ts_token_mask_start.clone().long()
-        patch_index = 0
+
+        attn_mask_cnt = attention_mask.sum(dim=-1)
         for i in range(batch_size):
             num_ts_in_batch = num_special_ts_tokens[i]
-            num_total_patches[i] = patch_cnt[patch_index:patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
+            num_total_patches[i] = patch_cnt[patch_index : patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
             for idx in range(patch_index, patch_index + num_ts_in_batch):
-                batch_idx, pos_idx = special_ts_token_mask_start_nonzero[idx]
-                special_ts_token_mask_start_with_size[batch_idx, pos_idx] *= (patch_cnt[idx].item() - 2)
+                b_idx, pos = special_ts_token_mask_start_nonzero[idx]
+                special_ts_token_mask_start_with_size[b_idx, pos] *= (patch_cnt[idx].item() - 2)
             patch_index += num_ts_in_batch
-
-        # Compute the maximum embed dimension, considering both start and end tokens
+            attn_mask_cnt[i] += num_total_patches[i].item()
+
+        # 3. Embedding length
         max_embed_dim = sequence_length + num_total_patches.max()
-
-        # batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
+
+        # 4. Non ts tokens
         batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
 
-        # 2. Compute the positions where text should be written
+        # 5. Text token in final text positions
        	new_token_positions = torch.cumsum((special_ts_token_mask_start_with_size + 1), dim=-1) - 1
+
+        # nb_ts_pad
        	nb_ts_pad = max_embed_dim - 1 - new_token_positions[:, -1]
         if left_padding:
-            new_token_positions += nb_ts_pad[:, None]  # offset for left padding
+            new_token_positions += nb_ts_pad[:, None]
+
        	text_to_overwrite = new_token_positions[batch_indices, non_ts_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
+
+        # 6. Final embedding and attention masks
        	final_embedding = torch.zeros(
            	batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        	)
-        final_attention_mask = torch.zeros(
-            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
-        )
+
+        final_attention_mask = torch.zeros(batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device)
+        for i in range(attention_mask.size(0)):
+            if left_padding:
+                final_attention_mask[i, max_embed_dim - attn_mask_cnt[i] :] = 1
+            else:
+                final_attention_mask[i, : attn_mask_cnt[i]] = 1
+
+        final_labels = None
        	if labels is not None:
            	final_labels = torch.full(
                	(batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            	)
+
        	target_device = inputs_embeds.device
        	batch_indices, non_ts_indices, text_to_overwrite = (
            	batch_indices.to(target_device),
            	non_ts_indices.to(target_device),
            	text_to_overwrite.to(target_device),
        	)
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask
+
+        # 7. Move embedding and labels to final positions
        	final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_ts_indices]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_ts_indices]
        	if labels is not None:
            	final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_ts_indices]
-
-        # 5. Fill the embeddings corresponding to the time series
+
+        # 8. Move time series to final positions
        	ts_to_overwrite = torch.full(
            	(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        	)
        	ts_to_overwrite[batch_indices, text_to_overwrite] = False
+
        	reversed_cumsum = ts_to_overwrite.flip(dims=[-1]).cumsum(-1).flip(dims=[-1]) - 1
        	ts_to_overwrite &= reversed_cumsum >= nb_ts_pad[:, None].to(target_device)
-
+
+        # Check that the number of time series tokens is correct
        	if ts_to_overwrite.sum() != time_series_features.shape[:-1].numel():
            	raise ValueError(
                	f"The input provided to the model are wrong. The number of time series tokens is {torch.sum(special_ts_token_mask_start)} while"
                	f" the number of time series given to the model is {len(patch_cnt)}. This prevents correct indexing and breaks batch generation."
            	)
-
        	final_embedding[ts_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        final_attention_mask |= ts_to_overwrite
+
+        # 9. Calculate position ids
        	position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-        # 6. Mask out the embedding at padding positions
-        batch_indices, pad_indices = torch.where(input_ids == self.config.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
+        if position_ids.size(-1) < input_ids.size(-1):
+            position_ids = position_ids[:, -input_ids.size(-1) :]
+
+        # 10. Move attention mask to final positions
+        pad_batch_indices, pad_indices = torch.where(input_ids == self.config.pad_token_id)
+        if len(pad_batch_indices) > 0:
+            indices_to_mask = new_token_positions[pad_batch_indices, pad_indices]
+            final_embedding[pad_batch_indices, indices_to_mask] = 0
 
-        if labels is None:
-            final_labels = None
-
        	return final_embedding, final_attention_mask, position_ids, final_labels
 
     @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@@ -1382,10 +1400,8 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
            inputs_embeds = self.get_input_embeddings()(input_ids)
 
        if timeseries is not None and timeseries.shape[0] > 0:
-            # Disable KV Cache as it has not been implemented yet
-            use_cache = False
+            # use_cache = False
            ts_features, patch_cnt = self.ts_encoder(timeseries)
-
            inputs_embeds = inputs_embeds.to(ts_features.dtype)
 
            inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_time_series_features(
@@ -1424,14 +1440,63 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
 
-        return CausalLMOutputWithPast(
+
+        return Qwen2TSCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+            attention_mask=attention_mask
        )
 
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        num_new_tokens: int = 1,
+    ) -> Dict[str, Any]:
+        # update past_key_values keeping its naming used in model code
+        cache_name, cache = self._extract_past_from_model_output(outputs)
+        model_kwargs[cache_name] = cache
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update attention_mask
+        if getattr(outputs, "attention_mask", None) is not None:
+            model_kwargs["attention_mask"] = outputs.attention_mask
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
+                    dim=-1,
+                )
+
+        if model_kwargs.get("use_cache", True):
+            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
+        else:
+            past_positions = model_kwargs.pop("cache_position")
+            new_positions = torch.arange(
+                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
+            ).to(past_positions.device)
+            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
+        return model_kwargs
 
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, timeseries=None, **kwargs
@@ -1446,20 +1511,23 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
 
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            real_len = self._get_real_length(timeseries, input_ids)
-            origin_past_len, past_num_ts = self._get_original_length(timeseries, input_ids, past_length)
-            if attention_mask is not None and attention_mask.shape[1] > real_len:
+            has_ts = timeseries is not None and len(timeseries) > 0
+
+            if has_ts and kwargs.get("attention_mask") is not None:
+                attention_mask = kwargs["attention_mask"]
+                attention_mask = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                )
+
+            # Set attention mask and input_ids
+            if has_ts and past_length > 0:
+                # We have only one token added and timeseries are already inferenced
+                input_ids = input_ids[:, -1:]
+                timeseries = None
+            elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < real_len:
-                input_ids = input_ids[:, origin_past_len:]
-                if timeseries is not None:
-                    timeseries = timeseries[past_num_ts:]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
@@ -1476,7 +1544,7 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
-                position_ids = position_ids[:, -input_ids.size(1) :]
+                position_ids = position_ids[:, -input_ids.shape[1] :]
 
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
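
A note on the rewritten _merge_input_ids_with_time_series_features: text positions in the expanded sequence come from a cumulative sum in which every token has width 1, except each <ts> start token, whose width is inflated by its patch count minus 2; the leftover slots are then filled with the time-series patch embeddings. A toy illustration of that indexing step (illustrative only, not part of the committed code):

import torch

# toy sequence: [tok, tok, <ts>, tok, <ts/>, tok]; assume the encoder yields patch_cnt = 5 for the series
ts_start_mask = torch.tensor([[0, 0, 1, 0, 0, 0]])
widths = ts_start_mask.clone()
widths[ts_start_mask == 1] = 5 - 2                          # mirrors patch_cnt[idx].item() - 2
new_token_positions = torch.cumsum(widths + 1, dim=-1) - 1
print(new_token_positions)                                  # tensor([[0, 1, 5, 6, 7, 8]])
# the non-ts tokens land at positions 0, 1, 6 and 8 of the length-9 merged sequence;
# the remaining slots are overwritten with the time-series patch features.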
processing_qwen2_ts.py ADDED
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2024 Tsinghua University and ByteDance.
+#
+# Licensed under the MIT License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://opensource.org/license/mit
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from typing import List, Union, Tuple, Optional
+import torch
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import (
+    PreTokenizedInput,
+    TextInput,
+    PaddingStrategy,
+)
+
+def sp_encoding(timeseries: np.ndarray, eots_token: bool = True) -> Tuple[np.ndarray, str, dict]:
+    """
+    Encodes a time series with scalar normalization.
+
+    Args:
+        timeseries (np.ndarray): The raw time series data (1D or 2D).
+
+    Returns:
+        result_timeseries (np.ndarray): The encoded time series, shape [seq_len, 1].
+        prompt (str): The placeholder string with offset and scaling info.
+        metadata (dict): Metadata containing the offset and scaling factor.
+    """
+    mean = np.mean(timeseries)
+    scaled_timeseries = timeseries - mean
+    scale_factor = 1.0
+    if np.any(np.abs(scaled_timeseries) >= 3.0):
+        scale_factor = np.max(np.abs(scaled_timeseries)) / 3.0
+        scaled_timeseries /= scale_factor
+
+    prompt = f"[Value Offset: {-mean:.4f}|Value Scaling: {scale_factor:.4f}]<ts>"
+    if eots_token:
+        prompt += '<ts/>'
+
+    result_timeseries = np.stack([scaled_timeseries, np.ones_like(scaled_timeseries)], axis=-1).reshape(-1, 1)
+
+    return result_timeseries, prompt, {"offset": float(-mean), "scale_factor": float(scale_factor)}
+
+class Qwen2TSProcessor(ProcessorMixin):
+    """
+    A processor for ChatTS that integrates text prompt processing and time series encoding.
+    """
+
+    attributes = ["tokenizer"]
+    feature_extractor_class = None  # You can add a feature extractor if needed
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, tokenizer=None):
+        """
+        Args:
+            tokenizer: An optional tokenizer to process text prompts.
+        """
+        super().__init__(tokenizer=tokenizer)
+
+    def __call__(
+        self,
+        text: List[str],
+        timeseries: List[List[np.ndarray]],
+        padding: Union[bool, str, PaddingStrategy] = False,
+        padding_side: str = 'left',
+        vllm_flag: bool = False,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Encodes a prompt and its associated time series.
+
+        Args:
+            text (List[str]): The input prompts containing <ts><ts/> placeholders.
+            timeseries (List[np.ndarray]): A list of time series matched to placeholders in the prompt.
+            padding (bool or str or PaddingStrategy, optional): Passed to the tokenizer for text padding.
+            return_tensors (str, optional): "pt" to return PyTorch tensors; None to return NumPy arrays.
+            **kwargs: Additional tokenizer parameters.
+
+        Returns:
+            BatchFeature: Contains processed prompt, encoded time series, and tokenizer outputs.
+        """
+        if type(text) == str:
+            text = [text]
+
+        encoded_ts_arrays = []
+        reconstructed_prompts = []
+        total_ts_cnt = 0
+        for idx, prompt in enumerate(text):
+            # Split prompt by <ts><ts/> placeholders
+            last_ts_cnt = total_ts_cnt
+            prompt_segments = prompt.split("<ts><ts/>")
+            total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
+
+            # Encode each time series and rebuild the prompt
+            reconstructed_prompt = prompt_segments[0]
+
+            for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
+                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
+                reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
+                # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
+                encoded_ts_arrays.append(encoded_ts[None, ...])
+
+            reconstructed_prompts.append(reconstructed_prompt)
+
+        if len(timeseries) != len(encoded_ts_arrays):
+            raise ValueError(
+                f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
+                f"and time series ({len(encoded_ts_arrays)})."
+            )
+
+        if len(encoded_ts_arrays) > 0:
+            # Pad time series to the same length
+            max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
+            padded_ts_arrays = [
+                np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
+                for ts in encoded_ts_arrays
+            ]
+            concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+
+            # Convert to torch
+            concatenated_ts = torch.from_numpy(concatenated_ts).half()
+        else:
+            concatenated_ts = None
+
+        # Tokenize the processed prompt
+        tokenizer_outputs = {}
+        if self.tokenizer is not None:
+            tokenizer_outputs = self.tokenizer(reconstructed_prompts, padding=padding, padding_side=padding_side, **kwargs)
+
+        # Create the final output
+        outputs = {
+            "timeseries": concatenated_ts
+        }
+        outputs.update(tokenizer_outputs)
+
+        return BatchFeature(data=outputs)
+
+    @property
+    def model_input_names(self):
+        """
+        Define the input names expected by the model.
+        """
+        tokenizer_input_names = []
+        if self.tokenizer and hasattr(self.tokenizer, "model_input_names"):
+            tokenizer_input_names = self.tokenizer.model_input_names
+        return list(dict.fromkeys(["processed_prompt", "time_series"] + tokenizer_input_names))
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
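
A rough end-to-end sketch of how the new processor is intended to be used with the model (the checkpoint path and prompt are placeholders; device_map="auto" assumes accelerate is installed; the <ts><ts/> handling follows the code above):

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

path = "path/to/chatts_checkpoint"  # placeholder
processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto")

prompt = "Here is a metric:<ts><ts/>. Please describe any anomalies in it."
series = np.sin(np.arange(256) / 10.0)  # one series per <ts><ts/> placeholder

inputs = processor(text=[prompt], timeseries=[series], padding=True, return_tensors="pt").to(model.device)
# inputs holds input_ids / attention_mask from the tokenizer plus the padded "timeseries" tensor
outputs = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0])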