hans00 committed
Commit 476f08d (unverified)
1 Parent(s): bc05181

Use the HF backend when CUDA is available; load models per request instead of keeping them cached

Files changed (1)
  1. app.py +31 -41
app.py CHANGED
@@ -34,48 +34,35 @@ def get_file_hash(file_path):
             hash_md5.update(chunk)
     return hash_md5.hexdigest()
 
-def try_auto_model_config(model: outetts.Models, backend: outetts.Backend, quantization: outetts.LlamaCppQuantization):
+def try_ggml_model(model: outetts.Models, backend: outetts.Backend, quantization: outetts.LlamaCppQuantization):
     model_config = MODEL_INFO[model]
-    try:
-        repo = f"OuteAI/{model.value}-GGUF"
-        filename = f"{model.value}-{quantization.value}.gguf"
-        model_path = hf_hub_download(
-            repo_id=repo,
-            filename=filename,
-            local_dir=os.path.join(helpers.get_cache_dir(), "gguf"),
-            local_files_only=False
-        )
-        return outetts.ModelConfig(
-            model_path=model_path,
-            tokenizer_path=f"OuteAI/{model.value}",
-            backend=backend,
-            n_gpu_layers=99,
-            verbose=False,
-            device=None,
-            dtype=None,
-            additional_model_config={},
-            audio_codec_path=None,
-            **model_config
-        )
-    except Exception as e:
-        print(f"Error: {e}")
-        return None
+    repo = f"OuteAI/{model.value}-GGUF"
+    filename = f"{model.value}-{quantization.value}.gguf"
+    model_path = hf_hub_download(
+        repo_id=repo,
+        filename=filename,
+        local_dir=os.path.join(helpers.get_cache_dir(), "gguf"),
+        local_files_only=False
+    )
+    return outetts.ModelConfig(
+        model_path=model_path,
+        tokenizer_path=f"OuteAI/{model.value}",
+        backend=backend,
+        n_gpu_layers=99,
+        verbose=False,
+        device=None,
+        dtype=None,
+        additional_model_config={},
+        audio_codec_path=None,
+        **model_config
+    )
 
-@lru_cache(maxsize=5)
-def get_cached_interface(model_name: str):
-    """Get cached interface instance for the model."""
+def get_interface(model_name: str):
+    """Get interface instance for the model (no caching to avoid CUDA memory issues)."""
     model = MODELS[model_name]
 
-    quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q6_K)
-    config = try_auto_model_config(model, outetts.Backend.LLAMACPP, quantization)
-    # self.model = AutoModelForCausalLM.from_pretrained(
-    #     model_path,
-    #     torch_dtype=dtype,
-    #     **additional_model_config
-    # ).to(self.device)
-    if not config:
-        # Fallback to HF model
-        has_cuda = torch.cuda.is_available()
+    has_cuda = torch.cuda.is_available()
+    if has_cuda:
         model_config = MODEL_INFO[model]
         config = outetts.ModelConfig(
             model_path=f"OuteAI/{model_name}",
@@ -83,13 +70,16 @@ def get_cached_interface(model_name: str):
             backend=outetts.Backend.HF,
             additional_model_config={
                 "device_map": "auto" if has_cuda else "cpu",
-                "attn_implementation": "flash_attention_2",
+                "attn_implementation": "flash_attention_2" if has_cuda else "eager",
                 "quantization_config": BitsAndBytesConfig(
                     load_in_8bit=True
                 ) if has_cuda else None,
             },
             **model_config
         )
+    else:
+        quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q6_K)
+        config = try_ggml_model(model, outetts.Backend.LLAMACPP, quantization)
 
     # Initialize the interface
     interface = outetts.Interface(config=config)
@@ -122,8 +112,8 @@ def create_speaker_and_generate(model_name, audio_file, test_text: Optional[str]
         # Return default values for startup/caching purposes
         return "Please upload an audio file to create a speaker profile.", None
 
-    # Get cached interface
-    interface = get_cached_interface(model_name)
+    # Get interface (no caching to avoid CUDA memory issues)
+    interface = get_interface(model_name)
 
     # Get or create speaker profile (with caching)
     speaker = get_or_create_speaker(interface, audio_file)
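
A side note on error handling, with a minimal sketch that is not part of this commit: the old try/except/return None was deleted along with try_auto_model_config, so a failed GGUF download inside try_ggml_model (for example a network error raised by hf_hub_download) now propagates to the caller, and CPU-only hosts no longer fall back to the HF backend. A caller could restore that behaviour along the following lines; build_hf_config is a hypothetical stand-in for the config that get_interface builds inline in its has_cuda branch, while MODELS, MODEL_QUANTIZATION, try_ggml_model, and outetts all come from app.py above.

import outetts

def get_interface_with_fallback(model_name: str):
    # Sketch only: recreate the pre-commit GGUF-then-HF fallback.
    model = MODELS[model_name]
    quantization = MODEL_QUANTIZATION.get(model, outetts.LlamaCppQuantization.Q6_K)
    try:
        # Since this commit, download/load errors raise instead of returning None.
        config = try_ggml_model(model, outetts.Backend.LLAMACPP, quantization)
    except Exception as e:
        print(f"GGUF load failed ({e}); falling back to the HF backend")
        config = build_hf_config(model_name)  # hypothetical helper, not in app.py
    return outetts.Interface(config=config)

As committed, the app instead picks the backend once from torch.cuda.is_available(): GPU hosts go straight to HF with 8-bit BitsAndBytes quantization, CPU hosts go straight to llama.cpp, and nothing recovers from a failed GGUF download.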