# OmAgent Gradio webpage client.
# (Removed non-Python residue from a Hugging Face Spaces page header.)
import html
import json
import os
import queue
import shutil
import sys
import threading
import uuid
from pathlib import Path
from time import sleep

# Gradio temp-file directory; set before `import gradio` below — presumably the
# env var must exist before gradio reads it (TODO confirm gradio reads it lazily).
os.environ['GRADIO_TEMP_DIR'] = os.getcwd()
# Local directory scanned for video files offered in the UI dropdown.
video_root_path = os.path.join(os.getcwd(), 'video_root')
os.makedirs(video_root_path, exist_ok=True)

from omagent_core.clients.devices.app.callback import AppCallback
from omagent_core.clients.devices.app.input import AppInput
from omagent_core.clients.devices.app.schemas import ContentStatus, MessageType
from omagent_core.engine.automator.task_handler import TaskHandler
from omagent_core.engine.http.models.workflow_status import terminal_status
from omagent_core.engine.workflow.conductor_workflow import ConductorWorkflow
from omagent_core.services.connectors.redis import RedisConnector
from omagent_core.utils.build import build_from_file
from omagent_core.utils.container import container
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry

# Discover worker classes, then wire the shared container: the Redis stream
# connector plus the app-flavoured callback/input handlers.
registry.import_module()
container.register_connector(name="redis_stream_client", connector=RedisConnector)
# container.register_stm(stm='RedisSTM')
container.register_callback(callback=AppCallback)
container.register_input(input=AppInput)

import gradio as gr
class WebpageClient:
    """Gradio web client that drives two ConductorWorkflows: a *processor*
    (video pre-processing) and an *interactor* (chat Q&A), streaming
    intermediate results through Redis streams."""

    def __init__(
        self,
        interactor: ConductorWorkflow = None,
        processor: ConductorWorkflow = None,
        config_path: str = "./config",
        workers: list = None,
    ) -> None:
        """Initialize the client.

        Args:
            interactor: workflow handling chat questions.
            processor: workflow handling video processing.
            config_path: directory of worker config files.
            workers: pre-built worker instances to register alongside the
                config-declared ones. Defaults to an empty list.
                (Changed from a shared mutable ``[]`` default to ``None`` to
                avoid the mutable-default-argument pitfall; callers see
                identical behavior.)
        """
        self._interactor = interactor
        self._processor = processor
        self._config_path = config_path
        self._workers = workers if workers is not None else []
        # Set once a chat-triggered workflow is running; None means "idle".
        self._workflow_instance_id = None
        # Parse worker configs once and reuse (the original parsed them twice).
        self._worker_config = build_from_file(self._config_path)
        self._task_to_domain = {}
        # Buffer for streaming (INCOMPLETE) answer chunks; see bot().
        self._incomplete_message = ""
        self._custom_css = """
            #OmAgent {
                height: 100vh !important;
                max-height: calc(100vh - 190px) !important;
                overflow-y: auto;
            }

            .running-message {
                margin: 0;
                padding: 2px 4px;
                white-space: pre-wrap;
                word-wrap: break-word;
                font-family: inherit;
            }

            /* Remove the background and border of the message box */
            .message-wrap {
                background: none !important;
                border: none !important;
                padding: 0 !important;
                margin: 0 !important;
            }
            /* Remove the bubble style of the running message */
            .message:has(.running-message) {
                background: none !important;
                border: none !important;
                padding: 0 !important;
                box-shadow: none !important;
            }
        """
        # NOTE(review): `workflow_instance_id` (public, random uuid) is distinct
        # from `_workflow_instance_id` (the currently-running workflow) — the
        # two are intentionally separate; confirm against workflow start logic.
        self.workflow_instance_id = str(uuid.uuid4())
        self.processor_instance_id = str(uuid.uuid4())
        self.initialization(self._workers, self._worker_config)
