nirajandhakal commited on
Commit
39f57d1
·
verified ·
1 Parent(s): 5e1f3c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -6,11 +6,17 @@ import torch
6
  from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
7
  from PIL import Image
8
  import io
 
 
 
9
 
10
  dtype = torch.bfloat16
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
 
 
 
14
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  MAX_IMAGE_SIZE = 2048
@@ -92,7 +98,8 @@ body {
92
  }
93
  """
94
 
95
- with gr.Blocks(css=css) as demo:
 
96
 
97
  with gr.Column(elem_id="col-container"):
98
  gr.Markdown(f"""# FLUX.1 [dev]
@@ -105,7 +112,8 @@ with gr.Blocks(css=css) as demo:
105
  </div>
106
  </a>
107
  """)
108
-
 
109
  with gr.Row():
110
  prompt = gr.Text(
111
  label="Prompt",
@@ -225,4 +233,5 @@ with gr.Blocks(css=css) as demo:
225
  """
226
  )
227
 
 
228
  demo.launch()
 
6
  from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
7
  from PIL import Image
8
  import io
9
+ import os
10
+ import subprocess
11
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
 
13
  dtype = torch.bfloat16
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
 
16
+ huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
17
+
18
+
19
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype, token = huggingface_token).to(device)
20
 
21
  MAX_SEED = np.iinfo(np.int32).max
22
  MAX_IMAGE_SIZE = 2048
 
98
  }
99
  """
100
 
101
+ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
102
+ gr.HTML(title)
103
 
104
  with gr.Column(elem_id="col-container"):
105
  gr.Markdown(f"""# FLUX.1 [dev]
 
112
  </div>
113
  </a>
114
  """)
115
+
116
+
117
  with gr.Row():
118
  prompt = gr.Text(
119
  label="Prompt",
 
233
  """
234
  )
235
 
236
+
237
  demo.launch()