# RWKV-7 / app.py
import os
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
import requests
from tqdm import tqdm
# Download tokenizer if not present
TOKENIZER_FILE = "20B_tokenizer.json"
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"
def download_file(url, filename):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        # Fail fast on HTTP errors instead of writing an error page to disk.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)
# Ensure tokenizer exists
if not os.path.exists(TOKENIZER_FILE):
    download_file(TOKENIZER_URL, TOKENIZER_FILE)
tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
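# Optional sanity check of the `tokenizers` API used below (`encode(...).ids`
# and `decode(list_of_ids)`) — this byte-level BPE tokenizer round-trips text:
#   assert tokenizer.decode(tokenizer.encode("Hello world").ids) == "Hello world"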
class RWKV_Model:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file {self.model_path} not found")
        # NOTE: this assumes the .pth checkpoint is a fully pickled nn.Module.
        # Published RWKV checkpoints are usually plain state_dicts, which would
        # require instantiating the RWKV architecture before loading weights.
        self.model = torch.load(self.model_path, map_location=self.device)
        if isinstance(self.model, nn.Module):
            self.model.eval()
        print("Model loaded successfully")
    def generate(self, prompt, max_length=100, temperature=1.0, top_p=0.9):
        if self.model is None:
            self.load_model()
        input_ids = tokenizer.encode(prompt).ids
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
        # The 20B tokenizer has no "</s>" token; its end-of-text token is
        # "<|endoftext|>". Guard against token_to_id returning None.
        eos_id = tokenizer.token_to_id("<|endoftext|>")
        with torch.no_grad():
            output_sequence = []
            for _ in range(max_length):
                outputs = self.model(input_tensor)
                next_token_logits = outputs[0, -1, :] / temperature
                # Top-p (nucleus) sampling: keep the smallest set of tokens
                # whose cumulative probability exceeds top_p.
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the mask right so the first token above the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                output_sequence.append(next_token.item())
                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
                if eos_id is not None and next_token.item() == eos_id:
                    break
        return tokenizer.decode(output_sequence)
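# Cache loaded models by checkpoint path so repeated Gradio requests reuse
# the weights instead of re-reading the file from disk on every click.
_model_cache = {}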
def generate_text(
    prompt,
    temperature=1.0,
    top_p=0.9,
    max_length=100,
    model_size="small"
):
    try:
        # Select the checkpoint based on the requested size.
        if model_size == "small":
            model_path = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
        else:
            model_path = "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
        if model_path not in _model_cache:
            _model_cache[model_path] = RWKV_Model(model_path)
        model = _model_cache[model_path]
        generated_text = model.generate(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )
        return generated_text
    except Exception as e:
        return f"Error: {str(e)}"
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Text Generation Demo")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            model_size = gr.Radio(
                choices=["small", "large"],
                label="Model Size",
                value="small"
            )
        with gr.Column():
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                label="Top-p"
            )
            max_length_slider = gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length"
            )
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Output", lines=10)
    generate_button.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            temperature_slider,
            top_p_slider,
            max_length_slider,
            model_size
        ],
        outputs=output_text
    )
gr.Markdown("""
## Parameters:
- **Temperature**: Controls randomness (higher = more random)
- **Top-p**: Controls diversity (higher = more diverse)
- **Maximum Length**: Maximum number of tokens to generate
- **Model Size**:
- Small (0.1B parameters)
- Large (0.4B parameters)
""")
if __name__ == "__main__":
    demo.launch()