Nikita Pavlichenko commited on
Commit
94d9602
·
1 Parent(s): e71ff86

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +44 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ import numpy as np
4
+ from diffusers import DiffusionPipeline
5
+
6
+
7
+ prompt_writer = pipeline('text-generation', model='toloka/gpt2-large-rl-prompt-writing')
8
+ prompt_reward_model = pipeline('text-classification', model='toloka/prompts_reward_model')
9
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
10
+
11
+
12
+ def write_prompt(img_desc):
13
+ prompts = [p['generated_text'] for p in prompt_writer(img_desc + '</s>', max_new_tokens=100, num_return_sequences=2)]
14
+ scores = [p['score'] for p in prompt_reward_model(prompts, function_to_apply='none')]
15
+ return prompts[np.argmax(scores)].split('</s>')[1].strip()
16
+
17
+
18
+ def generate(text):
19
+ prompt = write_prompt(text)
20
+ img = pipe(prompt=prompt, num_inference_steps=50).images[0]
21
+ return img, prompt
22
+
23
+ with gr.Blocks() as demo:
24
+ with gr.Column(variant="panel"):
25
+ with gr.Row(variant="compact"):
26
+ text = gr.Textbox(
27
+ label="Enter your image description",
28
+ show_label=False,
29
+ max_lines=1,
30
+ placeholder="Enter your prompt",
31
+ ).style(
32
+ container=False,
33
+ )
34
+ btn = gr.Button("Generate image").style(full_width=False)
35
+
36
+ written_prompt = gr.outputs.Textbox(label="Written prompt")
37
+ gen_img = gr.outputs.Image(type="pil",
38
+ label="Generated image",
39
+ ).style(object_fit="contain", height=512)
40
+
41
+ btn.click(generate, text, [gen_img, written_prompt])
42
+
43
+ if __name__ == "__main__":
44
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ transformers
4
+ Pillow
5
+ diffusers==0.12.1