def initialization(self, workers, worker_config): | |
self.workers = {} | |
for worker in workers: | |
worker.workflow_instance_id = self.workflow_instance_id | |
self.workers[type(worker).__name__] = worker | |
for config in worker_config: | |
worker_cls = registry.get_worker(config['name']) | |
worker = worker_cls(**config) | |
worker.workflow_instance_id = self.workflow_instance_id | |
self.workers[config['name']] = worker | |
    def gradio_app(self):
        """Build and launch the main Gradio UI: a video selector/player column
        next to a chatbot column.

        Selecting a video runs the processor workflow synchronously (the UI
        blocks until it reaches a terminal status) and stashes the result in
        the shared gr.State; chat messages are then routed through
        self.add_message / self.bot.
        """
        with gr.Blocks() as demo:
            def load_local_video() -> dict:
                # Map file name -> absolute path for every video found under
                # video_root_path (by extension).
                result = {}
                for root, _, files in os.walk(video_root_path):
                    for file in filter(lambda x: x.split('.')[-1].lower() in (
                            'mp4', 'avi', 'mov', 'wmv', 'flv', 'mkv', 'webm', 'm4v'), files):
                        file_obs_path = os.path.join(root, file)
                        result[Path(file_obs_path).name] = file_obs_path
                return result

            video_dict = load_local_video()
            current_video = None
            # Session state shared between the selector and the chat handlers.
            state = gr.State(value={
                'video_dict': video_dict,
                'current_video': current_video
            })
            with gr.Row():
                with gr.Column():
                    with gr.Column():
                        def display_video_map(video_title):
                            # change display video
                            video_path = state.value.get('video_dict', {}).get(video_title)
                            exception_queue = queue.Queue()
                            workflow_input = {'video_path': video_path}
                            processor_result = None

                            def run_workflow(workflow_input):
                                nonlocal processor_result
                                try:
                                    processor_result = self._processor.start_workflow_with_input(
                                        workflow_input=workflow_input, workers=self.workers
                                    )
                                except Exception as e:
                                    exception_queue.put(e)  # add exception to queue
                                    logging.error(f"Error starting workflow: {e}")
                                    raise e

                            # workflow_thread = threading.Thread(target=run_workflow, args=(workflow_input,))
                            # workflow_thread.start()
                            # Runs synchronously: the handler blocks until the
                            # processor workflow finishes.
                            run_workflow(workflow_input)
                            processor_workflow_instance_id = self.processor_instance_id
                            # Poll until the workflow reaches a terminal status.
                            while True:
                                status = self._processor.get_workflow(
                                    workflow_id=processor_workflow_instance_id).status
                                if status in terminal_status:
                                    break
                                sleep(1)
                            # Refresh the video list (processing may add files)
                            # and stash results for add_message to pick up.
                            state.value['video_dict'] = load_local_video()
                            state.value.update(current_video=video_path)
                            state.value.update(processor_result=processor_result)
                            state.value.update(processor_workflow_instance_id=processor_workflow_instance_id)
                            return video_path, state

                        select_video = gr.Dropdown(
                            state.value['video_dict'].keys(),
                            value=None
                        )
                        display_video = gr.Video(
                            state.value['current_video'],
                        )
                        select_video.change(
                            fn=display_video_map,
                            inputs=[select_video],
                            outputs=[display_video, state]
                        )
                with gr.Column():
                    chatbot = gr.Chatbot(
                        type="messages",
                    )
                    chat_input = gr.Textbox(
                        interactive=True,
                        placeholder="Enter message...",
                        show_label=False,
                    )
                    chat_msg = chat_input.submit(
                        self.add_message,
                        [chatbot, chat_input, state],
                        [chatbot, chat_input]
                    )
                    bot_msg = chat_msg.then(
                        self.bot, (chatbot, state), chatbot, api_name="bot_response"
                    )
                    # Re-enable the textbox once the bot round-trip completes.
                    bot_msg.then(
                        lambda: gr.Textbox(interactive=True), None, [chat_input]
                    )
        demo.launch(
            max_file_size='1gb'
        )
