Try to report memory errors, and try clearing the traceback to avoid leaking the prior cache.
custom_llm_inference.py  CHANGED  +23 -0
@@ -2,6 +2,25 @@ import torch
 from transformers.cache_utils import DynamicCache
 
 
+def catch_and_report_memory_exceptions(func):
+    """
+    Decorator to catch and report memory exceptions.
+    """
+    def wrapper(*args, **kwargs):
+        # https://docs.pytorch.org/docs/stable/torch_cuda_memory.html
+        torch.cuda.memory._record_memory_history()
+        try:
+            return func(*args, **kwargs)
+        except torch.OutOfMemoryError as e:
+            torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
+            print(f"Memory error: {e}")
+            # clear frames in the traceback to avoid memory leak
+            import traceback, sys
+            traceback.clear_frames(sys.exc_info()[2])
+            raise e
+    return wrapper
+
+
 def get_tokenized_chat(tokenizer, prompt, doc):
     messages = [
         {
@@ -28,6 +47,7 @@ def tokenize_doc_in_progress(tokenizer, doc_in_progress):
     return doc_in_progress_ids
 
 
+@catch_and_report_memory_exceptions
 def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
     tokenized_chat = get_tokenized_chat(tokenizer, prompt, doc)
     assert len(tokenized_chat.shape) == 1
@@ -63,6 +83,7 @@ def get_highlights_inner(model, tokenizer, doc, prompt, updated_doc, k):
     return highlights
 
 
+@catch_and_report_memory_exceptions
 def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens, device):
     """
     For each of the n_branch_tokens next tokens, generate most-likely next tokens and append back on.
@@ -113,6 +134,7 @@ def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens, device):
     return lookahead_sequences, next_token_logits
 
 
+@catch_and_report_memory_exceptions
 def get_next_token_predictions_inner(
         model, tokenizer, original_doc, prompt, doc_in_progress, k):
 
@@ -145,6 +167,7 @@ def get_next_token_predictions_inner(
     return decoded_next_tokens, next_token_logits
 
 
+@catch_and_report_memory_exceptions
 def get_next_token_predictions_slow(
         model, tokenizer, original_doc, prompt, doc_in_progress, k):
 
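A note on the traceback.clear_frames call, which is the part of this change aimed at "leaking the prior cache": a caught exception keeps a reference to every frame it propagated through, and each frame keeps its local variables alive, so a failed generation call can pin its GPU tensors (for example the previous request's DynamicCache) for as long as the exception object is retained. Clearing the frames drops those local references. The sketch below illustrates that mechanism only; it is not code from this Space, failing_step and the tensor size are hypothetical, and it assumes a CUDA device is available.

import sys
import traceback

import torch


def failing_step():
    # Hypothetical stand-in for a generation call that allocates a large
    # cache on the GPU and then hits an error.
    cache = torch.empty(64_000_000, device="cuda")  # ~256 MiB of float32
    raise RuntimeError("simulated failure")


try:
    failing_step()
except RuntimeError:
    tb = sys.exc_info()[2]
    # `cache` is still referenced by the traceback's frames here.
    print("before clear_frames:", torch.cuda.memory_allocated())
    # Clearing the frames drops their locals, so `cache` can be freed.
    traceback.clear_frames(tb)
    print("after clear_frames:", torch.cuda.memory_allocated())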
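As for the snapshot the decorator writes: because _record_memory_history() is enabled before the wrapped call, the pickle produced by torch.cuda.memory._dump_snapshot includes allocation stack traces and can be dragged into the interactive viewer at https://pytorch.org/memory_viz (see the PyTorch docs linked in the diff), or inspected programmatically. The following is a minimal, hypothetical sketch, not part of custom_llm_inference.py; it assumes a memory_snapshot.pickle produced by the decorator above, and the dict keys used ("segments", "device", "segment_type", "total_size", "allocated_size") follow the current private snapshot format, which may change between PyTorch versions.

import pickle

# Summarize the allocator segments recorded in a dumped snapshot.
with open("memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

for segment in snapshot.get("segments", []):
    print(
        f"device={segment['device']} "
        f"type={segment['segment_type']} "
        f"total={segment['total_size'] / 2**20:.1f} MiB "
        f"allocated={segment['allocated_size'] / 2**20:.1f} MiB"
    )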