tugot17 commited on
Commit
1d906b4
·
1 Parent(s): 819a5f1

Upload 4 files

Browse files
Files changed (3) hide show
  1. img_gen_v2.py +72 -0
  2. requirements.txt +6 -0
  3. streamlit_app.py +95 -0
img_gen_v2.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from diffusers import StableDiffusionImg2ImgPipeline, \
4
+ StableDiffusionPipeline
5
+
6
+
7
+ def check_cuda_device():
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+ return device
10
+
11
+
12
+ def get_the_model(device=None):
13
+ model_id = "stabilityai/stable-diffusion-2"
14
+ # if path:
15
+ # pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
16
+ # else:
17
+ pipe = StableDiffusionPipeline.from_pretrained(model_id,
18
+ torch_dtype=torch.float16)
19
+ if device:
20
+ pipe.to(device)
21
+ else:
22
+ device = check_cuda_device()
23
+ pipe.to(device)
24
+
25
+ return pipe
26
+
27
+
28
+ def get_image_to_image_model(path=None, device=None):
29
+ model_id = "stabilityai/stable-diffusion-2"
30
+ if path:
31
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
32
+ path,
33
+ torch_dtype=torch.float16)
34
+ else:
35
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
36
+ model_id,
37
+ torch_dtype=torch.float16)
38
+ if device:
39
+ if device == "cuda" or device == "cpu":
40
+ pipe.to(device)
41
+ else:
42
+ device = check_cuda_device()
43
+ pipe.to(device)
44
+
45
+ return pipe
46
+
47
+
48
+ def gen_initial_img(int_prompt):
49
+ # image = get_the_model(num_inference_steps=100).images[0]
50
+ model = get_the_model(None)
51
+ image = model(int_prompt, num_inference_steps=100).images[0]
52
+
53
+ return image
54
+
55
+
56
+ def generate_story(int_prompt, steps, iterations=100):
57
+ image_dic = {}
58
+ init_img = gen_initial_img(int_prompt)
59
+ img2img_model = get_image_to_image_model()
60
+
61
+ img = init_img
62
+
63
+ for idx, step in enumerate(steps):
64
+ image = img2img_model(prompt=step, image=img, strength=0.75, guidance_scale=7.5,
65
+ num_inference_steps=iterations).images[0]
66
+ image_dic[idx] = {
67
+ "image": image,
68
+ "prompt": step
69
+ }
70
+ img = image
71
+
72
+ return init_img, image_dic
requirements.txt CHANGED
@@ -3,3 +3,9 @@ langchain==0.0.153
3
  openai==0.27.5
4
  anthropic==0.2.7
5
  python-dotenv==1.0.0
 
 
 
 
 
 
 
3
  openai==0.27.5
4
  anthropic==0.2.7
5
  python-dotenv==1.0.0
6
+ gTTS==2.3.2
7
+ torch==2.0.0
8
+ diffusers==0.16.1
9
+ transformers
10
+ ftfy
11
+ accelerate
streamlit_app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import requests
3
+ from PIL import Image
4
+ from io import BytesIO
5
+ from gtts import gTTS
6
+
7
+ from img_gen_v2 import generate_story
8
+ from prompt_generation import pipeline
9
+
10
+
11
+ # Function to create the page navigation
12
+ def page_navigation(current_page):
13
+ col1, col2, col3 = st.columns(3)
14
+
15
+ if current_page > 0:
16
+ with col1:
17
+ if st.button('<< Previous'):
18
+ current_page -= 1
19
+
20
+ with col2:
21
+ st.write(f'Step {current_page} of 10')
22
+
23
+ if current_page < 10:
24
+ with col3:
25
+ if st.button('Next >>'):
26
+ if current_page == 0:
27
+ user_input = st.session_state.user_input
28
+ prompt_response = pipeline(user_input, 10)
29
+ steps = prompt_response.get("steps")
30
+ init_prompt = prompt_response.get("story")
31
+
32
+ init_img, img_dict = generate_story(init_prompt, steps)
33
+
34
+ st.session_state.pipeline_response = prompt_response
35
+ st.session_state.init_img = init_img
36
+ st.session_state.img_dict = img_dict
37
+
38
+ current_page += 1
39
+
40
+ return current_page
41
+
42
+
43
+ # Main function to display the pages
44
+ def get_pipeline_data(page_number):
45
+ pipeline_response = st.session_state.pipeline_response
46
+ text_output = pipeline_response.get("steps")[page_number - 1]
47
+
48
+ # random_img = f"https://picsum.photos/800/600?random={page_number}"
49
+ # response = requests.get(random_img)
50
+ # image = Image.open(BytesIO(response.content))
51
+ img_dict = st.session_state.img_dict
52
+ img = img_dict[page_number-1]
53
+
54
+ return {"text_output": text_output, "image_obj": img}
55
+
56
+
57
+ def main():
58
+ st.set_page_config(page_title="Narrative chat", layout="wide")
59
+ st.title("DreamBot")
60
+
61
+ # Initialize the current page
62
+ current_page = st.session_state.get('current_page', 0)
63
+
64
+ # Display content for each page
65
+ if current_page == 0:
66
+ st.write("Tell me what story you would like me to tell:")
67
+ user_input = st.text_area("")
68
+ st.session_state.user_input = user_input
69
+
70
+ else:
71
+ # Retrieve data from random generators
72
+ data = get_pipeline_data(current_page)
73
+ text_output = data.get('text_output', '')
74
+ image = data.get('image_obj', '')
75
+
76
+ # Display text output
77
+ st.write(text_output)
78
+
79
+ tts = gTTS(text_output)
80
+ tts.save('audio.mp3')
81
+ st.audio('audio.mp3')
82
+
83
+ # Display image output
84
+ if image:
85
+ st.image(image, use_column_width=False, width=400)
86
+
87
+ # Display page navigation
88
+ current_page = page_navigation(current_page)
89
+
90
+ st.write('current_page:', current_page)
91
+ st.session_state.current_page = current_page
92
+
93
+
94
+ if __name__ == "__main__":
95
+ main()