def start_interactor(self): | |
try: | |
self.gradio_app() | |
except KeyboardInterrupt: | |
logging.info("\nDetected Ctrl+C, stopping workflow...") | |
if self._workflow_instance_id is not None: | |
self._interactor._executor.terminate( | |
workflow_id=self._workflow_instance_id | |
) | |
raise | |
def stop_interactor(self): | |
# self._task_handler_interactor.stop_processes() | |
print("stop_interactor") | |
sys.exit(0) | |
    def start_processor(self):
        """Start the processor task handler and launch a standalone multimodal
        chat UI (port 7861) used to feed images into the processor workflow.

        On Ctrl+C, terminates the in-flight workflow (if any) and re-raises.
        """
        self._task_handler_processor = TaskHandler(
            worker_config=self._worker_config, workers=self._workers, task_to_domain=self._task_to_domain
        )
        self._task_handler_processor.start_processes()
        try:
            with gr.Blocks(title="OmAgent", css=self._custom_css) as chat_interface:
                chatbot = gr.Chatbot(
                    elem_id="OmAgent",
                    bubble_full_width=False,
                    type="messages",
                    height="100%",
                )
                chat_input = gr.MultimodalTextbox(
                    interactive=True,
                    file_count="multiple",
                    placeholder="Enter message or upload file...",
                    show_label=False,
                )
                chat_msg = chat_input.submit(
                    self.add_processor_message,
                    [chatbot, chat_input],
                    [chatbot, chat_input],
                )
                bot_msg = chat_msg.then(
                    self.processor_bot, chatbot, chatbot, api_name="bot_response"
                )
                # Re-enable the input box once the bot round-trip completes.
                bot_msg.then(
                    lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
                )
            chat_interface.launch(server_port=7861)
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._processor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise
def stop_processor(self): | |
self._task_handler_processor.stop_processes() | |
def add_message(self, history, message, state): | |
if isinstance(state, gr.State): | |
if state.value.get('current_video') is None: | |
history.append({"role": "user", "content": message}) | |
history.append({"role": "assistant", "content": 'Please select a video'}) | |
return history, gr.Textbox(value=None, interactive=False) | |
else: | |
if not state: | |
history.append({"role": "user", "content": message}) | |
history.append({"role": "assistant", "content": 'Please reselect the video'}) | |
return history, gr.Textbox(value=None, interactive=False) | |
if state.get('current_video') is None: | |
history.append({"role": "user", "content": message}) | |
history.append({"role": "assistant", "content": 'Please select a video'}) | |
return history, gr.Textbox(value=None, interactive=False) | |
if self._workflow_instance_id is None: | |
workflow_input = { | |
'question': message, | |
"video_md5": state.value.get('processor_result', {}).get("video_md5"), | |
"video_path": state.value.get('processor_result', {}).get("video_path"), | |
"instance_id": state.value.get('processor_result', {}).get("instance_id"), | |
"processor_workflow_instance_id": state.value.get("processor_workflow_instance_id") | |
} | |
exception_queue = queue.Queue() | |
def run_workflow(workflow_input): | |
try: | |
self._interactor.start_workflow_with_input( | |
workflow_input=workflow_input, workers=self.workers | |
) | |
except Exception as e: | |
exception_queue.put(e) # add exception to queue | |
logging.error(f"Error starting workflow: {e}") | |
raise e | |
workflow_thread = threading.Thread(target=run_workflow, args=(workflow_input,),daemon=True) | |
workflow_thread.start() | |
self._workflow_instance_id = self.workflow_instance_id | |
contents = [] | |
history.append({"role": "user", "content": message}) | |
contents.append({"data": message, "type": "text"}) | |
result = { | |
"agent_id": self._workflow_instance_id, | |
"messages": [{"role": "user", "content": contents}], | |
"kwargs": {}, | |
} | |
container.get_connector("redis_stream_client")._client.xadd( | |
f"{self._workflow_instance_id}_input", | |
{"payload": json.dumps(result, ensure_ascii=False)}, | |
) | |
return history, gr.Textbox(value=None, interactive=False) | |
def add_processor_message(self, history, message): | |
if self._workflow_instance_id is None: | |
self._workflow_instance_id = self._processor.start_workflow_with_input( | |
workflow_input={}, task_to_domain=self._task_to_domain | |
) | |
image_items = [] | |
for idx, x in enumerate(message["files"]): | |
history.append({"role": "user", "content": {"path": x}}) | |
image_items.append( | |
{"type": "image_url", "resource_id": str(idx), "data": str(x)} | |
) | |
result = {"content": image_items} | |
container.get_connector("redis_stream_client")._client.xadd( | |
f"image_process", {"payload": json.dumps(result, ensure_ascii=False)} | |
) | |
return history, gr.MultimodalTextbox(value=None, interactive=False) | |
    def bot(self, history, state):
        """Stream assistant output into the chat history.

        Generator used as a gradio event handler. Consumes two Redis streams
        for the current workflow instance — "<id>_running" (progress updates,
        rendered as escaped HTML) and "<id>_output" (answer content) — yielding
        the updated history after each message, acking as it goes, until an
        end-of-answer or interaction marker arrives.
        """
        # Bail out early if no video has been selected yet (mirrors add_message's
        # dual handling of gr.State vs raw dict).
        if isinstance(state, gr.State):
            if state.value.get('current_video') is None:
                yield history
                return
        else:
            if state.get('current_video') is None:
                yield history
                return
        stream_name = f"{self._workflow_instance_id}_output"
        consumer_name = f"{self._workflow_instance_id}_agent"  # consumer name
        group_name = "omappagent"  # replace with your consumer group name
        running_stream_name = f"{self._workflow_instance_id}_running"
        # Make sure the consumer group exists on both streams before reading.
        self._check_redis_stream_exist(stream_name, group_name)
        self._check_redis_stream_exist(running_stream_name, group_name)
        while True:
            # read running stream
            running_messages = self._get_redis_stream_message(
                group_name, consumer_name, running_stream_name
            )
            for stream, message_list in running_messages:
                for message_id, message in message_list:
                    payload_data = self._get_message_payload(message)
                    if payload_data is None:
                        continue
                    # Escape before embedding in HTML. NOTE: this rebinds the
                    # loop variable `message` — harmless here, but easy to misread.
                    progress = html.escape(payload_data.get("progress", ""))
                    message = html.escape(payload_data.get("message", ""))
                    formatted_message = (
                        f'<pre class="running-message">{progress}: {message}</pre>'
                    )
                    history.append({"role": "assistant", "content": formatted_message})
                    yield history
                    # Ack only after the update has been yielded to the UI.
                    container.get_connector("redis_stream_client")._client.xack(
                        running_stream_name, group_name, message_id
                    )
            # read output stream
            messages = self._get_redis_stream_message(
                group_name, consumer_name, stream_name
            )
            finish_flag = False
            for stream, message_list in messages:
                for message_id, message in message_list:
                    incomplete_flag = False
                    payload_data = self._get_message_payload(message)
                    if payload_data is None:
                        continue
                    if payload_data["content_status"] == ContentStatus.INCOMPLETE.value:
                        incomplete_flag = True
                    message_item = payload_data["message"]
                    if message_item["type"] == MessageType.IMAGE_URL.value:
                        history.append(
                            {
                                "role": "assistant",
                                "content": {"path": message_item["content"]},
                            }
                        )
                    else:
                        if incomplete_flag:
                            # Accumulate streaming chunks into one assistant
                            # bubble, updating the last entry in place.
                            self._incomplete_message = (
                                self._incomplete_message + message_item["content"]
                            )
                            if history and history[-1]["role"] == "assistant":
                                history[-1]["content"] = self._incomplete_message
                            else:
                                history.append(
                                    {
                                        "role": "assistant",
                                        "content": self._incomplete_message,
                                    }
                                )
                        else:
                            if self._incomplete_message != "":
                                # Final chunk of an in-progress message: append
                                # it and flush the accumulation buffer.
                                self._incomplete_message = (
                                    self._incomplete_message + message_item["content"]
                                )
                                if history and history[-1]["role"] == "assistant":
                                    history[-1]["content"] = self._incomplete_message
                                else:
                                    history.append(
                                        {
                                            "role": "assistant",
                                            "content": self._incomplete_message,
                                        }
                                    )
                                self._incomplete_message = ""
                            else:
                                # Complete stand-alone message.
                                history.append(
                                    {
                                        "role": "assistant",
                                        "content": message_item["content"],
                                    }
                                )
                    yield history
                    container.get_connector("redis_stream_client")._client.xack(
                        stream_name, group_name, message_id
                    )
                    # check finish flag
                    if (
                        "interaction_type" in payload_data
                        and payload_data["interaction_type"] == 1
                    ):
                        finish_flag = True
                    if (
                        "content_status" in payload_data
                        and payload_data["content_status"]
                        == ContentStatus.END_ANSWER.value
                    ):
                        # Answer finished: clear the instance id so the next
                        # chat message starts a fresh workflow.
                        self._workflow_instance_id = None
                        finish_flag = True
            if finish_flag:
                break
            sleep(0.01)
