animrods commited on
Commit
f70143b
·
verified ·
1 Parent(s): c74635b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -13
app.py CHANGED
@@ -1,21 +1,33 @@
 
 
 
1
  import gradio as gr
2
  import os
3
- hf_token = os.environ.get("HF_TOKEN")
4
  import spaces
5
- import torch
6
- from pipeline_bria import BriaPipeline, BriaTransformer2DModel
7
  import time
 
 
 
 
 
 
 
 
8
 
9
  resolutions = ["1024 1024","1280 768","1344 768","768 1344","768 1280"]
10
 
11
  # Ng
12
- default_negative_prompt= "Logo,Watermark,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
13
 
14
- transformer = BriaTransformer2DModel.from_pretrained("briaai/BRIA-3.2",subfolder='transformer',torch_dtype=torch.bfloat16)
15
- pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.1", transformer=transformer, torch_dtype=torch.bfloat16,trust_remote_code=True)
16
- pipe.to(device="cuda")
 
17
 
18
- @spaces.GPU(enable_queue=True)
19
  def infer(prompt,negative_prompt,seed,resolution):
20
  print(f"""
21
  —/n
@@ -30,13 +42,34 @@ def infer(prompt,negative_prompt,seed,resolution):
30
  else:
31
  try:
32
  seed=int(seed)
33
- generator = torch.Generator("cuda").manual_seed(seed)
34
  except:
35
  generator=None
36
 
37
  w,h = resolution.split()
38
  w,h = int(w),int(h)
39
- image = pipe(prompt,num_inference_steps=30, negative_prompt=negative_prompt,generator=generator,width=w,height=h).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  print(f'gen time is {time.time()-t} secs')
41
 
42
  # Future
@@ -44,7 +77,7 @@ def infer(prompt,negative_prompt,seed,resolution):
44
  # if nsfw:
45
  # raise gr.Error("Generated image is NSFW")
46
 
47
- return image
48
 
49
  css = """
50
  #col-container{
@@ -54,7 +87,7 @@ css = """
54
  """
55
  with gr.Blocks(css=css) as demo:
56
  with gr.Column(elem_id="col-container"):
57
- gr.Markdown("## BRIA 3.2")
58
  gr.HTML('''
59
  <p style="margin-bottom: 10px; font-size: 94%">
60
  This is a demo for
@@ -70,7 +103,7 @@ with gr.Blocks(css=css) as demo:
70
  ''')
71
  with gr.Group():
72
  with gr.Column():
73
- prompt_in = gr.Textbox(label="Prompt", value="""photo of mystical dragon eating sushi, text bubble says "Sushi Time".""")
74
  resolution = gr.Dropdown(value=resolutions[0], show_label=True, label="Resolution", choices=resolutions)
75
  seed = gr.Textbox(label="Seed", value=-1)
76
  negative_prompt = gr.Textbox(label="Negative Prompt", value=default_negative_prompt)
 
1
+ import json
2
+ import requests
3
+ from io import BytesIO
4
  import gradio as gr
5
  import os
6
+ # hf_token = os.environ.get("HF_TOKEN")
7
  import spaces
8
+ # import torch
9
+ # from pipeline_bria import BriaPipeline
10
  import time
11
+ from PIL import Image
12
+
13
+ def download_image(url):
14
+ response = requests.get(url)
15
+ return Image.open(BytesIO(response.content)).convert("RGB")
16
+
17
+ hf_token = os.environ.get("HF_TOKEN_API_DEMO") # we get it from a secret env variable, such that it's private
18
+ auth_headers = {"api_token": hf_token}
19
 
20
  resolutions = ["1024 1024","1280 768","1344 768","768 1344","768 1280"]
21
 
22
  # Ng
23
+ default_negative_prompt= "Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers"
24
 
25
+ # Load pipeline
26
+ # trust_remote_code = True - allows loading a transformer which is not present at the transformers library(from transformer/bria_transformer.py)
27
+ # pipe = BriaPipeline.from_pretrained("briaai/BRIA-3.0-TOUCAN", torch_dtype=torch.bfloat16,trust_remote_code=True)
28
+ # pipe.to(device="cuda")
29
 
30
+ # @spaces.GPU(enable_queue=True)
31
  def infer(prompt,negative_prompt,seed,resolution):
32
  print(f"""
33
  —/n
 
42
  else:
43
  try:
44
  seed=int(seed)
45
+ # generator = torch.Generator("cuda").manual_seed(seed)
46
  except:
47
  generator=None
48
 
49
  w,h = resolution.split()
50
  w,h = int(w),int(h)
51
+ # image = pipe(prompt,num_inference_steps=30, negative_prompt=negative_prompt,generator=generator,width=w,height=h).images[0]
52
+ url = "http://engine.prod.bria-api.com/v1/text-to-image/base/3.2"
53
+
54
+ payload = json.dumps({
55
+ "prompt": prompt,
56
+ "num_results": 1,
57
+ "sync": True,
58
+ "prompt_enhancement": False,
59
+ "debias": False,
60
+ "fast": False,
61
+ "model_influence": 0.0000001,
62
+ "include_generation_prefix": False,
63
+ "negative_prompt": negative_prompt,
64
+ "num_inference_steps": 30,
65
+ "seed": seed
66
+ })
67
+ response = requests.request("POST", url, headers=auth_headers, data=payload)
68
+ print('1',response)
69
+ response = response.json()
70
+ print('2',response)
71
+ res_image = download_image(response["result"][0]['urls'][0])
72
+
73
  print(f'gen time is {time.time()-t} secs')
74
 
75
  # Future
 
77
  # if nsfw:
78
  # raise gr.Error("Generated image is NSFW")
79
 
80
+ return res_image
81
 
82
  css = """
83
  #col-container{
 
87
  """
88
  with gr.Blocks(css=css) as demo:
89
  with gr.Column(elem_id="col-container"):
90
+ gr.Markdown("## BRIA-3.2")
91
  gr.HTML('''
92
  <p style="margin-bottom: 10px; font-size: 94%">
93
  This is a demo for
 
103
  ''')
104
  with gr.Group():
105
  with gr.Column():
106
+ prompt_in = gr.Textbox(label="Prompt", value="A smiling man with wavy brown hair and a trimmed beard")
107
  resolution = gr.Dropdown(value=resolutions[0], show_label=True, label="Resolution", choices=resolutions)
108
  seed = gr.Textbox(label="Seed", value=-1)
109
  negative_prompt = gr.Textbox(label="Negative Prompt", value=default_negative_prompt)