Monke64 commited on
Commit
5685b90
·
1 Parent(s): d203a00

Removed image model

Browse files
Files changed (1) hide show
  1. app.py +11 -12
app.py CHANGED
@@ -2,11 +2,10 @@ import streamlit as st
2
  from flask.Emotion_spotting_service import _Emotion_spotting_service
3
  from flask.Genre_spotting_service import _Genre_spotting_service
4
  from flask.Beat_tracking_service import _Beat_tracking_service
5
- from diffusers import StableDiffusionPipeline
6
  import torch
7
  import os
8
 
9
- os.environ['CUDA_VISIBLE_DEVICES'] = '4'
10
  emo_list = []
11
  gen_list = []
12
  tempo_list = []
@@ -24,11 +23,11 @@ def load_beat_model():
24
  beat_service = _Beat_tracking_service()
25
  return beat_service
26
 
27
- @st.cache_resource
28
- def load_image_model():
29
- pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
30
- pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
31
- return pipeline
32
 
33
 
34
  if 'emotion' not in st.session_state:
@@ -43,7 +42,7 @@ if 'beat' not in st.session_state:
43
  emotion_service = load_emo_model()
44
  genre_service = load_genre_model()
45
  beat_service = load_beat_model()
46
- image_service = load_image_model()
47
 
48
  st.title("Music2Image webpage")
49
  user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader")
@@ -73,7 +72,7 @@ if st.session_state.emotion != None and st.session_state.genre != None and st.se
73
  st.caption("Text description of your music file")
74
  text_output = "This piece of music falls under the " + st.session_state.genre[0] + " genre. It is of tempo " + str(int(st.session_state.beat)) + " and evokes a sense of" + st.session_state.emotion + "."
75
  st.text(text_output)
76
- if text_output:
77
- if st.button("Generate image from text description"):
78
- image = image_service(text_output)
79
- st.image(image)
 
2
  from flask.Emotion_spotting_service import _Emotion_spotting_service
3
  from flask.Genre_spotting_service import _Genre_spotting_service
4
  from flask.Beat_tracking_service import _Beat_tracking_service
5
+ #from diffusers import StableDiffusionPipeline
6
  import torch
7
  import os
8
 
 
9
  emo_list = []
10
  gen_list = []
11
  tempo_list = []
 
23
  beat_service = _Beat_tracking_service()
24
  return beat_service
25
 
26
+ # @st.cache_resource
27
+ # def load_image_model():
28
+ # pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",torch_dtype=torch.float16).to("cuda")
29
+ # pipeline.load_lora_weights("Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
30
+ # return pipeline
31
 
32
 
33
  if 'emotion' not in st.session_state:
 
42
  emotion_service = load_emo_model()
43
  genre_service = load_genre_model()
44
  beat_service = load_beat_model()
45
+ # image_service = load_image_model()
46
 
47
  st.title("Music2Image webpage")
48
  user_input = st.file_uploader("Upload your wav/mp3 files here", type=["wav","mp3"],key = "file_uploader")
 
72
  st.caption("Text description of your music file")
73
  text_output = "This piece of music falls under the " + st.session_state.genre[0] + " genre. It is of tempo " + str(int(st.session_state.beat)) + " and evokes a sense of" + st.session_state.emotion + "."
74
  st.text(text_output)
75
+ # if text_output:
76
+ # if st.button("Generate image from text description"):
77
+ # image = image_service(text_output)
78
+ # st.image(image)