def processor_bot(self, history: list): | |
history.append({"role": "assistant", "content": f"processing..."}) | |
yield history | |
while True: | |
status = self._processor.get_workflow( | |
workflow_id=self._workflow_instance_id | |
).status | |
if status in terminal_status: | |
history.append({"role": "assistant", "content": f"completed"}) | |
yield history | |
self._workflow_instance_id = None | |
break | |
sleep(0.01) | |
def _get_redis_stream_message( | |
self, group_name: str, consumer_name: str, stream_name: str | |
): | |
messages = container.get_connector("redis_stream_client")._client.xreadgroup( | |
group_name, consumer_name, {stream_name: ">"}, count=1 | |
) | |
messages = [ | |
( | |
stream, | |
[ | |
( | |
message_id, | |
{ | |
k.decode("utf-8"): v.decode("utf-8") | |
for k, v in message.items() | |
}, | |
) | |
for message_id, message in message_list | |
], | |
) | |
for stream, message_list in messages | |
] | |
return messages | |
def _check_redis_stream_exist(self, stream_name: str, group_name: str): | |
try: | |
container.get_connector("redis_stream_client")._client.xgroup_create( | |
stream_name, group_name, id="0", mkstream=True | |
) | |
except Exception as e: | |
logging.debug(f"Consumer group may already exist: {e}") | |
def _get_message_payload(self, message: dict): | |
logging.info(f"Received running message: {message}") | |
payload = message.get("payload") | |
# check payload data | |
if not payload: | |
logging.error("Payload is empty") | |
return None | |
try: | |
payload_data = json.loads(payload) | |
except json.JSONDecodeError as e: | |
logging.error(f"Payload is not a valid JSON: {e}") | |
return None | |
return payload_data | |