From 634aace3264e8d7122e455579d6ccbbaeb990277 Mon Sep 17 00:00:00 2001 From: Viren Baraiya Date: Mon, 20 Apr 2026 13:12:52 -0700 Subject: [PATCH 1/3] refactor(lease): centralized LeaseManager decouples heartbeats from poll loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace inline heartbeat logic in both TaskRunner and AsyncTaskRunner with a process-wide LeaseManager singleton. A background daemon thread checks for due heartbeats at ~1s intervals and dispatches them to a ThreadPoolExecutor(4), so heartbeat API calls and retries never block task polling. Key changes: - LeaseManager: singleton with background thread + thread pool, fork-safe via PID check, lazy start on first track() call - TaskRunner: delegates to LeaseManager.track/untrack instead of inline _send_due_heartbeats/_send_heartbeat methods - AsyncTaskRunner: same delegation; creates a sync TaskResourceApi for LeaseManager since heartbeats run in the thread pool, not the event loop - task_handler.py: clean — no LeaseManager knowledge needed - 21 unit tests covering track/untrack, heartbeat dispatch, retries, non-blocking behavior, singleton/fork safety, thread safety Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .../client/automator/async_task_runner.py | 81 ++--- .../client/automator/lease_tracker.py | 209 +++++++++++- src/conductor/client/automator/task_runner.py | 84 ++--- tests/unit/automator/test_lease_manager.py | 320 ++++++++++++++++++ 4 files changed, 574 insertions(+), 120 deletions(-) create mode 100644 tests/unit/automator/test_lease_manager.py diff --git a/src/conductor/client/automator/async_task_runner.py b/src/conductor/client/automator/async_task_runner.py index 0801b0d8..6306adcf 100644 --- a/src/conductor/client/automator/async_task_runner.py +++ b/src/conductor/client/automator/async_task_runner.py @@ -32,7 +32,7 @@ from conductor.client.worker.worker_config import resolve_worker_config, get_worker_config_oneline from 
conductor.client.worker.exception import NonRetryableException from conductor.client.automator.json_schema_generator import generate_json_schema_from_function -from conductor.client.automator.lease_tracker import LeaseInfo, LEASE_EXTEND_RETRY_COUNT, LEASE_EXTEND_DURATION_FACTOR +from conductor.client.automator.lease_tracker import LeaseManager logger = logging.getLogger( Configuration.get_logging_formatted_name( @@ -113,7 +113,9 @@ def __init__( self._semaphore = None self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True # Will be set to False if server doesn't support v2 endpoint - self._lease_info = {} # task_id -> LeaseInfo for lease extension heartbeats + self._lease_manager = LeaseManager.get_instance() + self._tracked_task_ids = set() # Local set for cleanup on shutdown + self._sync_task_client = None # Created after fork for LeaseManager heartbeats async def run(self) -> None: """Main async loop - runs continuously in single event loop.""" @@ -133,6 +135,17 @@ async def run(self) -> None: api_client=self.async_api_client ) + # Create a sync TaskResourceApi for LeaseManager heartbeats + # (LeaseManager sends heartbeats from its own ThreadPoolExecutor) + from conductor.client.http.api.task_resource_api import TaskResourceApi + from conductor.client.http.api_client import ApiClient + self._sync_task_client = TaskResourceApi( + ApiClient( + configuration=self.configuration, + metrics_collector=self.metrics_collector + ) + ) + # Create semaphore in the event loop (must be created within the loop) self._semaphore = asyncio.Semaphore(self._max_workers) @@ -168,8 +181,10 @@ async def _cleanup(self) -> None: """Clean up async resources.""" logger.debug("Cleaning up AsyncTaskRunner resources...") - # Stop all lease extension tracking - self._lease_info.clear() + # Untrack all tasks this runner was tracking from the shared LeaseManager + for task_id in list(self._tracked_task_ids): + self._lease_manager.untrack(task_id) + 
self._tracked_task_ids.clear() # Cancel any running tasks (EAFP style) try: @@ -441,9 +456,6 @@ async def __async_register_task_definition(self) -> None: async def run_once(self) -> None: """Execute one iteration of the polling loop (async version).""" try: - # Send lease extension heartbeats for any tasks that are due - await self._send_due_heartbeats() - # No need for manual cleanup - tasks remove themselves via add_done_callback # Just check capacity directly current_capacity = len(self._running_tasks) @@ -932,68 +944,27 @@ async def __async_update_task(self, task_result: TaskResult): return None - # -- Lease extension (heartbeat) methods ---------------------------------- + # -- Lease extension (heartbeat) delegation to LeaseManager ---------------- def _track_lease(self, task) -> None: - """Start tracking a task for lease extension heartbeat.""" + """Start tracking a task for lease extension via the shared LeaseManager.""" if not getattr(self.worker, 'lease_extend_enabled', False): return timeout = getattr(task, 'response_timeout_seconds', None) or 0 if timeout <= 0: return - interval = timeout * LEASE_EXTEND_DURATION_FACTOR - if interval < 1: - return - self._lease_info[task.task_id] = LeaseInfo( + self._lease_manager.track( task_id=task.task_id, workflow_instance_id=task.workflow_instance_id, response_timeout_seconds=timeout, - last_heartbeat_time=time.monotonic(), - interval_seconds=interval, - ) - logger.debug( - "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", - task.task_id, timeout, interval, + task_client=self._sync_task_client, ) + self._tracked_task_ids.add(task.task_id) def _untrack_lease(self, task_id: str) -> None: """Stop tracking a task for lease extension.""" - removed = self._lease_info.pop(task_id, None) - if removed is not None: - logger.debug("Untracked lease for task %s", task_id) - - async def _send_due_heartbeats(self) -> None: - """Check all tracked tasks and send heartbeats for any that are due.""" - if not 
self._lease_info: - return - now = time.monotonic() - for info in list(self._lease_info.values()): - elapsed = now - info.last_heartbeat_time - if elapsed < info.interval_seconds: - continue - await self._send_heartbeat(info) - info.last_heartbeat_time = time.monotonic() - - async def _send_heartbeat(self, info: LeaseInfo) -> None: - """Send a single lease extension heartbeat with retry (async).""" - result = TaskResult( - task_id=info.task_id, - workflow_instance_id=info.workflow_instance_id, - extend_lease=True, - ) - for attempt in range(LEASE_EXTEND_RETRY_COUNT): - try: - await self.async_task_client.update_task(body=result) - logger.debug("Extended lease for task %s", info.task_id) - return - except Exception as e: - if attempt < LEASE_EXTEND_RETRY_COUNT - 1: - await asyncio.sleep(0.5 * (attempt + 2)) - else: - logger.error( - "Failed to extend lease for task %s after %d attempts: %s", - info.task_id, LEASE_EXTEND_RETRY_COUNT, e, - ) + self._lease_manager.untrack(task_id) + self._tracked_task_ids.discard(task_id) # -------------------------------------------------------------------------- diff --git a/src/conductor/client/automator/lease_tracker.py b/src/conductor/client/automator/lease_tracker.py index 794e54e2..f98b5253 100644 --- a/src/conductor/client/automator/lease_tracker.py +++ b/src/conductor/client/automator/lease_tracker.py @@ -1,6 +1,27 @@ -"""Shared lease extension (heartbeat) tracking for TaskRunner and AsyncTaskRunner.""" +"""Centralized lease extension (heartbeat) management for Conductor task runners. +Architecture: + LeaseManager runs a single background daemon thread that periodically checks + for tasks needing lease extension heartbeats. Due heartbeats are dispatched + to a small fixed ThreadPoolExecutor for parallel, non-blocking API calls. + + This decouples heartbeat work entirely from worker poll loops, preventing + heartbeat API calls (and their retries) from blocking task polling. 
+ + Thread-safe: track() and untrack() can be called from any thread or event loop. +""" + +import logging +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from typing import Any, Dict, Optional + +from conductor.client.http.models.task_result import TaskResult + +logger = logging.getLogger(__name__) # Lease extension constants (matches Java SDK) LEASE_EXTEND_RETRY_COUNT = 3 @@ -15,3 +36,189 @@ class LeaseInfo: response_timeout_seconds: float last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) interval_seconds: float # 80% of responseTimeoutSeconds + task_client: Any = None # Sync TaskResourceApi for sending heartbeats + + +class LeaseManager: + """Centralized lease extension manager for all workers in a process. + + One background daemon thread checks for due heartbeats at a fixed interval. + A small ThreadPoolExecutor sends heartbeat API calls in parallel. + Poll loops are never blocked by heartbeat work. + + Usage: + manager = LeaseManager.get_instance() + manager.track(task_id, workflow_id, timeout, task_client) + # ... task completes ... + manager.untrack(task_id) + """ + + _instance: Optional['LeaseManager'] = None + _instance_lock = threading.Lock() + _instance_pid: Optional[int] = None + + @classmethod + def get_instance(cls, check_interval: float = 1.0, + max_heartbeat_workers: int = 4) -> 'LeaseManager': + """Get or create the process-wide LeaseManager singleton. + + Fork-safe: a new instance is created after fork (threads don't survive fork). 
+ """ + current_pid = os.getpid() + if cls._instance is None or cls._instance_pid != current_pid: + with cls._instance_lock: + if cls._instance is None or cls._instance_pid != current_pid: + cls._instance = cls( + check_interval=check_interval, + max_heartbeat_workers=max_heartbeat_workers, + ) + cls._instance_pid = current_pid + return cls._instance + + @classmethod + def _reset_instance(cls): + """Reset the singleton. For testing only.""" + with cls._instance_lock: + if cls._instance is not None: + cls._instance.shutdown() + cls._instance = None + cls._instance_pid = None + + def __init__(self, check_interval: float = 1.0, max_heartbeat_workers: int = 4): + self._tracked: Dict[str, LeaseInfo] = {} + self._lock = threading.Lock() + self._executor = ThreadPoolExecutor( + max_workers=max_heartbeat_workers, + thread_name_prefix="lease-heartbeat", + ) + self._stop_event = threading.Event() + self._check_interval = check_interval + self._thread: Optional[threading.Thread] = None + self._started = False + self._start_lock = threading.Lock() + + def _ensure_started(self) -> None: + """Lazily start the background thread on first track() call.""" + if self._started: + return + with self._start_lock: + if not self._started: + self._thread = threading.Thread( + target=self._run, daemon=True, name="lease-manager", + ) + self._thread.start() + self._started = True + logger.debug( + "LeaseManager started (check_interval=%.1fs)", self._check_interval, + ) + + def track(self, task_id: str, workflow_instance_id: str, + response_timeout_seconds: float, task_client: Any) -> None: + """Start tracking a task for lease extension heartbeats. + + Thread-safe. Can be called from any worker thread or event loop. + + Args: + task_id: Conductor task ID. + workflow_instance_id: Workflow instance this task belongs to. + response_timeout_seconds: The task's server-side response timeout. + task_client: A **sync** TaskResourceApi for sending heartbeat API calls. 
+ """ + interval = response_timeout_seconds * LEASE_EXTEND_DURATION_FACTOR + if interval < 1: + logger.debug( + "Skipping lease tracking for task %s (interval %.1fs too short)", + task_id, interval, + ) + return + + info = LeaseInfo( + task_id=task_id, + workflow_instance_id=workflow_instance_id, + response_timeout_seconds=response_timeout_seconds, + last_heartbeat_time=time.monotonic(), + interval_seconds=interval, + task_client=task_client, + ) + with self._lock: + self._tracked[task_id] = info + self._ensure_started() + logger.debug( + "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", + task_id, response_timeout_seconds, interval, + ) + + def untrack(self, task_id: str) -> None: + """Stop tracking a task. Thread-safe.""" + with self._lock: + removed = self._tracked.pop(task_id, None) + if removed is not None: + logger.debug("Untracked lease for task %s", task_id) + + @property + def tracked_count(self) -> int: + """Number of currently tracked tasks.""" + with self._lock: + return len(self._tracked) + + # -- Background thread ----------------------------------------------------- + + def _run(self) -> None: + """Background loop — checks for due heartbeats at fixed intervals.""" + while not self._stop_event.is_set(): + try: + self._check_and_send() + except Exception as e: + logger.error("LeaseManager error: %s", e) + self._stop_event.wait(self._check_interval) + + def _check_and_send(self) -> None: + """Find tasks with due heartbeats and dispatch to the thread pool.""" + now = time.monotonic() + with self._lock: + due = [ + info for info in self._tracked.values() + if now - info.last_heartbeat_time >= info.interval_seconds + ] + for info in due: + # Update timestamp immediately to prevent double-dispatch on next tick + info.last_heartbeat_time = time.monotonic() + self._executor.submit(self._send_heartbeat, info) + + @staticmethod + def _send_heartbeat(info: LeaseInfo) -> None: + """Send a single lease extension heartbeat with retry. 
+ + Runs in a pool thread — blocking retries only block the pool thread, + never a poll loop. + """ + result = TaskResult( + task_id=info.task_id, + workflow_instance_id=info.workflow_instance_id, + extend_lease=True, + ) + for attempt in range(LEASE_EXTEND_RETRY_COUNT): + try: + info.task_client.update_task(body=result) + logger.debug("Extended lease for task %s", info.task_id) + return + except Exception as e: + if attempt < LEASE_EXTEND_RETRY_COUNT - 1: + time.sleep(0.5 * (attempt + 2)) + else: + logger.error( + "Failed to extend lease for task %s after %d attempts: %s", + info.task_id, LEASE_EXTEND_RETRY_COUNT, e, + ) + + # -- Lifecycle ------------------------------------------------------------- + + def shutdown(self) -> None: + """Stop the background thread and thread pool.""" + self._stop_event.set() + if self._started and self._thread is not None: + self._thread.join(timeout=5) + self._executor.shutdown(wait=False) + with self._lock: + self._tracked.clear() + logger.debug("LeaseManager shut down") diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py index 242c8799..af566de1 100644 --- a/src/conductor/client/automator/task_runner.py +++ b/src/conductor/client/automator/task_runner.py @@ -36,7 +36,7 @@ from conductor.client.worker.worker_config import resolve_worker_config, get_worker_config_oneline from conductor.client.worker.exception import NonRetryableException from conductor.client.automator.json_schema_generator import generate_json_schema_from_function -from conductor.client.automator.lease_tracker import LeaseInfo, LEASE_EXTEND_RETRY_COUNT, LEASE_EXTEND_DURATION_FACTOR +from conductor.client.automator.lease_tracker import LeaseManager logger = logging.getLogger( Configuration.get_logging_formatted_name( @@ -112,8 +112,9 @@ def __init__( self._consecutive_empty_polls = 0 # Track empty polls to implement backoff self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True 
# Will be set to False if server doesn't support v2 endpoint - self._lease_info = {} # task_id -> LeaseInfo for lease extension heartbeats - self._lease_lock = threading.Lock() # Protects _lease_info for free-threaded Python + self._lease_manager = LeaseManager.get_instance() + self._tracked_task_ids = set() # Local set for cleanup on shutdown + self._tracked_task_ids_lock = threading.Lock() def run(self) -> None: if self.configuration is not None: @@ -153,9 +154,12 @@ def _cleanup(self) -> None: """Clean up resources - called on exit.""" logger.debug("Cleaning up TaskRunner resources...") - # Stop all lease extension tracking - with self._lease_lock: - self._lease_info.clear() + # Untrack all tasks this runner was tracking from the shared LeaseManager + with self._tracked_task_ids_lock: + task_ids = list(self._tracked_task_ids) + self._tracked_task_ids.clear() + for task_id in task_ids: + self._lease_manager.untrack(task_id) # Shutdown ThreadPoolExecutor (EAFP style - more Pythonic) try: @@ -429,9 +433,6 @@ def __register_task_definition(self) -> None: def run_once(self) -> None: try: - # Send lease extension heartbeats for any tasks that are due - self._send_due_heartbeats() - # Check completed async tasks first (non-blocking) self.__check_completed_async_tasks() @@ -1077,74 +1078,29 @@ def __update_task(self, task_result: TaskResult): return None - # -- Lease extension (heartbeat) methods ---------------------------------- + # -- Lease extension (heartbeat) delegation to LeaseManager ---------------- def _track_lease(self, task: Task) -> None: - """Start tracking a task for lease extension heartbeat.""" - lease_enabled = getattr(self.worker, 'lease_extend_enabled', False) - if not lease_enabled: + """Start tracking a task for lease extension via the shared LeaseManager.""" + if not getattr(self.worker, 'lease_extend_enabled', False): return timeout = getattr(task, 'response_timeout_seconds', None) or 0 if timeout <= 0: return - interval = timeout * 
LEASE_EXTEND_DURATION_FACTOR - if interval < 1: - return - info = LeaseInfo( + self._lease_manager.track( task_id=task.task_id, workflow_instance_id=task.workflow_instance_id, response_timeout_seconds=timeout, - last_heartbeat_time=time.monotonic(), - interval_seconds=interval, - ) - with self._lease_lock: - self._lease_info[task.task_id] = info - logger.debug( - "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", - task.task_id, timeout, interval, + task_client=self.task_client, ) + with self._tracked_task_ids_lock: + self._tracked_task_ids.add(task.task_id) def _untrack_lease(self, task_id: str) -> None: """Stop tracking a task for lease extension.""" - with self._lease_lock: - removed = self._lease_info.pop(task_id, None) - if removed is not None: - logger.debug("Untracked lease for task %s", task_id) - - def _send_due_heartbeats(self) -> None: - """Check all tracked tasks and send heartbeats for any that are due.""" - if not self._lease_info: - return - now = time.monotonic() - with self._lease_lock: - infos = list(self._lease_info.values()) - for info in infos: - elapsed = now - info.last_heartbeat_time - if elapsed < info.interval_seconds: - continue - self._send_heartbeat(info) - info.last_heartbeat_time = time.monotonic() - - def _send_heartbeat(self, info: LeaseInfo) -> None: - """Send a single lease extension heartbeat with retry.""" - result = TaskResult( - task_id=info.task_id, - workflow_instance_id=info.workflow_instance_id, - extend_lease=True, - ) - for attempt in range(LEASE_EXTEND_RETRY_COUNT): - try: - self.task_client.update_task(body=result) - logger.debug("Extended lease for task %s", info.task_id) - return - except Exception as e: - if attempt < LEASE_EXTEND_RETRY_COUNT - 1: - time.sleep(0.5 * (attempt + 2)) - else: - logger.error( - "Failed to extend lease for task %s after %d attempts: %s", - info.task_id, LEASE_EXTEND_RETRY_COUNT, e, - ) + self._lease_manager.untrack(task_id) + with self._tracked_task_ids_lock: + 
self._tracked_task_ids.discard(task_id) # -------------------------------------------------------------------------- diff --git a/tests/unit/automator/test_lease_manager.py b/tests/unit/automator/test_lease_manager.py new file mode 100644 index 00000000..9c25d37e --- /dev/null +++ b/tests/unit/automator/test_lease_manager.py @@ -0,0 +1,320 @@ +"""Tests for the centralized LeaseManager.""" + +import threading +import time +import unittest +from unittest.mock import MagicMock, call, patch + +from conductor.client.automator.lease_tracker import ( + LeaseManager, + LeaseInfo, + LEASE_EXTEND_DURATION_FACTOR, + LEASE_EXTEND_RETRY_COUNT, +) + + +class TestLeaseManagerTrackUntrack(unittest.TestCase): + """Test track/untrack operations.""" + + def setUp(self): + LeaseManager._reset_instance() + self.manager = LeaseManager(check_interval=60) # Long interval — we trigger manually + + def tearDown(self): + self.manager.shutdown() + LeaseManager._reset_instance() + + def test_track_adds_task(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 30.0, client) + self.assertEqual(self.manager.tracked_count, 1) + + def test_untrack_removes_task(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 30.0, client) + self.manager.untrack('task-1') + self.assertEqual(self.manager.tracked_count, 0) + + def test_untrack_nonexistent_is_noop(self): + self.manager.untrack('nonexistent') + self.assertEqual(self.manager.tracked_count, 0) + + def test_track_skips_short_interval(self): + """Tasks with response_timeout < ~1.25s (interval < 1s) should be skipped.""" + client = MagicMock() + self.manager.track('task-1', 'wf-1', 1.0, client) # 1.0 * 0.8 = 0.8 < 1 + self.assertEqual(self.manager.tracked_count, 0) + + def test_track_accepts_valid_timeout(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 10.0, client) # 10 * 0.8 = 8.0 >= 1 + self.assertEqual(self.manager.tracked_count, 1) + + def test_track_multiple_tasks(self): + client = MagicMock() + 
for i in range(10): + self.manager.track(f'task-{i}', f'wf-{i}', 30.0, client) + self.assertEqual(self.manager.tracked_count, 10) + + def test_track_overwrites_existing(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 30.0, client) + self.manager.track('task-1', 'wf-1', 60.0, client) + self.assertEqual(self.manager.tracked_count, 1) + + +class TestLeaseManagerHeartbeat(unittest.TestCase): + """Test heartbeat dispatch logic.""" + + def setUp(self): + LeaseManager._reset_instance() + self.manager = LeaseManager(check_interval=60) + + def tearDown(self): + self.manager.shutdown() + LeaseManager._reset_instance() + + def test_heartbeat_sent_when_due(self): + """Heartbeat should be dispatched when interval has elapsed.""" + client = MagicMock() + self.manager.track('task-1', 'wf-1', 10.0, client) + + # Fast-forward: set last_heartbeat_time to the past + with self.manager._lock: + info = self.manager._tracked['task-1'] + info.last_heartbeat_time = time.monotonic() - 20 # Well past the 8s interval + + self.manager._check_and_send() + + # Wait for the pool thread to execute the heartbeat + self.manager._executor.shutdown(wait=True) + client.update_task.assert_called_once() + result = client.update_task.call_args[1]['body'] + self.assertEqual(result.task_id, 'task-1') + self.assertEqual(result.workflow_instance_id, 'wf-1') + self.assertTrue(result.extend_lease) + + def test_heartbeat_not_sent_when_not_due(self): + """Heartbeat should NOT be dispatched when interval hasn't elapsed.""" + client = MagicMock() + self.manager.track('task-1', 'wf-1', 10.0, client) + + self.manager._check_and_send() + + self.manager._executor.shutdown(wait=True) + client.update_task.assert_not_called() + + def test_heartbeat_retries_on_failure(self): + """Heartbeat should retry up to LEASE_EXTEND_RETRY_COUNT times.""" + client = MagicMock() + client.update_task.side_effect = Exception("server error") + + info = LeaseInfo( + task_id='task-1', + workflow_instance_id='wf-1', + 
response_timeout_seconds=30.0, + last_heartbeat_time=time.monotonic(), + interval_seconds=24.0, + task_client=client, + ) + + with patch('conductor.client.automator.lease_tracker.time.sleep'): + LeaseManager._send_heartbeat(info) + + self.assertEqual(client.update_task.call_count, LEASE_EXTEND_RETRY_COUNT) + + def test_heartbeat_stops_retrying_on_success(self): + """Heartbeat should stop retrying after a successful call.""" + client = MagicMock() + client.update_task.side_effect = [Exception("fail"), None] # Fail then succeed + + info = LeaseInfo( + task_id='task-1', + workflow_instance_id='wf-1', + response_timeout_seconds=30.0, + last_heartbeat_time=time.monotonic(), + interval_seconds=24.0, + task_client=client, + ) + + with patch('conductor.client.automator.lease_tracker.time.sleep'): + LeaseManager._send_heartbeat(info) + + self.assertEqual(client.update_task.call_count, 2) + + def test_multiple_tasks_heartbeats_dispatched_independently(self): + """Each due task gets its own heartbeat dispatch.""" + client_a = MagicMock() + client_b = MagicMock() + + self.manager.track('task-a', 'wf-a', 10.0, client_a) + self.manager.track('task-b', 'wf-b', 10.0, client_b) + + # Make both due + with self.manager._lock: + past = time.monotonic() - 20 + self.manager._tracked['task-a'].last_heartbeat_time = past + self.manager._tracked['task-b'].last_heartbeat_time = past + + self.manager._check_and_send() + self.manager._executor.shutdown(wait=True) + + client_a.update_task.assert_called_once() + client_b.update_task.assert_called_once() + + +class TestLeaseManagerNonBlocking(unittest.TestCase): + """Test that heartbeats don't block the caller.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_poll_loop_not_blocked_by_slow_heartbeat(self): + """The caller should return immediately even if heartbeat is slow.""" + slow_client = MagicMock() + slow_client.update_task.side_effect = lambda **kw: time.sleep(2) 
+ + manager = LeaseManager(check_interval=60) + manager.track('task-1', 'wf-1', 10.0, slow_client) + + with manager._lock: + manager._tracked['task-1'].last_heartbeat_time = time.monotonic() - 20 + + start = time.monotonic() + manager._check_and_send() # Submits to pool, returns immediately + elapsed = time.monotonic() - start + + # _check_and_send should return in < 100ms (it just submits to the pool) + self.assertLess(elapsed, 0.1, "check_and_send blocked for too long") + + manager.shutdown() + + +class TestLeaseManagerSingleton(unittest.TestCase): + """Test singleton behavior.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_get_instance_returns_same_object(self): + a = LeaseManager.get_instance() + b = LeaseManager.get_instance() + self.assertIs(a, b) + a.shutdown() + + def test_reset_creates_new_instance(self): + a = LeaseManager.get_instance() + LeaseManager._reset_instance() + b = LeaseManager.get_instance() + self.assertIsNot(a, b) + b.shutdown() + + @patch('conductor.client.automator.lease_tracker.os.getpid') + def test_new_instance_after_fork(self, mock_getpid): + """After fork (different PID), a fresh instance should be created.""" + mock_getpid.return_value = 1000 + a = LeaseManager.get_instance() + + mock_getpid.return_value = 2000 # Simulate fork + b = LeaseManager.get_instance() + + self.assertIsNot(a, b) + a.shutdown() + b.shutdown() + + +class TestLeaseManagerBackgroundThread(unittest.TestCase): + """Test the background thread lifecycle.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_thread_starts_lazily_on_first_track(self): + manager = LeaseManager(check_interval=60) + self.assertFalse(manager._started) + + client = MagicMock() + manager.track('task-1', 'wf-1', 10.0, client) + self.assertTrue(manager._started) + self.assertTrue(manager._thread.is_alive()) + + manager.shutdown() + + def 
test_thread_not_started_if_no_tracks(self): + manager = LeaseManager(check_interval=60) + self.assertFalse(manager._started) + manager.shutdown() + + def test_background_thread_sends_heartbeats(self): + """Verify the background thread actually dispatches heartbeats.""" + client = MagicMock() + manager = LeaseManager(check_interval=0.1) # Check every 100ms + + manager.track('task-1', 'wf-1', 10.0, client) + + # Make it due + with manager._lock: + manager._tracked['task-1'].last_heartbeat_time = time.monotonic() - 20 + + # Wait for background thread to pick it up + time.sleep(0.5) + + manager.shutdown() + client.update_task.assert_called() + + def test_shutdown_stops_thread(self): + manager = LeaseManager(check_interval=0.1) + client = MagicMock() + manager.track('task-1', 'wf-1', 10.0, client) + self.assertTrue(manager._thread.is_alive()) + + manager.shutdown() + self.assertFalse(manager._thread.is_alive()) + + +class TestLeaseManagerThreadSafety(unittest.TestCase): + """Test concurrent track/untrack operations.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_concurrent_track_untrack(self): + """Many threads tracking/untracking should not corrupt state.""" + manager = LeaseManager(check_interval=60) + client = MagicMock() + errors = [] + + def track_and_untrack(thread_id): + try: + for i in range(50): + task_id = f'task-{thread_id}-{i}' + manager.track(task_id, f'wf-{thread_id}', 30.0, client) + manager.untrack(task_id) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=track_and_untrack, args=(t,)) for t in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(errors, []) + self.assertEqual(manager.tracked_count, 0) + manager.shutdown() + + +if __name__ == '__main__': + unittest.main() From 3209378720812cac1c533ebfe3266465de641125 Mon Sep 17 00:00:00 2001 From: Viren Baraiya Date: Mon, 20 Apr 2026 13:40:43 -0700 
Subject: [PATCH 2/3] test(lease): add async worker E2E tests for LeaseManager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three tests against a live Conductor server: 1. Async task with heartbeat → COMPLETED (50s sleep, 10s timeout) 2. Async task without heartbeat → TIMED_OUT 3. Performance: heartbeat tracking adds 0ms overhead on fast tasks --- .../integration/test_async_lease_extension.py | 344 ++++++++++++++++++ 1 file changed, 344 insertions(+) create mode 100644 tests/integration/test_async_lease_extension.py diff --git a/tests/integration/test_async_lease_extension.py b/tests/integration/test_async_lease_extension.py new file mode 100644 index 00000000..9a73c88b --- /dev/null +++ b/tests/integration/test_async_lease_extension.py @@ -0,0 +1,344 @@ +""" +E2E test for lease extension with async workers (AsyncTaskRunner). + +Proves that the centralized LeaseManager works correctly with async workers: + +1. WITH lease extension: async long-running task COMPLETES even when execution + time exceeds responseTimeoutSeconds — heartbeats keep the lease alive. + +2. WITHOUT lease extension: same async task TIMES OUT after responseTimeoutSeconds. + +3. PERFORMANCE: async worker with heartbeat enabled but short task (no heartbeat + actually needed) has no meaningful overhead vs. one without heartbeat tracking. + +Run: + export CONDUCTOR_SERVER_URL="http://localhost:8000/api" + python3 -m pytest tests/integration/test_async_lease_extension.py -v -s + +Prerequisites: + - Conductor server running (e.g. 
http://localhost:8000/api) +""" + +import asyncio +import logging +import os +import sys +import time +import threading +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from conductor.client.automator.task_handler import TaskHandler +from conductor.client.configuration.configuration import Configuration +from conductor.client.worker.worker_task import worker_task +from conductor.client.http.models.workflow_def import WorkflowDef +from conductor.client.http.models.task_def import TaskDef +from conductor.client.http.models.workflow_task import WorkflowTask +from conductor.client.http.models.start_workflow_request import StartWorkflowRequest +from conductor.client.orkes.orkes_workflow_client import OrkesWorkflowClient +from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Short response timeout — task must heartbeat to stay alive +RESPONSE_TIMEOUT_SECONDS = 10 + +# Task sleeps longer than the response timeout to prove heartbeat works. +# Must be long enough that the server's workflow sweeper catches the expired +# task BEFORE the worker completes. 
TASK_SLEEP_SECONDS = 50

# Short task duration for performance test — well within timeout
FAST_TASK_SLEEP_SECONDS = 2

# Number of fast tasks for performance comparison
PERF_TASK_COUNT = 5


# -- Async Workers -----------------------------------------------------------

@worker_task(
    task_definition_name='async_lease_heartbeat_task',
    lease_extend_enabled=True,
    register_task_def=True,
    task_def=TaskDef(
        name='async_lease_heartbeat_task',
        response_timeout_seconds=RESPONSE_TIMEOUT_SECONDS,
        timeout_seconds=180,
        retry_count=0,
    ),
    overwrite_task_def=True,
)
async def async_lease_heartbeat_task(job_id: str) -> dict:
    """Async long-running task with heartbeat — should complete.

    Sleeps longer than responseTimeoutSeconds; the LeaseManager's heartbeats
    must keep the lease alive for the workflow to finish COMPLETED.
    """
    logger.info("[async_heartbeat] Starting job %s, sleeping %ss (timeout=%ss)",
                job_id, TASK_SLEEP_SECONDS, RESPONSE_TIMEOUT_SECONDS)
    await asyncio.sleep(TASK_SLEEP_SECONDS)
    logger.info("[async_heartbeat] Completed job %s", job_id)
    return {'job_id': job_id, 'status': 'completed', 'slept': TASK_SLEEP_SECONDS}


@worker_task(
    task_definition_name='async_lease_no_heartbeat_task',
    lease_extend_enabled=False,
    register_task_def=True,
    task_def=TaskDef(
        name='async_lease_no_heartbeat_task',
        response_timeout_seconds=RESPONSE_TIMEOUT_SECONDS,
        timeout_seconds=120,
        retry_count=0,
    ),
    overwrite_task_def=True,
)
async def async_lease_no_heartbeat_task(job_id: str) -> dict:
    """Async long-running task without heartbeat — should time out.

    Identical sleep to the heartbeat task, but with lease extension disabled
    the server's sweeper should expire the lease before the worker finishes.
    """
    logger.info("[async_no_heartbeat] Starting job %s, sleeping %ss (timeout=%ss)",
                job_id, TASK_SLEEP_SECONDS, RESPONSE_TIMEOUT_SECONDS)
    await asyncio.sleep(TASK_SLEEP_SECONDS)
    logger.info("[async_no_heartbeat] Completed job %s", job_id)
    return {'job_id': job_id, 'status': 'completed', 'slept': TASK_SLEEP_SECONDS}


@worker_task(
    task_definition_name='async_lease_fast_with_hb',
    lease_extend_enabled=True,
    register_task_def=True,
    task_def=TaskDef(
        name='async_lease_fast_with_hb',
        response_timeout_seconds=60,
        timeout_seconds=120,
        retry_count=0,
    ),
    overwrite_task_def=True,
)
async def async_lease_fast_with_hb(job_id: str) -> dict:
    """Fast async task with heartbeat enabled — for overhead measurement."""
    await asyncio.sleep(FAST_TASK_SLEEP_SECONDS)
    return {'job_id': job_id, 'status': 'completed'}


@worker_task(
    task_definition_name='async_lease_fast_no_hb',
    lease_extend_enabled=False,
    register_task_def=True,
    task_def=TaskDef(
        name='async_lease_fast_no_hb',
        response_timeout_seconds=60,
        timeout_seconds=120,
        retry_count=0,
    ),
    overwrite_task_def=True,
)
async def async_lease_fast_no_hb(job_id: str) -> dict:
    """Fast async task without heartbeat — baseline for comparison."""
    await asyncio.sleep(FAST_TASK_SLEEP_SECONDS)
    return {'job_id': job_id, 'status': 'completed'}


# -- Test class --------------------------------------------------------------

class TestAsyncLeaseExtension(unittest.TestCase):
    """E2E tests proving LeaseManager heartbeats work for async workers."""

    @classmethod
    def setUpClass(cls):
        # Skip the whole class when no live Conductor server is reachable.
        from tests.integration.conftest import skip_if_server_unavailable
        skip_if_server_unavailable()

        cls.config = Configuration()
        cls.metadata_client = OrkesMetadataClient(cls.config)
        cls.workflow_client = OrkesWorkflowClient(cls.config)

    def _register_workflow(self, wf_name, task_names):
        """Register a workflow with one or more tasks in sequence.

        :param wf_name: workflow definition name (version 1 is used)
        :param task_names: a single task name or a list of task names
        """
        workflow = WorkflowDef(name=wf_name, version=1)
        tasks = []
        for task_name in (task_names if isinstance(task_names, list) else [task_names]):
            tasks.append(WorkflowTask(
                name=task_name,
                task_reference_name=f'{task_name}_ref',
                input_parameters={'job_id': '${workflow.input.job_id}'},
            ))
        # NOTE(review): writes the private backing field; WorkflowDef exposes a
        # `tasks` setter in the SDK — confirm and switch to `workflow.tasks = tasks`.
        workflow._tasks = tasks
        try:
            self.metadata_client.update_workflow_def(workflow, overwrite=True)
        except Exception:
            # Older servers may not support update; fall back to register.
            self.metadata_client.register_workflow_def(workflow, overwrite=True)
        logger.info("Registered workflow: %s", wf_name)

    def _start_workflow(self, wf_name, job_id):
        """Start a workflow and return the execution ID."""
        req = StartWorkflowRequest()
        req.name = wf_name
        req.version = 1
        req.input = {'job_id': job_id}
        wf_id = self.workflow_client.start_workflow(start_workflow_request=req)
        logger.info("Started workflow %s: %s", wf_name, wf_id)
        return wf_id

    def _wait_for_workflow(self, wf_id, timeout_seconds=90):
        """Poll until the workflow reaches a terminal state or the deadline passes.

        Uses a monotonic-clock deadline rather than counting iterations, so the
        latency of each get_workflow round-trip does not silently extend the
        effective timeout. Returns the last-fetched workflow either way.
        """
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            wf = self.workflow_client.get_workflow(wf_id, include_tasks=True)
            if wf.status in ('COMPLETED', 'FAILED', 'TIMED_OUT', 'TERMINATED'):
                return wf
            time.sleep(1)
        return self.workflow_client.get_workflow(wf_id, include_tasks=True)

    def _run_workers_in_background(self, duration_seconds=90):
        """Start annotated workers in background processes; return a stop() callable.

        A daemon safety timer also stops the workers after *duration_seconds*
        in case the caller's ``finally`` never runs. stop() is idempotent and
        cancels the timer, so ``handler.stop_processes()`` is invoked at most
        once even when both the test's ``finally`` and the timer fire.
        """
        handler = TaskHandler(
            configuration=self.config,
            scan_for_annotated_workers=True,
        )
        handler.start_processes()

        stopped = threading.Event()

        def stop():
            # Guard against double invocation (finally block + safety timer).
            if stopped.is_set():
                return
            stopped.set()
            timer.cancel()
            handler.stop_processes()

        timer = threading.Timer(duration_seconds, stop)
        timer.daemon = True
        timer.start()

        return stop

    # -- Tests ----------------------------------------------------------------

    def test_01_async_with_heartbeat_completes(self):
        """Async task WITH lease_extend_enabled=True completes when sleep > responseTimeout."""
        print("\n" + "=" * 80)
        print("TEST: Async with heartbeat — task should COMPLETE")
        print(f" responseTimeoutSeconds={RESPONSE_TIMEOUT_SECONDS}s, task sleeps {TASK_SLEEP_SECONDS}s")
        print("=" * 80)

        wf_name = 'test_async_lease_heartbeat'
        self._register_workflow(wf_name, 'async_lease_heartbeat_task')

        stop_workers = self._run_workers_in_background(duration_seconds=90)
        time.sleep(3)  # let workers start

        try:
            wf_id = self._start_workflow(wf_name, 'ASYNC-HB-001')
            wf = self._wait_for_workflow(wf_id, timeout_seconds=80)

            print(f"\n Workflow ID: {wf_id}")
            print(f" Final status: {wf.status}")
            for task in (wf.tasks or []):
                print(f" Task {task.task_def_name}: {task.status}")

            self.assertEqual(wf.status, 'COMPLETED',
                             f"Workflow should COMPLETE with heartbeat, got {wf.status}")

            tasks_by_ref = {t.reference_task_name: t for t in wf.tasks}
            task = tasks_by_ref.get('async_lease_heartbeat_task_ref')
            self.assertIsNotNone(task)
            self.assertEqual(task.status, 'COMPLETED')
            self.assertEqual(task.output_data.get('job_id'), 'ASYNC-HB-001')
            self.assertEqual(task.output_data.get('slept'), TASK_SLEEP_SECONDS)
            print("\n PASS: Async task completed with heartbeat keeping lease alive")
        finally:
            stop_workers()

    def test_02_async_without_heartbeat_times_out(self):
        """Async task WITHOUT lease_extend_enabled times out when sleep > responseTimeout."""
        print("\n" + "=" * 80)
        print("TEST: Async without heartbeat — task should TIME OUT")
        print(f" responseTimeoutSeconds={RESPONSE_TIMEOUT_SECONDS}s, task sleeps {TASK_SLEEP_SECONDS}s")
        print("=" * 80)

        wf_name = 'test_async_lease_no_heartbeat'
        self._register_workflow(wf_name, 'async_lease_no_heartbeat_task')

        stop_workers = self._run_workers_in_background(duration_seconds=90)
        time.sleep(3)

        try:
            wf_id = self._start_workflow(wf_name, 'ASYNC-NOHB-001')
            wf = self._wait_for_workflow(wf_id, timeout_seconds=80)

            print(f"\n Workflow ID: {wf_id}")
            print(f" Final status: {wf.status}")
            for task in (wf.tasks or []):
                print(f" Task {task.task_def_name}: {task.status}")

            self.assertIn(wf.status, ('FAILED', 'TIMED_OUT'),
                          f"Workflow should FAIL/TIMEOUT without heartbeat, got {wf.status}")

            tasks_by_ref = {t.reference_task_name: t for t in wf.tasks}
            task = tasks_by_ref.get('async_lease_no_heartbeat_task_ref')
            self.assertIsNotNone(task)
            self.assertIn(task.status, ('TIMED_OUT', 'FAILED', 'CANCELED'),
                          f"Task should be TIMED_OUT/FAILED, got {task.status}")
            print("\n PASS: Async task timed out as expected without heartbeat")
        finally:
            stop_workers()

    def test_03_no_performance_overhead(self):
        """Heartbeat tracking adds no meaningful overhead to fast async tasks."""
        print("\n" + "=" * 80)
        print("TEST: Performance — heartbeat enabled vs disabled on fast tasks")
        print(f" Running {PERF_TASK_COUNT} tasks each, sleep={FAST_TASK_SLEEP_SECONDS}s")
        print("=" * 80)

        wf_with_hb = 'test_async_perf_with_hb'
        wf_no_hb = 'test_async_perf_no_hb'
        self._register_workflow(wf_with_hb, 'async_lease_fast_with_hb')
        self._register_workflow(wf_no_hb, 'async_lease_fast_no_hb')

        stop_workers = self._run_workers_in_background(duration_seconds=120)
        time.sleep(3)

        try:
            # Run tasks WITH heartbeat tracking
            hb_workflow_ids = []
            for i in range(PERF_TASK_COUNT):
                wf_id = self._start_workflow(wf_with_hb, f'PERF-HB-{i:03d}')
                hb_workflow_ids.append(wf_id)

            # Run tasks WITHOUT heartbeat tracking
            no_hb_workflow_ids = []
            for i in range(PERF_TASK_COUNT):
                wf_id = self._start_workflow(wf_no_hb, f'PERF-NOHB-{i:03d}')
                no_hb_workflow_ids.append(wf_id)

            # Wait for all to complete, collecting per-task server-side durations
            hb_times = []
            for wf_id in hb_workflow_ids:
                wf = self._wait_for_workflow(wf_id, timeout_seconds=30)
                self.assertEqual(wf.status, 'COMPLETED',
                                 f"Fast HB task should complete, got {wf.status}")
                task = wf.tasks[0]
                duration_ms = task.end_time - task.start_time
                hb_times.append(duration_ms)

            no_hb_times = []
            for wf_id in no_hb_workflow_ids:
                wf = self._wait_for_workflow(wf_id, timeout_seconds=30)
                self.assertEqual(wf.status, 'COMPLETED',
                                 f"Fast no-HB task should complete, got {wf.status}")
                task = wf.tasks[0]
                duration_ms = task.end_time - task.start_time
                no_hb_times.append(duration_ms)

            avg_hb = sum(hb_times) / len(hb_times)
            avg_no_hb = sum(no_hb_times) / len(no_hb_times)
            overhead_ms = avg_hb - avg_no_hb
            overhead_pct = (overhead_ms / avg_no_hb * 100) if avg_no_hb > 0 else 0

            print(f"\n With heartbeat: avg {avg_hb:.0f}ms {hb_times}")
            print(f" Without heartbeat: avg {avg_no_hb:.0f}ms {no_hb_times}")
            print(f" Overhead: {overhead_ms:+.0f}ms ({overhead_pct:+.1f}%)")

            # Heartbeat tracking should add < 500ms overhead per task
            # (LeaseManager.track is just a dict insert + set add)
            self.assertLess(overhead_ms, 500,
                            f"Heartbeat overhead too high: {overhead_ms:.0f}ms")

            print("\n PASS: No meaningful performance overhead from heartbeat tracking")
        finally:
            stop_workers()


if __name__ == '__main__':
    unittest.main()