John6666 committed
Commit 31cedff
1 Parent(s): c6d9f2f

Upload 3 files

Files changed (2)
  1. model.py +2 -2
  2. multit2i.py +26 -5
model.py CHANGED
@@ -14,9 +14,9 @@ models = [
     'digiplay/majicMIX_realistic_v7',
     'votepurchase/counterfeitV30_v30',
     'Meina/MeinaMix_V11',
-    'John6666/cute-illustration-style-reinforced-model-v61-sd15',
     'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
-    'kayfahaarukku/UrangDiffusion-1.1',
+    'KBlueLeaf/Kohaku-XL-Zeta',
+    'kayfahaarukku/UrangDiffusion-1.2',
     'Eugeoter/artiwaifu-diffusion-1.0',
     'Raelina/Rae-Diffusion-XL-V2',
     'Raelina/Raemu-XL-V4',
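The new entries ('KBlueLeaf/Kohaku-XL-Zeta', 'kayfahaarukku/UrangDiffusion-1.2') are plain Hub repo ids in "owner/name" form, which is what the is_repo_name() helper in multit2i.py checks. A minimal sketch of that check, reusing only the regex visible in the diff below:

import re

# Same pattern as is_repo_name() in multit2i.py: exactly one "/" separating
# a non-empty owner from a non-empty repo name.
def is_repo_name(s):
    return re.fullmatch(r'^[^/]+?/[^/]+?$', s)

print(bool(is_repo_name('KBlueLeaf/Kohaku-XL-Zeta')))  # True
print(bool(is_repo_name('Kohaku-XL-Zeta')))            # False: missing owner prefix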
multit2i.py CHANGED
@@ -33,22 +33,43 @@ def is_repo_name(s):
     return re.fullmatch(r'^[^/]+?/[^/]+?$', s)
 
 
-def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30):
+def get_status(model_name: str):
+    from huggingface_hub import InferenceClient
+    client = InferenceClient(timeout=10)
+    return client.get_model_status(model_name)
+
+
+def is_loadable(model_name: str, force_gpu: bool = False):
+    try:
+        status = get_status(model_name)
+    except Exception as e:
+        print(e)
+        print(f"Couldn't load {model_name}.")
+        return False
+    gpu_state = isinstance(status.compute_type, dict) and "gpu" in status.compute_type.keys()
+    if status is None or status.state not in ["Loadable", "Loaded"] or (force_gpu and not gpu_state):
+        print(f"Couldn't load {model_name}. Model state:'{status.state}', GPU:{gpu_state}")
+    return status is not None and status.state in ["Loadable", "Loaded"] and (not force_gpu or gpu_state)
+
+
+def find_model_list(author: str="", tags: list[str]=[], not_tag="", sort: str="last_modified", limit: int=30, force_gpu=False, check_status=False):
     from huggingface_hub import HfApi
     api = HfApi()
     default_tags = ["diffusers"]
     if not sort: sort = "last_modified"
+    limit = limit * 20 if check_status and force_gpu else limit * 5
     models = []
     try:
-        model_infos = api.list_models(author=author, pipeline_tag="text-to-image", token=HF_TOKEN,
-                                      tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit * 5)
+        model_infos = api.list_models(author=author, task="text-to-image",
+                                      tags=list_uniq(default_tags + tags), cardData=True, sort=sort, limit=limit)
     except Exception as e:
         print(f"Error: Failed to list models.")
         print(e)
         return models
     for model in model_infos:
-        if not model.private and not model.gated and HF_TOKEN is None:
-            if not_tag and not_tag in model.tags: continue
+        if not model.private and not model.gated:
+            loadable = is_loadable(model.id, force_gpu) if check_status else True
+            if not_tag and not_tag in model.tags or not loadable: continue
             models.append(model.id)
             if len(models) == limit: break
     return models
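A minimal usage sketch of the status-aware listing added here, assuming multit2i.py is importable and the environment can reach the Hugging Face Hub and Inference API (the author and limit values are illustrative, not from the commit):

from multit2i import find_model_list, is_loadable

# is_loadable() asks the serverless Inference API for the model's status and,
# with force_gpu=True, additionally requires a GPU compute type.
print(is_loadable("Raelina/Raemu-XL-V4"))

# With check_status=True, find_model_list() keeps only models whose status is
# "Loadable" or "Loaded"; combining it with force_gpu=True also widens the raw
# Hub listing (limit * 20) since most candidates get filtered out.
models = find_model_list(author="Raelina", check_status=True, force_gpu=True, limit=5)
print(models)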