Spaces: Running on Zero
Use HF if CUDA available and not persistent model load
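
In short, the commit makes backend selection conditional on the runtime: with CUDA available the app loads the Hugging Face checkpoint (8-bit BitsAndBytes quantization, flash attention); without it, it downloads a quantized GGUF build and runs it through llama.cpp. The interface is also rebuilt on every request instead of being cached. A minimal sketch of that dispatch, using only the `outetts` symbols visible in the diff below; `pick_config` is a hypothetical name, not part of the app, and the real code passes extra per-model settings this sketch omits:

```python
import torch
import outetts
from huggingface_hub import hf_hub_download

def pick_config(model_name: str) -> outetts.ModelConfig:
    """Hypothetical condensation of get_interface() from the diff below."""
    if torch.cuda.is_available():
        # GPU path: plain HF checkpoint (the real code adds 8-bit
        # quantization and flash attention on top of this).
        return outetts.ModelConfig(
            model_path=f"OuteAI/{model_name}",
            backend=outetts.Backend.HF,
        )
    # CPU path: fetch a quantized GGUF file and serve it via llama.cpp.
    gguf_path = hf_hub_download(
        repo_id=f"OuteAI/{model_name}-GGUF",
        filename=f"{model_name}-Q6_K.gguf",  # Q6_K is the app's default quantization
    )
    return outetts.ModelConfig(
        model_path=gguf_path,
        tokenizer_path=f"OuteAI/{model_name}",
        backend=outetts.Backend.LLAMACPP,
    )
```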
app.py CHANGED
```diff
@@ -34,48 +34,35 @@ def get_file_hash(file_path):
             hash_md5.update(chunk)
     return hash_md5.hexdigest()
 
-def [rest of removed line not captured in the rendered diff]
+def try_ggml_model(model: outetts.Models, backend: outetts.Backend, quantization: outetts.LlamaCppQuantization):
     model_config = MODEL_INFO[model]
-    [removed body, old lines 39-58, not captured in the rendered diff]
-        )
-    except Exception as e:
-        print(f"Error: {e}")
-        return None
+    repo = f"OuteAI/{model.value}-GGUF"
+    filename = f"{model.value}-{quantization.value}.gguf"
+    model_path = hf_hub_download(
+        repo_id=repo,
+        filename=filename,
+        local_dir=os.path.join(helpers.get_cache_dir(), "gguf"),
+        local_files_only=False
+    )
+    return outetts.ModelConfig(
+        model_path=model_path,
+        tokenizer_path=f"OuteAI/{model.value}",
+        backend=backend,
+        n_gpu_layers=99,
+        verbose=False,
+        device=None,
+        dtype=None,
+        additional_model_config={},
+        audio_codec_path=None,
+        **model_config
+    )
 
-[removed line, old line 64, not captured in the rendered diff]
-def get_cached_interface(model_name: str):
-    """Get cached interface instance for the model."""
+def get_interface(model_name: str):
+    """Get interface instance for the model (no caching to avoid CUDA memory issues)."""
     model = MODELS[model_name]
 
-    [removed lines, old lines 69-70, not captured in the rendered diff]
-    # self.model = AutoModelForCausalLM.from_pretrained(
-    #     model_path,
-    #     torch_dtype=dtype,
-    #     **additional_model_config
-    # ).to(self.device)
-    if not config:
-        # Fallback to HF model
-        has_cuda = torch.cuda.is_available()
+    has_cuda = torch.cuda.is_available()
+    if has_cuda:
         model_config = MODEL_INFO[model]
         config = outetts.ModelConfig(
             model_path=f"OuteAI/{model_name}",
@@ -83,13 +70,16 @@ def get_cached_interface(model_name: str):
             backend=outetts.Backend.HF,
             additional_model_config={
                 "device_map": "auto" if has_cuda else "cpu",
-                "attn_implementation": "flash_attention_2",
+                "attn_implementation": "flash_attention_2" if has_cuda else "eager",
                 "quantization_config": BitsAndBytesConfig(
                     load_in_8bit=True
                 ) if has_cuda else None,
             },
             **model_config
         )
+    else:
+        quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q6_K)
+        config = try_ggml_model(model, outetts.Backend.LLAMACPP, quantization)
 
     # Initialize the interface
     interface = outetts.Interface(config=config)
@@ -122,8 +112,8 @@ def create_speaker_and_generate(model_name, audio_file, test_text: Optional[str]
         # Return default values for startup/caching purposes
         return "Please upload an audio file to create a speaker profile.", None
 
-    # Get cached interface
-    interface = get_cached_interface(model_name)
+    # Get interface (no caching to avoid CUDA memory issues)
+    interface = get_interface(model_name)
 
     # Get or create speaker profile (with caching)
     speaker = get_or_create_speaker(interface, audio_file)
```
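
For reference, a hedged sketch of how the per-request interface ends up being used downstream of this handler. The model key and file paths are placeholders, and the `create_speaker`/`generate` call shapes follow OuteTTS's documented interface rather than anything shown in this diff:

```python
# Build a fresh interface per request; on ZeroGPU this avoids pinning
# CUDA memory across Gradio calls, which is what the removed cache did.
interface = get_interface("OuteTTS-1.0-0.6B")  # placeholder model key

# Clone a voice from an uploaded sample, then synthesize the test text.
speaker = interface.create_speaker("voice_sample.wav")
output = interface.generate(
    config=outetts.GenerationConfig(text="Hello there!", speaker=speaker)
)
output.save("output.wav")
```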
|