radames commited on
Commit
3280cce
·
1 Parent(s): 03258aa
Files changed (1) hide show
  1. app.py +26 -9
app.py CHANGED
@@ -40,6 +40,7 @@ CLASSIFIER_URL = (
40
  )
41
  ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
42
 
 
43
 
44
  s3 = boto3.client(
45
  service_name="s3",
@@ -120,12 +121,22 @@ def get_yaml_data(text_content):
120
  print(exc)
121
  return {}
122
 
123
- async def find_image_in_model_card(text):
124
- image_regex = re.compile(r"https?://\S+(?:png|jpg|jpeg|webp)")
125
- urls = re.findall(image_regex, text)
126
- if not urls:
 
 
 
 
 
 
 
 
 
127
  return []
128
 
 
129
  async with aiohttp.ClientSession() as session:
130
  tasks = [
131
  asyncio.ensure_future(upload_resize_image_url(session, image_url))
@@ -188,9 +199,10 @@ async def sync_data():
188
  with open(DB_FOLDER / "models.json", "w") as f:
189
  json.dump(all_models, f)
190
  # with open(DB_FOLDER / "models.json", "r") as f:
191
- # new_models = json.load(f)
192
 
193
  new_models_ids = [model["id"] for model in all_models]
 
194
 
195
  # get existing models
196
  with database.get_db() as db:
@@ -212,7 +224,7 @@ async def sync_data():
212
  print("Parsing model card")
213
  model_card_data = get_yaml_data(model_card)
214
  print("Finding images in model card")
215
- images = await find_image_in_model_card(model_card)
216
 
217
  classifier = run_classifier(images)
218
  print(images, classifier)
@@ -257,7 +269,7 @@ async def sync_data():
257
  print("Parsing model card")
258
  model_card_data = get_yaml_data(model_card)
259
  print("Finding images in model card")
260
- images = await find_image_in_model_card(model_card)
261
  classifier = run_classifier(images)
262
  model_data["images"] = images
263
  model_data["class"] = classifier
@@ -322,6 +334,7 @@ class Style(str, Enum):
322
  s3D = "3d"
323
  realistic = "realistic"
324
  nsfw = "nsfw"
 
325
 
326
 
327
  @app.get("/api/models")
@@ -344,9 +357,13 @@ def get_page(
344
  style_query = "json_extract(data, '$.class.3d') > 0.1 AND isNFSW = false"
345
  elif style == Style.realistic:
346
  style_query = "json_extract(data, '$.class.real_life') > 0.1 AND isNFSW = false"
 
 
347
  elif style == Style.nsfw:
348
  style_query = "isNFSW = true"
349
 
 
 
350
  with database.get_db() as db:
351
  cursor = db.cursor()
352
  cursor.execute(
@@ -359,7 +376,7 @@ def get_page(
359
  json_extract(data, '$.class.explicit') > 0.3 OR json_extract(data, '$.class.suggestive') > 0.3 AS isNFSW
360
  FROM models
361
  ) AS subquery
362
- WHERE (? IS NULL AND likes > 3 OR ? IS NOT NULL)
363
  AND {style_query}
364
  AND (? IS NULL OR EXISTS (
365
  SELECT 1
@@ -368,7 +385,7 @@ def get_page(
368
  ))
369
  ORDER BY {sort_query}
370
  LIMIT {MAX_PAGE_SIZE} OFFSET {(page - 1) * MAX_PAGE_SIZE};
371
- """,
372
  (tag, tag, tag, tag),
373
  )
374
  results = cursor.fetchall()
 
40
  )
41
  ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
42
 
43
+ BLOCKED_MODELS_REGEX = re.compile(r"(CyberHarem)", re.IGNORECASE)
44
 
45
  s3 = boto3.client(
46
  service_name="s3",
 
121
  print(exc)
122
  return {}
123
 
124
+ async def find_image_in_model_card(text, model_id):
125
+ base_url = f"https://huggingface.co/{model_id}/resolve/main/"
126
+ image_regex = re.compile(r"!\[.*\]\((.*?\.(png|jpg|jpeg|gif|bmp|webp))\)|src=\"(.*?\.(png|jpg|jpeg|gif|bmp|webp))\">", re.IGNORECASE)
127
+ matches = image_regex.findall(text)
128
+ urls = []
129
+ for match in matches:
130
+ for url in match:
131
+ if url:
132
+ if not url.startswith("http") and not url.startswith("https"):
133
+ url = base_url + url
134
+ urls.append(url)
135
+
136
+ if len(urls) == 0:
137
  return []
138
 
139
+ print(urls)
140
  async with aiohttp.ClientSession() as session:
141
  tasks = [
142
  asyncio.ensure_future(upload_resize_image_url(session, image_url))
 
199
  with open(DB_FOLDER / "models.json", "w") as f:
200
  json.dump(all_models, f)
201
  # with open(DB_FOLDER / "models.json", "r") as f:
202
+ # all_models = json.load(f)
203
 
204
  new_models_ids = [model["id"] for model in all_models]
205
+ new_models_ids = [model_id for model_id in new_models_ids if not re.match(BLOCKED_MODELS_REGEX, model_id)]
206
 
207
  # get existing models
208
  with database.get_db() as db:
 
224
  print("Parsing model card")
225
  model_card_data = get_yaml_data(model_card)
226
  print("Finding images in model card")
227
+ images = await find_image_in_model_card(model_card, model_id)
228
 
229
  classifier = run_classifier(images)
230
  print(images, classifier)
 
269
  print("Parsing model card")
270
  model_card_data = get_yaml_data(model_card)
271
  print("Finding images in model card")
272
+ images = await find_image_in_model_card(model_card, model_id)
273
  classifier = run_classifier(images)
274
  model_data["images"] = images
275
  model_data["class"] = classifier
 
334
  s3D = "3d"
335
  realistic = "realistic"
336
  nsfw = "nsfw"
337
+ lora = "lora"
338
 
339
 
340
  @app.get("/api/models")
 
357
  style_query = "json_extract(data, '$.class.3d') > 0.1 AND isNFSW = false"
358
  elif style == Style.realistic:
359
  style_query = "json_extract(data, '$.class.real_life') > 0.1 AND isNFSW = false"
360
+ elif style == Style.lora:
361
+ style_query = "json_extract(data, '$.meta.tags') LIKE '%lora%' AND isNFSW = false"
362
  elif style == Style.nsfw:
363
  style_query = "isNFSW = true"
364
 
365
+
366
+
367
  with database.get_db() as db:
368
  cursor = db.cursor()
369
  cursor.execute(
 
376
  json_extract(data, '$.class.explicit') > 0.3 OR json_extract(data, '$.class.suggestive') > 0.3 AS isNFSW
377
  FROM models
378
  ) AS subquery
379
+ WHERE (? IS NULL AND likes > 1 OR ? IS NOT NULL)
380
  AND {style_query}
381
  AND (? IS NULL OR EXISTS (
382
  SELECT 1
 
385
  ))
386
  ORDER BY {sort_query}
387
  LIMIT {MAX_PAGE_SIZE} OFFSET {(page - 1) * MAX_PAGE_SIZE};
388
+ """,
389
  (tag, tag, tag, tag),
390
  )
391
  results = cursor.fetchall()