Spaces:
Running
Running
import pandas | |
import numpy | |
import pandas.io.formats.style | |
import random | |
import functools | |
from typing import Callable, Literal | |
# Directory holding the *.feather tables that get_feather reads.
DATA_FOLDER = "."

# Numeric tag-category ids, compared against the "category" column of the tag
# tables (see tag_to_prompt and format_tags). The order below fixes the values.
(
    CAT_GENERAL,    # 0
    CAT_ARTIST,     # 1
    CAT_UNUSED,     # 2
    CAT_COPYRIGHT,  # 3
    CAT_CHARACTER,  # 4
    CAT_SPECIES,    # 5
    CAT_INVALID,    # 6
    CAT_META,       # 7
    CAT_LORE,       # 8
) = range(9)

# CSS colour used for a tag's name in styled result tables, keyed by category id.
CATEGORY_COLORS = {
    CAT_GENERAL: "#808080",
    CAT_ARTIST: "#f2ac08",
    CAT_UNUSED: "#ff3d3d",
    CAT_COPYRIGHT: "#d0d",
    CAT_CHARACTER: "#0a0",
    CAT_SPECIES: "#ed5d1f",
    CAT_INVALID: "#ff3d3d",
    CAT_META: "#04f",
    CAT_LORE: "#282",
}
def get_feather(filename: str) -> pandas.DataFrame:
    """Read the table ``<DATA_FOLDER>/<filename>.feather`` into a DataFrame."""
    path = "{}/{}.feather".format(DATA_FOLDER, filename)
    return pandas.read_feather(path)
# Load every data table once, at import time.
tags = get_feather("tags")
posts_by_tag = get_feather("posts_by_tag").set_index("tag_id")
tags_by_post = get_feather("tags_by_post").set_index("post_id")
tag_ratings = get_feather("tag_ratings")
implications = get_feather("implications")

# Two views of the tag table: one keyed by name (tag_id stays a column), built
# before the other, which re-keys the table by tag_id.
tags_by_name = tags.set_index("name")
tags = tags.set_index("tag_id")
def get_related_tags(targets: tuple[str, ...], exclude: tuple[str, ...] = (), samples: int = 100_000) -> pandas.DataFrame:
    """Score every tag by how strongly it co-occurs with the combination ``targets``.

    The "selection" is the set of posts tagged with every tag in ``targets``
    and none of the tags in ``exclude``. Tags are counted on at most
    ``samples`` randomly chosen selected posts, and the counts are scaled back
    up to the size of the whole selection.

    Args:
        targets: Tag names every selected post must carry (at least one).
        exclude: Tag names; posts carrying any of them are dropped.
        samples: Maximum number of selected posts whose tags are counted.

    Returns:
        One row per tag with ``post_count > 0``, indexed like ``tags``
        (by tag_id), with columns:
        - category, name, post_count: copied from ``tags``;
        - overlap: estimated number of selected posts carrying the tag,
          capped at post_count;
        - correlation: phi coefficient between "post is selected" and
          "post carries the tag" over all ``len(tags_by_post)`` posts
          (NaN when the selection is empty).

    Raises:
        ValueError: if ``targets`` is empty.
    """
    if not targets:
        raise ValueError("get_related_tags needs at least one target tag")

    # Posts carrying every target tag.
    target_ids = tags_by_name.loc[list(targets), "tag_id"]
    selected = set.intersection(*map(set, posts_by_tag.loc[target_ids, "post_id"]))
    # Minus posts carrying any excluded tag.
    if len(exclude) > 0:
        excluded_ids = tags_by_name.loc[list(exclude), "tag_id"]
        selected -= set().union(*posts_by_tag.loc[excluded_ids, "post_id"])

    total_post_count_together = len(selected)
    if total_post_count_together > samples:
        sample_posts = random.sample(list(selected), samples)
    else:
        sample_posts = list(selected)
    # Factor for scaling sampled counts up to the whole selection. An empty
    # selection has nothing to scale (0/0 used to raise ZeroDivisionError).
    if total_post_count_together:
        sample_ratio = len(sample_posts) / total_post_count_together
    else:
        sample_ratio = 1.0

    counts_in_these_posts = tags_by_post.loc[sample_posts, "tag_id"].explode().value_counts().rename("overlap")
    summaries = pandas.DataFrame(counts_in_these_posts).join(tags[tags["post_count"] > 0], how="right").fillna(0)
    summaries["overlap"] = numpy.minimum(summaries["overlap"] / sample_ratio, summaries["post_count"])
    summaries = summaries[["category", "name", "overlap", "post_count"]]
    # Old "interestingness" value, didn't give as good results as an actual statistical technique, go figure. Code kept for curiosity's sake.
    #summaries["interestingness"] = summaries["overlap"].pow(2) / (total_post_count_together * summaries["post_count"])

    # Phi coefficient of the 2x2 table (selected?, has tag?) over all posts.
    n = float(len(tags_by_post))
    n11 = summaries["overlap"]
    n1x = float(total_post_count_together)
    nx1 = summaries["post_count"].astype("float64")
    summaries["correlation"] = (n * n11 - n1x * nx1) / numpy.sqrt(n1x * nx1 * (n - n1x) * (n - nx1))
    return summaries
