furusu commited on
Commit
ddad67b
·
verified ·
1 Parent(s): 11de6c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -19,6 +19,12 @@ general_tags = df[df["category"] == 0].index
19
  character_tags = df[df["category"] == 4].index
20
  all_tags = df.index
21
 
 
 
 
 
 
 
22
  def predict(target_tags, search_in):
23
  target_tags = [tag.strip().replace(" ", "_") for tag in target_tags.split(",")]
24
  target_ids = [label2id[tag] for tag in target_tags]
@@ -26,7 +32,7 @@ def predict(target_tags, search_in):
26
 
27
  sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)
28
  tags = general_tags if search_in == "general" else character_tags if search_in == "character" else all_tags
29
- return {id2label[i]: sim[i].item() for i in tags}
30
 
31
  demo = gr.Interface(
32
  fn=predict,
 
19
  character_tags = df[df["category"] == 4].index
20
  all_tags = df.index
21
 
22
+ tag_pair_df = pd.read_parquet("hf://datasets/p1atdev/danbooru-ja-tag-pair-20241015/data/train-00000-of-00001.parquet")
23
+ tag_pair = {title:other_names[0] for title, other_names in zip(tag_pair_df["title"], tag_pair_df["other_names"])}
24
+ for tag in df["name"]:
25
+ if tag not in tag_pair:
26
+ tag_pair[tag] = ""
27
+
28
  def predict(target_tags, search_in):
29
  target_tags = [tag.strip().replace(" ", "_") for tag in target_tags.split(",")]
30
  target_ids = [label2id[tag] for tag in target_tags]
 
32
 
33
  sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)
34
  tags = general_tags if search_in == "general" else character_tags if search_in == "character" else all_tags
35
+ return {f"{id2label[i]}({tag_pair[id2label[i]]})": sim[i].item() for i in tags}
36
 
37
  demo = gr.Interface(
38
  fn=predict,