LordFarquaad42 commited on
Commit
40041d3
·
verified ·
1 Parent(s): 29c957d

use inference api

Browse files
Files changed (1) hide show
  1. app.py +17 -39
app.py CHANGED
@@ -2,48 +2,26 @@
2
  import streamlit as st
3
  import torch
4
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
 
 
 
 
 
 
5
 
6
- def load_pipelines(device, dtype):
7
- prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype).to(device)
8
- decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype).to(device)
9
- return prior, decoder
10
-
11
- def generate_images(prompt, negative_prompt, num_images_per_prompt, device, dtype):
12
- with torch.cuda.amp.autocast(dtype=dtype):
13
- prior_output = prior(
14
- prompt=prompt,
15
- height=1024,
16
- width=1024,
17
- negative_prompt=negative_prompt,
18
- guidance_scale=4.0,
19
- num_images_per_prompt=num_images_per_prompt,
20
- )
21
- decoder_output = decoder(
22
- image_embeddings=prior_output.image_embeddings,
23
- prompt=prompt,
24
- negative_prompt=negative_prompt,
25
- guidance_scale=0.0,
26
- output_type="pil",
27
- )
28
- return decoder_output.images
29
-
30
 
31
  st.title("Image Generator with Diffusers")
32
 
33
- @st.cache(allow_output_mutation=True)
34
- def init_model():
35
- device = "cpu"
36
- dtype = torch.bfloat16
37
- return load_pipelines(device, dtype), device, dtype
38
-
39
- (prior, decoder), device, dtype = init_model()
40
-
41
- prompt = st.text_input("Enter a prompt:", "Anthropomorphic cat dressed as a pilot")
42
- negative_prompt = st.text_input("Enter a negative prompt:", "")
43
- num_images_per_prompt = st.slider("Number of images per prompt:", 1, 5, 2)
44
 
 
45
 
46
- if st.button("Generate"):
47
- images = generate_images(prompt, negative_prompt, num_images_per_prompt, device, dtype)
48
- for img in images:
49
- st.image(img, use_column_width=True)
 
 
2
  import streamlit as st
3
  import torch
4
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
5
+ import requests
6
+ import io
7
+ from PIL import Image
8
+ import os
9
+ from dotenv import load_dotenv
10
+ load_dotenv()
11
 
12
+ API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-cascade-prior"
13
+ headers = {"Authorization": f"Bearer ${os.getenv('bearer_token')}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  st.title("Image Generator with Diffusers")
16
 
17
+ def query(payload):
18
+ response = requests.post(API_URL, headers=headers, json=payload)
19
+ return response.content
 
 
 
 
 
 
 
 
20
 
21
+ prompt = st.text_input("Enter a prompt:", "batman hitting the griddy in gotham")
22
 
23
+ image_bytes = query({
24
+ "inputs": prompt,
25
+ })
26
+ # You can access the image with PIL.Image for example
27
+ image = Image.open(io.BytesIO(image_bytes))