def format_tags(styler: pandas.io.formats.style.Styler):
    """Apply the shared tag-table styling to ``styler`` and return it.

    Meant for ``DataFrame.style.pipe``; expects ``category`` and ``name``
    columns. The ``name`` cell is coloured via ``CATEGORY_COLORS`` by the row's
    category, and the index and the ``category`` column are hidden. When
    present, ``overlap`` is shown as a whole number. When present,
    ``correlation`` and ``score`` are shown with two decimals on a
    red-yellow-green gradient spanning -1..1.
    """
    def name_colour(row):
        # One CSS string per cell of the row: only the "name" cell is coloured.
        colour = CATEGORY_COLORS[row["category"]]
        return numpy.where(row.index == "name", "color:" + colour, "")

    styler.apply(name_colour, axis=1)
    styler.hide(level=0)
    styler.hide("category", axis=1)

    frame = styler.data
    if "overlap" in frame:
        styler.format("{:.0f}".format, subset=["overlap"])
    for column in ("correlation", "score"):
        if column not in frame:
            continue
        styler.format("{:.2f}".format, subset=[column])
        styler.background_gradient(vmin=-1.0, vmax=1.0, cmap="RdYlGn", subset=[column])
    return styler
def related_tags(*targets: str, exclude: tuple[str, ...] = (), category: int | None = None, samples: int = 100_000, min_overlap: int = 5, min_posts: int = 20, top: int = 30, bottom: int = 0) -> "pandas.io.formats.style.Styler":
    """Show the tags most (and optionally least) correlated with ``targets``.

    Runs ``get_related_tags`` and keeps rows that are not themselves targets,
    that match ``category`` (any category when None), and that satisfy
    ``overlap >= min_overlap`` and ``post_count >= min_posts``.

    Returns:
        A Styler (formatted by ``format_tags``) listing the ``top`` rows with the
        highest correlation, followed by the ``bottom`` lowest-correlation rows
        not already listed, shown in descending order.
    """
    result = get_related_tags(targets, exclude=exclude, samples=samples)
    if category is not None:
        result = result[result["category"] == category]
    result = result[~result["name"].isin(targets)]
    result = result[result["overlap"] >= min_overlap]
    result = result[result["post_count"] >= min_posts]

    top_part = result.sort_values("correlation", ascending=False)[:top]
    # Pick the bottom rows only from what the top slice left over; otherwise a
    # result with fewer than top + bottom rows listed some tags twice.
    remainder = result.drop(index=top_part.index)
    bottom_part = remainder.sort_values("correlation", ascending=True)[:bottom].sort_values("correlation", ascending=False)
    return pandas.concat([top_part, bottom_part]).style.pipe(format_tags)
