prakhar patidar commited on
Commit
54c4cef
·
1 Parent(s): 85c84bf

updated app with presist history

Browse files
Files changed (1) hide show
  1. app.py +61 -9
app.py CHANGED
@@ -1,24 +1,76 @@
1
- import streamlit as st
 
2
 
 
 
3
  from diffusers import StableDiffusionPipeline
4
  import torch
5
 
6
  model_id = "runwayml/stable-diffusion-v1-5"
7
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
8
 
9
  # pipe = pipe.to("cuda")
10
 
11
- # prompt = "a photo of an astronaut riding a horse on mars"
12
- # image = pipe(prompt).images[0]
13
- # st.image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
 
 
 
 
 
 
 
 
15
 
16
  with st.form("tti_form"):
17
- st.write("Text to Image")
18
  prompt = st.text_input("Enter Prompt")
19
  # Every form must have a submit button.
20
  submitted = st.form_submit_button("Submit")
21
  if submitted:
22
- st.write("prompt", prompt)
23
- image = pipe(prompt).images[0]
24
- st.image(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import json
3
 
4
+ import streamlit as st
5
+ from PIL import Image
6
  from diffusers import StableDiffusionPipeline
7
  import torch
8
 
9
  model_id = "runwayml/stable-diffusion-v1-5"
10
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
11
 
12
  # pipe = pipe.to("cuda")
13
 
14
+ st.title("Convert Text To Image :sunglasses:")
15
+
16
+
17
+ def load_metadata():
18
+ metadata = []
19
+ try:
20
+ with open("metadata.json", "r") as f:
21
+ metadata = json.load(f)
22
+ except:
23
+ print("json file doesn't exist")
24
+ return metadata
25
+
26
+
27
+ def update_metadata(data):
28
+ metadata = []
29
+ try:
30
+ with open("metadata.json", "r") as f:
31
+ metadata = json.load(f)
32
+ except:
33
+ print("json file doesn't exist")
34
+
35
+ metadata.append(data)
36
+
37
+ with open("metadata.json", "w") as json_file:
38
+ json.dump(metadata, json_file)
39
+
40
 
41
+ def render_metadata(data):
42
+ st.write("Previous prompt results (common for all)")
43
+ for d in data:
44
+ st.write("Prompt: ", d["prompt"])
45
+ st.image(Image.open(d["file_name"]))
46
+
47
+
48
+ metadata = load_metadata()
49
 
50
  with st.form("tti_form"):
 
51
  prompt = st.text_input("Enter Prompt")
52
  # Every form must have a submit button.
53
  submitted = st.form_submit_button("Submit")
54
  if submitted:
55
+ with st.spinner("Processing..."):
56
+ image_name = uuid.uuid4().hex + ".png"
57
+ st.write("prompt", prompt)
58
+ image = pipe(prompt).images[0]
59
+ # test mode - comment above and uncomment following
60
+ # image = Image.open("abcd.png")
61
+ pil_image = Image.fromarray(image)
62
+
63
+ st.image(image)
64
+ # save image
65
+ data = {"file_name": image_name, "prompt": prompt}
66
+ pil_image.save(image_name)
67
+
68
+ # with open(image_name, 'wb') as f:
69
+ # f.write(image)
70
+
71
+ update_metadata(data)
72
+
73
+ st.success("Image generated!")
74
+
75
+
76
+ render_metadata(metadata)