fthor committed
Commit b083d4d · 1 Parent(s): ed1cd13

Avoiding CUDA Memory limit by rebatching inputs

Files changed (1)
  1. app.py +64 -11
app.py CHANGED
@@ -13,9 +13,7 @@ quantization_config = BitsAndBytesConfig(
 )
 
 embedder = SentenceTransformer('all-mpnet-base-v2')
-
 model_id = "llava-hf/llava-1.5-7b-hf"
-
 processor = AutoProcessor.from_pretrained(model_id)
 model = LlavaForConditionalGeneration.from_pretrained(
     model_id,
@@ -25,24 +23,79 @@ model = LlavaForConditionalGeneration.from_pretrained(
     low_cpu_mem_usage=True
 )
 
-
-def text_to_image(image, prompt, duplications: int):
+MAXIMUM_PIXEL_VALUES = 3725568
+
+def text_to_image(image, prompt, duplications: float):
     prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
 
     image_batch = [image]
     prompt_batch = [prompt]
-    for _ in range(duplications):
+    for _ in range(int(duplications)):
         image_batch.append(deepcopy(image))
         prompt_batch.append(prompt)
 
-    inputs = processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt").to(model.device)
-    output = model.generate(**inputs, max_new_tokens=500, temperature=0.3)
-    generated_text = processor.batch_decode(output, skip_special_tokens=True)
-    text = generated_text.pop()
-    text_output = text.split("ASSISTANT:")[-1]
-    text_embeddings = embedder.encode(text_output)
-
-    return text_output, dict(text_embeddings=text_embeddings)
+    inputs = processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt")
+
+    batched_inputs: list[dict[str, torch.Tensor]] = list()
+    if inputs['pixel_values'].flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
+        batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
+        i = 0
+        while i < len(inputs['pixel_values']):
+            batch['input_ids'].append(inputs['input_ids'][i])
+            batch['attention_mask'].append(inputs['attention_mask'][i])
+            batch['pixel_values'].append(inputs['pixel_values'][i])
+
+            if torch.cat(batch['pixel_values'], dim=0).flatten().shape[0] > MAXIMUM_PIXEL_VALUES:
+                print(f'[{i}/{len(inputs["pixel_values"])}] - Reached max pixel values for batch prediction on a T4 '
+                      f'16GB GPU. Splitting into more batches.')
+                # Remove the last added image because it pushes the batch over the limit
+                batch['input_ids'].pop()
+                batch['attention_mask'].pop()
+                batch['pixel_values'].pop()
+
+                # Transform the lists of per-sample tensors into batched tensors
+                batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
+                batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
+                batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
+
+                # Add to batched_inputs and start a fresh batch; i is not advanced,
+                # so the removed image becomes the first item of the next batch
+                batched_inputs.append(batch)
+                batch = dict(input_ids=list(), attention_mask=list(), pixel_values=list())
+            else:
+                i += 1
+
+        if batch['pixel_values']:
+            # Flush the last, partially filled batch
+            batch['input_ids'] = torch.stack(batch['input_ids'], dim=0)
+            batch['attention_mask'] = torch.stack(batch['attention_mask'], dim=0)
+            batch['pixel_values'] = torch.stack(batch['pixel_values'], dim=0)
+            batched_inputs.append(batch)
+    else:
+        # Everything fits under the limit, keep a single batch
+        batched_inputs.append(dict(inputs))
+
+    maurice_description = list()
+    maurice_embeddings = list()
+    for batch in batched_inputs:
+        # Load the batch on the GPU
+        batch['input_ids'] = batch['input_ids'].to(model.device)
+        batch['attention_mask'] = batch['attention_mask'].to(model.device)
+        batch['pixel_values'] = batch['pixel_values'].to(model.device)
+        output = model.generate(**batch, max_new_tokens=500, temperature=0.3)
+
+        # Unload the GPU
+        batch['input_ids'] = batch['input_ids'].to('cpu')
+        batch['attention_mask'] = batch['attention_mask'].to('cpu')
+        batch['pixel_values'] = batch['pixel_values'].to('cpu')
+
+        generated_text = processor.batch_decode(output, skip_special_tokens=True)
+        output = output.to('cpu')
+
+        for text in generated_text:
+            text_output = text.split("ASSISTANT:")[-1]
+            text_embeddings = embedder.encode(text_output)
+            maurice_description.append(text_output)
+            maurice_embeddings.append(text_embeddings)
+
+    return '\n---\n'.join(maurice_description), dict(text_embeddings=maurice_embeddings)
 
 
 demo = gr.Interface(
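
Note on the rebatching constant: MAXIMUM_PIXEL_VALUES = 3725568 equals 3 * 336 * 336 * 11, which matches eleven images at LLaVA-1.5's 3x336x336 input resolution, so the threshold effectively caps each generate() call at roughly 11 images on the 16 GB T4. The same cap can be expressed by slicing all three processor outputs into fixed-size chunks. The sketch below is illustrative only and not part of the commit; chunk_inputs is a hypothetical helper, and it assumes inputs holds input_ids, attention_mask and pixel_values tensors sharing a batch dimension, as returned by the processor call above.

    import torch

    MAXIMUM_PIXEL_VALUES = 3725568  # assumed budget: ~11 images of 3x336x336 per batch on a 16 GB T4

    def chunk_inputs(inputs, pixel_budget=MAXIMUM_PIXEL_VALUES):
        # Hypothetical helper: yield sub-batches whose pixel_values stay under the pixel budget.
        pixels_per_image = inputs['pixel_values'][0].numel()         # 3 * 336 * 336 = 338688 for LLaVA-1.5
        images_per_batch = max(1, pixel_budget // pixels_per_image)  # 11 with the budget above
        total = inputs['pixel_values'].shape[0]
        for start in range(0, total, images_per_batch):
            end = start + images_per_batch
            yield {
                'input_ids': inputs['input_ids'][start:end],
                'attention_mask': inputs['attention_mask'][start:end],
                'pixel_values': inputs['pixel_values'][start:end],
            }

Each yielded chunk can then be moved to model.device, passed to model.generate(**chunk, ...), and moved back to the CPU, which keeps peak GPU memory bounded no matter how many duplications the caller requests.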