Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
# Load the tokenizer and model once at import time so every chat request
# reuses the same instances.
MODEL_NAME = "DarwinAnim8or/TinyRP"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Half precision is only a safe default on GPU: on CPU-only hosts float16
# kernels are slow and, on some PyTorch builds, missing for common ops.
# Fall back to float32 when CUDA is not available.
_MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=_MODEL_DTYPE, device_map="auto")
# Built-in character presets: dropdown label -> system-prompt text.
# "Custom Character" maps to an empty description so the user can type
# their own; it must stay first because the featured-card strip skips
# index 0.
SAMPLE_CHARACTERS = {
    "Custom Character": "",
    "Adventurous Knight": (
        "You are Sir Gareth, a brave and noble knight on a quest to save the kingdom. "
        "You speak with honor and courage, always ready to help those in need. "
        "You carry an enchanted sword and have a loyal horse named Thunder."
    ),
    "Mysterious Wizard": (
        "You are Eldara, an ancient and wise wizard who speaks in riddles and knows secrets of the mystical arts. "
        "You live in a tower filled with magical books and potions. "
        "You are helpful but often cryptic in your responses."
    ),
    "Friendly Tavern Keeper": (
        "You are Bram, a cheerful tavern keeper who loves telling stories and meeting new travelers. "
        "Your tavern 'The Dancing Dragon' is a warm, welcoming place. "
        "You know all the local gossip and always have a tale to share."
    ),
    "Curious Scientist": (
        "You are Dr. Maya Chen, a brilliant scientist who is fascinated by discovery and invention. "
        "You're enthusiastic about explaining complex concepts in simple ways "
        "and always looking for new experiments to try."
    ),
    "Space Explorer": (
        "You are Captain Nova, a fearless space explorer who has traveled to distant galaxies. "
        "You pilot the starship 'Wanderer' and have encountered many alien species. "
        "You're brave, curious, and always ready for the next adventure."
    ),
    "Fantasy Princess": (
        "You are Princess Lyra, kind-hearted royalty who cares deeply about her people. "
        "You're intelligent, diplomatic, and skilled in both politics and magic. "
        "You often sneak out of the castle to help citizens in need."
    ),
}
def _build_prompt(message, history, character_description, chatml):
    """Render the character description, past turns and new message as one prompt string."""
    if chatml:
        # ChatML: the character description becomes the system message.
        parts = [f"<|im_start|>system\n{character_description}<|im_end|>\n"]
        for user_msg, assistant_msg in history:
            if user_msg:
                parts.append(f"<|im_start|>user\n{user_msg}<|im_end|>\n")
            if assistant_msg:
                parts.append(f"<|im_start|>assistant\n{assistant_msg}<|im_end|>\n")
        # End with an open assistant turn for the model to complete.
        parts.append(f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n")
        return "".join(parts)

    # Simple "Human:/Assistant:" transcript, optionally led by the description.
    parts = [f"{character_description}\n\n"] if character_description.strip() else []
    for user_msg, assistant_msg in history:
        if user_msg:
            parts.append(f"Human: {user_msg}\n")
        if assistant_msg:
            parts.append(f"Assistant: {assistant_msg}\n")
    parts.append(f"Human: {message}\nAssistant:")
    return "".join(parts)


def generate_response(
    message,
    history: list[tuple[str, str]],
    character_description,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
    use_chatml_format
):
    """Generate the character's next reply to ``message``.

    Args:
        message: The user's new message.
        history: Previous ``(user, assistant)`` turns, oldest first. ``None``
            is treated as an empty history.
        character_description: Free-text persona. Used as the ChatML system
            message, or as a leading paragraph in the simple format.
        max_tokens: Upper bound on newly generated tokens (cast to ``int``).
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.
        use_chatml_format: Request ChatML prompting. ChatML is only used when
            the description is non-empty; otherwise the simple transcript
            format is used.

    Returns:
        The generated reply text with surrounding whitespace removed and any
        hallucinated follow-up turn cut off.
    """
    # Sliders can deliver floats; generate() and tensor slicing need ints.
    max_tokens = int(max_tokens)
    character_description = character_description or ""
    history = history or []

    # Decide the format once so prompt construction and post-processing can
    # never disagree (previously ChatML post-processing ran on a simple-format
    # prompt when the description was empty, returning the whole prompt).
    chatml = bool(use_chatml_format) and bool(character_description.strip())
    conversation = _build_prompt(message, history, character_description, chatml)

    # Token budget carried over from the original code (1024 total).
    # NOTE(review): presumably the model's context window - confirm.
    context_window = 1024
    prompt_budget = max(1, context_window - max_tokens)

    input_ids = tokenizer.encode(conversation, return_tensors="pt")
    if input_ids.shape[-1] > prompt_budget:
        # Keep the END of the prompt: the newest message and the trailing
        # assistant cue matter most, so drop the oldest tokens instead.
        input_ids = input_ids[:, -prompt_budget:]
    # With device_map="auto" the model may not live on the CPU.
    input_ids = input_ids.to(model.device)
    # pad_token_id == eos_token_id below, so the mask cannot be inferred.
    attention_mask = torch.ones_like(input_ids)
    prompt_length = input_ids.shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation
    # rather than searching the decoded text for role markers, which breaks
    # when markers are stripped as special tokens or appear in user text.
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Cut at the first sign of the model starting another turn.
    if chatml:
        response = response.split("<|im_end|>")[0].split("<|im_start|>")[0]
    else:
        response = response.split("\nHuman:")[0]
    return response.strip()
def load_character_preset(character_name):
    """Return the stored description for a preset, or "" if the name is unknown."""
    try:
        return SAMPLE_CHARACTERS[character_name]
    except KeyError:
        return ""
# Custom CSS for better styling.
# Passed to gr.Blocks(css=...) when the interface is created. Classes:
#   .character-card - purple gradient card; used by the featured-character
#                     blurbs rendered near the bottom of the page.
#   .title-text     - centered gradient headline; used by the page title.
#   .parameter-box  - light rounded panel; not referenced by any component
#                     in the visible code.
css = """
.character-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 15px;
padding: 20px;
margin: 10px 0;
color: white;
}
.title-text {
text-align: center;
font-size: 2.5em;
font-weight: bold;
background: linear-gradient(45deg, #667eea, #764ba2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
margin-bottom: 20px;
}
.parameter-box {
background: #f8f9fa;
border-radius: 10px;
padding: 15px;
margin: 10px 0;
}
"""
# Create the Gradio interface
# Layout: chat pane (left), character + generation settings (right), then a
# strip of featured-character cards. Event wiring lives inside the Blocks
# context because Gradio requires listeners to be registered there.
# NOTE(review): nesting below was reconstructed from component order; confirm
# it matches the intended layout.
# NOTE(review): several labels contain sequences such as "π" / "βοΈ" that look
# like mis-decoded emoji. They are reproduced unchanged here; restore the
# intended glyphs if the file was saved with the wrong encoding.
with gr.Blocks(css=css, title="TinyRP Chat Demo") as demo:
    gr.HTML('<div class="title-text">π TinyRP Character Chat</div>')
    gr.Markdown("""
### Welcome to TinyRP!
This is a demo of a small but capable roleplay model. Choose a character preset or create your own!
**Tips for better roleplay:**
- Be descriptive in your messages
- Stay in character
- Use ChatML format for best results
- Adjust temperature for creativity vs consistency
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Chat interface
            # NOTE(review): avatar_images is documented as image paths/URLs;
            # these values are plain text - verify they render as intended.
            chatbot = gr.Chatbot(
                label="Chat",
                height=500,
                show_label=False,
                avatar_images=("π§", "π")
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    lines=2,
                    scale=4
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)
        with gr.Column(scale=1):
            # Character selection
            with gr.Group():
                gr.Markdown("### π Character Setup")
                character_preset = gr.Dropdown(
                    choices=list(SAMPLE_CHARACTERS.keys()),
                    value="Custom Character",
                    label="Character Presets",
                    interactive=True
                )
                character_description = gr.Textbox(
                    label="Character Description",
                    placeholder="Describe your character's personality, background, and speaking style...",
                    lines=6,
                    value=""
                )
                load_preset_btn = gr.Button("Load Preset", variant="secondary")
            # Generation parameters
            with gr.Group():
                gr.Markdown("### βοΈ Generation Settings")
                use_chatml_format = gr.Checkbox(
                    label="Use ChatML Format",
                    value=True,
                    info="Recommended for better character consistency"
                )
                max_tokens = gr.Slider(
                    minimum=16,
                    maximum=512,
                    value=128,
                    step=16,
                    label="Max Response Length",
                    info="Longer = more detailed responses"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.9,
                    step=0.1,
                    label="Temperature",
                    info="Higher = more creative/random"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.85,
                    step=0.05,
                    label="Top-p",
                    info="Focus on top % of likely words"
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=1.5,
                    value=1.1,
                    step=0.05,
                    label="Repetition Penalty",
                    info="Reduce repetitive text"
                )
            # Control buttons
            with gr.Group():
                clear_btn = gr.Button("ποΈ Clear Chat", variant="secondary")
    # Sample character cards
    with gr.Row():
        gr.Markdown("### π Featured Characters")
    with gr.Row():
        # Index 0 is the empty "Custom Character" entry, so show entries 1-3.
        for char_name, char_desc in list(SAMPLE_CHARACTERS.items())[1:4]:
            with gr.Column(scale=1):
                gr.Markdown(f"""
<div class="character-card">
<h4>{char_name}</h4>
<p>{char_desc[:100]}...</p>
</div>
""")

    # Event handlers
    def respond_wrapper(message, history, char_desc, max_tok, temp, top_p, rep_pen, use_chatml):
        """Handle one chat submission.

        Returns ``(updated_history, "")``: the new chatbot contents and an
        empty string that clears the message box. Blank messages leave the
        history unchanged. Generation errors are logged and shown in the
        chat instead of propagating, so the UI stays usable.
        """
        import logging

        # An empty Chatbot may arrive as None rather than [].
        if history is None:
            history = []
        if not message or not message.strip():
            return history, ""
        try:
            response = generate_response(
                message, history, char_desc, max_tok, temp, top_p, rep_pen, use_chatml
            )
        except Exception as e:
            # UI boundary: keep the traceback in the server log, and surface a
            # short message to the user rather than failing the request.
            logging.getLogger(__name__).exception("Response generation failed")
            response = f"Error generating response: {str(e)}"
        history.append((message, response))
        return history, ""

    # Connect events
    # The Send button and Enter-to-submit share one input/output spec so the
    # two triggers cannot drift apart.
    chat_inputs = [
        msg,
        chatbot,
        character_description,
        max_tokens,
        temperature,
        top_p,
        repetition_penalty,
        use_chatml_format,
    ]
    chat_outputs = [chatbot, msg]
    send_btn.click(respond_wrapper, inputs=chat_inputs, outputs=chat_outputs)
    msg.submit(respond_wrapper, inputs=chat_inputs, outputs=chat_outputs)
    # Presets load both on explicit button press and on dropdown change.
    load_preset_btn.click(
        load_character_preset,
        inputs=[character_preset],
        outputs=[character_description]
    )
    character_preset.change(
        load_character_preset,
        inputs=[character_preset],
        outputs=[character_description]
    )
    # Reset both the conversation and the message box.
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])
# Start the Gradio server only when this file is run as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    demo.launch()