sam749 commited on
Commit
b2d5a8f
·
verified ·
1 Parent(s): a9507d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -36
app.py CHANGED
@@ -2,53 +2,36 @@ import os
2
  import torch
3
  import gradio as gr
4
  from PIL import Image
5
- from transformers import AutoModelForCausalLM,AutoProcessor
6
 
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
 
9
  processor = AutoProcessor.from_pretrained("microsoft/git-base")
10
  model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption").to(device)
11
 
12
- def generate_captions(images:[Image],max_length=200):
13
- # prepare image for the model
14
- inputs = processor(images=images, return_tensors="pt").to(device)
15
- pixel_values = inputs.pixel_values
16
- generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
17
- generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
18
- return generated_caption
19
 
20
- def generate_caption(image,max_length=200):
21
- return generate_captions(image,max_length)[0]
22
 
23
-
24
- inputs = [
25
- gr.Image(sources=["upload", "clipboard"],
26
- height=400,
27
- type="pil"
28
- ),
29
- gr.Slider(minimum=10,
30
- maximum=400,
31
- value=200,
32
- label='max length',
33
- step=8,
34
- )
35
- ]
36
- outputs = [
37
- gr.Text(label="Generated Caption"),
38
- ]
39
 
40
  demo = gr.Interface(
41
  fn=generate_caption,
42
- inputs=inputs,
43
- outputs=outputs,
44
  title="Stable Diffusion Portrait Captioner",
45
- theme="gradio/monochrome",
46
- api_name="caption",
47
- submit_btn=gr.Button("caption it", variant="primary"),
48
- allow_flagging="never",
49
- )
50
- demo.queue(
51
- max_size=10,
52
  )
53
 
54
- demo.launch()
 
 
2
  import torch
3
  import gradio as gr
4
  from PIL import Image
5
+ from transformers import AutoModelForCausalLM, AutoProcessor
6
 
7
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
8
 
9
  processor = AutoProcessor.from_pretrained("microsoft/git-base")
10
  model = AutoModelForCausalLM.from_pretrained("sam749/sd-portrait-caption").to(device)
11
 
12
+ def generate_captions(images, max_length=200):
13
+ # prepare image for the model
14
+ inputs = processor(images=images, return_tensors="pt").to(device)
15
+ pixel_values = inputs.pixel_values
16
+ generated_ids = model.generate(pixel_values=pixel_values, max_length=max_length)
17
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)
18
+ return generated_caption
19
 
20
+ def generate_caption(image, max_length=200):
21
+ return generate_captions([image], max_length)[0]
22
 
23
+ image_input = gr.Image(source="upload", type="pil", label="Upload Image", height=400)
24
+ max_length_slider = gr.Slider(minimum=10, maximum=400, value=200, step=8, label="Max Length")
25
+ caption_output = gr.Textbox(label="Generated Caption")
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  demo = gr.Interface(
28
  fn=generate_caption,
29
+ inputs=[image_input, max_length_slider],
30
+ outputs=caption_output,
31
  title="Stable Diffusion Portrait Captioner",
32
+ theme="default",
33
+ allow_flagging="never"
 
 
 
 
 
34
  )
35
 
36
+ if __name__ == "__main__":
37
+ demo.launch()