def implications_for(*subjects: str, seen: set[str] = None):
    """Yield every tag implied, directly or transitively, by ``subjects``.

    Each consequent found in ``implications`` is yielded, added to ``seen``,
    and then expanded recursively (depth-first). A tag already in ``seen`` is
    neither yielded nor expanded, so each tag is reported once and cycles
    terminate. ``seen`` is shared with, and mutated by, the recursive calls.
    Callers may pass a set of tags to suppress. The subjects themselves are
    not added to it.
    """
    visited = set() if seen is None else seen
    for subject in subjects:
        subject_id = tags_by_name.loc[subject, "tag_id"]
        rows = implications[implications["antecedent_id"] == subject_id]
        consequent_ids = list(rows["consequent_id"])
        for implied in tags.loc[consequent_ids, "name"].values:
            if implied in visited:
                continue
            yield implied
            visited.add(implied)
            yield from implications_for(implied, seen=visited)
def parse_tag(potential_tag: str):
    r"""Normalise one user-entered tag and resolve it to a known tag name.

    The input is stripped, spaces become underscores, and the prompt-style
    escaped parentheses ``\(`` / ``\)`` are unescaped. If the result is not a
    known tag but starts with ``by_``, the remainder is tried instead.

    Returns:
        The matching tag name, or None for blank input. For unknown tags, a
        notice is printed and None is returned.
    """
    candidate = potential_tag.strip().replace(" ", "_")
    for escaped, plain in (("\\(", "("), ("\\)", ")")):
        candidate = candidate.replace(escaped, plain)
    if not candidate:
        return None
    known_names = tags_by_name.index
    if candidate in known_names:
        return candidate
    if candidate.startswith("by_"):
        unprefixed = candidate[3:]
        if unprefixed in known_names:
            return unprefixed
    print(f"Couldn't find tag '{candidate}', skipping it.")
    return None
def parse_tags(*parts: str):
    """Lazily yield the resolvable tags from comma-separated strings.

    Every comma-separated piece of every part is passed, in order, to
    ``parse_tag``. Pieces it cannot resolve (blank or unknown) are skipped.
    """
    pieces = (piece for part in parts for piece in part.split(","))
    resolved = (parse_tag(piece) for piece in pieces)
    yield from (tag for tag in resolved if tag is not None)
def add_suggestions(suggestions: pandas.DataFrame, new_tags: str | list[str], multiplier: int, samples : int, min_posts: int, rating: Literal['s', 'q', 'e']): | |
if isinstance(new_tags, str): | |
new_tags = [new_tags] | |
for new_tag in new_tags: | |
related = get_related_tags((new_tag,), samples=samples) | |
# Implementing the rating filter this way is horribly inefficient, fix it later | |
if rating == 's': | |
related = related.join(tag_ratings.set_index("tag_id"), on="tag_id") | |
related["post_count"] = related["s"] | |
related = related.drop("s", axis=1) | |
related = related.drop("q", axis=1) | |
related = related.drop("e", axis=1) | |
elif rating == 'q': | |
related = related.join(tag_ratings.set_index("tag_id"), on="tag_id") | |
related["post_count"] = related["s"] + related["q"] | |
related = related.drop("s", axis=1) | |
related = related.drop("q", axis=1) | |
related = related.drop("e", axis=1) | |
related = related[related["post_count"] >= min_posts] | |
if suggestions is None: | |
suggestions = related.rename(columns={"correlation": "score"}) | |
else: | |
suggestions = suggestions.join(related, rsuffix="r") | |
# This is a totally made up way to combine correlations. It keeps them from going outside the +/- 1 range, which is nice. It also makes older | |
# tags less important every time newer ones are added. That could be considered a feature or not. | |
suggestions["score"] = numpy.real(numpy.power((numpy.sqrt(suggestions["score"] + 0j) + numpy.sqrt(multiplier * suggestions["correlation"] + 0j)) / 2, 2)) | |
return suggestions[["category", "name", "post_count", "score"]] | |
def pick_tags(suggestions: pandas.DataFrame, category: int, count: int, from_top: int, excluding: list[str], weighted: bool = True):
    """Randomly draw ``count`` distinct tag names from the best suggestions.

    Candidates are rows with a positive ``score`` whose ``name`` is not in
    ``excluding`` and, unless ``category`` is None, whose ``category`` matches.
    Only the ``from_top`` highest-scoring candidates are kept. With
    ``weighted``, tags are drawn one by one without replacement, with
    probability proportional to score. Otherwise the draw is uniform.
    """
    eligible = (suggestions["score"] > 0) & ~suggestions["name"].isin(excluding)
    if category is not None:
        eligible = eligible & (suggestions["category"] == category)
    ranked = suggestions[eligible].sort_values("score", ascending=False)
    candidates = ranked[:from_top]

    names = list(candidates["name"].values)
    if not weighted:
        return random.sample(names, count)

    scores = list(candidates["score"].values)
    picked = []
    for _ in range(count):
        choice = random.choices(population=names, weights=scores, k=1)[0]
        position = names.index(choice)
        del scores[position]
        del names[position]
        picked.append(choice)
    return picked
