BarBar288 committed
Commit e8ec0a6 · verified · 1 Parent(s): b276854

Update app.py


Add lazy loading for models to prevent overuse of resources.
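In outline, the change follows the lazy-loading pattern app.py already uses for the conversational models: a module-level cache keyed by model name, each pipeline instantiated only on first request, and everything placed on one shared device. A minimal, self-contained sketch of that pattern (get_pipeline and _pipelines are illustrative names, not identifiers from app.py):

import torch
from transformers import pipeline

# One shared device for every model: the GPU if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cache of already-instantiated pipelines, keyed by (task, model name).
_pipelines = {}

def get_pipeline(task, model_name):
    # Lazy loading: build each pipeline at most once, on first use.
    key = (task, model_name)
    if key not in _pipelines:
        _pipelines[key] = pipeline(task, model=model_name, device=device)
    return _pipelines[key]

# The first call loads the model; the second returns the cached instance.
summarizer = get_pipeline("summarization", "facebook/bart-large-cnn")
assert summarizer is get_pipeline("summarization", "facebook/bart-large-cnn")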

Files changed (1)
  1. app.py +18 -15
app.py CHANGED
@@ -51,12 +51,15 @@ text_to_image_pipelines = {}
 text_to_speech_pipelines = {}
 
 # Initialize pipelines for other tasks
-visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
-document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
-image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224")
-object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50")
-video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400")
-summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logger.info(f"Device set to use {device}")
+
+visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa", device=device)
+document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2", device=device)
+image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224", device=device)
+object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50", device=device)
+video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400", device=device)
+summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
 
 # Load speaker embeddings for text-to-audio
 def load_speaker_embeddings(model_name):
@@ -64,26 +67,26 @@ def load_speaker_embeddings(model_name):
         logger.info("Loading speaker embeddings for SpeechT5")
         from datasets import load_dataset
         dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-        speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0).to() # Example speaker
+        speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0).to(device) # Example speaker
         return speaker_embeddings
     return None
 
 # Use a different model for text-to-audio if stabilityai/stable-audio-open-1.0 is not supported
 try:
-    text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0")
+    text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0", device=device)
 except ValueError as e:
     logger.error(f"Error loading stabilityai/stable-audio-open-1.0: {e}")
     logger.info("Falling back to a different text-to-audio model.")
-    text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts")
+    text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts", device=device)
     speaker_embeddings = load_speaker_embeddings("microsoft/speecht5_tts")
 
-audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base")
+audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base", device=device)
 
 def load_conversational_model(model_name):
     if model_name not in conversational_models_loaded:
         logger.info(f"Loading conversational model: {model_name}")
         tokenizer = AutoTokenizer.from_pretrained(conversational_models[model_name], use_auth_token=read_token)
-        model = AutoModelForCausalLM.from_pretrained(conversational_models[model_name], use_auth_token=read_token).to()
+        model = AutoModelForCausalLM.from_pretrained(conversational_models[model_name], use_auth_token=read_token).to(device)
         conversational_tokenizers[model_name] = tokenizer
         conversational_models_loaded[model_name] = model
     return conversational_tokenizers[model_name], conversational_models_loaded[model_name]
@@ -92,7 +95,7 @@ def chat(model_name, user_input, history=[]):
     tokenizer, model = load_conversational_model(model_name)
 
     # Encode the input
-    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to()
+    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(device)
 
     # Generate a response
     with torch.no_grad():
@@ -112,8 +115,8 @@ def generate_image(model_name, prompt):
     if model_name not in text_to_image_pipelines:
         logger.info(f"Loading text-to-image model: {model_name}")
         text_to_image_pipelines[model_name] = StableDiffusionPipeline.from_pretrained(
-            text_to_image_models[model_name], use_auth_token=read_token, torch_dtype=torch.float16, _map="auto"
-        )
+            text_to_image_models[model_name], use_auth_token=read_token, torch_dtype=torch.float16, device_map="auto"
+        ).to(device)
     pipeline = text_to_image_pipelines[model_name]
     image = pipeline(prompt).images[0]
     return image
@@ -122,7 +125,7 @@ def generate_speech(model_name, text):
     if model_name not in text_to_speech_pipelines:
         logger.info(f"Loading text-to-speech model: {model_name}")
         text_to_speech_pipelines[model_name] = pipeline(
-            "text-to-speech", model=text_to_speech_models[model_name], use_auth_token=read_token
+            "text-to-speech", model=text_to_speech_models[model_name], use_auth_token=read_token, device=device
        )
     pipeline = text_to_speech_pipelines[model_name]
     audio = pipeline(text, speaker_embeddings=speaker_embeddings)
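For reference, the SpeechT5 fallback path above depends on the x-vector speaker embedding living on the same device as the model, which is why the diff moves it with .to(device). A standalone sketch of that flow, passing the embedding through the pipeline's documented forward_params argument (the prompt string is only an example):

import torch
from datasets import load_dataset
from transformers import pipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SpeechT5 needs an x-vector speaker embedding in addition to the input text.
tts = pipeline("text-to-speech", model="microsoft/speecht5_tts", device=device)

# Same example speaker (index 7306) the diff uses, moved to the shared device.
dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0).to(device)

# forward_params is handed through to the model's generate/forward call;
# the result is a dict with "audio" (a numpy array) and "sampling_rate".
result = tts("Hello from SpeechT5.", forward_params={"speaker_embeddings": speaker_embeddings})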