import torch
import torch.nn.functional as F
import gradio as gr
from huggingface_hub import hf_hub_download

from namegenerator import Model, ModelConfig

# Inference-only app: disable autograd globally.
torch.set_grad_enabled(False)

# NOTE(review): the original source listed four identical empty strings here --
# almost certainly angle-bracket token names ("<pad>", "<sos>", ...) stripped by
# an HTML pipeline, which collapsed every special token onto one index.  The
# names below are a reconstruction; the checkpoint depends on the *index*
# layout (0-3 specials, 4-5 gender digits, 6+ letters), so confirm this
# ordering against the training code.
PAD, SOS, EOS, UNK = "<pad>", "<sos>", "<eos>", "<unk>"
special_tokens = [PAD, SOS, EOS, UNK, "0", "1"]
tokens = special_tokens + list("abcdefghijklmnopqrstuvwxyz")
char_to_idx = {char: idx for idx, char in enumerate(tokens)}
idx_to_char = {idx: char for idx, char in enumerate(tokens)}

# Fetch the checkpoint into ./model/model.pth (no-op if already cached).
hf_hub_download(
    "karanravindra/namegenerator", "model.pth", subfolder="model", local_dir="."
)

model = Model(
    ModelConfig(
        vocab_size=len(tokens),
        embedding_dim=48,
        num_layers=6,
        max_length=24,  # not padding to nearest 32 because max length of names is 17 - bump this for `theoretically` better performance
        q_heads=12,
        kv_heads=4,
        m=4,
        tie_weights=False,
    )
)
model.load_state_dict(
    torch.load("model/model.pth", map_location="cpu", weights_only=True)
)
model.eval()


def decode(encoded_name: list[int], strip_special_tokens: bool = True) -> str:
    """Map a sequence of token indices back to a string.

    Args:
        encoded_name: Token indices produced by the model.
        strip_special_tokens: Drop the SOS/EOS/PAD markers before joining.

    Returns:
        The decoded string (the leading gender digit is kept; callers strip it).
    """
    if strip_special_tokens:
        # Set membership: O(1) per lookup instead of scanning a list each time.
        drop = {char_to_idx[SOS], char_to_idx[EOS], char_to_idx[PAD]}
        encoded_name = [idx for idx in encoded_name if idx not in drop]
    return "".join(idx_to_char[idx] for idx in encoded_name)


def decode_batch(
    encoded_names: torch.Tensor, strip_special_tokens: bool = True
) -> list[str]:
    """Decode a (batch, seq_len) tensor of token indices row by row."""
    return [
        decode(encoded_name.tolist(), strip_special_tokens)
        for encoded_name in encoded_names
    ]


def generate_names(n=16, gender=None, temperature=0.6):
    """Autoregressively sample up to 22 characters per name.

    Args:
        n: Number of names to generate.
        gender: 0 (male) or 1 (female); None generates half of each.
        temperature: Softmax temperature -- higher means more random output.

    Returns:
        List of n decoded strings, each starting with its gender digit.
    """
    model.eval()
    if gender is None:
        # Half male, half female.  `n - n // 2` (not `n // 2` twice) keeps the
        # total at exactly n when n is odd -- the original dropped one row.
        genders = torch.cat(
            [
                torch.tensor([[char_to_idx["0"]]]).repeat(n // 2, 1),
                torch.tensor([[char_to_idx["1"]]]).repeat(n - n // 2, 1),
            ],
            dim=0,
        )
    else:
        genders = torch.full((n, 1), char_to_idx[str(gender)])

    # Prompt every row with <sos> followed by its gender token.
    start_token = torch.tensor([[char_to_idx[SOS]]]).repeat(n, 1)
    generated = torch.cat([start_token, genders], dim=1)

    eos_idx = char_to_idx[EOS]
    for _ in range(22):
        output = model(generated) / temperature
        # Sample one next token per row from the softmaxed last-position logits.
        token = torch.multinomial(F.softmax(output[:, -1], dim=1), 1)
        generated = torch.cat([generated, token], dim=1)
        # BUG FIX: the original wrote `token.all() == eos_idx`, which compares
        # the boolean .all() result to a token index and is never true, so the
        # loop always ran all 22 steps.  Stop once *every* row has emitted EOS.
        if (token == eos_idx).all():
            break

    return decode_batch(generated, strip_special_tokens=True)


def generate_name(gender: str, num_names: int, temperature: float):
    """Gradio callback: return `num_names` generated names, one per line."""
    names = generate_names(num_names, gender, temperature)
    # Drop the leading gender digit and capitalize the name.
    names = [name[1:].capitalize() for name in names]
    return "\n".join(names)


demo = gr.Interface(
    generate_name,
    # type="index" passes 0 for "Male", 1 for "Female" -- matching the gender
    # digits in the vocabulary.
    gr.Radio(["Male", "Female"], label="Sex", type="index"),
    gr.TextArea(lines=16, label="Generated Names"),
    additional_inputs=[
        gr.Number(16, label="Number of Names"),
        gr.Slider(0.1, 2, 0.6, label="Temperature", step=0.1),
    ],
    title="Name Generator",
    description="Generates names based on sex using a GPT-2 model trained on names.",
)
demo.launch()