zhichyu commited on
Commit
4822328
·
1 Parent(s): e6a705f

Added current task into task executor's hearbeat (#3444)

Browse files

### What problem does this PR solve?

Added current task into task executor's hearbeat

### Type of change

- [x] Refactoring

Files changed (1) hide show
  1. rag/svr/task_executor.py +32 -18
rag/svr/task_executor.py CHANGED
@@ -83,11 +83,14 @@ FACTORY = {
83
  CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
84
  PAYLOAD: Payload | None = None
85
  BOOT_AT = datetime.now().isoformat()
86
- DONE_TASKS = 0
87
- FAILED_TASKS = 0
88
  PENDING_TASKS = 0
89
  LAG_TASKS = 0
90
 
 
 
 
 
 
91
 
92
  def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
93
  global PAYLOAD
@@ -135,12 +138,14 @@ def collect():
135
  return None
136
 
137
  if TaskService.do_cancel(msg["id"]):
138
- DONE_TASKS += 1
 
139
  logging.info("Task {} has been canceled.".format(msg["id"]))
140
  return None
141
  task = TaskService.get_task(msg["id"])
142
  if not task:
143
- DONE_TASKS += 1
 
144
  logging.warning("{} empty task!".format(msg["id"]))
145
  return None
146
 
@@ -427,16 +432,22 @@ def do_handle_task(r):
427
 
428
 
429
  def handle_task():
430
- global PAYLOAD, DONE_TASKS, FAILED_TASKS
431
  task = collect()
432
  if task:
433
  try:
434
  logging.info(f"handle_task begin for task {json.dumps(task)}")
 
 
435
  do_handle_task(task)
436
- DONE_TASKS += 1
437
- logging.exception(f"handle_task done for task {json.dumps(task)}")
 
 
438
  except Exception:
439
- FAILED_TASKS += 1
 
 
440
  logging.exception(f"handle_task got exception for task {json.dumps(task)}")
441
  if PAYLOAD:
442
  PAYLOAD.ack()
@@ -444,7 +455,7 @@ def handle_task():
444
 
445
 
446
  def report_status():
447
- global CONSUMER_NAME, BOOT_AT, DONE_TASKS, FAILED_TASKS, PENDING_TASKS, LAG_TASKS
448
  REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
449
  while True:
450
  try:
@@ -454,15 +465,17 @@ def report_status():
454
  PENDING_TASKS = int(group_info["pending"])
455
  LAG_TASKS = int(group_info["lag"])
456
 
457
- heartbeat = json.dumps({
458
- "name": CONSUMER_NAME,
459
- "now": now.isoformat(),
460
- "boot_at": BOOT_AT,
461
- "done": DONE_TASKS,
462
- "failed": FAILED_TASKS,
463
- "pending": PENDING_TASKS,
464
- "lag": LAG_TASKS,
465
- })
 
 
466
  REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
467
  logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
468
 
@@ -474,6 +487,7 @@ def report_status():
474
  time.sleep(30)
475
 
476
  def main():
 
477
  background_thread = threading.Thread(target=report_status)
478
  background_thread.daemon = True
479
  background_thread.start()
 
83
  CONSUMER_NAME = "task_consumer_" + CONSUMER_NO
84
  PAYLOAD: Payload | None = None
85
  BOOT_AT = datetime.now().isoformat()
 
 
86
  PENDING_TASKS = 0
87
  LAG_TASKS = 0
88
 
89
+ mt_lock = threading.Lock()
90
+ DONE_TASKS = 0
91
+ FAILED_TASKS = 0
92
+ CURRENT_TASK = None
93
+
94
 
95
  def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
96
  global PAYLOAD
 
138
  return None
139
 
140
  if TaskService.do_cancel(msg["id"]):
141
+ with mt_lock:
142
+ DONE_TASKS += 1
143
  logging.info("Task {} has been canceled.".format(msg["id"]))
144
  return None
145
  task = TaskService.get_task(msg["id"])
146
  if not task:
147
+ with mt_lock:
148
+ DONE_TASKS += 1
149
  logging.warning("{} empty task!".format(msg["id"]))
150
  return None
151
 
 
432
 
433
 
434
  def handle_task():
435
+ global PAYLOAD, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
436
  task = collect()
437
  if task:
438
  try:
439
  logging.info(f"handle_task begin for task {json.dumps(task)}")
440
+ with mt_lock:
441
+ CURRENT_TASK = copy.deepcopy(task)
442
  do_handle_task(task)
443
+ with mt_lock:
444
+ DONE_TASKS += 1
445
+ CURRENT_TASK = None
446
+ logging.info(f"handle_task done for task {json.dumps(task)}")
447
  except Exception:
448
+ with mt_lock:
449
+ FAILED_TASKS += 1
450
+ CURRENT_TASK = None
451
  logging.exception(f"handle_task got exception for task {json.dumps(task)}")
452
  if PAYLOAD:
453
  PAYLOAD.ack()
 
455
 
456
 
457
  def report_status():
458
+ global CONSUMER_NAME, BOOT_AT, PENDING_TASKS, LAG_TASKS, mt_lock, DONE_TASKS, FAILED_TASKS, CURRENT_TASK
459
  REDIS_CONN.sadd("TASKEXE", CONSUMER_NAME)
460
  while True:
461
  try:
 
465
  PENDING_TASKS = int(group_info["pending"])
466
  LAG_TASKS = int(group_info["lag"])
467
 
468
+ with mt_lock:
469
+ heartbeat = json.dumps({
470
+ "name": CONSUMER_NAME,
471
+ "now": now.isoformat(),
472
+ "boot_at": BOOT_AT,
473
+ "pending": PENDING_TASKS,
474
+ "lag": LAG_TASKS,
475
+ "done": DONE_TASKS,
476
+ "failed": FAILED_TASKS,
477
+ "current": CURRENT_TASK,
478
+ })
479
  REDIS_CONN.zadd(CONSUMER_NAME, heartbeat, now.timestamp())
480
  logging.info(f"{CONSUMER_NAME} reported heartbeat: {heartbeat}")
481
 
 
487
  time.sleep(30)
488
 
489
  def main():
490
+ settings.init_settings()
491
  background_thread = threading.Thread(target=report_status)
492
  background_thread.daemon = True
493
  background_thread.start()