Update app.py
Browse files
app.py
CHANGED
@@ -19,6 +19,12 @@ general_tags = df[df["category"] == 0].index
|
|
19 |
character_tags = df[df["category"] == 4].index
|
20 |
all_tags = df.index
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
def predict(target_tags, search_in):
|
23 |
target_tags = [tag.strip().replace(" ", "_") for tag in target_tags.split(",")]
|
24 |
target_ids = [label2id[tag] for tag in target_tags]
|
@@ -26,7 +32,7 @@ def predict(target_tags, search_in):
|
|
26 |
|
27 |
sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)
|
28 |
tags = general_tags if search_in == "general" else character_tags if search_in == "character" else all_tags
|
29 |
-
return {id2label[i]: sim[i].item() for i in tags}
|
30 |
|
31 |
demo = gr.Interface(
|
32 |
fn=predict,
|
|
|
19 |
character_tags = df[df["category"] == 4].index
|
20 |
all_tags = df.index
|
21 |
|
22 |
+
tag_pair_df = pd.read_parquet("hf://datasets/p1atdev/danbooru-ja-tag-pair-20241015/data/train-00000-of-00001.parquet")
|
23 |
+
tag_pair = {title:other_names[0] for title, other_names in zip(tag_pair_df["title"], tag_pair_df["other_names"])}
|
24 |
+
for tag in df["name"]:
|
25 |
+
if tag not in tag_pair:
|
26 |
+
tag_pair[tag] = ""
|
27 |
+
|
28 |
def predict(target_tags, search_in):
|
29 |
target_tags = [tag.strip().replace(" ", "_") for tag in target_tags.split(",")]
|
30 |
target_ids = [label2id[tag] for tag in target_tags]
|
|
|
32 |
|
33 |
sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)
|
34 |
tags = general_tags if search_in == "general" else character_tags if search_in == "character" else all_tags
|
35 |
+
return {f"{id2label[i]}({tag_pair[id2label[i]]})": sim[i].item() for i in tags}
|
36 |
|
37 |
demo = gr.Interface(
|
38 |
fn=predict,
|