diff --git a/.project-root b/.project-root new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app.py b/app.py index 0da0319a5b670dce5025888fde58916b96f19869..ce9d40439dfb9589894efb05a100bb3d54651d8b 100644 --- a/app.py +++ b/app.py @@ -1,64 +1,230 @@ +import re import gradio as gr -from huggingface_hub import InferenceClient +import numpy as np +import os +import threading +import subprocess +import sys +import time -""" -For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference -""" -client = InferenceClient("HuggingFaceH4/zephyr-7b-beta") +from huggingface_hub import snapshot_download +from tools.fish_e2e import FishE2EAgent, FishE2EEventType +from tools.schema import ServeMessage, ServeTextPart, ServeVQPart +# Download Weights +os.makedirs("checkpoints", exist_ok=True) +snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4") +snapshot_download(repo_id="fishaudio/fish-agent-v0.1-3b", local_dir="./checkpoints/fish-agent-v0.1-3b") -def respond( - message, - history: list[tuple[str, str]], - system_message, - max_tokens, - temperature, - top_p, +class ChatState: + def __init__(self): + self.conversation = [] + self.added_systext = False + self.added_sysaudio = False + + def get_history(self): + results = [] + for msg in self.conversation: + results.append({"role": msg.role, "content": self.repr_message(msg)}) + + # Process assistant messages to extract questions and update user messages + for i, msg in enumerate(results): + if msg["role"] == "assistant": + match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"]) + if match and i > 0 and results[i - 1]["role"] == "user": + # Update previous user message with extracted question + results[i - 1]["content"] += "\n" + match.group(1) + # Remove the Question/Answer format from assistant message + msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1] + return results + + def repr_message(self, msg: ServeMessage): + response = "" + for part in msg.parts: + if isinstance(part, ServeTextPart): + response += part.text + elif isinstance(part, ServeVQPart): + response += f"