Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -29,15 +29,11 @@ from huggingface_hub import snapshot_download
 from dotenv import load_dotenv
 load_dotenv()
 
-# ---------------------------
 # Set up device
-# ---------------------------
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 tts_device = "cuda" if torch.cuda.is_available() else "cpu" # for SNAC and Orpheus TTS
 
-# ---------------------------
 # Load DeepHermes Llama (chat/LLM) model
-# ---------------------------
 hermes_model_id = "prithivMLmods/DeepHermes-3-Llama-3-3B-Preview-abliterated"
 hermes_llm_tokenizer = AutoTokenizer.from_pretrained(hermes_model_id)
 hermes_llm_model = AutoModelForCausalLM.from_pretrained(
@@ -47,9 +43,7 @@ hermes_llm_model = AutoModelForCausalLM.from_pretrained(
 )
 hermes_llm_model.eval()
 
-# ---------------------------
 # Load Qwen2-VL processor and model for multimodal tasks
-# ---------------------------
 MODEL_ID_QWEN = "prithivMLmods/Qwen2-VL-OCR2-2B-Instruct"
 # (If needed, you can pass extra arguments such as a size dict here if required.)
 processor = AutoProcessor.from_pretrained(MODEL_ID_QWEN, trust_remote_code=True)
@@ -59,9 +53,7 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# ---------------------------
 # Load Orpheus TTS model and SNAC for TTS synthesis
-# ---------------------------
 print("Loading SNAC model...")
 snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
 snac_model = snac_model.to(tts_device)
@@ -93,17 +85,13 @@ orpheus_tts_model.to(tts_device)
 orpheus_tts_tokenizer = AutoTokenizer.from_pretrained(tts_model_name)
 print(f"Orpheus TTS model loaded to {tts_device}")
 
-# ---------------------------
 # Some global parameters for chat and image generation
-# ---------------------------
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-# ---------------------------
 # Stable Diffusion XL setup
-# ---------------------------
-MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
+MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") #SG161222/RealVisXL_V5.0_Lightning
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
@@ -126,9 +114,7 @@ if ENABLE_CPU_OFFLOAD:
 
 MAX_SEED = np.iinfo(np.int32).max
 
-# ---------------------------
 # Utility functions
-# ---------------------------
 def save_image(img: Image.Image) -> str:
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
@@ -223,9 +209,7 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
-# ---------------------------
 # New TTS functions (SNAC/Orpheus pipeline)
-# ---------------------------
 def process_prompt(prompt, voice, tokenizer, device):
     prompt = f"{voice}: {prompt}"
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids
@@ -307,9 +291,7 @@ def generate_speech(text, voice, temperature, top_p, repetition_penalty, max_new
         print(f"Error generating speech: {e}")
         return None
 
-# ---------------------------
 # Main generate function for the chat interface
-# ---------------------------
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -501,9 +483,7 @@ def generate(
     final_response = "".join(outputs)
     yield final_response
 
-# ---------------------------
 # Gradio Interface
-# ---------------------------
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
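
For context, the Stable Diffusion XL section above reads its settings from environment variables, and app.py calls load_dotenv() at startup. The lines below are a minimal, hypothetical sketch of supplying those variables before launching the app locally; the MODEL_VAL_PATH value follows the repository named in the updated comment, and the remaining values simply mirror the defaults already present in app.py.

# Hypothetical configuration sketch -- not part of this commit. app.py reads these
# values via os.getenv(); on Hugging Face Spaces they would normally be set as
# Space variables/secrets, or placed in a .env file picked up by load_dotenv().
import os

os.environ.setdefault("MODEL_VAL_PATH", "SG161222/RealVisXL_V5.0_Lightning")  # SDXL repo named in the updated comment
os.environ.setdefault("MAX_IMAGE_SIZE", "4096")            # same as the in-code default
os.environ.setdefault("USE_TORCH_COMPILE", "0")            # torch.compile disabled by default
os.environ.setdefault("ENABLE_CPU_OFFLOAD", "0")           # CPU offload disabled by default
os.environ.setdefault("MAX_INPUT_TOKEN_LENGTH", "4096")    # chat input truncation limit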
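
The Gradio section at the end of the diff wires the streaming generate generator into gr.ChatInterface. As a reminder of the contract involved (an illustrative sketch only, not code from this Space; echo_stream is a hypothetical stand-in for generate), a ChatInterface fn receives the user message and the chat history and can stream by yielding progressively longer strings, with Gradio showing the last yielded value as the assistant reply.

# Illustrative sketch of the gr.ChatInterface streaming contract (simplified and
# text-only; the real app passes a dict input and extra controls via additional_inputs).
import gradio as gr

def echo_stream(message, history):
    partial = ""
    for ch in str(message):
        partial += ch
        yield partial  # each yield replaces the displayed assistant reply

demo = gr.ChatInterface(fn=echo_stream)

if __name__ == "__main__":
    demo.launch()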