Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,30 +5,31 @@ from diffusers import DiffusionPipeline
|
|
5 |
from transformers import pipeline
|
6 |
|
7 |
# Load text generation pipeline
|
8 |
-
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def extend_prompt(prompt):
|
11 |
return text_pipe(prompt + ',', num_return_sequences=1, max_new_tokens=50)[0]["generated_text"]
|
12 |
|
13 |
@st.cache_resource
|
14 |
-
def load_pipeline(
|
15 |
-
device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
|
16 |
pipe = DiffusionPipeline.from_pretrained(
|
17 |
"stabilityai/sdxl-turbo",
|
18 |
-
torch_dtype=torch.
|
19 |
-
variant=
|
20 |
use_safetensors=True
|
21 |
)
|
22 |
-
|
23 |
-
pipe.enable_xformers_memory_efficient_attention()
|
24 |
-
pipe.to(device)
|
25 |
return pipe
|
26 |
|
27 |
-
def generate_image(prompt, use_details
|
28 |
-
pipe = load_pipeline(
|
29 |
generator = torch.manual_seed(np.random.randint(0, 2**32))
|
30 |
extended_prompt = extend_prompt(prompt) if use_details else prompt
|
31 |
-
image = pipe(prompt=extended_prompt, generator=generator, num_inference_steps=
|
32 |
return image, extended_prompt
|
33 |
|
34 |
# Add the custom CSS file
|
@@ -41,11 +42,10 @@ st.markdown("<div class='subheader'>Create an anime-style GitHub profile picture
|
|
41 |
input_text = st.text_area("Describe your GitHub profile picture:", "Create an anime-style GitHub profile picture for a boy")
|
42 |
|
43 |
details_checkbox = st.checkbox("Generate Details?", value=True)
|
44 |
-
cuda_checkbox = st.checkbox("Use CUDA?", value=False)
|
45 |
|
46 |
if st.button("Generate Image"):
|
47 |
with st.spinner('Generating image...'):
|
48 |
-
image, extended_prompt = generate_image(input_text, details_checkbox
|
49 |
st.image(image, caption="Generated Image")
|
50 |
st.text(f"Extended Prompt: {extended_prompt}")
|
51 |
st.balloons()
|
|
|
5 |
from transformers import pipeline
|
6 |
|
7 |
# Load text generation pipeline
|
8 |
+
@st.cache_resource
|
9 |
+
def load_text_pipeline():
|
10 |
+
return pipeline('text-generation', model='daspartho/prompt-extend')
|
11 |
+
|
12 |
+
text_pipe = load_text_pipeline()
|
13 |
|
14 |
def extend_prompt(prompt):
|
15 |
return text_pipe(prompt + ',', num_return_sequences=1, max_new_tokens=50)[0]["generated_text"]
|
16 |
|
17 |
@st.cache_resource
|
18 |
+
def load_pipeline():
|
|
|
19 |
pipe = DiffusionPipeline.from_pretrained(
|
20 |
"stabilityai/sdxl-turbo",
|
21 |
+
torch_dtype=torch.float32,
|
22 |
+
variant=None,
|
23 |
use_safetensors=True
|
24 |
)
|
25 |
+
pipe.to("cpu")
|
|
|
|
|
26 |
return pipe
|
27 |
|
28 |
+
def generate_image(prompt, use_details):
|
29 |
+
pipe = load_pipeline()
|
30 |
generator = torch.manual_seed(np.random.randint(0, 2**32))
|
31 |
extended_prompt = extend_prompt(prompt) if use_details else prompt
|
32 |
+
image = pipe(prompt=extended_prompt, generator=generator, num_inference_steps=15, guidance_scale=7.5).images[0]
|
33 |
return image, extended_prompt
|
34 |
|
35 |
# Add the custom CSS file
|
|
|
42 |
input_text = st.text_area("Describe your GitHub profile picture:", "Create an anime-style GitHub profile picture for a boy")
|
43 |
|
44 |
details_checkbox = st.checkbox("Generate Details?", value=True)
|
|
|
45 |
|
46 |
if st.button("Generate Image"):
|
47 |
with st.spinner('Generating image...'):
|
48 |
+
image, extended_prompt = generate_image(input_text, details_checkbox)
|
49 |
st.image(image, caption="Generated Image")
|
50 |
st.text(f"Extended Prompt: {extended_prompt}")
|
51 |
st.balloons()
|