# E621TagExplorer / functions.py
# Author: Specimen5423 — revision 6013c6d: "Slight adjustments to general and meta colors"
import pandas
import numpy
import pandas.io.formats.style
import random
import functools
from typing import Callable, Literal
# Directory that get_feather reads the .feather data files from.
DATA_FOLDER = "."

# Tag category ids, compared against the "category" column of the tag tables.
# Unpacking range(9) numbers them 0 through 8 in the order listed.
(
    CAT_GENERAL,    # 0
    CAT_ARTIST,     # 1
    CAT_UNUSED,     # 2
    CAT_COPYRIGHT,  # 3
    CAT_CHARACTER,  # 4
    CAT_SPECIES,    # 5
    CAT_INVALID,    # 6
    CAT_META,       # 7
    CAT_LORE,       # 8
) = range(9)

# CSS colour used for each category's tag names when styling result tables.
# The unused and invalid categories deliberately share the same red.
CATEGORY_COLORS: dict[int, str] = {
    CAT_GENERAL: "#808080",
    CAT_ARTIST: "#f2ac08",
    CAT_UNUSED: "#ff3d3d",
    CAT_COPYRIGHT: "#d0d",
    CAT_CHARACTER: "#0a0",
    CAT_SPECIES: "#ed5d1f",
    CAT_INVALID: "#ff3d3d",
    CAT_META: "#04f",
    CAT_LORE: "#282",
}
def get_feather(filename: str) -> pandas.DataFrame:
    """Read ``<DATA_FOLDER>/<filename>.feather`` and return it as a DataFrame."""
    feather_path = f"{DATA_FOLDER}/{filename}.feather"
    return pandas.read_feather(feather_path)
# Load the data tables once at import time.
tags = get_feather("tags")
posts_by_tag = get_feather("posts_by_tag").set_index("tag_id")
tags_by_post = get_feather("tags_by_post").set_index("post_id")
tag_ratings = get_feather("tag_ratings")
implications = get_feather("implications")

# Two views of the tag table. tags_by_name is keyed by tag name and keeps
# "tag_id" as a column. tags is keyed by tag_id. set_index returns a new
# frame, so the name-keyed view is independent of the id-keyed one.
tags_by_name = tags.set_index("name")
tags = tags.set_index("tag_id")
@functools.cache
def get_related_tags(targets: tuple[str, ...], exclude: tuple[str, ...] = (), samples: int = 100_000) -> pandas.DataFrame:
    """Measure how strongly every tag co-occurs with all of ``targets``.

    The matched posts are those tagged with every tag in ``targets`` and with
    none of the tags in ``exclude``. If there are more than ``samples`` matched
    posts, a random subset of that size is examined and tag counts are scaled
    back up. For each tag, the phi coefficient is then computed between
    "post is matched" and "post carries this tag".

    The result is memoized with ``functools.cache``:

    - ``targets`` and ``exclude`` must be hashable (tuples).
    - Repeated calls return the very same DataFrame object, so callers must
      not modify the result in place.

    Returns:
        A DataFrame indexed like ``tags``, with one row per tag whose
        ``post_count`` is positive. Columns:

        - ``category``
        - ``name``
        - ``overlap``: estimated number of matched posts carrying the tag,
          capped at ``post_count``.
        - ``post_count``
        - ``correlation``: the phi coefficient. It is NaN or infinite where
          the coefficient is undefined, e.g. when no post matches.
    """
    # Posts carrying *every* target tag. Grouping on a constant key puts all
    # rows into one group, so the aggregation intersects their post-id sets.
    these_tags = tags_by_name.loc[list(targets)]
    posts_with_these_tags = posts_by_tag.loc[these_tags["tag_id"]].map(set).groupby(lambda x: True).agg(lambda x: set.intersection(*x))["post_id"][True]
    if exclude:
        # Remove posts carrying *any* excluded tag (union of their post-id sets).
        excluded_tags = tags_by_name.loc[list(exclude)]
        posts_with_excluded_tags = posts_by_tag.loc[excluded_tags["tag_id"]].map(set).groupby(lambda x: True).agg(lambda x: set.union(*x))["post_id"][True]
        posts_with_these_tags = posts_with_these_tags - posts_with_excluded_tags
    total_post_count_together = len(posts_with_these_tags)
    if total_post_count_together > samples:
        sample_posts = random.sample(list(posts_with_these_tags), samples)
    else:
        sample_posts = list(posts_with_these_tags)
    post_count_together = len(sample_posts)
    # With no matching posts there is nothing to scale up. Use 1.0 instead of
    # dividing by zero; every tag then simply gets an overlap of 0.
    sample_ratio = post_count_together / total_post_count_together if total_post_count_together else 1.0
    tags_in_these_posts = tags_by_post.loc[sample_posts]
    counts_in_these_posts = tags_in_these_posts["tag_id"].explode().value_counts().rename("overlap")
    # Right join: keep every tag with posts, even those absent from the sample.
    summaries = pandas.DataFrame(counts_in_these_posts).join(tags[tags["post_count"] > 0], how="right").fillna(0)
    # Scale sampled counts to the full matched set. A tag cannot overlap on
    # more posts than it has.
    summaries["overlap"] = numpy.minimum(summaries["overlap"] / sample_ratio, summaries["post_count"])
    summaries = summaries[["category", "name", "overlap", "post_count"]]
    # Old "interestingness" value, didn't give as good results as an actual statistical technique, go figure. Code kept for curiosity's sake.
    #summaries["interestingness"] = summaries["overlap"].pow(2) / (total_post_count_together * summaries["post_count"])
    # Phi coefficient of the 2x2 table (matched vs. not) x (has tag vs. not):
    #   phi = (n*n11 - n1x*nx1) / sqrt(n1x * nx1 * (n - n1x) * (n - nx1))
    n = float(len(tags_by_post))
    n11 = summaries["overlap"]
    n1x = float(total_post_count_together)
    nx1 = summaries["post_count"].astype("float64")
    summaries["correlation"] = (n * n11 - n1x * nx1) / numpy.sqrt(n1x * nx1 * (n - n1x) * (n - nx1))
    return summaries
def format_tags(styler: pandas.io.formats.style.Styler):
    """Apply the shared presentation to a tag table's Styler and return it.

    The changes, in order:

    - Colour each tag's ``name`` cell by its category.
    - Hide the index and the ``category`` column.
    - Round ``overlap`` to whole numbers.
    - Show ``correlation`` and ``score`` with two decimals on a red-to-green
      gradient spanning -1..1.

    Only columns present in the data are formatted.
    """
    def color_name_cell(row):
        css = "color:" + CATEGORY_COLORS[row["category"]]
        return numpy.where(row.index == "name", css, "")

    styler.apply(color_name_cell, axis=1)
    styler.hide(level=0)
    styler.hide("category", axis=1)
    if 'overlap' in styler.data:
        styler.format("{:.0f}".format, subset=["overlap"])
    for column in ('correlation', 'score'):
        if column not in styler.data:
            continue
        styler.format("{:.2f}".format, subset=[column])
        styler.background_gradient(vmin=-1.0, vmax=1.0, cmap="RdYlGn", subset=[column])
    return styler
def related_tags(*targets: str, exclude: tuple[str, ...] = (), category: int | None = None, samples: int = 100_000, min_overlap: int = 5, min_posts: int = 20, top: int = 30, bottom: int = 0) -> pandas.io.formats.style.Styler:
    """Return a styled table of the tags most (and optionally least) correlated with ``targets``.

    Args:
        targets: tag names that every considered post must carry.
        exclude: tag names whose posts are left out. Any iterable of names is
            accepted; it is converted to a tuple so the cached lookup can
            hash it.
        category: if not None, keep only tags of this category id.
        samples: maximum number of posts examined (see get_related_tags).
        min_overlap: drop tags whose estimated overlap is below this.
        min_posts: drop tags whose post_count is below this.
        top: number of highest-correlation tags to show.
        bottom: number of lowest-correlation tags to show after them.

    Returns:
        A Styler over the selected rows, formatted by ``format_tags``.
        The target tags themselves are never listed.
    """
    # get_related_tags is memoized, so its arguments must be hashable.
    result = get_related_tags(targets, exclude=tuple(exclude), samples=samples)
    if category is not None:
        result = result[result["category"] == category]
    result = result[~result["name"].isin(targets)]
    result = result[(result["overlap"] >= min_overlap) & (result["post_count"] >= min_posts)]
    top_part = result.sort_values("correlation", ascending=False)[:top]
    # Draw the bottom rows only from tags not already shown at the top. With
    # fewer than top + bottom rows the two slices would otherwise share
    # rows. The duplicated index makes Styler.apply (in format_tags) raise.
    remaining = result[~result.index.isin(top_part.index)]
    bottom_part = remaining.sort_values("correlation", ascending=True)[:bottom].sort_values("correlation", ascending=False)
    return pandas.concat([top_part, bottom_part]).style.pipe(format_tags)
def implications_for(*subjects: str, seen: set[str] = None):
    """Yield the names of all tags transitively implied by ``subjects``, depth-first.

    Each name is yielded at most once. ``seen`` holds names that must not be
    yielded again. It is shared with, and updated by, the recursive calls, so
    a caller may pass its own set to pre-suppress names or collect results.
    """
    visited = set() if seen is None else seen
    for subject in subjects:
        subject_id = tags_by_name.loc[subject, "tag_id"]
        direct_rows = implications[implications["antecedent_id"] == subject_id]
        consequent_ids = list(direct_rows.loc[:, "consequent_id"])
        for implied in tags.loc[consequent_ids, "name"].values:
            if implied in visited:
                continue
            yield implied
            visited.add(implied)
            # Follow the chain before moving on to the next direct implication.
            yield from implications_for(implied, seen=visited)
def parse_tag(potential_tag: str):
    """Turn one user-written tag into a known tag name, or return None.

    Normalization:

    - Strip surrounding whitespace.
    - Turn spaces into underscores.
    - Unescape backslash-escaped parentheses.

    A leading "by_" (the artist prompt form produced by tag_to_prompt) is
    dropped when only the bare name is a known tag. Empty input returns None
    silently. Unknown tags print a notice and return None.
    """
    cleaned = (
        potential_tag.strip()
        .replace(" ", "_")
        .replace("\\(", "(")
        .replace("\\)", ")")
    )
    if not cleaned:
        return None
    known_names = tags_by_name.index
    if cleaned in known_names:
        return cleaned
    if cleaned.startswith("by_") and cleaned[3:] in known_names:
        return cleaned[3:]
    print(f"Couldn't find tag '{cleaned}', skipping it.")
    return None
def parse_tags(*parts: str):
    """Lazily yield the recognized tag names from comma-separated strings, in order.

    Every comma-separated piece of every part goes through parse_tag. Pieces
    it rejects (returns None for) are skipped.
    """
    pieces = (piece for part in parts for piece in part.split(","))
    for tag in map(parse_tag, pieces):
        if tag is not None:
            yield tag
def add_suggestions(suggestions: pandas.DataFrame | None, new_tags: str | list[str], multiplier: int, samples : int, min_posts: int, rating: Literal['s', 'q', 'e']):
    """Fold the related-tag correlations of ``new_tags`` into a suggestion table.

    Args:
        suggestions: a table previously returned by this function. Pass None
            to start a new table from the first new tag.
        new_tags: one tag name or a list of names, folded in order.
        multiplier: weight applied to each new tag's correlations. PromptBuilder
            uses 1 for tags it wants and -1 for tags it avoids.
        samples: forwarded to get_related_tags.
        min_posts: tags whose (rating-adjusted) post_count is below this are
            dropped.
        rating: how post_count is taken from the "s"/"q" columns of
            tag_ratings:

            - 's': the "s" count only.
            - 'q': the sum of "s" and "q".
            - any other value (e.g. 'e'): the overall post_count is kept.

    Returns:
        A DataFrame with columns ``category``, ``name``, ``post_count``
        and ``score``.
    """
    if isinstance(new_tags, str):
        new_tags = [new_tags]
    # The rating filter is still a join per tag (inefficient, fix it later),
    # but the indexed ratings table is at least built once per call.
    ratings = tag_ratings.set_index("tag_id") if rating in ("s", "q") else None
    for new_tag in new_tags:
        related = get_related_tags((new_tag,), samples=samples)
        if ratings is not None:
            related = related.join(ratings, on="tag_id")
            related["post_count"] = related["s"] if rating == "s" else related["s"] + related["q"]
            related = related.drop(columns=["s", "q", "e"])
        related = related[related["post_count"] >= min_posts]
        if suggestions is None:
            suggestions = related.rename(columns={"correlation": "score"})
        else:
            # Bring in only this tag's correlation. Joining the whole related
            # frame collides with columns left by earlier tags in this loop. A
            # stale "correlation" would then shadow the new one, and every
            # later tag would reuse the first tag's correlations.
            combined = suggestions.drop(columns="correlation", errors="ignore").join(related["correlation"])
            # This is a totally made up way to combine correlations. It keeps them from going outside the +/- 1 range, which is nice. It also makes older
            # tags less important every time newer ones are added. That could be considered a feature or not.
            # The "+ 0j" keeps the square roots defined for negative values.
            combined["score"] = numpy.real(numpy.power((numpy.sqrt(combined["score"] + 0j) + numpy.sqrt(multiplier * combined["correlation"] + 0j)) / 2, 2))
            suggestions = combined.drop(columns="correlation")
    return suggestions[["category", "name", "post_count", "score"]]
def pick_tags(suggestions: pandas.DataFrame, category: int, count: int, from_top: int, excluding: list[str], weighted: bool = True):
    """Randomly choose ``count`` distinct tag names from the best suggestions.

    A row is a candidate when all of these hold:

    - its ``score`` is positive;
    - its ``name`` is not in ``excluding``;
    - its ``category`` matches, unless ``category`` is None.

    Only the ``from_top`` highest-scoring candidates are eligible.

    With ``weighted``, names are drawn one at a time without replacement,
    each draw weighted by score. Otherwise ``random.sample`` picks them
    uniformly. Either way an error is raised if fewer than ``count``
    candidates are eligible.
    """
    eligible = (suggestions["score"] > 0) & ~suggestions["name"].isin(excluding)
    if category is not None:
        eligible = (suggestions["category"] == category) & eligible
    ranked = suggestions[eligible].sort_values("score", ascending=False)
    options = ranked[:from_top]
    names = list(options["name"].values)
    if not weighted:
        return random.sample(names, count)
    scores = list(options["score"].values)
    picked = []
    for _ in range(count):
        (choice,) = random.choices(population=names, weights=scores, k=1)
        # Remove the chosen name and its weight so it cannot be drawn again.
        position = names.index(choice)
        del scores[position]
        del names[position]
        picked.append(choice)
    return picked
def tag_to_prompt(tag: str) -> str:
    """Render a tag name as prompt text.

    Artist tags are prefixed with "by ". Underscores become spaces, and
    parentheses are escaped with a backslash.
    """
    text = "by " + tag if tags_by_name.loc[tag, "category"] == CAT_ARTIST else tag
    # A single translation pass. No replacement produces a character that
    # another one rewrites, so this equals applying the substitutions in turn.
    escapes = str.maketrans({"_": " ", "(": "\\(", ")": "\\)"})
    return text.translate(escapes)
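# A lambda written inside a loop or comprehension looks its variables up only when it is called (late binding), so every such lambda would see the last iteration's values; this factory function binds the arguments at creation time instead.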
def add_suggestions_later(suggestions: pandas.DataFrame, new_tags: str | list[str], multiplier: int, samples: int, min_posts: int, rating: Literal['s', 'q', 'e']):
return lambda: add_suggestions(suggestions, new_tags, multiplier, samples, min_posts, rating)
# One prompt under construction: (positive tag names, negative tag names, thunk).
# Calling the thunk yields the prompt's current suggestion table. Note that
# PromptBuilder's default initial prompt uses `lambda: None`, so the thunk may
# return None before any tag has been added, despite the annotation below.
Prompt = tuple[list[str], list[str], Callable[[], pandas.DataFrame]]
class PromptBuilder:
prompts: list[Prompt]
samples: int
min_posts: int
rating: Literal['s', 'q', 'e']
skip_list: list[str]
def __init__(self, prompts = [([],[],lambda: None)], skip=[], samples = 100_000, min_posts = 20, rating: Literal['s', 'q', 'e'] = 'e'):
self.prompts = prompts
self.samples = samples
self.min_posts = min_posts
self.rating = rating
self.skip_list = skip
def include(self, tag: str):
return PromptBuilder(prompts=[
(tag_list + [tag], negative_list, add_suggestions_later(suggestions(), tag, 1, self.samples, self.min_posts, self.rating))
for (tag_list, negative_list, suggestions) in self.prompts
], samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def focus(self, tag: str):
return PromptBuilder(prompts=[
(tag_list, negative_list, add_suggestions_later(suggestions(), tag, 1, self.samples, self.min_posts, self.rating))
for (tag_list, negative_list, suggestions) in self.prompts
], samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def exclude(self, tag: str):
return PromptBuilder(prompts=[
(tag_list, negative_list + [tag], add_suggestions_later(suggestions(), tag, -1, self.samples, self.min_posts, self.rating))
for (tag_list, negative_list, suggestions) in self.prompts
], samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def avoid(self, tag: str):
return PromptBuilder(prompts=[
(tag_list, negative_list, add_suggestions_later(suggestions(), tag, -1, self.samples, self.min_posts, self.rating))
for (tag_list, negative_list, suggestions) in self.prompts
], samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def pick(self, category: int, count: int, from_top: int):
new_prompts = self.prompts
for _ in range(count):
new_prompts = [
(tag_list + [tag], negative_list, add_suggestions_later(s, tag, 1, self.samples, self.min_posts, self.rating))
for (tag_list, negative_list, suggestions) in new_prompts
for s in (suggestions(),)
for tag in pick_tags(s, category, 1, from_top, tag_list + negative_list + self.skip_list)
]
return PromptBuilder(new_prompts, samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def foreach_pick(self, category: int, count: int, from_top: int):
return PromptBuilder(prompts=[
(tag_list + [tag], negative_list, add_suggestions_later(s, tag, 1, self.samples, self.min_posts, self.rating))
for (tag_list, negative_list, suggestions) in self.prompts
for s in (suggestions(),)
for tag in pick_tags(s, category, count, from_top, tag_list + negative_list + self.skip_list)
], samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def pick_fast(self, category: int, count: int, from_top: int):
prompts = []
for (tag_list, negative_list, suggestions) in self.prompts:
s = suggestions()
new_tags = pick_tags(s, category, count, from_top, tag_list + negative_list + self.skip_list)
prompts.append((tag_list + new_tags, negative_list, add_suggestions_later(s, new_tags, 1, self.samples, self.min_posts, self.rating)))
return PromptBuilder(prompts=prompts, samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def branch(self, count: int):
return PromptBuilder(prompts=[prompt for prompt in self.prompts for _ in range(count)], samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)
def build(self):
for (tag_list, negative_list, _) in self.prompts:
positive_prompt = ", ".join([ tag_to_prompt(tag) for tag in tag_list])
negative_prompt = ", ".join([ tag_to_prompt(tag) for tag in negative_list])
if negative_prompt:
yield f"{positive_prompt}\nNegative prompt: {negative_prompt}"
else:
yield positive_prompt
def print(self):
for prompt in self.build():
print(prompt)
def get_one(self):
for prompt in self.build():
return prompt