xiezhe24 committed on
Commit 01b96b5
1 Parent(s): 7fd83b6

Update README and scripts

README.md CHANGED
@@ -1,20 +1,22 @@
----
-license: apache-2.0
-base_model:
-- Qwen/Qwen2.5-14B-Instruct
----
 # ChatTS-14B Model
-This model is fine-tuned on the QWen2.5-14B-Instruct (https://huggingface.co/Qwen/Qwen2.5-14B-Instruct) model. For more usage details, please refer to the `README.md` in the ChatTS repository.
+`ChatTS` focuses on **Understanding and Reasoning** about time series, much like what vision/video/audio MLLMs do.
+This repo provides the code, datasets, and model for `ChatTS`: [ChatTS: Aligning Time Series with LLMs via Synthetic Data for Enhanced Understanding and Reasoning](https://arxiv.org/pdf/2412.03104).
 
-# Reference
+Here is an example of a ChatTS application, which allows users to interact with an LLM to understand and reason about time series data:
+![Chat](figures/chat_example.png)
+
+## Usage
+This model is fine-tuned on the QWen2.5-14B-Instruct (https://huggingface.co/Qwen/Qwen2.5-14B-Instruct) model. For more usage details, please refer to the `README.md` in the ChatTS repository.
+
+## Reference
 - QWen2.5-14B-Instruct (https://huggingface.co/Qwen/Qwen2.5-14B-Instruct)
 - transformers (https://github.com/huggingface/transformers.git)
 - [ChatTS Paper](https://arxiv.org/pdf/2412.03104)
 
-# License
+## License
 This model is licensed under the [Apache License 2.0](LICENSE).
 
-# Cite
+## Cite
 ```
 @article{xie2024chatts,
   title={ChatTS: Aligning Time Series with LLMs via Synthetic Data for Enhanced Understanding and Reasoning},
@@ -22,4 +24,4 @@ This model is licensed under the [Apache License 2.0](LICENSE).
   journal={arXiv preprint arXiv:2412.03104},
   year={2024}
 }
-```
+```
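The `## Usage` entry in the README diff above defers all usage details to the ChatTS repository. For orientation only, here is a minimal loading sketch under the usual `transformers` remote-code path; the model path is a placeholder, and the exact prompt format and time-series input handling come from the ChatTS repository README rather than from this sketch.

```python
# Minimal sketch, not the official ChatTS usage; the model path below is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "path/to/ChatTS-14B"  # placeholder: local path or Hub id of this repo

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,  # load the repo's configuration_qwen2.py / modeling_qwen2.py
    torch_dtype=torch.float16,
    device_map="auto",
)

# Plain-text generation works like any Qwen2.5-based chat model. Passing actual time
# series (the `timeseries` tensor consumed by Qwen2TSForCausalLM.forward) follows the
# conventions documented in the ChatTS repository.
prompt = "Describe the overall trend of the attached time series."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

`trust_remote_code=True` is what allows `transformers` to use the repo's own `configuration_qwen2.py` and `modeling_qwen2.py` (e.g. `Qwen2TSForCausalLM`) instead of the stock Qwen2 classes, assuming the repo's config maps the auto classes accordingly.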
configuration_qwen2.py CHANGED
@@ -1,5 +1,4 @@
 # coding=utf-8
-# The following code is reused from the QWen project (https://huggingface.co/Qwen/Qwen2.5-14B-Instruct) of Alibaba Cloud.
 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,10 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# The code is modified by ByteDance and Tsinghua University from the original implementation of Qwen:
-# - We changed Qwen2Config to Qwen2TSConfig to support time series modeling.
-
 """ Qwen2 model configuration"""
 
 from transformers import PretrainedConfig
figures/chat_example.png ADDED
modeling_qwen2.py CHANGED
@@ -1,5 +1,4 @@
 # coding=utf-8
-# The following code is reused from the QWen project (https://huggingface.co/Qwen/Qwen2.5-14B-Instruct) of Alibaba Cloud.
 # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -18,10 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# The code is modified by ByteDance and Tsinghua University from the original implementation of Qwen:
-# - Support time series modality for Qwen2 model.
-
 """ PyTorch Qwen2 model."""
 import inspect
 import math
@@ -78,7 +73,6 @@ class TimeSeriesEmbedding(nn.Module):
         self.num_features = config['num_features']
 
         layers = []
-        # Adjust the input size to include the mask channel
         input_size = 1 * self.patch_size
 
         for _ in range(self.num_layers - 1):
@@ -97,7 +91,6 @@
         valid_lengths = mask.sum(dim=1).long()  # Shape: (batch_size)
 
         patch_cnt = (valid_lengths + self.patch_size - 1) // self.patch_size  # ceil division
-        # print(f"[DEBUG] TimeSeriesEmbedding: {valid_lengths=}, {patch_cnt=}, {mask.shape=}")
 
         patches_list = []
         for i in range(batch_size):
@@ -118,9 +111,7 @@
             x_patches = torch.cat(patches_list, dim=0)  # Shape: (total_patch_cnt, patch_size * num_features)
             x = self.mlp(x_patches)
         else:
-            # If there are no valid patches, return an empty tensor
             x = torch.empty(0, self.hidden_size, device=x.device)
-            # print(f"[DEBUG] TimeSeriesEmbedding OUTPUT: {x.shape=}, {patch_cnt=}")
 
         return x, patch_cnt
 
@@ -1204,21 +1195,7 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
         return num_special_ts_tokens * (num_patches - 2) + input_ids.size(1)
 
     def _get_original_length(self, timeseries, input_ids, past_length):
-        """
-        Compute the original sequence length corresponding to the transformed past_length, and return the number of <ts> tokens it contains.
-
-        Args:
-            timeseries (Tensor): Time series tensor of shape (batch_size, num_time_steps).
-            input_ids (Tensor): Original input IDs tensor of shape (batch_size, seq_length).
-            past_length (int or Tensor): Transformed sequence length (including the inserted time series feature tokens); either a scalar or a tensor of shape (batch_size,).
-
-        Returns:
-            Tuple[Tensor, Tensor]:
-                - original_length (Tensor): Original sequence length of each sample, shape (batch_size,).
-                - num_special_ts_tokens_within_past (Tensor): Number of <ts> tokens within past_length for each sample, shape (batch_size,).
-        """
         if timeseries is None:
-            # If no time series features were inserted, the original length equals past_length
             if isinstance(past_length, int):
                 original_length = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
             else:
@@ -1226,45 +1203,32 @@
             num_special_ts_tokens_within_past = torch.zeros(input_ids.size(0), dtype=torch.long, device=input_ids.device)
             return original_length, num_special_ts_tokens_within_past
 
-        # Get configuration parameters
         patch_size = self.config.ts['patch_size']
         num_patches = timeseries.size(1) * timeseries.size(2) // patch_size // self.config.ts['num_features']
         ts_token_start_index = self.config.ts_token_start_index
 
-        # Build a mask marking the positions of the <ts> tokens
         ts_mask = (input_ids == ts_token_start_index).long()  # (batch_size, seq_length)
 
-        # Cumulative number of <ts> tokens before each position
        cumsum_ts = torch.cumsum(ts_mask, dim=1)  # (batch_size, seq_length)
 
-        # Position indices, starting from 1
         seq_length = input_ids.size(1)
         positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_length)
 
-        # Compute the transformed positions
         transformed_length = positions + cumsum_ts * (num_patches - 2)  # (batch_size, seq_length)
 
-        # past_length may be a scalar or a tensor
         if isinstance(past_length, int):
             past_length_tensor = torch.full((input_ids.size(0),), past_length, dtype=torch.long, device=input_ids.device)
         else:
             past_length_tensor = past_length.to(input_ids.device)
 
-        # Create a mask marking which original positions stay within past_length after the transformation
         mask = transformed_length <= past_length_tensor.unsqueeze(1)  # (batch_size, seq_length)
 
-        # For each sample, count the positions that satisfy the condition, i.e. the original length
         original_length = torch.sum(mask, dim=1)  # (batch_size,)
-
-        # Count the <ts> tokens that fall within original_length
-        # Build a mask marking the <ts> tokens within original_length
-        # First build a position index
         original_positions = torch.arange(1, seq_length + 1, device=input_ids.device).unsqueeze(0).expand_as(input_ids)  # (batch_size, seq_length)
         original_mask = original_positions <= original_length.unsqueeze(1)  # (batch_size, seq_length)
         ts_within_original_mask = ts_mask.bool() & original_mask.bool()  # (batch_size, seq_length)
         num_special_ts_tokens_within_past = torch.sum(ts_within_original_mask, dim=1)  # (batch_size,)
 
-        # Ensure original_length is not negative
         original_length = torch.clamp(original_length, min=0)
 
         return original_length, num_special_ts_tokens_within_past
@@ -1280,7 +1244,6 @@
         special_ts_token_mask_start = input_ids == self.config.ts_token_start_index
         special_ts_token_mask_end = input_ids == self.config.ts_token_end_index
         special_ts_token_mask = special_ts_token_mask_start | special_ts_token_mask_end
-        # print("Special ts token mask:", special_ts_token_mask)
         num_special_ts_tokens = torch.sum(special_ts_token_mask_start, dim=-1)
         # Correctly calculate the total number of patches per batch
         num_total_patches = torch.zeros(batch_size, dtype=patch_cnt.dtype, device=patch_cnt.device)
@@ -1291,8 +1254,8 @@
             num_ts_in_batch = num_special_ts_tokens[i]
             num_total_patches[i] = patch_cnt[patch_index:patch_index + num_ts_in_batch].sum() - 2 * num_ts_in_batch
             for idx in range(patch_index, patch_index + num_ts_in_batch):
-                batch_idx, seq_idx = special_ts_token_mask_start_nonzero[idx]
-                special_ts_token_mask_start_with_size[batch_idx, seq_idx] *= (patch_cnt[idx].item() - 2)
+                batch_idx, pos_idx = special_ts_token_mask_start_nonzero[idx]
+                special_ts_token_mask_start_with_size[batch_idx, pos_idx] *= (patch_cnt[idx].item() - 2)
             patch_index += num_ts_in_batch
 
         # Compute the maximum embed dimension, considering both start and end tokens
@@ -1300,17 +1263,13 @@
 
         # batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
         batch_indices, non_ts_indices = torch.where(~special_ts_token_mask)
-        # print("non_ts_indices:", non_ts_indices)
-        # print("batch_indices:", batch_indices)
-
+
         # 2. Compute the positions where text should be written
         new_token_positions = torch.cumsum((special_ts_token_mask_start_with_size + 1), dim=-1) - 1
-        # print("new_token_positions", new_token_positions)
         nb_ts_pad = max_embed_dim - 1 - new_token_positions[:, -1]
         if left_padding:
             new_token_positions += nb_ts_pad[:, None]  # offset for left padding
         text_to_overwrite = new_token_positions[batch_indices, non_ts_indices]
-        # print('nb_ts_pad', nb_ts_pad)
 
         # 3. Create the full embedding, already padded to the maximum position
         final_embedding = torch.zeros(
@@ -1334,7 +1293,6 @@
         # 4. Fill the embeddings based on the mask
         final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_ts_indices]
         final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_ts_indices]
-        # print('final_attention_mask=', final_attention_mask)
         if labels is not None:
             final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_ts_indices]
 
@@ -1343,11 +1301,8 @@
             (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
         )
         ts_to_overwrite[batch_indices, text_to_overwrite] = False
-        # print('ts_to_overwrite.long().cumsum(-1) - 1=', ts_to_overwrite.long().cumsum(-1) - 1)
-        # print('nb_ts_pad=', nb_ts_pad[:, None])
         reversed_cumsum = ts_to_overwrite.flip(dims=[-1]).cumsum(-1).flip(dims=[-1]) - 1
         ts_to_overwrite &= reversed_cumsum >= nb_ts_pad[:, None].to(target_device)
-        # print('ts_to_overwrite=', ts_to_overwrite)
 
         if ts_to_overwrite.sum() != time_series_features.shape[:-1].numel():
             raise ValueError(
@@ -1356,7 +1311,6 @@
             )
 
         final_embedding[ts_to_overwrite] = time_series_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        # logger.warning(f"[DEBUG] {final_embedding[ts_to_overwrite][:, 0]=}")
         final_attention_mask |= ts_to_overwrite
         position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
 
@@ -1423,47 +1377,16 @@
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
         if timeseries is not None and timeseries.shape[0] > 0:
+            # Disable KV Cache as it has not been implemented yet
             use_cache = False
-            # print(f"[DEBUG] input timeseries.shape: {timeseries.shape}")
-
-            # Call ts_encoder and print the input/output shapes
             ts_features, patch_cnt = self.ts_encoder(timeseries)
-            # print(f"[DEBUG] ts_features.shape: {ts_features.shape}")
-            # print(f"[DEBUG] patch_cnt: {patch_cnt}")
 
             inputs_embeds = inputs_embeds.to(ts_features.dtype)
 
-            # Print the relevant shapes before merging
-            # print(f"[DEBUG] Before merging:")
-            # print(f"{inputs_embeds[0, -5:, :5]=}")
-            # print(f"{attention_mask.sum()=}")
-            # print(f" inputs_embeds.shape: {inputs_embeds.shape}")
-            # print(f" input_ids.shape: {input_ids.shape}")
-            # print(f" attention_mask.shape: {attention_mask.shape}")
-            # if labels is not None:
-            # print(f" labels.shape: {labels.shape}")
-            # else:
-            # print(f" labels: None")
-            # print(f" patch_cnt.shape: {patch_cnt.shape}")
-
-            # Call _merge_input_ids_with_time_series_features and print the output shapes
             inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_time_series_features(
                 ts_features, inputs_embeds, input_ids, attention_mask, labels, patch_cnt
             )
 
-            # print(f"[DEBUG] After merging:")
-            # print(f" inputs_embeds.shape: {inputs_embeds.shape}")
-            # print(f" attention_mask.shape: {attention_mask.shape}")
-            # print(f"{attention_mask.sum()=}")
-            # print(f"{inputs_embeds[0, -5:, :5]=}")
-
-            # print(f" position_ids.shape: {position_ids.shape}")
-            # if labels is not None:
-            # print(f" labels.shape: {labels.shape}")
-            # else:
-            # print(f" labels: None")
-
-        # Continue with the model forward pass
         outputs = self.model(
             attention_mask=attention_mask,
             position_ids=position_ids,
@@ -1518,8 +1441,6 @@
             cache_length = past_length = past_key_values[0][0].shape[2]
             max_cache_length = None
 
-            # print(f"[prepare_inputs_for_generation] {cache_length=}, {past_length=}, {max_cache_length=}")
-
         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
         # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
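The modeling diff above removes debug output but leaves the `<ts>` patch bookkeeping unchanged: `TimeSeriesEmbedding.forward` computes a ceil-division patch count per series, and the sequence length grows by `patch_cnt - 2` positions for each `<ts>`/`</ts>` pair. The short sketch below only illustrates that arithmetic with hypothetical numbers; the two formulas are taken from the code shown in the diff.

```python
# Hypothetical numbers; only the formulas come from the diff above.
patch_size = 16     # self.patch_size in TimeSeriesEmbedding
valid_length = 100  # valid (unmasked) time steps in one series

# Ceil division, as in: patch_cnt = (valid_lengths + self.patch_size - 1) // self.patch_size
patch_cnt = (valid_length + patch_size - 1) // patch_size
print(patch_cnt)  # 7 patches for 100 points with patch_size 16

# Expanded sequence length, as in: num_special_ts_tokens * (num_patches - 2) + input_ids.size(1);
# here the single series' patch count stands in for num_patches, and the "- 2" reflects the
# two positions already occupied by the <ts> start/end special tokens.
seq_len = 50  # tokenized prompt length, including the <ts> start/end special tokens
num_ts = 1    # number of <ts> start tokens in the prompt
expanded_len = num_ts * (patch_cnt - 2) + seq_len
print(expanded_len)  # 55: the sequence grows by patch_cnt - 2 positions per time series
```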