def tag_to_prompt(tag: str) -> str:
    """Render a tag name in text-prompt form.

    Artist tags get a "by " prefix, underscores become spaces, and
    parentheses are escaped with a backslash.
    """
    category = tags_by_name.loc[tag]["category"]
    text = "by " + tag if category == CAT_ARTIST else tag
    escapes = str.maketrans({"_": " ", "(": "\\(", ")": "\\)"})
    return text.translate(escapes)
# A lambda in a for loop doesn't capture variables the way I want it to, so this is a method now | |
def add_suggestions_later(suggestions: pandas.DataFrame, new_tags: str | list[str], multiplier: int, samples: int, min_posts: int, rating: Literal['s', 'q', 'e']): | |
return lambda: add_suggestions(suggestions, new_tags, multiplier, samples, min_posts, rating) | |
Prompt = tuple[list[str], list[str], Callable[[], pandas.DataFrame]] | |
class PromptBuilder:
    """Chainable, non-mutating builder for randomised tag prompts.

    A builder holds a list of prompts in progress (see ``Prompt``). Every
    operation returns a new builder with the same settings and leaves the
    receiver unchanged, so partial chains can be reused. ``samples``,
    ``min_posts`` and ``rating`` are forwarded to ``add_suggestions`` whenever
    a prompt's suggestions are updated. Tags in ``skip_list`` are never picked.
    """

    prompts: list[Prompt]
    samples: int
    min_posts: int
    rating: Literal['s', 'q', 'e']
    skip_list: list[str]

    def __init__(self, prompts: list[Prompt] | None = None, skip: list[str] | None = None, samples=100_000, min_posts=20, rating: Literal['s', 'q', 'e'] = 'e'):
        # None defaults instead of list literals, so default-constructed
        # builders don't share one mutable list object.
        self.prompts = [([], [], lambda: None)] if prompts is None else prompts
        self.samples = samples
        self.min_posts = min_posts
        self.rating = rating
        self.skip_list = [] if skip is None else skip

    def _derive(self, prompts: list[Prompt]) -> "PromptBuilder":
        # A new builder holding `prompts` with this builder's settings.
        return PromptBuilder(prompts=prompts, samples=self.samples, min_posts=self.min_posts, skip=self.skip_list, rating=self.rating)

    def _defer(self, suggestions, new_tags, multiplier: int):
        # Lazily fold `new_tags` into `suggestions` using this builder's settings.
        return add_suggestions_later(suggestions, new_tags, multiplier, self.samples, self.min_posts, self.rating)

    def _add_tag(self, tag: str, multiplier: int, to_positive: bool, to_negative: bool) -> "PromptBuilder":
        # Shared body of include/focus/exclude/avoid: update every prompt's
        # suggestions with `tag` and optionally write it into a tag list.
        return self._derive([
            (
                tag_list + [tag] if to_positive else tag_list,
                negative_list + [tag] if to_negative else negative_list,
                self._defer(suggestions(), tag, multiplier),
            )
            for (tag_list, negative_list, suggestions) in self.prompts
        ])

    def include(self, tag: str):
        """Add ``tag`` to every prompt and steer suggestions towards its correlates."""
        return self._add_tag(tag, 1, to_positive=True, to_negative=False)

    def focus(self, tag: str):
        """Steer suggestions towards ``tag``'s correlates without adding it to the prompt."""
        return self._add_tag(tag, 1, to_positive=False, to_negative=False)

    def exclude(self, tag: str):
        """Add ``tag`` to every negative prompt and steer suggestions away from its correlates."""
        return self._add_tag(tag, -1, to_positive=False, to_negative=True)

    def avoid(self, tag: str):
        """Steer suggestions away from ``tag``'s correlates without adding it to the negative prompt."""
        return self._add_tag(tag, -1, to_positive=False, to_negative=False)

    def pick(self, category: int | None, count: int, from_top: int):
        """Add ``count`` tags to every prompt, one at a time.

        Each tag is drawn (score-weighted) from the ``from_top`` best
        suggestions in ``category`` (any category when None), and the
        suggestions are refreshed after every added tag.
        """
        new_prompts = self.prompts
        for _ in range(count):
            new_prompts = [
                (tag_list + [tag], negative_list, self._defer(s, tag, 1))
                for (tag_list, negative_list, suggestions) in new_prompts
                for s in (suggestions(),)
                for tag in pick_tags(s, category, 1, from_top, tag_list + negative_list + self.skip_list)
            ]
        return self._derive(new_prompts)

    def foreach_pick(self, category: int | None, count: int, from_top: int):
        """Replace each prompt by ``count`` variants, each extended with a different drawn tag."""
        return self._derive([
            (tag_list + [tag], negative_list, self._defer(s, tag, 1))
            for (tag_list, negative_list, suggestions) in self.prompts
            for s in (suggestions(),)
            for tag in pick_tags(s, category, count, from_top, tag_list + negative_list + self.skip_list)
        ])

    def pick_fast(self, category: int | None, count: int, from_top: int):
        """Like ``pick``, but draw all ``count`` tags from the current suggestions at once."""
        prompts = []
        for (tag_list, negative_list, suggestions) in self.prompts:
            s = suggestions()
            new_tags = pick_tags(s, category, count, from_top, tag_list + negative_list + self.skip_list)
            prompts.append((tag_list + new_tags, negative_list, self._defer(s, new_tags, 1)))
        return self._derive(prompts)

    def branch(self, count: int):
        """Repeat every prompt ``count`` times so later random picks can diverge."""
        return self._derive([prompt for prompt in self.prompts for _ in range(count)])

    def build(self):
        """Yield each prompt as text.

        Positive tags are joined with ", ". If there are negative tags, they
        follow on a "Negative prompt: " line.
        """
        for (tag_list, negative_list, _) in self.prompts:
            positive_prompt = ", ".join([tag_to_prompt(tag) for tag in tag_list])
            negative_prompt = ", ".join([tag_to_prompt(tag) for tag in negative_list])
            if negative_prompt:
                yield f"{positive_prompt}\nNegative prompt: {negative_prompt}"
            else:
                yield positive_prompt

    def print(self):
        """Print every built prompt."""
        for prompt in self.build():
            print(prompt)

    def get_one(self):
        """Return the first built prompt, or None if there are no prompts."""
        return next(self.build(), None)