zhiweili committed
Commit 4b80f72 · 1 Parent(s): 7a58dc4

change text_to_image to image_to_image

Files changed (1):
  1. app_tensorrt.py +40 -28
app_tensorrt.py CHANGED
@@ -50,39 +50,51 @@ compiledModel = torch.compile(
 
 base_pipe.unet = compiledModel
 
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png")
-generated_image = base_pipe(
-    image=init_image,
-    prompt="A white cat",
-    num_inference_steps=5,
-).images[0]
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 
-generated_image.save("/tmp/gradio/generated_image.png")
+try:
+    init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png")
+    generated_image = base_pipe(
+        image=init_image,
+        prompt="A white cat",
+        num_inference_steps=5,
+    ).images[0]
+
+    generated_image.save("/tmp/gradio/generated_image.png")
+except Exception as e:
+    print(f"Error: {e}")
 
 
 def create_demo() -> gr.Blocks:
 
     @spaces.GPU(duration=30)
-    def text_to_image(
+    def image_to_image(
+        image: gr.Image,
         prompt: str,
         steps: int,
     ):
-        import torch_tensorrt
-        print('Compiling model...')
-        compiledModel = torch.compile(
-            base_pipe.unet,
-            backend=backend,
-            options={
-                "truncate_long_and_double": True,
-                "enabled_precisions": {torch.float32, torch.float16},
-            },
-            dynamic=False,
-        )
-        print('Model compiled!')
-
-        print('Saving compiled model...')
-        torch_tensorrt.save(compiledModel, "compiled_pipe.ep")
-        print('Compiled model saved!')
+        run_task_time = 0
+        time_cost_str = ''
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+        generated_image = base_pipe(
+            image=image,
+            prompt=prompt,
+            num_inference_steps=steps,
+        ).images[0]
+        run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+        return generated_image
+
+    def get_time_cost(run_task_time, time_cost_str):
+        now_time = int(time.time() * 1000)
+        if run_task_time == 0:
+            time_cost_str = 'start'
+        else:
+            if time_cost_str != '':
+                time_cost_str += '-->'
+            time_cost_str += f'{now_time - run_task_time}'
+        run_task_time = now_time
+        return run_task_time, time_cost_str
 
     with gr.Blocks() as demo:
         with gr.Row():
@@ -94,15 +106,15 @@ def create_demo() -> gr.Blocks:
 
         with gr.Row():
            with gr.Column():
-                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
             with gr.Column():
+                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
                 time_cost = gr.Textbox(label="Time Cost", lines=1, interactive=False)
 
         g_btn.click(
             fn=text_to_image,
-            inputs=[prompt, steps],
-            # outputs=[generated_image, time_cost],
-            outputs=[],
+            inputs=[input_image, prompt, steps],
+            outputs=[generated_image, time_cost],
        )
 
     return demo
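For context, the module-level change in the first hunk follows a common warm-up pattern: suppress torch._dynamo errors so a failed TensorRT compile falls back to eager execution instead of crashing the Space, then run one throwaway img2img call so compilation happens at startup rather than on the first user request. Below is a rough standalone sketch of that pattern, not the commit's exact code; the AutoPipelineForImage2Image class, the sdxl-turbo checkpoint, and the simplified compile call are assumptions filled in for illustration, since base_pipe and the backend options are defined outside this diff.

# Standalone warm-up sketch (assumed setup; not the exact app_tensorrt.py code).
import torch
import torch._dynamo
import torch_tensorrt  # registers the "torch_tensorrt" backend for torch.compile (assumed installed)
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# Fall back to eager execution instead of raising if the backend
# cannot handle part of the UNet graph.
torch._dynamo.config.suppress_errors = True

# Hypothetical pipeline; the real app builds `base_pipe` earlier in the file.
base_pipe = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")
base_pipe.unet = torch.compile(base_pipe.unet, backend="torch_tensorrt", dynamic=False)

try:
    # One throwaway generation so compilation runs before Gradio serves requests.
    init_image = load_image(
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img.png"
    )
    warmup = base_pipe(image=init_image, prompt="A white cat", num_inference_steps=5).images[0]
    warmup.save("/tmp/generated_image.png")
except Exception as e:
    # As in the commit, a failed warm-up is only logged; the demo can still start.
    print(f"Error: {e}")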
 
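The renamed image_to_image handler feeds the uploaded picture and the step count into base_pipe and timestamps the call. Two details worth noting: with gr.Image(type="pil") the handler actually receives a PIL.Image rather than a gr.Image component, and the click wiring lists two outputs while the committed function returns only the image, so the Time Cost textbox presumably stays empty unless the time-cost string is returned as well. A hedged sketch under those assumptions, relying on base_pipe and get_time_cost from the rest of app_tensorrt.py:

# Sketch of the renamed handler; `base_pipe` and `get_time_cost` are assumed
# to be in scope from elsewhere in app_tensorrt.py.
import spaces
from PIL import Image

@spaces.GPU(duration=30)
def image_to_image(image: Image.Image, prompt: str, steps: int):
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    generated_image = base_pipe(
        image=image,                 # PIL image delivered by gr.Image(type="pil")
        prompt=prompt,
        num_inference_steps=steps,
    ).images[0]
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    # Returning both values matches outputs=[generated_image, time_cost];
    # as committed, the function returns only the image.
    return generated_image, time_cost_str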
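The new get_time_cost helper accumulates a running string of millisecond gaps between checkpoints: the first call records 'start', and every later call appends the time elapsed since the previous call. A small self-contained demo of the same logic; the sleep calls below stand in for the pipeline and are only illustrative:

import time

def get_time_cost(run_task_time, time_cost_str):
    # Same logic as the helper added in this commit: milliseconds since the
    # previous checkpoint, accumulated into a '-->'-separated string.
    now_time = int(time.time() * 1000)
    if run_task_time == 0:
        time_cost_str = 'start'
    else:
        if time_cost_str != '':
            time_cost_str += '-->'
        time_cost_str += f'{now_time - run_task_time}'
    run_task_time = now_time
    return run_task_time, time_cost_str

if __name__ == "__main__":
    t, s = get_time_cost(0, '')   # s == 'start'
    time.sleep(0.25)              # stand-in for the pipeline call
    t, s = get_time_cost(t, s)    # s looks like 'start-->250'
    time.sleep(0.1)
    t, s = get_time_cost(t, s)    # s looks like 'start-->250-->100'
    print(s)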
 
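On the UI side, the commit adds an input-image column next to the generated image and wires both real outputs into the click handler, but the click call still passes fn=text_to_image even though the function is now named image_to_image, so the button presumably needs to point at the renamed handler. A minimal self-contained sketch of that wiring; the prompt, steps, and g_btn widgets are assumptions, since their definitions sit outside this diff, and the placeholder handler only echoes its inputs:

import gradio as gr
from PIL import Image

def image_to_image(image: Image.Image, prompt: str, steps: int):
    # Placeholder handler: the real app runs base_pipe here.
    return image, f"echoed {steps} steps for: {prompt}"

def create_demo() -> gr.Blocks:
    with gr.Blocks() as demo:
        with gr.Row():
            # Assumed widgets; their real definitions are outside the diff.
            prompt = gr.Textbox(label="Prompt", value="A white cat")
            steps = gr.Slider(label="Steps", minimum=1, maximum=20, step=1, value=5)
            g_btn = gr.Button("Generate")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            with gr.Column():
                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
                time_cost = gr.Textbox(label="Time Cost", lines=1, interactive=False)

        g_btn.click(
            fn=image_to_image,  # points at the renamed handler
            inputs=[input_image, prompt, steps],
            outputs=[generated_image, time_cost],
        )
    return demo

if __name__ == "__main__":
    create_demo().launch()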