veasnakao commited on
Commit
18e3498
·
verified ·
1 Parent(s): b1c3546

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -17
app.py CHANGED
@@ -4,6 +4,8 @@ import random
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
 
 
7
 
8
  dtype = torch.bfloat16
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -13,28 +15,47 @@ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", tor
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 2048
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  @spaces.GPU()
17
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
18
  if randomize_seed:
19
  seed = random.randint(0, MAX_SEED)
20
  generator = torch.Generator().manual_seed(seed)
21
  image = pipe(
22
- prompt = prompt,
23
- width = width,
24
- height = height,
25
- num_inference_steps = num_inference_steps,
26
- generator = generator,
27
  guidance_scale=0.0
28
  ).images[0]
29
- return image, seed
30
-
 
 
 
 
 
 
 
31
  examples = [
32
  "a tiny astronaut hatching from an egg on the moon",
33
  "a cat holding a sign that says hello world",
34
  "an anime illustration of a wiener schnitzel",
35
  ]
36
 
37
- css="""
38
  #col-container {
39
  margin: 0 auto;
40
  max-width: 520px;
@@ -62,6 +83,7 @@ with gr.Blocks(css=css) as demo:
62
  run_button = gr.Button("Run", scale=0)
63
 
64
  result = gr.Image(label="Result", show_label=False)
 
65
 
66
  with gr.Accordion("Advanced Settings", open=False):
67
 
@@ -95,7 +117,6 @@ with gr.Blocks(css=css) as demo:
95
 
96
  with gr.Row():
97
 
98
-
99
  num_inference_steps = gr.Slider(
100
  label="Number of inference steps",
101
  minimum=1,
@@ -105,18 +126,18 @@ with gr.Blocks(css=css) as demo:
105
  )
106
 
107
  gr.Examples(
108
- examples = examples,
109
- fn = infer,
110
- inputs = [prompt],
111
- outputs = [result, seed],
112
  cache_examples="lazy"
113
  )
114
 
115
  gr.on(
116
  triggers=[run_button.click, prompt.submit],
117
- fn = infer,
118
- inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
119
- outputs = [result, seed]
120
  )
121
 
122
- demo.launch()
 
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
7
+ import boto3
8
+ from io import BytesIO
9
 
10
  dtype = torch.bfloat16
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  MAX_IMAGE_SIZE = 2048
17
 
18
+ # Initialize S3 client
19
+ s3_client = boto3.client('s3')
20
+ BUCKET_NAME = 'your-s3-bucket-name' # Replace with your S3 bucket name
21
+
22
def upload_to_s3(image, image_name):
    """Upload a PIL image as a PNG to the configured S3 bucket.

    Args:
        image: PIL image (anything exposing ``save(fileobj, format=...)``).
        image_name: S3 object key to store the image under.

    Returns:
        str: Public https URL of the uploaded object. The key is
        percent-encoded so the URL stays valid even when the key contains
        spaces or non-ASCII characters ('/' is preserved so key "folders"
        survive).
    """
    from urllib.parse import quote  # stdlib; local so the module import block is untouched

    buffer = BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)  # rewind so put_object reads the whole payload
    s3_client.put_object(
        Bucket=BUCKET_NAME,
        Key=image_name,
        Body=buffer,
        ContentType='image/png',
    )
    return f"https://{BUCKET_NAME}.s3.amazonaws.com/{quote(image_name)}"
29
+
30
@spaces.GPU()
def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with FLUX.1-schnell, upload it to S3, and return it.

    Args:
        prompt: Text prompt for the diffusion pipeline.
        seed: RNG seed; ignored when ``randomize_seed`` is True.
        randomize_seed: When True, draw a fresh seed in [0, MAX_SEED].
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Denoising steps (schnell is distilled for ~4).
        progress: Gradio progress tracker wired to tqdm.

    Returns:
        tuple: ``(image, seed, s3_url)`` — the generated PIL image, the seed
        actually used, and the public S3 URL of the uploaded PNG.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=generator,
        # schnell is a guidance-distilled model: guidance stays disabled.
        guidance_scale=0.0,
    ).images[0]

    # Build an S3-safe object key: raw prompts may contain '/', spaces, or
    # punctuation that would create pseudo-directories or broken keys.
    safe_prompt = "".join(c if c.isalnum() else "_" for c in prompt[:10]) or "image"
    image_name = f"{seed}_{safe_prompt}.png"

    # Upload image to S3 and surface the public link alongside the image.
    s3_url = upload_to_s3(image, image_name)

    return image, seed, s3_url
51
+
52
  examples = [
53
  "a tiny astronaut hatching from an egg on the moon",
54
  "a cat holding a sign that says hello world",
55
  "an anime illustration of a wiener schnitzel",
56
  ]
57
 
58
+ css = """
59
  #col-container {
60
  margin: 0 auto;
61
  max-width: 520px;
 
83
  run_button = gr.Button("Run", scale=0)
84
 
85
  result = gr.Image(label="Result", show_label=False)
86
+ s3_link = gr.Text(label="S3 URL", show_label=False)
87
 
88
  with gr.Accordion("Advanced Settings", open=False):
89
 
 
117
 
118
  with gr.Row():
119
 
 
120
  num_inference_steps = gr.Slider(
121
  label="Number of inference steps",
122
  minimum=1,
 
126
  )
127
 
128
  gr.Examples(
129
+ examples=examples,
130
+ fn=infer,
131
+ inputs=[prompt],
132
+ outputs=[result, seed, s3_link],
133
  cache_examples="lazy"
134
  )
135
 
136
  gr.on(
137
  triggers=[run_button.click, prompt.submit],
138
+ fn=infer,
139
+ inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
140
+ outputs=[result, seed, s3_link]
141
  )
142
 
143
+ demo.launch()