"""Gradio chat demo that streams replies from a local RWKV-6 World model."""

import re
import threading
import time  # NOTE(review): currently unused in this script.

import gradio as gr
import torch
from gradio import deploy  # NOTE(review): only needed if deploy() below is enabled; confirm this name exists in the installed Gradio version.
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

# Upper bound on newly generated tokens per reply.
MAX_NEW_TOKENS = 333
# A blank line ends the assistant's turn in the prompt template below.
STOP_SEQUENCE = "\n\n"


def _normalize(text):
    """Strip *text*, convert CRLF to LF, and collapse runs of newlines to one.

    "\\n\\n" separates turns in the prompt template, so user text must never
    contain it. A single ``str.replace("\\n\\n", "\\n")`` only halves longer
    runs (``"\\n\\n\\n\\n"`` -> ``"\\n\\n"``); the regex removes them entirely.
    """
    text = text.strip().replace("\r\n", "\n")
    return re.sub(r"\n{2,}", "\n", text)


def generate_prompt(instruction, input=""):
    """Build the text prompt fed to the model.

    Args:
        instruction: The user's request.
        input: Optional extra context. The name shadows the builtin but is
            kept so keyword callers keep working.

    Returns:
        An "Instruction / Input / Response" prompt when *input* is non-empty.
        Otherwise, a short chat transcript that ends with *instruction* as
        the user turn, followed by an open "Assistant:" turn.
    """
    instruction = _normalize(instruction)
    input = _normalize(input)
    # NOTE(review): turns are separated by a blank line ("\n\n"), which is why
    # _normalize strips that sequence from user text; confirm against the
    # model card's prompt template.
    if input:
        return f"Instruction: {instruction}\n\nInput: {input}\n\nResponse:"
    return (
        "User: hi\n\n"
        "Assistant: Hi. I am your assistant and I will provide expert full "
        "response in full details. Please feel free to ask any question and I "
        "will always answer it.\n\n"
        f"User: {instruction}\n\n"
        "Assistant:"
    )


model_path = "models/rwkv-6-world-1b6/"  # Path to your local model directory

# trust_remote_code=True lets transformers run the modeling code stored with the weights.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    trust_remote_code=True,
).to(torch.float32)

# The tokenizer files must be present in model_path. Special tokens
# (bos/eos/unk/pad) come from the tokenizer's own config there. Do not
# override them with empty strings, which name no real token.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
    padding_side="left",
    clean_up_tokenization_spaces=False,  # Or set to True if you prefer
)


class _StopWhenSet(StoppingCriteria):
    """Stopping criterion that halts generation once *event* is set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        # One flag per batch row; newer transformers expect a BoolTensor here.
        return torch.full(
            (input_ids.shape[0],),
            self.event.is_set(),
            dtype=torch.bool,
            device=input_ids.device,
        )


def generate_text(input_text):
    """Stream a sampled reply to *input_text*.

    Yields the cumulative reply after each decoded chunk; Gradio replaces the
    output box with each yielded value. Generation ends at EOS, after
    MAX_NEW_TOKENS new tokens, or when STOP_SEQUENCE appears. The stop
    sequence itself is not part of the final yielded text. Any exception
    raised by the model is re-raised here.
    """
    prompt = generate_prompt(input_text)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # The streamer decodes the accumulated tokens rather than one token at a
    # time, so characters split across several tokens come out intact.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    stop_event = threading.Event()
    errors = []

    def _worker():
        # A single generate() call produces the whole reply, instead of
        # re-running the model over the full sequence for every new token.
        try:
            model.generate(
                input_ids,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=True,
                temperature=1.0,
                top_p=0.3,
                top_k=0,
                streamer=streamer,
                stopping_criteria=StoppingCriteriaList([_StopWhenSet(stop_event)]),
            )
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # unblock the consumer loop below

    threading.Thread(target=_worker, daemon=True).start()

    generated_text = ""
    try:
        for chunk in streamer:
            print(chunk, end="", flush=True)  # Print to console for monitoring
            generated_text += chunk
            cut = generated_text.find(STOP_SEQUENCE)
            if cut != -1:
                yield generated_text[:cut]
                return
            yield generated_text
        if errors:
            raise errors[0]
        # Guarantees at least one update, e.g. when the model emits EOS at once.
        yield generated_text
    finally:
        # Also runs if Gradio abandons the stream; tells the worker to stop.
        stop_event.set()
        print(flush=True)


# Create the Gradio interface
iface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="RWKV Chatbot",
    description="Enter your prompt below:",
    flagging_dir="gradio_flagged/",
)

# For local testing:
iface.launch(share=False)

# deploy()  # Not needed on Hugging Face Spaces, which launches the interface automatically.