Added current task into task executor's hearbeat (#3444)
Browse files### What problem does this PR solve?
Added current task into task executor's hearbeat
### Type of change
- [x] Refactoring
- rag/svr/task_executor.py +32 -18
rag/svr/task_executor.py
CHANGED
@@ -83,11 +83,14 @@ FACTORY = {
|
|
83 |
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
|
84 |
PAYLOAD: Payload | None = None
|
85 |
BOOT_AT = datetime.now().isoformat()
|
86 |
-
DONE_TASKS = 0
|
87 |
-
FAILED_TASKS = 0
|
88 |
PENDING_TASKS = 0
|
89 |
LAG_TASKS = 0
|
90 |
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
93 |
global PAYLOAD
|
@@ -135,12 +138,14 @@ def collect():
|
|
135 |
return None
|
136 |
|
137 |
if TaskService.do_cancel(msg["id"]):
|
138 |
-
|
|
|
139 |
logging.info("Task {} has been canceled.".format(msg["id"]))
|
140 |
return None
|
141 |
task = TaskService.get_task(msg["id"])
|
142 |
if not task:
|
143 |
-
|
|
|
144 |
logging.warning("{} empty task!".format(msg["id"]))
|
145 |
return None
|
146 |
|
@@ -427,16 +432,22 @@ def do_handle_task(r):
|
|
427 |
|
428 |
|
429 |
def handle_task():
|
430 |
-
global PAYLOAD, DONE_TASKS, FAILED_TASKS
|
431 |
task = collect()
|
432 |
if task:
|
433 |
try:
|
434 |
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
|
|
|
|
435 |
do_handle_task(task)
|
436 |
-
|
437 |
-
|
|
|
|
|
438 |
except Exception:
|
439 |
-
|
|
|
|
|
440 |
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
441 |
if PAYLOAD:
|
442 |
PAYLOAD.ack()
|
@@ -444,7 +455,7 @@ def handle_task():
|
|
444 |
|
445 |
|
446 |
def report_status():
|
447 |
-
global CONSUMER_NAME, BOOT_AT, DONE_TASKS, FAILED_TASKS,
|
448 |
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
|
449 |
while True:
|
450 |
try:
|
@@ -454,15 +465,17 @@ def report_status():
|
|
454 |
PENDING_TASKS = int(group_info["pending"])
|
455 |
LAG_TASKS = int(group_info["lag"])
|
456 |
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
466 |
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
467 |
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
468 |
|
@@ -474,6 +487,7 @@ def report_status():
|
|
474 |
time.sleep(30)
|
475 |
|
476 |
def main():
|
|
|
477 |
background_thread = threading.Thread(target=report_status)
|
478 |
background_thread.daemon = True
|
479 |
background_thread.start()
|
|
|
83 |
CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
|
84 |
PAYLOAD: Payload | None = None
|
85 |
BOOT_AT = datetime.now().isoformat()
|
|
|
|
|
86 |
PENDING_TASKS = 0
|
87 |
LAG_TASKS = 0
|
88 |
|
89 |
+
mt_lock = threading.Lock()
|
90 |
+
DONE_TASKS = 0
|
91 |
+
FAILED_TASKS = 0
|
92 |
+
CURRENT_TASK = None
|
93 |
+
|
94 |
|
95 |
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
|
96 |
global PAYLOAD
|
|
|
138 |
return None
|
139 |
|
140 |
if TaskService.do_cancel(msg["id"]):
|
141 |
+
with mt_lock:
|
142 |
+
DONE_TASKS += 1
|
143 |
logging.info("Task {} has been canceled.".format(msg["id"]))
|
144 |
return None
|
145 |
task = TaskService.get_task(msg["id"])
|
146 |
if not task:
|
147 |
+
with mt_lock:
|
148 |
+
DONE_TASKS += 1
|
149 |
logging.warning("{} empty task!".format(msg["id"]))
|
150 |
return None
|
151 |
|
|
|
432 |
|
433 |
|
434 |
def handle_task():
|
435 |
+
global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
|
436 |
task = collect()
|
437 |
if task:
|
438 |
try:
|
439 |
logging.info(f"handle_task begin for task {json.dumps(task)}")
|
440 |
+
with mt_lock:
|
441 |
+
CURRENT_TASK = copy.deepcopy(task)
|
442 |
do_handle_task(task)
|
443 |
+
with mt_lock:
|
444 |
+
DONE_TASKS += 1
|
445 |
+
CURRENT_TASK = None
|
446 |
+
logging.info(f"handle_task done for task {json.dumps(task)}")
|
447 |
except Exception:
|
448 |
+
with mt_lock:
|
449 |
+
FAILED_TASKS += 1
|
450 |
+
CURRENT_TASK = None
|
451 |
logging.exception(f"handle_task got exception for task {json.dumps(task)}")
|
452 |
if PAYLOAD:
|
453 |
PAYLOAD.ack()
|
|
|
455 |
|
456 |
|
457 |
def report_status():
|
458 |
+
global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
|
459 |
REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
|
460 |
while True:
|
461 |
try:
|
|
|
465 |
PENDING_TASKS = int(group_info["pending"])
|
466 |
LAG_TASKS = int(group_info["lag"])
|
467 |
|
468 |
+
with mt_lock:
|
469 |
+
heartbeat = json.dumps({
|
470 |
+
"name": CONSUMER_NAME,
|
471 |
+
"now": now.isoformat(),
|
472 |
+
"boot_at": BOOT_AT,
|
473 |
+
"pending": PENDING_TASKS,
|
474 |
+
"lag": LAG_TASKS,
|
475 |
+
"done": DONE_TASKS,
|
476 |
+
"failed": FAILED_TASKS,
|
477 |
+
"current": CURRENT_TASK,
|
478 |
+
})
|
479 |
REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
|
480 |
logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
|
481 |
|
|
|
487 |
time.sleep(30)
|
488 |
|
489 |
def main():
|
490 |
+
settings.init_settings()
|
491 |
background_thread = threading.Thread(target=report_status)
|
492 |
background_thread.daemon = True
|
493 |
background_thread.start()
|