zhiweili commited on
Commit
a1553b6
Β·
1 Parent(s): 8491a0f

test onediff

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. app.py +10 -0
  3. app_onediff.py +85 -0
  4. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .vscode
2
+ .DS_Store
3
+ __pycache__
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from app_base import create_demo as create_demo_face
4
+
5
+ with gr.Blocks(css="style.css") as demo:
6
+ with gr.Tabs():
7
+ with gr.Tab(label="Face"):
8
+ create_demo_face()
9
+
10
+ demo.launch()
app_onediff.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import time
4
+ import torch
5
+
6
+ from diffusers import (
7
+ DDPMScheduler,
8
+ AutoPipelineForText2Image,
9
+ AutoencoderTiny,
10
+ )
11
+
12
+ import oneflow as flow
13
+ from onediff.infer_compiler import oneflow_compile
14
+
15
+ BASE_MODEL = "stabilityai/sdxl-turbo"
16
+ device = "cuda"
17
+
18
+ vae = AutoencoderTiny.from_pretrained(
19
+ 'madebyollin/taesdxl',
20
+ use_safetensors=True,
21
+ torch_dtype=torch.float16,
22
+ ).to('cuda')
23
+ base_pipe = AutoPipelineForText2Image.from_pretrained(
24
+ BASE_MODEL,
25
+ vae=vae,
26
+ torch_dtype=torch.float16,
27
+ variant="fp16",
28
+ use_safetensors=True,
29
+ )
30
+ base_pipe.to(device)
31
+
32
+ base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
33
+ base_pipe.scheduler = DDPMScheduler.from_pretrained(
34
+ BASE_MODEL,
35
+ subfolder="scheduler",
36
+ )
37
+ base_pipe.unet = oneflow_compile(base_pipe.unet)
38
+ # base_pipe.vae.decoder = oneflow_compile(base_pipe.vae.decoder)
39
+
40
+ def create_demo() -> gr.Blocks:
41
+
42
+ @spaces.GPU(duration=10)
43
+ def text_to_image(
44
+ prompt:str,
45
+ steps:int,
46
+ ):
47
+ run_task_time = 0
48
+ time_cost_str = ''
49
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
50
+ generated_image = base_pipe(
51
+ prompt=prompt,
52
+ num_inference_steps=steps,
53
+ ).images[0]
54
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
55
+ return generated_image
56
+
57
+ def get_time_cost(run_task_time, time_cost_str):
58
+ now_time = int(time.time()*1000)
59
+ if run_task_time == 0:
60
+ time_cost_str = 'start'
61
+ else:
62
+ if time_cost_str != '':
63
+ time_cost_str += f'-->'
64
+ time_cost_str += f'{now_time - run_task_time}'
65
+ run_task_time = now_time
66
+ return run_task_time, time_cost_str
67
+
68
+ with gr.Blocks() as demo:
69
+ with gr.Row():
70
+ with gr.Column():
71
+ prompt = gr.Textbox(label="Prompt", placeholder="Write a prompt here", lines=2, value="A beautiful sunset over the city")
72
+ with gr.Column():
73
+ steps = gr.Slider(label="Inference Steps", min=1, max=30, step=1, value=5)
74
+ g_btn = gr.Button("Generate")
75
+
76
+ with gr.Row():
77
+ generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
78
+
79
+ g_btn.click(
80
+ fn=text_to_image,
81
+ inputs=[prompt, steps],
82
+ outputs=[generated_image],
83
+ )
84
+
85
+ return demo
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ diffusers
5
+ transformers
6
+ accelerate
7
+ spaces