import os

import gradio as gr
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from tqdm import tqdm

# Download tokenizer if not present
TOKENIZER_FILE = "20B_tokenizer.json"
TOKENIZER_URL = "https://raw.githubusercontent.com/BlinkDL/ChatRWKV/main/20B_tokenizer.json"


def download_file(url, filename):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        with open(filename, 'wb') as file, tqdm(
            desc=filename,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as pbar:
            for data in response.iter_content(chunk_size=1024):
                size = file.write(data)
                pbar.update(size)


# Ensure tokenizer exists
if not os.path.exists(TOKENIZER_FILE):
    download_file(TOKENIZER_URL, TOKENIZER_FILE)

# NOTE: RWKV "World" checkpoints are trained with the rwkv world vocabulary
# (rwkv_vocab_v20230424), not the NeoX 20B tokenizer loaded here; output
# quality will suffer unless the tokenizer matches the checkpoint.
tokenizer = Tokenizer.from_file(TOKENIZER_FILE)


class RWKV_Model:
    def __init__(self, model_path):
        self.model_path = model_path
        self.model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self):
        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Model file {self.model_path} not found")
        # Assumes the checkpoint is a full pickled nn.Module. Official RWKV
        # .pth releases are plain state dicts; see the loading sketch at the
        # bottom of this file for the official `rwkv` package instead.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        if isinstance(self.model, nn.Module):
            self.model.eval()
        print("Model loaded successfully")

    def generate(self, prompt, max_length=100, temperature=1.0, top_p=0.9):
        if self.model is None:
            self.load_model()
        input_ids = tokenizer.encode(prompt).ids
        input_tensor = torch.tensor(input_ids).unsqueeze(0).to(self.device)
        eot_id = tokenizer.token_to_id("<|endoftext|>")
        with torch.no_grad():
            output_sequence = []
            for _ in range(max_length):
                # Re-runs the full sequence each step; a native RWKV forward
                # would carry recurrent state instead.
                outputs = self.model(input_tensor)
                next_token_logits = outputs[0, -1, :] / temperature

                # Apply top-p (nucleus) sampling: mask out the tail of tokens
                # past the top_p cumulative probability, keeping at least one.
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                next_token_logits[indices_to_remove] = float('-inf')

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                output_sequence.append(next_token.item())
                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
                # Stop at the end-of-text token
                if eot_id is not None and next_token.item() == eot_id:
                    break
        return tokenizer.decode(output_sequence)


# Cache loaded models so repeated requests don't reload the checkpoint
_model_cache = {}


def generate_text(prompt, temperature=1.0, top_p=0.9, max_length=100, model_size="small"):
    try:
        # Select checkpoint based on requested size
        model_path = (
            "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth"
            if model_size == "small"
            else "RWKV-x070-World-0.4B-v2.9-20250107-ctx4096.pth"
        )
        if model_path not in _model_cache:
            _model_cache[model_path] = RWKV_Model(model_path)
        model = _model_cache[model_path]
        return model.generate(
            prompt=prompt,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:
        return f"Error: {str(e)}"


# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RWKV-7 Text Generation Demo")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Input Prompt",
                placeholder="Enter your prompt here...",
                lines=5,
            )
            model_size = gr.Radio(
                choices=["small", "large"],
                label="Model Size",
                value="small",
            )
        with gr.Column():
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=1.0, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, label="Top-p"
            )
            max_length_slider = gr.Slider(
                minimum=10, maximum=500, value=100, step=10,
label="Maximum Length" ) generate_button = gr.Button("Generate") output_text = gr.Textbox(label="Generated Output", lines=10) generate_button.click( fn=generate_text, inputs=[ prompt_input, temperature_slider, top_p_slider, max_length_slider, model_size ], outputs=output_text ) gr.Markdown(""" ## Parameters: - **Temperature**: Controls randomness (higher = more random) - **Top-p**: Controls diversity (higher = more diverse) - **Maximum Length**: Maximum number of tokens to generate - **Model Size**: - Small (0.1B parameters) - Large (0.4B parameters) """) if __name__ == "__main__": demo.launch()