nightey3s commited on
Commit
984bc80
·
unverified ·
1 Parent(s): ee2109f

Fix compatability for ZeroGPU

Browse files
Files changed (1) hide show
  1. profanity_detector.py +60 -34
profanity_detector.py CHANGED
@@ -76,53 +76,79 @@ def load_models():
76
  PROFANITY_MODEL = "parsawar/profanity_model_3.1"
77
  profanity_tokenizer = AutoTokenizer.from_pretrained(PROFANITY_MODEL)
78
 
79
- # Load model with memory optimization using half-precision
80
- profanity_model = AutoModelForSequenceClassification.from_pretrained(PROFANITY_MODEL)
81
-
82
- # Only move to device for local runs
83
- if not IS_ZEROGPU and torch.cuda.is_available():
84
- profanity_model = profanity_model.to(device)
85
- try:
86
- profanity_model = profanity_model.half()
87
- logger.info("Successfully converted profanity model to half precision")
88
- except Exception as e:
89
- logger.warning(f"Could not convert to half precision: {str(e)}")
 
 
 
 
 
 
 
 
90
 
 
91
  logger.info("Loading detoxification model...")
92
  T5_MODEL = "s-nlp/t5-paranmt-detox"
93
  t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL)
94
 
95
- # Load model with memory optimization
96
- t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL)
97
-
98
- # Move to GPU if available and optimize with half-precision where possible
99
- if not IS_ZEROGPU and torch.cuda.is_available():
100
- t5_model = t5_model.to(device)
101
- # Convert to half precision to save memory (if possible)
102
- try:
103
- t5_model = t5_model.half() # Convert to FP16
104
- logger.info("Successfully converted T5 model to half precision")
105
- except Exception as e:
106
- logger.warning(f"Could not convert to half precision: {str(e)}")
 
 
 
107
 
108
  logger.info("Loading Whisper speech-to-text model...")
109
- whisper_model = whisper.load_model("large")
110
- if not IS_ZEROGPU and torch.cuda.is_available():
111
- whisper_model = whisper_model.to(device)
 
 
 
 
112
 
113
  logger.info("Loading Text-to-Speech model...")
114
  TTS_MODEL = "microsoft/speecht5_tts"
115
  tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL)
116
- # Load TTS models without automatic device mapping
117
- tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL)
118
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
119
 
120
- # Move models to appropriate device
121
- if not IS_ZEROGPU and torch.cuda.is_available():
122
- tts_model = tts_model.to(device)
123
- vocoder = vocoder.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- # Speaker embeddings for TTS
126
  speaker_embeddings = torch.zeros((1, 512))
127
  if not IS_ZEROGPU and torch.cuda.is_available():
128
  speaker_embeddings = speaker_embeddings.to(device)
 
76
  PROFANITY_MODEL = "parsawar/profanity_model_3.1"
77
  profanity_tokenizer = AutoTokenizer.from_pretrained(PROFANITY_MODEL)
78
 
79
+ # Load model without moving to CUDA directly
80
+ if IS_ZEROGPU:
81
+ logger.info("ZeroGPU mode: Loading model without CUDA initialization")
82
+ # For ZeroGPU, use device_map='auto' or just stay on CPU
83
+ profanity_model = AutoModelForSequenceClassification.from_pretrained(
84
+ PROFANITY_MODEL,
85
+ device_map=None, # Explicitly stay on CPU
86
+ low_cpu_mem_usage=True
87
+ )
88
+ else:
89
+ # For local runs, normal loading with CUDA if available
90
+ profanity_model = AutoModelForSequenceClassification.from_pretrained(PROFANITY_MODEL)
91
+ if torch.cuda.is_available():
92
+ profanity_model = profanity_model.to(device)
93
+ try:
94
+ profanity_model = profanity_model.half()
95
+ logger.info("Successfully converted profanity model to half precision")
96
+ except Exception as e:
97
+ logger.warning(f"Could not convert to half precision: {str(e)}")
98
 
99
+ # Apply similar changes to all other model loading...
100
  logger.info("Loading detoxification model...")
101
  T5_MODEL = "s-nlp/t5-paranmt-detox"
102
  t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL)
103
 
104
+ if IS_ZEROGPU:
105
+ t5_model = AutoModelForSeq2SeqLM.from_pretrained(
106
+ T5_MODEL,
107
+ device_map=None,
108
+ low_cpu_mem_usage=True
109
+ )
110
+ else:
111
+ t5_model = AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL)
112
+ if torch.cuda.is_available():
113
+ t5_model = t5_model.to(device)
114
+ try:
115
+ t5_model = t5_model.half()
116
+ logger.info("Successfully converted T5 model to half precision")
117
+ except Exception as e:
118
+ logger.warning(f"Could not convert to half precision: {str(e)}")
119
 
120
  logger.info("Loading Whisper speech-to-text model...")
121
+ if IS_ZEROGPU:
122
+ # For ZeroGPU, stay on CPU in the main process
123
+ whisper_model = whisper.load_model("medium", device="cpu")
124
+ else:
125
+ whisper_model = whisper.load_model("large")
126
+ if torch.cuda.is_available():
127
+ whisper_model = whisper_model.to(device)
128
 
129
  logger.info("Loading Text-to-Speech model...")
130
  TTS_MODEL = "microsoft/speecht5_tts"
131
  tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL)
 
 
 
132
 
133
+ if IS_ZEROGPU:
134
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained(
135
+ TTS_MODEL,
136
+ device_map=None,
137
+ low_cpu_mem_usage=True
138
+ )
139
+ vocoder = SpeechT5HifiGan.from_pretrained(
140
+ "microsoft/speecht5_hifigan",
141
+ device_map=None,
142
+ low_cpu_mem_usage=True
143
+ )
144
+ else:
145
+ tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL)
146
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
147
+ if torch.cuda.is_available():
148
+ tts_model = tts_model.to(device)
149
+ vocoder = vocoder.to(device)
150
 
151
+ # Speaker embeddings - always on CPU for ZeroGPU
152
  speaker_embeddings = torch.zeros((1, 512))
153
  if not IS_ZEROGPU and torch.cuda.is_available():
154
  speaker_embeddings = speaker_embeddings.to(device)