randeom committed (verified)
Commit 740a7c8 · Parent(s): ab9dfa6

Update app.py

Files changed (1)
  1. app.py +5 -42
app.py CHANGED
@@ -1,6 +1,5 @@
 import streamlit as st
 from huggingface_hub import InferenceClient
-from gradio_client import Client
 import re
 
 # Load custom CSS
@@ -8,8 +7,7 @@ with open('style.css') as f:
     st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
 
 # Initialize the HuggingFace Inference Client
-text_client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
-image_client = Client("cagliostrolab/animagine-xl-3.1")
+client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")
 
 def format_prompt_for_description(name, hair_color, personality, outfit_style, hobbies, favorite_food, background_story):
     prompt = f"Create a waifu character named {name} with {hair_color} hair, a {personality} personality, and wearing a {outfit_style}. "
@@ -25,7 +23,7 @@ def clean_generated_text(text):
     clean_text = re.sub(r'</s>$', '', text).strip()
     return clean_text
 
-def generate_text(client, prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
+def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
     temperature = max(temperature, 1e-2)
     generate_kwargs = dict(
         temperature=temperature,
@@ -45,32 +43,6 @@ def generate_text(client, prompt, temperature=0.9, max_new_tokens=512, top_p=0.9
         st.error(f"Error generating text: {e}")
         return ""
 
-def generate_image(prompt):
-    try:
-        result = image_client.predict(
-            prompt,  # Image prompt
-            "",  # Negative prompt
-            0,  # Seed
-            512,  # Width
-            512,  # Height
-            7.5,  # Guidance scale
-            25,  # Number of inference steps
-            'DPM++ 2M Karras',  # Sampler
-            '1024 x 1024',  # Aspect Ratio
-            'Anime',  # Style Preset
-            '(None)',  # Quality Tags Presets
-            True,  # Use Upscaler
-            0,  # Strength
-            1,  # Upscale by
-            True,  # Add Quality Tags
-            api_name="/run"
-        )
-        return result[0]['image']
-    except Exception as e:
-        st.error(f"Error generating image: {e}")
-        st.write("Full error details:", e)
-        return None
-
 def main():
     st.title("Enhanced Waifu Character Generator")
 
@@ -94,13 +66,11 @@ def main():
         top_p = st.slider("Top-p (nucleus sampling)", 0.0, 1.0, 0.95, step=0.05)
         repetition_penalty = st.slider("Repetition penalty", 1.0, 2.0, 1.0, step=0.05)
 
-    # Initialize session state for generated text and image prompt
+    # Initialize session state for generated text
    if "character_description" not in st.session_state:
         st.session_state.character_description = ""
     if "image_prompt" not in st.session_state:
         st.session_state.image_prompt = ""
-    if "image_path" not in st.session_state:
-        st.session_state.image_path = ""
 
     # Generate button
     if st.button("Generate Waifu"):
@@ -109,13 +79,10 @@
         image_prompt = format_prompt_for_image(name, hair_color, personality, outfit_style)
 
         # Generate character description
-        st.session_state.character_description = generate_text(text_client, description_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
+        st.session_state.character_description = generate_text(description_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
 
         # Generate image prompt
-        st.session_state.image_prompt = generate_text(text_client, image_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
-
-        # Generate image from image prompt
-        st.session_state.image_path = generate_image(st.session_state.image_prompt)
+        st.session_state.image_prompt = generate_text(image_prompt, temperature, max_new_tokens, top_p, repetition_penalty)
 
         st.success("Waifu character generated!")
 
@@ -126,10 +93,6 @@
     if st.session_state.image_prompt:
         st.subheader("Image Prompt")
         st.write(st.session_state.image_prompt)
-    if st.session_state.image_path:
-        st.subheader("Generated Image")
-        st.image(st.session_state.image_path)
 
 if __name__ == "__main__":
     main()
-
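
Note: the body of generate_text between the hunks above is unchanged by this commit and is not shown in the diff. The following is a minimal sketch of how the retained text path might look after the change, assuming the standard huggingface_hub InferenceClient.text_generation call; the actual request code is not visible here, so this is an illustration rather than the committed implementation.

import re
import streamlit as st
from huggingface_hub import InferenceClient

client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.1")

def clean_generated_text(text):
    # Strip a trailing end-of-sequence token the model may append.
    return re.sub(r'</s>$', '', text).strip()

def generate_text(prompt, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    temperature = max(temperature, 1e-2)  # avoid a zero temperature
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,  # assumption: sampling enabled, as in typical Mistral demos
    )
    try:
        # Assumed call; the diff does not show how the client is invoked.
        output = client.text_generation(prompt, **generate_kwargs)
        return clean_generated_text(output)
    except Exception as e:
        st.error(f"Error generating text: {e}")
        return ""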