"""Streamlit app that generates a waifu character description via the
Hugging Face Inference API (mistralai/Mistral-7B-Instruct-v0.1)."""

import streamlit as st
from huggingface_hub import InferenceClient

# Load custom CSS
with open('style.css') as f:
    st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
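# NOTE: this assumes a style.css file sits next to the script. A minimal
# placeholder (hypothetical, not part of the original file) could be as
# simple as:
#   body { background-color: #fff0f5; }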

# Initialize the HuggingFace Inference Client
client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
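# NOTE: for gated models or to avoid anonymous rate limits you may need to
# authenticate, e.g. InferenceClient(model=..., token=os.environ["HF_TOKEN"])
# (hypothetical token setup; not part of the original app).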

def format_prompt(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story, system_prompt=""):
    """Build the character-description prompt, optionally prefixed with a system prompt."""
    prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}. "
    prompt += f"Her hobbies include {hobbies}. Her favorite food is {favorite_food}. Here is her background story: {background_story}."
    if system_prompt:
        prompt = f"[SYS] {system_prompt} [/SYS] " + prompt
    return prompt

def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion from the Inference API and return the concatenated text."""
    temperature = max(temperature, 1e-2)  # temperature must be strictly positive for sampling
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    try:
        # Stream tokens as they arrive and accumulate them into a single string
        stream = client.text_generation(prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
        return output
    except Exception as e:
        st.error(f"Error generating text: {e}")
        return ""

def main():
    st.title("Enhanced Waifu Character Generator")

    # User inputs
    col1, col2 = st.columns(2)
    with col1:
        name = st.text_input("Name of the Waifu")
        hair_color = st.selectbox("Hair Color", ["Blonde", "Brunette", "Red", "Black", "Blue", "Pink"])
        personality = st.selectbox("Personality", ["Tsundere", "Yandere", "Kuudere", "Dandere", "Genki", "Normal"])
        outfit_style = st.selectbox("Outfit Style", ["School Uniform", "Maid Outfit", "Casual", "Kimono", "Gothic Lolita"])
    with col2:
        hobbies = st.text_input("Hobbies")
        favorite_food = st.text_input("Favorite Food")
        background_story = st.text_area("Background Story")
        system_prompt = st.text_input("Optional System Prompt", "")

    # Advanced settings
    with st.expander("Advanced Settings"):
        temperature = st.slider("Temperature", 0.0, 1.0, 0.9, step=0.05)
        max_new_tokens = st.slider("Max new tokens", 64, 8192, 512, step=64)
        top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
        repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)

    # Generate button; initialize so the display check below works before the first click
    generated_text = ""
    if st.button("Generate Waifu"):
        with st.spinner("Generating waifu character..."):
            prompt = format_prompt(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story, system_prompt)
            generated_text = generate_text(prompt, temperature, max_new_tokens, top_p, repetition_penalty)
            if generated_text:
                st.success("Waifu character generated!")

    # Display the generated character
    if generated_text:
        st.subheader("Generated Waifu Character")
        st.write(generated_text)

if __name__ == "__main__":
    main()
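# To run locally (assumption, not stated in the original file): save this script
# as app.py next to style.css and launch it with `streamlit run app.py`.
# Generation calls the hosted Hugging Face Inference API, so network access
# (and possibly an API token) is required.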