Monke64 committed
Commit 1c1af26
Parent: 9e17df9

Added image model

Files changed (1): app.py (+11 -11)
app.py CHANGED

@@ -2,7 +2,7 @@ import streamlit as st
 from flask.Emotion_spotting_service import _Emotion_spotting_service
 from flask.Genre_spotting_service import _Genre_spotting_service
 from flask.Beat_tracking_service import _Beat_tracking_service
-#from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline
 import torch
 
 emo_list = []
@@ -22,11 +22,11 @@ def load_beat_model():
     beat_service = _Beat_tracking_service()
     return beat_service
 
-# @st.cache_resource
-# def load_image_model():
-#     pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
-#     pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
-#     return pipeline
+@st.cache_resource
+def load_image_model():
+    pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
+    pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
+    return pipeline
 
 
 if 'emotion' not in st.session_state:
@@ -41,7 +41,7 @@ if 'beat' not in st.session_state:
 emotion_service = load_emo_model()
 genre_service = load_genre_model()
 beat_service = load_beat_model()
-# image_service = load_image_model()
+image_service = load_image_model()
 
 st.title("Music2Image webpage")
 user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"], key="file_uploader")
@@ -71,7 +71,7 @@ if st.session_state.emotion != None and st.session_state.genre != None and st.se
     st.caption("Text description of your music file")
     text_output = "This piece of music falls under the " + st.session_state.genre[0] + " genre. It is of tempo " + str(int(st.session_state.beat)) + " and evokes a sense of " + st.session_state.emotion + "."
     st.text(text_output)
-    #if text_output:
-    #    if st.button("Generate image from text description"):
-    #        image = image_service(text_output)
-    #        st.image(image)
+    if text_output:
+        if st.button("Generate image from text description"):
+            image = image_service(text_output).images[0]
+            st.image(image)
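
For reference, a minimal sketch of how the newly enabled image model could be loaded in a device-aware way, so the app can still start on CPU-only hardware. The CPU/float32 fallback, the function name load_image_model_sketch, and the num_inference_steps value are assumptions for illustration, not part of this commit; the commit itself always targets CUDA with float16.

import torch
from diffusers import StableDiffusionPipeline

def load_image_model_sketch():
    # Assumption: fall back to CPU and float32 when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
    ).to(device)
    # LoRA weights shipped with the Space, as referenced in app.py.
    pipeline.load_lora_weights(
        "Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline

# Usage mirroring the new Streamlit branch: the pipeline call returns a
# StableDiffusionPipelineOutput, so the generated image is taken from .images[0].
# pipeline = load_image_model_sketch()
# image = pipeline(text_output, num_inference_steps=25).images[0]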