Yemin Shi commited on
Commit
405afc3
1 Parent(s): 61ffabe

update APIs

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +42 -26
.DS_Store DELETED
Binary file (6.15 kB)
 
app.py CHANGED
@@ -11,7 +11,6 @@ from css_and_js import js, call_JS
11
  from PIL import Image, PngImagePlugin, ImageChops
12
 
13
  url_host = "http://flagstudio.baai.ac.cn"
14
- #token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiZjAxOGMxMzJiYTUyNDBjMzk5NTMzYTI5YjBmMzZiODMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6IjVjMmQzMjdiLWI5Y2MtNDhiZS1hZWQ4LTllMjQ4MDk4NzMxYyIsIm5iZiI6MTY2OTAwNjE5NywiZXhwIjoxOTg0MzY2MTk3LCJpYXQiOjE2NjkwMDYxOTd9.9B3MDk8wA6iWH5puXjcD19tJJ4Ox7mdpRyWZs5Kwt70"
15
  token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMGY4M2QxMDg3N2MzMTFlZGFiYzYwZmU5ZGFjMTI1ZDMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6ImE3YTE1N2I3LTllNTItNDllMS04YzA0LWEzZmI5YjZiZjNlYSIsIm5iZiI6MTY3MDU5MTcwMSwiZXhwIjoxOTg1OTUxNzAxLCJpYXQiOjE2NzA1OTE3MDF9.OcfGayna-wr_5mo4LT6OJHSCokna8vqKSmmCftFUsx8"
16
 
17
  def read_content(file_path: str) -> str:
@@ -58,7 +57,6 @@ def upload_image(img):
58
  def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
59
  data = {
60
  "type": "gen-image",
61
- "gen_image_num": image_num,
62
  "parameters": {
63
  "width": width, # output height width
64
  "height": height, # output image height
@@ -87,35 +85,47 @@ def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
87
  "height": mask.height,
88
  }
89
  headers = {"token": token}
 
90
  # Send create task request
91
- # url = "http://flagstudio.baai.ac.cn/api/v1/task/create"
92
  url = url_host+"/api/v1/task/create"
93
- r = requests.post(url, json=data, headers=headers)
94
- if r.status_code != 200:
95
- raise gr.Error(r.reason)
96
- create_res = r.json()
97
- task_id = create_res["data"]["task_id"]
 
 
 
 
 
98
 
99
  # Get result
100
  url = url_host+"/api/v1/task/status"
 
101
  while True:
102
- r = requests.post(url, json=create_res["data"], headers=headers)
103
- if r.status_code != 200:
104
- raise gr.Error(r.reason)
105
- res = r.json()
106
- if res["code"] == 6002:
107
- # Running
108
- time.sleep(1)
109
- continue
110
- elif res["code"] == 0:
111
- # Finished
112
- images = []
113
- for img_info in res["data"]["images"]:
114
- img_res = requests.get(img_info["url"])
115
- images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
116
  return images
117
- else:
118
- raise gr.Error(f"Error code: {res['code']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
121
  if filter_content(class_draw) != "鍥界敾":
@@ -134,7 +144,13 @@ def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
134
 
135
 
136
  def img2img(prompt, image_and_mask):
137
- return post_reqest(0, prompt, 512, 512, 1, image_and_mask["image"], image_and_mask["mask"])
 
 
 
 
 
 
138
 
139
 
140
  examples = [
@@ -311,4 +327,4 @@ if __name__ == "__main__":
311
  gr.HTML(read_content("footer.html"))
312
  # gr.Image('./contributors.png')
313
 
314
- block.queue(max_size=50, concurrency_count=20).launch()
 
11
  from PIL import Image, PngImagePlugin, ImageChops
12
 
13
  url_host = "http://flagstudio.baai.ac.cn"
 
14
  token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyX2lkIjoiMGY4M2QxMDg3N2MzMTFlZGFiYzYwZmU5ZGFjMTI1ZDMiLCJhcHBfbmFtZSI6IndlYiIsImlkZW50aXR5X3R5cGUiOiIyIiwidXNlcl9yb2xlIjoiMiIsImp0aSI6ImE3YTE1N2I3LTllNTItNDllMS04YzA0LWEzZmI5YjZiZjNlYSIsIm5iZiI6MTY3MDU5MTcwMSwiZXhwIjoxOTg1OTUxNzAxLCJpYXQiOjE2NzA1OTE3MDF9.OcfGayna-wr_5mo4LT6OJHSCokna8vqKSmmCftFUsx8"
15
 
16
  def read_content(file_path: str) -> str:
 
57
  def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
58
  data = {
59
  "type": "gen-image",
 
60
  "parameters": {
61
  "width": width, # output height width
62
  "height": height, # output image height
 
85
  "height": mask.height,
86
  }
87
  headers = {"token": token}
88
+
89
  # Send create task request
90
+ all_task_data = []
91
  url = url_host+"/api/v1/task/create"
92
+ for _ in range(image_num):
93
+ r = requests.post(url, json=data, headers=headers)
94
+ if r.status_code != 200:
95
+ raise gr.Error(r.reason)
96
+ create_res = r.json()
97
+ if create_res['code'] == 3002:
98
+ raise gr.Error("Inappropriate prompt detected.")
99
+ elif create_res['code'] != 0:
100
+ raise gr.Error("Unknown error")
101
+ all_task_data.append(create_res["data"])
102
 
103
  # Get result
104
  url = url_host+"/api/v1/task/status"
105
+ images = []
106
  while True:
107
+ if len(all_task_data) <= 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return images
109
+ for i in range(len(all_task_data)-1, -1, -1):
110
+ data = all_task_data[i]
111
+ r = requests.post(url, json=data, headers=headers)
112
+ if r.status_code != 200:
113
+ raise gr.Error(r.reason)
114
+ res = r.json()
115
+ if res["code"] == 6002:
116
+ # Running
117
+ continue
118
+ if res["code"] == 6005:
119
+ raise gr.Error("NSFW image detected.")
120
+ elif res["code"] == 0:
121
+ # Finished
122
+ for img_info in res["data"]["images"]:
123
+ img_res = requests.get(img_info["url"])
124
+ images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
125
+ del all_task_data[i]
126
+ else:
127
+ raise gr.Error(f"Error code: {res['code']}")
128
+ time.sleep(1)
129
 
130
  def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
131
  if filter_content(class_draw) != "鍥界敾":
 
144
 
145
 
146
  def img2img(prompt, image_and_mask):
147
+ if image_and_mask["image"].width <= image_and_mask["image"].height:
148
+ width = 512
149
+ height = int((width/image_and_mask["image"].width)*image_and_mask["image"].height)
150
+ else:
151
+ height = 512
152
+ width = int((height/image_and_mask["image"].height)*image_and_mask["image"].width)
153
+ return post_reqest(0, prompt, width, height, 1, image_and_mask["image"], image_and_mask["mask"])
154
 
155
 
156
  examples = [
 
327
  gr.HTML(read_content("footer.html"))
328
  # gr.Image('./contributors.png')
329
 
330
+ block.queue(max_size=100, concurrency_count=50).launch()