rayl-aoit commited on
Commit
d476e13
·
verified ·
1 Parent(s): 5753c7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -1,16 +1,31 @@
1
  import gradio as gr
2
  from transformers import pipeline
 
3
 
4
  playground = gr.Blocks()
5
 
6
  image_pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
7
  summary_pipe = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
8
  ner_pipe = pipeline("ner", model="dslim/bert-base-NER")
 
 
9
 
10
  def launch_image_pipe(input):
11
  out = image_pipe(input)
12
  return out[0]['generated_text']
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def translate(input_text, source, target):
15
  try:
16
  model = f"Helsinki-NLP/opus-mt-{source}-{target}"
@@ -100,12 +115,17 @@ with playground:
100
  with gr.Column():
101
  img = gr.Image(type='pil')
102
  with gr.Column():
103
- generated_textbox = gr.Textbox(lines=2, placeholder="", label="Generated Text")
104
-
105
- image_pipeline_button.click(launch_image_pipe,
106
- inputs=[img],
107
- outputs=[generated_textbox])
108
-
 
 
 
 
 
109
  with gr.TabItem("Text"):
110
  with gr.Row():
111
  with gr.Column(scale=4):
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ from diffusers import StableDiffusionPipeline
4
 
5
  playground = gr.Blocks()
6
 
7
  image_pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
8
  summary_pipe = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
9
  ner_pipe = pipeline("ner", model="dslim/bert-base-NER")
10
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
11
+ pipe = pipe.to("cuda")
12
 
13
  def launch_image_pipe(input):
14
  out = image_pipe(input)
15
  return out[0]['generated_text']
16
 
17
+ def base64_to_pil(img_base64):
18
+ base64_decoded = base64.b64decode(img_base64)
19
+ byte_stream = io.BytesIO(base64_decoded)
20
+ pil_image = Image.open(byte_stream)
21
+ return pil_image
22
+
23
+ def image_generate(prompt):
24
+ # output = get_completion(prompt)
25
+ image = pipe(prompt).images[0]
26
+ result_image = base64_to_pil(image)
27
+ return result_image
28
+
29
  def translate(input_text, source, target):
30
  try:
31
  model = f"Helsinki-NLP/opus-mt-{source}-{target}"
 
115
  with gr.Column():
116
  img = gr.Image(type='pil')
117
  with gr.Column():
118
+ with gr.Row():
119
+ with gr.Column(scale=4):
120
+ generated_textbox = gr.Textbox(lines=2, placeholder="", label="Generated Text")
121
+ with gr.Column(scale=1):
122
+ image_generation_button = gr.Button(value="GEN")
123
+ with gr.Row():
124
+ generated_image = gr.Image(label="Generated Image"))
125
+
126
+ image_pipeline_button.click(launch_image_pipe, inputs=[img], outputs=[generated_textbox])
127
+ image_generation_button.click(image_generate, inputs=[generated_textbox], outputs=[generated_image])
128
+
129
  with gr.TabItem("Text"):
130
  with gr.Row():
131
  with gr.Column(scale=4):