prithivMLmods committed on
Commit
a5607ef
·
verified ·
1 Parent(s): 153d99a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -21
app.py CHANGED
@@ -29,15 +29,11 @@ from huggingface_hub import snapshot_download
29
  from dotenv import load_dotenv
30
  load_dotenv()
31
 
32
- # ---------------------------
33
  # Set up device
34
- # ---------------------------
35
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36
  tts_device = "cuda" if torch.cuda.is_available() else "cpu" # for SNAC and Orpheus TTS
37
 
38
- # ---------------------------
39
  # Load DeepHermes Llama (chat/LLM) model
40
- # ---------------------------
41
  hermes_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
42
  hermes_llm_tokenizer = AutoTokenizer.from_pretrained(hermes_model_id)
43
  hermes_llm_model = AutoModelForCausalLM.from_pretrained(
@@ -47,9 +43,7 @@ hermes_llm_model = AutoModelForCausalLM.from_pretrained(
47
  )
48
  hermes_llm_model.eval()
49
 
50
- # ---------------------------
51
  # Load Qwen2-VL processor and model for multimodal tasks
52
- # ---------------------------
53
  MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
54
  # (If needed, you can pass extra arguments such as a size dict here.)
55
  processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
@@ -59,9 +53,7 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
59
  torch_dtype=torch.float16
60
  ).to("cuda").eval()
61
 
62
- # ---------------------------
63
  # Load Orpheus TTS model and SNAC for TTS synthesis
64
- # ---------------------------
65
  print("Loading SNAC model...")
66
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
67
  snac_model = snac_model.to(tts_device)
@@ -93,17 +85,13 @@ orpheus_tts_model.to(tts_device)
93
  orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
94
  print(f"Orpheus TTS model loaded to {tts_device}")
95
 
96
- # ---------------------------
97
  # Some global parameters for chat and image generation
98
- # ---------------------------
99
  MAX_MAX_NEW_TOKENS = 2048
100
  DEFAULT_MAX_NEW_TOKENS = 1024
101
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
102
 
103
- # ---------------------------
104
  # Stable Diffusion XL setup
105
- # ---------------------------
106
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
107
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
108
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
109
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
@@ -126,9 +114,7 @@ if ENABLE_CPU_OFFLOAD:
126
 
127
  MAX_SEED = np.iinfo(np.int32).max
128
 
129
- # ---------------------------
130
  # Utility functions
131
- # ---------------------------
132
  def save_image(img: Image.Image) -> str:
133
  unique_name = str(uuid.uuid4()) + ".png"
134
  img.save(unique_name)
@@ -223,9 +209,7 @@ def generate_image_fn(
223
  image_paths = [save_image(img) for img in images]
224
  return image_paths, seed
225
 
226
- # ---------------------------
227
  # New TTS functions (SNAC/Orpheus pipeline)
228
- # ---------------------------
229
  def process_prompt(prompt, voice, tokenizer, device):
230
  prompt = f"{voice}: {prompt}"
231
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
@@ -307,9 +291,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
307
  print(f"Error generating speech: {e}")
308
  return None
309
 
310
- # ---------------------------
311
  # Main generate function for the chat interface
312
- # ---------------------------
313
  @spaces.GPU
314
  def generate(
315
  input_dict: dict,
@@ -501,9 +483,7 @@ def generate(
501
  final_response = "".join(outputs)
502
  yield final_response
503
 
504
- # ---------------------------
505
  # Gradio Interface
506
- # ---------------------------
507
  demo = gr.ChatInterface(
508
  fn=generate,
509
  additional_inputs=[
 
29
  from dotenv import load_dotenv
30
  load_dotenv()
31
 
 
32
  # Set up device
 
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
34
  tts_device = "cuda" if torch.cuda.is_available() else "cpu" # for SNAC and Orpheus TTS
35
 
 
36
  # Load DeepHermes Llama (chat/LLM) model
 
37
  hermes_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
38
  hermes_llm_tokenizer = AutoTokenizer.from_pretrained(hermes_model_id)
39
  hermes_llm_model = AutoModelForCausalLM.from_pretrained(
 
43
  )
44
  hermes_llm_model.eval()
45
 
 
46
  # Load Qwen2-VL processor and model for multimodal tasks
 
47
  MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
48
  # (If needed, you can pass extra arguments such as a size dict here.)
49
  processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
 
53
  torch_dtype=torch.float16
54
  ).to("cuda").eval()
55
 
 
56
  # Load Orpheus TTS model and SNAC for TTS synthesis
 
57
  print("Loading SNAC model...")
58
  snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
59
  snac_model = snac_model.to(tts_device)
 
85
  orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
86
  print(f"Orpheus TTS model loaded to {tts_device}")
87
 
 
88
  # Some global parameters for chat and image generation
 
89
  MAX_MAX_NEW_TOKENS = 2048
90
  DEFAULT_MAX_NEW_TOKENS = 1024
91
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
92
 
 
93
  # Stable Diffusion XL setup
94
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") #SG161222/RealVisXL_V5.0_Lightning
 
95
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
96
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
97
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
114
 
115
  MAX_SEED = np.iinfo(np.int32).max
116
 
 
117
  # Utility functions
 
118
  def save_image(img: Image.Image) -> str:
119
  unique_name = str(uuid.uuid4()) + ".png"
120
  img.save(unique_name)
 
209
  image_paths = [save_image(img) for img in images]
210
  return image_paths, seed
211
 
 
212
  # New TTS functions (SNAC/Orpheus pipeline)
 
213
  def process_prompt(prompt, voice, tokenizer, device):
214
  prompt = f"{voice}: {prompt}"
215
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids
 
291
  print(f"Error generating speech: {e}")
292
  return None
293
 
 
294
  # Main generate function for the chat interface
 
295
  @spaces.GPU
296
  def generate(
297
  input_dict: dict,
 
483
  final_response = "".join(outputs)
484
  yield final_response
485
 
 
486
  # Gradio Interface
 
487
  demo = gr.ChatInterface(
488
  fn=generate,
489
  additional_inputs=[