top001 commited on
Commit
70c9623
·
verified ·
1 Parent(s): a138a91

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -8
app.py CHANGED
@@ -9,26 +9,22 @@ import onnxruntime as ort
9
  from PIL import Image
10
  from huggingface_hub import hf_hub_download
11
 
12
-
13
  def _yield_tags_from_txt_file(txt_file: str):
14
  with open(txt_file, 'r') as f:
15
  for line in f:
16
  if line:
17
  yield line.strip()
18
 
19
-
20
  @lru_cache()
21
  def get_deepdanbooru_tags() -> List[str]:
22
  tags_file = hf_hub_download('chinoll/deepdanbooru', 'tags.txt')
23
  return list(_yield_tags_from_txt_file(tags_file))
24
 
25
-
26
  @lru_cache()
27
  def get_deepdanbooru_onnx() -> ort.InferenceSession:
28
  onnx_file = hf_hub_download('chinoll/deepdanbooru', 'deepdanbooru.onnx')
29
  return ort.InferenceSession(onnx_file)
30
 
31
-
32
  def image_preprocess(image: Image.Image) -> np.ndarray:
33
  if image.mode != 'RGB':
34
  image = image.convert('RGB')
@@ -49,10 +45,8 @@ def image_preprocess(image: Image.Image) -> np.ndarray:
49
  assert data.shape == (512, 512, 3), f'Shape (512, 512, 3) expected, but {data.shape!r} found.'
50
  return data.reshape((1, 512, 512, 3)) # B x H x W x C
51
 
52
-
53
  RE_SPECIAL = re.compile(r'([\\()])')
54
 
55
-
56
  def image_to_deepdanbooru_tags(image: Image.Image, threshold: float,
57
  use_spaces: bool, use_escape: bool, include_ranks: bool, score_descend: bool) \
58
  -> Tuple[str, Mapping[str, float]]:
@@ -84,8 +78,12 @@ def image_to_deepdanbooru_tags(image: Image.Image, threshold: float,
84
 
85
  return output_text, filtered_tags
86
 
87
-
88
  if __name__ == '__main__':
 
 
 
 
 
89
  with gr.Blocks() as demo:
90
  with gr.Row():
91
  with gr.Column():
@@ -111,4 +109,33 @@ if __name__ == '__main__':
111
  inputs=[gr_input_image, gr_threshold, gr_space, gr_escape, gr_confidence, gr_order],
112
  outputs=[gr_output_text, gr_tags],
113
  )
114
- demo.queue(os.cpu_count()).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from PIL import Image
10
  from huggingface_hub import hf_hub_download
11
 
 
12
  def _yield_tags_from_txt_file(txt_file: str):
13
  with open(txt_file, 'r') as f:
14
  for line in f:
15
  if line:
16
  yield line.strip()
17
 
 
18
  @lru_cache()
19
  def get_deepdanbooru_tags() -> List[str]:
20
  tags_file = hf_hub_download('chinoll/deepdanbooru', 'tags.txt')
21
  return list(_yield_tags_from_txt_file(tags_file))
22
 
 
23
  @lru_cache()
24
  def get_deepdanbooru_onnx() -> ort.InferenceSession:
25
  onnx_file = hf_hub_download('chinoll/deepdanbooru', 'deepdanbooru.onnx')
26
  return ort.InferenceSession(onnx_file)
27
 
 
28
  def image_preprocess(image: Image.Image) -> np.ndarray:
29
  if image.mode != 'RGB':
30
  image = image.convert('RGB')
 
45
  assert data.shape == (512, 512, 3), f'Shape (512, 512, 3) expected, but {data.shape!r} found.'
46
  return data.reshape((1, 512, 512, 3)) # B x H x W x C
47
 
 
48
  RE_SPECIAL = re.compile(r'([\\()])')
49
 
 
50
  def image_to_deepdanbooru_tags(image: Image.Image, threshold: float,
51
  use_spaces: bool, use_escape: bool, include_ranks: bool, score_descend: bool) \
52
  -> Tuple[str, Mapping[str, float]]:
 
78
 
79
  return output_text, filtered_tags
80
 
 
81
  if __name__ == '__main__':
82
+ import io
83
+ from fastapi import FastAPI, File, UploadFile
84
+ from fastapi.responses import JSONResponse
85
+ from fastapi.middleware.cors import CORSMiddleware
86
+
87
  with gr.Blocks() as demo:
88
  with gr.Row():
89
  with gr.Column():
 
109
  inputs=[gr_input_image, gr_threshold, gr_space, gr_escape, gr_confidence, gr_order],
110
  outputs=[gr_output_text, gr_tags],
111
  )
112
+
113
+ # Get the FastAPI app from Gradio Blocks
114
+ app = demo.app
115
+
116
+ # Allow cross-origin requests (optional, useful for testing)
117
+ origins = ["*"]
118
+ app.add_middleware(
119
+ CORSMiddleware,
120
+ allow_origins=origins,
121
+ allow_methods=["*"],
122
+ allow_headers=["*"],
123
+ )
124
+
125
+ @app.post("/api/analyze_image")
126
+ async def analyze_image(file: UploadFile = File(...)):
127
+ contents = await file.read()
128
+ image = Image.open(io.BytesIO(contents))
129
+ output_text, filtered_tags = image_to_deepdanbooru_tags(
130
+ image,
131
+ threshold=0.5,
132
+ use_spaces=False,
133
+ use_escape=True,
134
+ include_ranks=False,
135
+ score_descend=True
136
+ )
137
+ return JSONResponse(content=filtered_tags)
138
+
139
+ # Launch the Gradio app
140
+ demo.queue(concurrency_count=os.cpu_count()).launch(server_name="0.0.0.0")
141
+