ywen commited on
Commit
809ed8d
·
1 Parent(s): 8da63a9

Cache examples

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -16,8 +16,12 @@ args.print_step = None
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model, _, preprocess = open_clip.create_model_and_transforms(args.clip_model, pretrained=args.clip_pretrain, device=device)
18
 
 
19
 
20
  def inference(target_image, prompt_len, iter):
 
 
 
21
  if prompt_len is not None:
22
  args.prompt_len = int(prompt_len)
23
  else:
@@ -29,10 +33,13 @@ def inference(target_image, prompt_len, iter):
29
  args.iter = 1000
30
 
31
  learned_prompt = optimize_prompt(model, preprocess, args, device, target_images=[target_image])
32
-
33
  return learned_prompt
34
-
35
  def inference_text(target_prompt, prompt_len, iter):
 
 
 
36
  if prompt_len is not None:
37
  args.prompt_len = min(int(prompt_len), 75)
38
  else:
@@ -85,7 +92,7 @@ with demo:
85
  image_button.click(inference, inputs=[input_image, prompt_len_field, num_step_field], outputs=output_prompt)
86
  prompt_button.click(inference_text, inputs=[input_prompt, prompt_len_field, num_step_field], outputs=output_prompt)
87
 
88
- gr.Examples([["sample.jpeg", 8, 1000]], inputs=[input_image, prompt_len_field, num_step_field])
89
- gr.Examples([["digital concept art of old wooden cabin in florida swamp, trending on artstation", 3, 1000]], inputs=[input_prompt, prompt_len_field, num_step_field])
90
 
91
  demo.launch(enable_queue=True)
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  model, _, preprocess = open_clip.create_model_and_transforms(args.clip_model, pretrained=args.clip_pretrain, device=device)
18
 
19
+ args.counter = 0
20
 
21
  def inference(target_image, prompt_len, iter):
22
+ args.counter += 1
23
+ print(args.counter)
24
+
25
  if prompt_len is not None:
26
  args.prompt_len = int(prompt_len)
27
  else:
 
33
  args.iter = 1000
34
 
35
  learned_prompt = optimize_prompt(model, preprocess, args, device, target_images=[target_image])
36
+
37
  return learned_prompt
38
+
39
  def inference_text(target_prompt, prompt_len, iter):
40
+ args.counter += 1
41
+ print(args.counter)
42
+
43
  if prompt_len is not None:
44
  args.prompt_len = min(int(prompt_len), 75)
45
  else:
 
92
  image_button.click(inference, inputs=[input_image, prompt_len_field, num_step_field], outputs=output_prompt)
93
  prompt_button.click(inference_text, inputs=[input_prompt, prompt_len_field, num_step_field], outputs=output_prompt)
94
 
95
+ gr.Examples([["sample.jpeg", 8, 1000]], inputs=[input_image, prompt_len_field, num_step_field], fn=inference, outputs=output_prompt, cache_examples=True)
96
+ gr.Examples([["digital concept art of old wooden cabin in florida swamp, trending on artstation", 3, 1000]], inputs=[input_prompt, prompt_len_field, num_step_field], fn=inference_text, outputs=output_prompt, cache_examples=True)
97
 
98
  demo.launch(enable_queue=True)