import os

import gradio as gr
import requests
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from tqdm import tqdm
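
# The 20B tokenizer file is published in the ChatRWKV repository and is
# fetched on first run if it is not already present locally.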
TOKENIZER_FILE = "20B_tokenizer.json"
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"


def download_file(url, filename):
    """Stream `url` to `filename` with a progress bar, skipping existing files."""
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)


download_file(TOKENIZER_URL, TOKENIZER_FILE)
tokenizer = Tokenizer.from_file(TOKENIZER_FILE)

# End-of-text marker for this vocabulary. The original check used "</s>",
# which this tokenizer does not define (token_to_id returns None), so
# generation could never stop early.
EOS_ID = tokenizer.token_to_id("<|endoftext|>")
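

# Minimal wrapper that lazily loads a checkpoint and samples from it.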
class RWKV_Model:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file {self.model_path} not found")

        # Assumes the .pth file holds a fully pickled, callable module saved
        # with torch.save(model); a bare state_dict would first need the RWKV
        # architecture instantiated. weights_only=False is required to
        # unpickle full modules on recent PyTorch versions.
        self.model = torch.load(self.model_path, map_location=self.device,
                                weights_only=False)
        print("Model loaded successfully")

    def generate(self, prompt, max_length=100, temperature=1.0, top_p=0.9):
        if self.model is None:
            self.load_model()

        input_ids = tokenizer.encode(prompt).ids
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
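
        # Sampling loop: re-feed the growing sequence each step and draw the
        # next token from the temperature-scaled, top-p-filtered distribution.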
        with torch.no_grad():
            output_sequence = []

            for _ in range(max_length):
                outputs = self.model(input_tensor)
                next_token_logits = outputs[0, -1, :] / temperature
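
                # Nucleus (top-p) filtering: sort by probability and mask all
                # tokens outside the smallest set whose cumulative probability
                # exceeds top_p, always keeping the single most likely token.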
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift the mask right so the first token past the threshold survives.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                output_sequence.append(next_token.item())
                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)

                if EOS_ID is not None and next_token.item() == EOS_ID:
                    break

        return tokenizer.decode(output_sequence)
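

# The RWKV-7 checkpoints referenced below are not downloaded by this script;
# they must be obtained separately (BlinkDL publishes the "World" weights on
# Hugging Face) and placed alongside it. Loaded models are cached so repeated
# requests don't reload a checkpoint from disk.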
_model_cache = {}


def generate_text(
    prompt,
    temperature=1.0,
    top_p=0.9,
    max_length=100,
    model_size="small"
):
    try:
        model_path = (
            "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
            if model_size == "small"
            else "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
        )

        if model_path not in _model_cache:
            _model_cache[model_path] = RWKV_Model(model_path)
        model = _model_cache[model_path]

        generated_text = model.generate(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p
        )
        return generated_text

    except Exception as e:
        return f"Error: {e}"


with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Text Generation Demo")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=5
            )
            model_size = gr.Radio(
                choices=["small", "large"],
                label="Model Size",
                value="small"
            )

        with gr.Column():
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                label="Top-p"
            )
            max_length_slider = gr.Slider(
                minimum=10,
                maximum=500,
                value=100,
                step=10,
                label="Maximum Length"
            )

    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Output", lines=10)
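
    # Inputs are passed to generate_text positionally, so this list must
    # match the order of its parameters.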
    generate_button.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            temperature_slider,
            top_p_slider,
            max_length_slider,
            model_size
        ],
        outputs=output_text
    )

    gr.Markdown("""
    ## Parameters
    - **Temperature**: Controls randomness (higher = more random)
    - **Top-p**: Nucleus sampling threshold; only the most likely tokens whose
      cumulative probability reaches top-p are candidates (higher = more diverse)
    - **Maximum Length**: Maximum number of tokens to generate
    - **Model Size**:
      - Small (0.1B parameters)
      - Large (0.4B parameters)
    """)


if __name__ == "__main__":
    demo.launch()
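
# Running this file directly serves the UI locally; demo.launch(share=True)
# would additionally create a temporary public link.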