Tonic committed
Commit 0025d7f · 1 Parent(s): b94ef5d

Ensure tokenizer EOS tokens are properly moved to the device

Files changed (1): app.py (+17, -0)
app.py CHANGED
@@ -420,6 +420,23 @@ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5
     for name, value in pipe.tokenizer.__dict__.items():
         if isinstance(value, torch.Tensor):
             pipe.tokenizer.__dict__[name] = value.to(device)
+
+    # Ensure all tokenizer methods are on GPU
+    if hasattr(pipe.tokenizer, '_append_eos_token'):
+        # Create a wrapper for _append_eos_token to ensure device consistency
+        original_append_eos = pipe.tokenizer._append_eos_token
+        def wrapped_append_eos(token_ids, attention_mask):
+            # Ensure both tensors are on GPU
+            token_ids = token_ids.to(device)
+            attention_mask = attention_mask.to(device)
+            # Get the EOS token and ensure it's on GPU
+            eos_token = torch.tensor([pipe.tokenizer.eos_token_id], device=device)
+            eos_tokens = eos_token.unsqueeze(0).expand(token_ids.shape[0], 1)
+            # Concatenate on GPU
+            token_ids = torch.cat((token_ids, eos_tokens), dim=1)
+            attention_mask = torch.cat((attention_mask, torch.ones_like(eos_tokens)), dim=1)
+            return token_ids, attention_mask
+        pipe.tokenizer._append_eos_token = wrapped_append_eos
 
     # Force synchronization again to ensure all tensors are on GPU
     torch.cuda.synchronize()
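
For reference, the monkey-patch added in this commit can be exercised in isolation. The sketch below is not part of the commit: it assumes a CUDA-capable (or CPU-fallback) setup and uses a hypothetical DummyTokenizer stand-in that exposes the same eos_token_id attribute and _append_eos_token method the diff relies on.

import torch

# Minimal sketch, not the app's actual code: a stand-in tokenizer with the
# attributes the diff relies on (eos_token_id, _append_eos_token).
class DummyTokenizer:
    eos_token_id = 2

    def _append_eos_token(self, token_ids, attention_mask):
        # Original (CPU-side) behaviour that the wrapper replaces.
        batch = token_ids.shape[0]
        eos = torch.full((batch, 1), self.eos_token_id, dtype=token_ids.dtype)
        return (
            torch.cat((token_ids, eos), dim=1),
            torch.cat((attention_mask, torch.ones_like(eos)), dim=1),
        )

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = DummyTokenizer()

# Same monkey-patch idea as in the diff, applied to the dummy tokenizer.
def wrapped_append_eos(token_ids, attention_mask):
    # Move both inputs to the target device before appending EOS.
    token_ids = token_ids.to(device)
    attention_mask = attention_mask.to(device)
    # Build the EOS column directly on the target device.
    eos_token = torch.tensor([tokenizer.eos_token_id], device=device)
    eos_tokens = eos_token.unsqueeze(0).expand(token_ids.shape[0], 1)
    # Concatenate on the same device, so no cross-device ops occur downstream.
    token_ids = torch.cat((token_ids, eos_tokens), dim=1)
    attention_mask = torch.cat((attention_mask, torch.ones_like(eos_tokens)), dim=1)
    return token_ids, attention_mask

tokenizer._append_eos_token = wrapped_append_eos

ids = torch.randint(0, 100, (2, 5))   # CPU tensors going in
mask = torch.ones_like(ids)
out_ids, out_mask = tokenizer._append_eos_token(ids, mask)
print(out_ids.device, out_ids.shape)  # e.g. cuda:0 torch.Size([2, 6])

Assigning the wrapper directly on the tokenizer instance means it is called as a plain function rather than a bound method, which is why it takes only the two tensor arguments and closes over device and the tokenizer itself, mirroring the approach in the committed code.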