import os
import torch
import gradio as gr
from typing import Optional
from dataclasses import dataclass
from transformers import AutoTokenizer
from model import Transformer
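# Transformer (model.py) is the custom SmolLM2-style decoder trained for this
# Space; it is assumed to expose an HF-style generate() supporting the
# sampling kwargs (min_length, temperature, top_k, top_p, ...) used below.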
@dataclass
class ModelArgs:
    # Arch params
    dim: int = 576
    intermediate_dim: int = 1536
    n_layers: int = 30
    n_heads: int = 9
    n_kv_heads: Optional[int] = 3
    vocab_size: int = 49152  # matches the cosmo2 tokenizer's vocab size
    norm_eps: float = 1.0e-05
    init_scale: float = 0.041666666666666664
    rope_theta: int = 10000
    dropout: float = 0.1
    # Training params
    seed: int = 42
    max_batch_size: int = 2
    max_seq_len: int = 2048
    steps: int = 5050
    breakpoint_step: int = 5000
    warmup_steps_frac: float = 0.5
    save_interval: int = 1000
    eval_interval: int = 500
    log_interval: int = 1
    grad_accum_steps: int = 8
    checkpoint_path: str = os.path.join(os.getcwd(), "checkpoints")
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Optimizer
    initial_lr: float = 5e-4
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_eps: float = 1.0e-08
    weight_decay: float = 0.01
    use_fused: bool = True
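# The training fields above are unused at inference but presumably kept so the
# config matches the one the checkpoint was trained with; one optimizer step
# would have covered max_batch_size * max_seq_len * grad_accum_steps
# = 2 * 2048 * 8 = 32768 tokens.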
# Initialize model and tokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
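# The tokenizer has no dedicated pad token, so EOS is reused for padding.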
tokenizer.pad_token = tokenizer.eos_token
config = ModelArgs()
config.device = device
model = Transformer(config)
# Load trained weights from a saved checkpoint
def load_checkpoint(model, path, device):
    try:
        checkpoint = torch.load(path, map_location=device)
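        # "_orig_mod." prefixes appear when a checkpoint is saved from a
        # torch.compile-wrapped model; strip them, and drop the KV-cache
        # buffers (cached_keys / cached_values), which are inference-time
        # state rather than weights.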
        state_dict = {
            k.replace("_orig_mod.", ""): v
            for k, v in checkpoint.items()
            if "cached_keys" not in k and "cached_values" not in k
        }
        model.load_state_dict(state_dict)
        return model
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None
model = load_checkpoint(model, "smollm2_HF.pth", device)
if model is None:
    raise RuntimeError("Failed to load weights from smollm2_HF.pth")
model.to(device)
model.eval()
def generate_text(prompt,
                  min_length: int = 28,
                  max_length: int = 40,
                  temperature: float = 0.7,
                  top_k: int = 50,
                  top_p: float = 0.7
                  ):
    """Generate text from a prompt"""
    min_length = int(min_length)
    max_length = int(max_length)
    temperature = float(temperature)
    top_k = int(top_k)
    top_p = float(top_p)
    input_ids = tokenizer(prompt,
                          padding=True,
                          truncation=True,
                          max_length=config.max_seq_len,
                          return_tensors="pt")["input_ids"].to(device)
    # model.eval() alone does not disable autograd, so sample under no_grad
    with torch.no_grad():
        generated = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
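# Quick sanity check outside Gradio (hypothetical prompt), handy when debugging:
#   print(generate_text("Once upon a time", min_length=20, max_length=60))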
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=10, maximum=500, value=28, label="Min Length"),
        gr.Slider(minimum=10, maximum=500, value=64, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=50, label="Top K"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Top P")
    ],
    outputs=gr.Textbox(label="Generated Text"),
    title="SmolLM2-135M Text Generation",
    description="SmolLM2-135M trained on cosmopedia-v2 for just 5000 steps",
    examples=[
        ["I found the love", 10, 50, 0.7, 50, 0.7],
        ["When the sun comes up", 20, 40, 0.8, 40, 0.9],
        ["The slow marching of ", 30, 60, 0.9, 45, 0.8],
    ],
)
if __name__ == "__main__":
    iface.launch()