fthor commited on
Commit
ed1cd13
·
1 Parent(s): bc91b52

duplicaction test

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoProcessor, LlavaForConditionalGeneration
@@ -24,10 +26,16 @@ model = LlavaForConditionalGeneration.from_pretrained(
24
  )
25
 
26
 
27
- def text_to_image(image, prompt):
28
  prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
29
 
30
- inputs = processor([prompt], images=[image], padding=True, return_tensors="pt").to(model.device)
 
 
 
 
 
 
31
  output = model.generate(**inputs, max_new_tokens=500, temperature=0.3)
32
  generated_text = processor.batch_decode(output, skip_special_tokens=True)
33
  text = generated_text.pop()
@@ -41,7 +49,8 @@ demo = gr.Interface(
41
  fn=text_to_image,
42
  inputs=[
43
  gr.Image(label='Select an image to analyze', type='pil'),
44
- gr.Textbox(label='Enter Prompt')
 
45
  ],
46
  outputs=[gr.Textbox(label='Maurice says:'), gr.JSON(label='Embedded text')]
47
  )
 
1
+ from copy import deepcopy
2
+
3
  import gradio as gr
4
  import torch
5
  from transformers import AutoProcessor, LlavaForConditionalGeneration
 
26
  )
27
 
28
 
29
+ def text_to_image(image, prompt, duplications: int):
30
  prompt = f'USER: <image>\n{prompt}\nASSISTANT:'
31
 
32
+ image_batch = [image]
33
+ prompt_batch = [prompt]
34
+ for _ in range(duplications):
35
+ image_batch.append(deepcopy(image))
36
+ prompt_batch.append(prompt)
37
+
38
+ inputs = processor(prompt_batch, images=image_batch, padding=True, return_tensors="pt").to(model.device)
39
  output = model.generate(**inputs, max_new_tokens=500, temperature=0.3)
40
  generated_text = processor.batch_decode(output, skip_special_tokens=True)
41
  text = generated_text.pop()
 
49
  fn=text_to_image,
50
  inputs=[
51
  gr.Image(label='Select an image to analyze', type='pil'),
52
+ gr.Textbox(label='Enter Prompt'),
53
+ gr.Number(label='How many duplications of the image (to test memory load)', value=0)
54
  ],
55
  outputs=[gr.Textbox(label='Maurice says:'), gr.JSON(label='Embedded text')]
56
  )