jbilcke-hf HF Staff commited on
Commit
26cd6a4
·
1 Parent(s): 4af8a5a

fix for log monitoring

Browse files
Files changed (1) hide show
  1. vms/training_service.py +21 -10
vms/training_service.py CHANGED
@@ -519,7 +519,6 @@ class TrainingService:
519
 
520
  def _start_log_monitor(self, process: subprocess.Popen) -> None:
521
  """Start monitoring process output for logs"""
522
-
523
 
524
  def monitor():
525
  self.append_log("Starting log monitor thread")
@@ -546,15 +545,27 @@ class TrainingService:
546
  return True
547
  return False
548
 
549
- # Use select to monitor both stdout and stderr
550
- while process.poll() is None:
551
- outputs = [process.stdout, process.stderr]
552
- readable, _, _ = select.select(outputs, [], [], 1.0)
553
-
554
- for stream in readable:
555
- is_error = (stream == process.stderr)
556
- read_stream(stream, is_error)
557
-
 
 
 
 
 
 
 
 
 
 
 
 
558
  # Process any remaining output after process ends
559
  while read_stream(process.stdout):
560
  pass
 
519
 
520
  def _start_log_monitor(self, process: subprocess.Popen) -> None:
521
  """Start monitoring process output for logs"""
 
522
 
523
  def monitor():
524
  self.append_log("Starting log monitor thread")
 
545
  return True
546
  return False
547
 
548
+ # Create separate threads to monitor stdout and stderr
549
+ def monitor_stream(stream, is_error=False):
550
+ while process.poll() is None:
551
+ if not read_stream(stream, is_error):
552
+ time.sleep(0.1) # Short sleep to avoid CPU thrashing
553
+
554
+ # Start threads to monitor each stream
555
+ stdout_thread = threading.Thread(target=monitor_stream, args=(process.stdout, False))
556
+ stderr_thread = threading.Thread(target=monitor_stream, args=(process.stderr, True))
557
+ stdout_thread.daemon = True
558
+ stderr_thread.daemon = True
559
+ stdout_thread.start()
560
+ stderr_thread.start()
561
+
562
+ # Wait for process to complete
563
+ process.wait()
564
+
565
+ # Wait for threads to finish reading any remaining output
566
+ stdout_thread.join(timeout=2)
567
+ stderr_thread.join(timeout=2)
568
+
569
  # Process any remaining output after process ends
570
  while read_stream(process.stdout):
571
  pass