Skip to content

Commit

Permalink
Include remove of task in hasmap when completed
Browse files Browse the repository at this point in the history
Signed-off-by: jorgee <[email protected]>
  • Loading branch information
jorgee committed Jan 30, 2025
1 parent 9fdae13 commit 5f398ea
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -469,18 +469,16 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
}

protected String getStateFromTaskStatus() {
final tasks = client.listTasks(jobId)
if( !tasks.iterator().hasNext() ) {
return getStateFromJobStatus()
}
final now = System.currentTimeMillis()
final delta = now - timestamp;
if( !taskState || delta >= 1_000) {
try {
final status = client.getTaskStatus(jobId, taskId)
final status = client.getTaskInArrayStatus(jobId, taskId)
if( status ) {
inspectTaskStatus(status)
}catch (NotFoundException e) {
manageNotFound(tasks)
} else {
// If no task status retrieved check job status
final jobStatus = client.getJobStatus(jobId)
inspectJobStatus(jobStatus)
}
}
return taskState
Expand Down Expand Up @@ -511,20 +509,6 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
}
}

protected String manageNotFound( Iterable<Task> tasks) {
// If task is array, check if the in the task list
for (Task t in tasks) {
if (t.name == client.generateTaskName(jobId, taskId)) {
inspectTaskStatus(t.status)
return taskState
}
}
// if not array or it task is not in the list, check job status.
final status = client.getJobStatus(jobId)
inspectJobStatus(status)
return taskState
}

