kcarnold committed
Commit 58ed19d · Parent(s): efaf817

Try to report memory errors, and try clearing the traceback to avoid leaking the prior cache.

Files changed (1): custom_llm_inference.py (+23 -0)
custom_llm_inference.py CHANGED
@@ -2,6 +2,25 @@ import torch
 from transformers.cache_utils import DynamicCache


+def catch_and_report_memory_exceptions(func):
+    """
+    Decorator to catch and report memory exceptions.
+    """
+    def wrapper(*args, **kwargs):
+        # https://docs.pytorch.org/docs/stable/torch_cuda_memory.html
+        torch.cuda.memory._record_memory_history()
+        try:
+            return func(*args, **kwargs)
+        except torch.OutOfMemoryError as e:
+            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
+            print(f"Memory error: {e}")
+            # clear frames in the traceback to avoid memory leak
+            import traceback, sys
+            traceback.clear_frames(sys.exc_info()[2])
+            raise e
+    return wrapper
+
+
 def get_tokenized_chat(tokenizer, prompt, doc):
     messages = [
         {
@@ -28,6 +47,7 @@ def tokenize_doc_in_progress(tokenizer, doc_in_progress):
     return doc_in_progress_ids


+@catch_and_report_memory_exceptions
 def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
     tokenized_chat = get_tokenized_chat(tokenizer, prompt, doc)
     assert len(tokenized_chat.shape) == 1
@@ -63,6 +83,7 @@ def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
     return highlights


+@catch_and_report_memory_exceptions
 def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens, device):
     """
     For each of the n_branch_tokens next tokens, generate most-likely next tokens and append back on.
@@ -113,6 +134,7 @@ def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens, device):
     return lookahead_sequences, next_token_logits


+@catch_and_report_memory_exceptions
 def get_next_token_predictions_inner(
         model, tokenizer, original_doc, prompt, doc_in_progress, k):

@@ -145,6 +167,7 @@ def get_next_token_predictions_inner(
     return decoded_next_tokens, next_token_logits


+@catch_and_report_memory_exceptions
 def get_next_token_predictions_slow(
         model, tokenizer, original_doc, prompt, doc_in_progress, k):
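
A note on the reporting half of this change: torch.cuda.memory._record_memory_history() starts tracking CUDA allocations, and _dump_snapshot(...) writes that history to a pickle when the out-of-memory error is caught. The snapshot is easiest to explore by dropping it onto https://pytorch.org/memory_viz, but it can also be loaded directly. A minimal sketch, assuming the file name used in the commit ("memory_snapshot.pickle") and the snapshot layout described on the linked docs page; the "segments"/"total_size" keys are taken from that page, not from this repository:

import pickle

# Load the snapshot that the decorator dumps on torch.OutOfMemoryError.
with open("memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Rough overview of what was resident when the OOM was raised; .get() guards
# against layout differences across PyTorch versions.
print(list(snapshot.keys()))
reserved = sum(seg.get("total_size", 0) for seg in snapshot.get("segments", []))
print(f"reserved bytes captured in snapshot: {reserved}")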
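
On the "avoid leaking the prior cache" half: a propagating exception keeps a reference to every frame in its traceback, and each frame keeps its locals alive, so large objects such as the previous DynamicCache that were locals in the failing function stay pinned for as long as the exception object does. traceback.clear_frames() drops those locals before the error is re-raised. A minimal standalone sketch of that mechanism (not from the commit; the names here are illustrative):

import sys
import traceback

def work():
    big_local = bytearray(50_000_000)  # stands in for a large KV cache / tensor
    raise MemoryError("simulated OOM")

try:
    work()
except MemoryError:
    tb = sys.exc_info()[2]
    # Without this call, big_local stays reachable through
    # tb.tb_next.tb_frame.f_locals for as long as the traceback is referenced.
    traceback.clear_frames(tb)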