Garrett Goon commited on
Commit
03168a3
·
1 Parent(s): 4e26ea8

running on t4 tests

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -52,30 +52,25 @@ def replace_concept_tokens(text: str):
52
  return text
53
 
54
 
55
- # def inference(
56
- # prompt: str, num_inference_steps: int = 50, guidance_scale: int = 3.0
57
- # ):
58
- # prompt = replace_concept_tokens(prompt)
59
- # for _ in range(3):
60
- # img_list = pipeline(
61
- # prompt=prompt,
62
- # num_inference_steps=num_inference_steps,
63
- # guidance_scale=guidance_scale,
64
- # )
65
- # if not img_list["nsfw_content_detected"]:
66
- # break
67
- # return img_list["sample"]
68
 
69
  DEFAULT_PROMPT = (
70
  "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
71
  " pastel colors, incredible composition, masterpiece"
72
  )
73
 
74
-
75
- def white_imgs(prompt: str, guidance_scale: float, num_inference_steps: int, seed: int):
76
- return [torch.ones(512, 512, 3).numpy() for _ in range(2)]
77
-
78
-
79
  with gr.Blocks() as demo:
80
  prompt = gr.Textbox(
81
  label="Prompt including the token '<det-logo>'",
@@ -110,11 +105,11 @@ with gr.Blocks() as demo:
110
  generate_btn = gr.Button(label="Generate")
111
  gallery = gr.Gallery(
112
  label="Generated Images",
113
- value=[torch.zeros(512, 512, 3).numpy() for _ in range(2)],
114
  ).style(height="auto")
115
 
116
  generate_btn.click(
117
- white_imgs,
118
  inputs=[prompt, guidance_scale, num_inference_steps, seed],
119
  outputs=gallery,
120
  )
 
52
  return text
53
 
54
 
55
+ def inference(
56
+ prompt: str, guidance_scale: int, num_inference_steps: int, seed: int
57
+ ):
58
+ prompt = replace_concept_tokens(prompt)
59
+ generator = torch.Generator(device=device).manual_seed(seed)
60
+ out = pipeline(
61
+ prompt=[prompt] * 2,
62
+ num_inference_steps=num_inference_steps,
63
+ guidance_scale=guidance_scale,
64
+ generator=generator,
65
+ )
66
+ img_list = [item['sample'] for item in out]
67
+ return img_list
68
 
69
  DEFAULT_PROMPT = (
70
  "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
71
  " pastel colors, incredible composition, masterpiece"
72
  )
73
 
 
 
 
 
 
74
  with gr.Blocks() as demo:
75
  prompt = gr.Textbox(
76
  label="Prompt including the token '<det-logo>'",
 
105
  generate_btn = gr.Button(label="Generate")
106
  gallery = gr.Gallery(
107
  label="Generated Images",
108
+ value=[torch.ones(512, 512, 3).numpy() for _ in range(2)],
109
  ).style(height="auto")
110
 
111
  generate_btn.click(
112
+ inference,
113
  inputs=[prompt, guidance_scale, num_inference_steps, seed],
114
  outputs=gallery,
115
  )