# (Hugging Face Spaces page banner captured during extraction: "Spaces: Running")
import torch | |
import torch.nn.functional as F | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
from namegenerator import Model, ModelConfig | |
# Inference-only service: disable autograd globally.
torch.set_grad_enabled(False)
# Vocabulary: control/conditioning tokens first, then the 26 lowercase letters.
# "0" and "1" are the gender-conditioning tokens (see generate_names).
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>", "0", "1"]
tokens = special_tokens + list("abcdefghijklmnopqrstuvwxyz")
idx_to_char = dict(enumerate(tokens))
char_to_idx = {char: idx for idx, char in idx_to_char.items()}
# Download the pretrained checkpoint from the Hugging Face Hub; with
# subfolder="model" and local_dir="." it lands at ./model/model.pth.
hf_hub_download(
    "karanravindra/namegenerator", "model.pth", subfolder="model", local_dir="."
)
# Rebuild the architecture with the hyperparameters the checkpoint was
# trained with — these must match or load_state_dict below will fail.
model = Model(
    ModelConfig(
        vocab_size=len(tokens),
        embedding_dim=48,
        num_layers=6,
        max_length=24,  # not padding to nearest 32 because max length of names is 17 - bump this for `theoretically` better performance
        q_heads=12,
        kv_heads=4,
        m=4,
        tie_weights=False,
    )
)
# weights_only=True restricts torch.load to plain tensors/containers
# (no arbitrary pickled objects), which is the safe way to load weights.
model.load_state_dict(
    torch.load("model/model.pth", map_location="cpu", weights_only=True)
)
# Inference mode: disables dropout / switches norm layers to eval behavior.
model.eval()
def decode(encoded_name: list[int], strip_special_tokens: bool = True) -> str:
    """Map a list of token indices back to a string.

    When ``strip_special_tokens`` is True, <sos>, <eos> and <pad> indices
    are dropped before decoding (the gender tokens "0"/"1" are kept).
    """
    if strip_special_tokens:
        skip = {
            char_to_idx["<sos>"],
            char_to_idx["<eos>"],
            char_to_idx["<pad>"],
        }
        encoded_name = [idx for idx in encoded_name if idx not in skip]
    return "".join(idx_to_char[idx] for idx in encoded_name)
def decode_batch(
    encoded_names: torch.Tensor, strip_special_tokens: bool = True
) -> list[str]:
    """Decode each row of a 2-D tensor of token indices into a string."""
    decoded: list[str] = []
    for row in encoded_names:
        decoded.append(decode(row.tolist(), strip_special_tokens))
    return decoded
def generate_names(n=16, gender=None, temperature=0.6):
    """Sample ``n`` names from the model.

    Args:
        n: number of names to generate.
        gender: 0/1 (or "0"/"1") to condition every name on one gender;
            None splits the batch between both genders.
        temperature: softmax temperature; lower values give more
            conservative samples.

    Returns:
        list[str]: decoded names with <sos>/<eos>/<pad> stripped. Each
        string still begins with its gender-conditioning character
        ("0" or "1"); callers strip it (see generate_name).
    """
    model.eval()
    if gender is None:
        half = n // 2
        # Give the second half the remainder so an odd n still produces
        # exactly n rows (n//2 + n//2 would be n - 1 and break the cat
        # with the n-row start tokens below).
        genders = torch.cat(
            [
                torch.tensor([[char_to_idx["0"]]]).repeat(half, 1),
                torch.tensor([[char_to_idx["1"]]]).repeat(n - half, 1),
            ],
            dim=0,
        )
    else:
        gender = char_to_idx[str(gender)]
        genders = torch.full((n, 1), gender)
    # Prompt for every row: <sos> followed by the gender token.
    start_token = torch.tensor([[char_to_idx["<sos>"]]]).repeat(n, 1)
    generated = torch.cat([start_token, genders], dim=1)
    for _ in range(22):  # max_length (24) minus the 2 prompt tokens
        output = model(generated) / temperature
        token = torch.multinomial(F.softmax(output[:, -1], dim=1), 1)
        generated = torch.cat([generated, token], dim=1)
        # Stop early once EVERY sequence emitted <pad>. The original
        # `token.all() == char_to_idx["<pad>"]` compared the reduced
        # boolean to index 0 and therefore broke as soon as ANY row
        # sampled <pad>, truncating unfinished names.
        if (token == char_to_idx["<pad>"]).all():
            break
    return decode_batch(generated, strip_special_tokens=True)
def generate_name(gender: str, num_names: int, temperature: float):
    """Gradio callback: sample names and format them for the text area.

    The first character of every raw name is the gender-conditioning
    token ("0"/"1"), so it is dropped before capitalizing.
    """
    raw_names = generate_names(num_names, gender, temperature)
    return "\n".join(name[1:].capitalize() for name in raw_names)
# Gradio UI wiring. The Radio uses type="index", so generate_name receives
# 0 (Male) or 1 (Female) rather than the label string — matching the "0"/"1"
# gender tokens in the vocabulary.
demo = gr.Interface(
    generate_name,
    gr.Radio(["Male", "Female"], label="Sex", type="index"),
    gr.TextArea(lines=16, label="Generated Names"),
    additional_inputs=[
        gr.Number(16, label="Number of Names"),
        gr.Slider(0.1, 2, 0.6, label="Temperature", step=0.1),
    ],
    title="Name Generator",
    description="Generates names based on sex using a GPT-2 model trained on names.",
)
# Blocking call: starts the web server for the demo.
demo.launch()