zhiweili commited on
Commit
fc1393b
Β·
1 Parent(s): 688a239

add app_tensorrt

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. app_tensorrt.py +85 -0
  3. requirements.txt +4 -9
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
 
3
- from app_onediff import create_demo as create_demo_face
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
 
1
  import gradio as gr
2
 
3
+ from app_tensorrt import create_demo as create_demo_face
4
 
5
  with gr.Blocks(css="style.css") as demo:
6
  with gr.Tabs():
app_tensorrt.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch_tensorrt
3
+
4
+ from diffusers import (
5
+ DDPMScheduler,
6
+ StableDiffusionXLImg2ImgPipeline,
7
+ AutoencoderKL,
8
+ )
9
+
10
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
11
+ device = "cuda"
12
+
13
+ vae = AutoencoderKL.from_pretrained(
14
+ "madebyollin/sdxl-vae-fp16-fix",
15
+ torch_dtype=torch.float16,
16
+ )
17
+
18
+ base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
19
+ BASE_MODEL,
20
+ vae=vae,
21
+ torch_dtype=torch.float16,
22
+ variant="fp16",
23
+ use_safetensors=True,
24
+ )
25
+ base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
26
+ base_pipe.scheduler = DDPMScheduler.from_pretrained(
27
+ BASE_MODEL,
28
+ subfolder="scheduler",
29
+ )
30
+
31
+ backend = "torch_tensorrt"
32
+
33
+
34
+
35
+ # print('Loading compiled model...')
36
+ # loadedModel = torch_tensorrt.load("compiled_pipe.ep").module()
37
+ # print('Compiled model loaded!')
38
+
39
+
40
+
41
+ def create_demo() -> gr.Blocks:
42
+
43
+ @spaces.GPU(duration=30)
44
+ def text_to_image(
45
+ prompt:str,
46
+ steps:int,
47
+ ):
48
+ print('Compiling model...')
49
+ compiledModel = torch.compile(
50
+ base_pipe.unet,
51
+ backend=backend,
52
+ options={
53
+ "truncate_long_and_double": True,
54
+ "enabled_precisions": {torch.float32, torch.float16},
55
+ },
56
+ dynamic=False,
57
+ )
58
+ print('Model compiled!')
59
+
60
+ print('Saving compiled model...')
61
+ torch_tensorrt.save(compiledModel, "compiled_pipe.ep")
62
+ print('Compiled model saved!')
63
+
64
+ with gr.Blocks() as demo:
65
+ with gr.Row():
66
+ with gr.Column():
67
+ prompt = gr.Textbox(label="Prompt", placeholder="Write a prompt here", lines=2, value="A beautiful sunset over the city")
68
+ with gr.Column():
69
+ steps = gr.Slider(minimum=1, maximum=100, value=5, step=1, label="Num Steps")
70
+ g_btn = gr.Button("Generate")
71
+
72
+ with gr.Row():
73
+ with gr.Column():
74
+ generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
75
+ with gr.Column():
76
+ time_cost = gr.Textbox(label="Time Cost", lines=1, interactive=False)
77
+
78
+ g_btn.click(
79
+ fn=text_to_image,
80
+ inputs=[prompt, steps],
81
+ # outputs=[generated_image, time_cost],
82
+ outputs=[],
83
+ )
84
+
85
+ return demo
requirements.txt CHANGED
@@ -1,13 +1,8 @@
1
  gradio
2
- torch
3
- torchvision
 
4
  diffusers
5
  transformers
6
  accelerate
7
- spaces
8
- git+https://github.com/XPixelGroup/BasicSR@master
9
- gfpgan
10
- facexlib
11
- realesrgan
12
- triton
13
- xformers
 
1
  gradio
2
+ torch==2.5.0
3
+ torch_tensorrt==2.5.0
4
+ torchvision==0.20.0
5
  diffusers
6
  transformers
7
  accelerate
8
+ spaces