protected String inspectJobStatus(JobStatus status) {
final newState = status?.state as String
if (newState) {
Expand Down Expand Up @@ -577,6 +561,8 @@ class GoogleBatchTaskHandler extends TaskHandler implements FusionAwareTask {
task.stderr = errorFile
}
status = TaskStatus.COMPLETED
if( belongsToArray )
client.removeFromArrayTasks(jobId, taskId)
return true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ import groovy.util.logging.Slf4j
@Slf4j
@CompileStatic
class BatchClient {

private static long TASK_STATE_INVALID_TIME = 1_000
protected String projectId
protected String location
protected BatchServiceClient batchServiceClient
protected BatchConfig config
private Map<String, TaskStatusRecord> arrayTaskStatus = new HashMap<String, TaskStatusRecord>()

BatchClient(BatchConfig config) {
this.config = config
Expand Down Expand Up @@ -198,4 +199,40 @@ class BatchClient {
// apply the action with
return Failsafe.with(policy).get(action)
}


TaskStatus getTaskInArrayStatus(String jobId, String taskId) {
final taskName = generateTaskName(jobId,taskId)
final now = System.currentTimeMillis()
TaskStatusRecord record = arrayTaskStatus.get(taskName)
if( !record || now - record.timestamp > TASK_STATE_INVALID_TIME ){
log.debug("[GOOGLE BATCH] Updating tasks status for job $jobId")
updateArrayTasks(jobId, now)
record = arrayTaskStatus.get(taskName)
}
return record?.status
}

private void updateArrayTasks(String jobId, long now){
for( Task t: listTasks(jobId) ){
arrayTaskStatus.put(t.name, new TaskStatusRecord(t.status, now))
}

}

void removeFromArrayTasks(String jobId, String taskId){
final taskName = generateTaskName(jobId,taskId)
TaskStatusRecord record = arrayTaskStatus.remove(taskName)
}
}

@CompileStatic
class TaskStatusRecord {
protected TaskStatus status
protected long timestamp

TaskStatusRecord(TaskStatus status, long timestamp) {
this.status = status
this.timestamp = timestamp
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,16 @@ class GoogleBatchTaskHandlerTest extends Specification {

}

TaskStatus makeTaskStatus(String desc) {
TaskStatus.newBuilder()
.addStatusEvents(
TaskStatus makeTaskStatus(TaskStatus.State state, String desc) {
def builder = TaskStatus.newBuilder()
if (state)
builder.setState(state)
if (desc)
builder.addStatusEvents(
StatusEvent.newBuilder()
.setDescription(desc)
)
.build()
builder.build()
}

def 'should detect spot failures from status event'() {
Expand All @@ -486,8 +489,8 @@ class GoogleBatchTaskHandlerTest extends Specification {

when:
client.getTaskStatus(jobId, taskId) >>> [
makeTaskStatus('Task failed due to Spot VM preemption with exit code 50001.'),
makeTaskStatus('Task succeeded')
makeTaskStatus(null,'Task failed due to Spot VM preemption with exit code 50001.'),
makeTaskStatus(null, 'Task succeeded')
]
then:
handler.getJobError().message == "Task failed due to Spot VM preemption with exit code 50001."
Expand Down Expand Up @@ -637,15 +640,15 @@ class GoogleBatchTaskHandlerTest extends Specification {
client.generateTaskName(jobId, taskId) >> "$jobId/group0/$taskId"
//Force errors
client.getTaskStatus(jobId, taskId) >> { throw new NotFoundException(new Exception("Error"), GrpcStatusCode.of(Status.Code.NOT_FOUND), false) }
client.listTasks(jobId) >> TASK_LIST
client.getTaskInArrayStatus(jobId, taskId) >> TASK_STATUS
client.getJobStatus(jobId) >> makeJobStatus(JOB_STATUS, "")
then:
handler.getTaskState() == EXPECTED

where:
EXPECTED | JOB_STATUS | TASK_LIST
"FAILED" | JobStatus.State.FAILED | {[ makeTask("1/group0/2", TaskStatus.State.PENDING), makeTask("1/group0/3", TaskStatus.State.PENDING) ].iterator() } // Task not in the list, get from job
"SUCCEEDED" | JobStatus.State.FAILED | {[ makeTask("1/group0/1", TaskStatus.State.SUCCEEDED), makeTask("1/group0/2", TaskStatus.State.PENDING)].iterator() } //Task in the list, get from task status
EXPECTED | JOB_STATUS | TASK_STATUS
"FAILED" | JobStatus.State.FAILED | null // Task not in the list, get from job
"SUCCEEDED" | JobStatus.State.FAILED | makeTaskStatus(TaskStatus.State.SUCCEEDED, "") // get from task status
}

def makeTask(String name, TaskStatus.State state){
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package nextflow.cloud.google.batch.client

import com.google.cloud.batch.v1.Task
import com.google.cloud.batch.v1.TaskName
import com.google.cloud.batch.v1.TaskStatus
import spock.lang.Specification

/**
*
* @author Jorge Ejarque <[email protected]>
*/
class BatchClientTest extends Specification{



def 'should return task status with getTaskInArray' () {
given:
def project = 'project-id'
def location = 'location-id'
def job1 = 'job1-id'
def task1 = 'task1-id'
def task1Name = TaskName.of(project, location, job1, 'group0', task1).toString()
def job2 = 'job2-id'
def task2 = 'task2-id'
def task2Name = TaskName.of(project, location, job2, 'group0', task2).toString()
def job3 = 'job3-id'
def task3 = 'task3-id'
def task3Name = TaskName.of(project, location, job3, 'group0', task3).toString()
def now = System.currentTimeMillis()
def arrayTasks = new HashMap<String,TaskStatusRecord>()
def client = Spy( new BatchClient( projectId: project, location: location, arrayTaskStatus: arrayTasks ) )

when:
client.listTasks(job2) >> {
def list = new LinkedList<>()
list.add(makeTask(task2Name, TaskStatus.State.FAILED))
return list
}
client.listTasks(job3) >> {
def list = new LinkedList<>()
list.add(makeTask(task3Name, TaskStatus.State.SUCCEEDED))
return list
}
arrayTasks.put(task1Name, makeTaskStatusRecord(TaskStatus.State.RUNNING, System.currentTimeMillis()))
arrayTasks.put(task2Name, makeTaskStatusRecord(TaskStatus.State.PENDING, System.currentTimeMillis() - 1_001))

then:
// recent cached task
client.getTaskInArrayStatus(job1, task1).state == TaskStatus.State.RUNNING
// Outdated cached task
client.getTaskInArrayStatus(job2, task2).state == TaskStatus.State.FAILED
// no cached task
client.getTaskInArrayStatus(job3, task3).state == TaskStatus.State.SUCCEEDED
}

def TaskStatusRecord makeTaskStatusRecord(TaskStatus.State state, long timestamp) {
return new TaskStatusRecord(TaskStatus.newBuilder().setState(state).build(), timestamp)

}

def makeTask(String name, TaskStatus.State state){
Task.newBuilder().setName(name)
.setStatus(TaskStatus.newBuilder().setState(state).build())
.build()

}

}

0 comments on commit 5f398ea

Please sign in to comment.