Monke64 commited on
Commit
a08179c
·
1 Parent(s): 481c7ae

Added image model

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -2,13 +2,19 @@ import streamlit as st
2
  from flask.Emotion_spotting_service import _Emotion_spotting_service
3
  from flask.Genre_spotting_service import _Genre_spotting_service
4
  from flask.Beat_tracking_service import _Beat_tracking_service
5
- #from diffusers import StableDiffusionPipeline
 
6
  import torch
7
  import os
8
 
9
  emo_list = []
10
  gen_list = []
11
  tempo_list = []
 
 
 
 
 
12
  @st.cache_resource
13
  def load_emo_model():
14
  emo_service = _Emotion_spotting_service("flask/emotion_model.h5")
@@ -23,11 +29,11 @@ def load_beat_model():
23
  beat_service = _Beat_tracking_service()
24
  return beat_service
25
 
26
- # @st.cache_resource
27
- # def load_image_model():
28
- # pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
29
- # pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
30
- # return pipeline
31
 
32
 
33
  if 'emotion' not in st.session_state:
@@ -42,7 +48,7 @@ if 'beat' not in st.session_state:
42
  emotion_service = load_emo_model()
43
  genre_service = load_genre_model()
44
  beat_service = load_beat_model()
45
- # image_service = load_image_model()
46
 
47
  st.title("Music2Image webpage")
48
  user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader")
@@ -76,7 +82,7 @@ if st.session_state.emotion != None and st.session_state.genre != None and st.se
76
  st.caption("Text description of your music file")
77
  text_output = "A scenic image that describes a " + speed + " pace with a feeling of" + st.session_state.emotion + "."
78
  st.text(text_output)
79
- # if text_output:
80
- # if st.button("Generate image from text description"):
81
- # image = image_service(text_output)
82
- # st.image(image)
 
2
  from flask.Emotion_spotting_service import _Emotion_spotting_service
3
  from flask.Genre_spotting_service import _Genre_spotting_service
4
  from flask.Beat_tracking_service import _Beat_tracking_service
5
+ from diffusers import StableDiffusionPipeline
6
+ import tensorflow as tf
7
  import torch
8
  import os
9
 
10
  emo_list = []
11
  gen_list = []
12
  tempo_list = []
13
+
14
+ physical_devices = tf.config.experimental.list_physical_devices('GPU')
15
+ if len(physical_devices) > 0:
16
+ tf.config.experimental.set_memory_growth(physical_devices[0], True)
17
+
18
  @st.cache_resource
19
  def load_emo_model():
20
  emo_service = _Emotion_spotting_service("flask/emotion_model.h5")
 
29
  beat_service = _Beat_tracking_service()
30
  return beat_service
31
 
32
+ @st.cache_resource
33
+ def load_image_model():
34
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
35
+ pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
36
+ return pipeline
37
 
38
 
39
  if 'emotion' not in st.session_state:
 
48
  emotion_service = load_emo_model()
49
  genre_service = load_genre_model()
50
  beat_service = load_beat_model()
51
+ image_service = load_image_model()
52
 
53
  st.title("Music2Image webpage")
54
  user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader")
 
82
  st.caption("Text description of your music file")
83
  text_output = "A scenic image that describes a " + speed + " pace with a feeling of" + st.session_state.emotion + "."
84
  st.text(text_output)
85
+ if text_output:
86
+ if st.button("Generate image from text description"):
87
+ image = image_service(text_output)
88
+ st.image(image)