diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerManager.java b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerManager.java index 0acc169577..b35cb3c003 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerManager.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/rm/TaskSchedulerManager.java @@ -227,8 +227,8 @@ public int getNumClusterNodes() { return getNumClusterNodes(false); } - public int getNumClusterNodes(boolean tryUpdate){ - if (cachedNodeCount == -1 && tryUpdate){ + public int getNumClusterNodes(boolean tryUpdate) { + if (cachedNodeCount == -1 || tryUpdate) { cachedNodeCount = countAllNodes(); } return cachedNodeCount; diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java index 595c045d80..d814260209 100644 --- a/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java +++ b/tez-dag/src/test/java/org/apache/tez/dag/app/rm/TestTaskSchedulerManager.java @@ -1240,4 +1240,29 @@ public static void waitFor(Supplier check, int checkEveryMillis, throw new TimeoutException("Timed out waiting for condition."); } } + + @Test(timeout = 5000) + public void testGetNumClusterNodes() throws Exception { + Configuration conf = new Configuration(false); + taskSchedulerManager.init(conf); + taskSchedulerManager.start(); + + // Mock the underlying task scheduler to simulate a scale up from 10 nodes to 20 nodes + when(mockTaskScheduler.getClusterNodeCount()).thenReturn(10).thenReturn(20); + + // Initial call, cachedNodeCount is -1, should fetch and cache 10 + int count1 = taskSchedulerManager.getNumClusterNodes(false); + assertEquals(10, count1); + + // Second call, tryUpdate is false, should return cached 10 + int count2 = taskSchedulerManager.getNumClusterNodes(false); + assertEquals(10, count2); + + // Third call, tryUpdate is true, should fetch new value 20 + int count3 = taskSchedulerManager.getNumClusterNodes(true); + assertEquals(20, count3); + + // Verify getClusterNodeCount was called exactly twice (once initially, once on forced update) + verify(mockTaskScheduler, times(2)).getClusterNodeCount(); + } }