BarBar288 committed
Commit 5cda4a7 · verified · 1 Parent(s): e8ec0a6

Update app.py

Files changed (1):
  1. app.py +30 -16
app.py CHANGED

@@ -54,12 +54,12 @@ text_to_speech_pipelines = {}
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 logger.info(f"Device set to use {device}")
 
-visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa", device=device)
-document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2", device=device)
-image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224", device=device)
-object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50", device=device)
-video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400", device=device)
-summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
+visual_qa_pipeline = pipeline("visual-question-answering", model="dandelin/vilt-b32-finetuned-vqa")
+document_qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
+image_classification_pipeline = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224")
+object_detection_pipeline = pipeline("object-detection", model="facebook/detr-resnet-50")
+video_classification_pipeline = pipeline("video-classification", model="facebook/timesformer-base-finetuned-k400")
+summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
 
 # Load speaker embeddings for text-to-audio
 def load_speaker_embeddings(model_name):

@@ -67,26 +67,34 @@ def load_speaker_embeddings(model_name):
         logger.info("Loading speaker embeddings for SpeechT5")
         from datasets import load_dataset
         dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
-        speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0).to(device)  # Example speaker
+        speaker_embeddings = torch.tensor(dataset[7306]["xvector"]).unsqueeze(0)  # Example speaker
         return speaker_embeddings
     return None
 
 # Use a different model for text-to-audio if stabilityai/stable-audio-open-1.0 is not supported
 try:
-    text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0", device=device)
+    text_to_audio_pipeline = pipeline("text-to-audio", model="stabilityai/stable-audio-open-1.0")
 except ValueError as e:
     logger.error(f"Error loading stabilityai/stable-audio-open-1.0: {e}")
     logger.info("Falling back to a different text-to-audio model.")
-    text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts", device=device)
+    text_to_audio_pipeline = pipeline("text-to-audio", model="microsoft/speecht5_tts")
     speaker_embeddings = load_speaker_embeddings("microsoft/speecht5_tts")
 
-audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base", device=device)
+audio_classification_pipeline = pipeline("audio-classification", model="facebook/wav2vec2-base")
 
 def load_conversational_model(model_name):
     if model_name not in conversational_models_loaded:
         logger.info(f"Loading conversational model: {model_name}")
-        tokenizer = AutoTokenizer.from_pretrained(conversational_models[model_name], use_auth_token=read_token)
-        model = AutoModelForCausalLM.from_pretrained(conversational_models[model_name], use_auth_token=read_token).to(device)
+        tokenizer = AutoTokenizer.from_pretrained(
+            conversational_models[model_name],
+            use_auth_token=read_token,
+            trust_remote_code=True
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            conversational_models[model_name],
+            use_auth_token=read_token,
+            trust_remote_code=True
+        )
         conversational_tokenizers[model_name] = tokenizer
         conversational_models_loaded[model_name] = model
     return conversational_tokenizers[model_name], conversational_models_loaded[model_name]

@@ -95,7 +103,7 @@ def chat(model_name, user_input, history=[]):
     tokenizer, model = load_conversational_model(model_name)
 
     # Encode the input
-    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt").to(device)
+    input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")
 
     # Generate a response
     with torch.no_grad():

@@ -115,8 +123,11 @@ def generate_image(model_name, prompt):
     if model_name not in text_to_image_pipelines:
         logger.info(f"Loading text-to-image model: {model_name}")
         text_to_image_pipelines[model_name] = StableDiffusionPipeline.from_pretrained(
-            text_to_image_models[model_name], use_auth_token=read_token, torch_dtype=torch.float16, device_map="auto"
-        ).to(device)
+            text_to_image_models[model_name],
+            use_auth_token=read_token,
+            torch_dtype=torch.float16,
+            device_map="auto"
+        )
     pipeline = text_to_image_pipelines[model_name]
     image = pipeline(prompt).images[0]
     return image

@@ -125,7 +136,10 @@ def generate_speech(model_name, text):
     if model_name not in text_to_speech_pipelines:
         logger.info(f"Loading text-to-speech model: {model_name}")
         text_to_speech_pipelines[model_name] = pipeline(
-            "text-to-speech", model=text_to_speech_models[model_name], use_auth_token=read_token, device=device
+            "text-to-speech",
+            model=text_to_speech_models[model_name],
+            use_auth_token=read_token,
+            device=device
         )
     pipeline = text_to_speech_pipelines[model_name]
     audio = pipeline(text, speaker_embeddings=speaker_embeddings)
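The main effect of the change is dropping explicit device placement from most pipelines. For reference, a minimal sketch of the difference, assuming the standard transformers pipeline API (the model name is taken from the diff above):

import torch
from transformers import pipeline

# Explicit placement (the pre-commit pattern): pin the pipeline to the detected device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clf_pinned = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224", device=device)

# Default placement (the post-commit pattern): with no `device` argument,
# the pipeline loads on CPU unless accelerate-style dispatch is configured.
clf_default = pipeline("image-classification", model="facebook/deit-base-distilled-patch16-224")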
 
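A hedged sketch of one chat turn using the EOS-token concatenation pattern that chat() relies on. The app's actual conversational_models mapping is not shown in this diff, so "microsoft/DialoGPT-medium" is an illustrative stand-in:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

# Encode the user turn, terminated with the EOS token, as in chat().
input_ids = tokenizer.encode("Hello, how are you?" + tokenizer.eos_token, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(input_ids, max_length=200, pad_token_id=tokenizer.eos_token_id)

# Decode only the tokens generated after the prompt.
reply = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(reply)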
 
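For the SpeechT5 fallback path, a minimal sketch following the pattern documented on the microsoft/speecht5_tts model card; note that the documented pipeline interface passes speaker embeddings via forward_params rather than as a direct keyword argument:

import torch
from datasets import load_dataset
from transformers import pipeline

synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Same xvector dataset and example speaker index as load_speaker_embeddings() above.
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

speech = synthesiser(
    "Hello from the fallback text-to-speech model.",
    forward_params={"speaker_embeddings": speaker_embedding},
)
# speech is a dict with "audio" (a numpy array) and "sampling_rate".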