Calculate member-based ensemble loss without running the forward pass twice. #26318
Unanswered · HarveySouth asked this question in Q&A
Replies: 1 comment
-
I've moved my simple FFNN from linen to nnx and come up with the following (`n_ensemble_members`, `batch_size`, `resolved_lambda`, `epoch_num`, `training_set`, `ensemble_models`, and `member_optimizers` are defined elsewhere):

```python
import concurrent.futures
import threading

import jax
import jax.numpy as jnp
from flax import nnx


@nnx.jit
def jit_loss_calculation(member_prediction, training_labels, non_current_member_outputs):
    # Squared error between this member's prediction and the targets.
    member_error = jnp.square(member_prediction - training_labels)
    # This member's contribution to the ensemble centroid, combined with the other members' outputs.
    member_contribution_to_ensemble = jnp.divide(member_prediction, n_ensemble_members)
    ensemble_centroid = member_contribution_to_ensemble + non_current_member_outputs
    # Diversity (negative-correlation) penalty: squared distance from the ensemble centroid.
    member_diversity = jnp.square(ensemble_centroid - member_prediction)
    full_loss = member_error - (resolved_lambda * member_diversity)
    return full_loss.mean()

def run_ensemble_member_loss_and_grad_in_parallel(training_input, training_labels, shared_data,
                                                  lock_memory, condition, member_index, model):
    def all_predictions_set():
        return all(lock_memory)

    def ncl_member_loss(model):
        member_prediction = model(training_input).squeeze()
        with condition:
            # Intended to publish this member's prediction for the other threads.
            # NB: .at[...].set() returns a new array, so shared_data itself is not updated in place.
            shared_data.at[member_index].set(member_prediction)
            lock_memory[member_index] = True
            condition.notify_all()
        with condition:
            # Wait until every ensemble member has produced its prediction.
            condition.wait_for(all_predictions_set)
        jax.block_until_ready(shared_data)
        # Mean over the other members' entries in the shared buffer.
        non_current_member_outputs = jnp.mean(
            jnp.concatenate((shared_data[:member_index], shared_data[member_index + 1:]))
        )
        return jit_loss_calculation(member_prediction, training_labels, non_current_member_outputs)

    return nnx.value_and_grad(ncl_member_loss)(model)

def setup_parallel_execution(training_input, training_labels, models):
    shared_data = jnp.zeros((n_ensemble_members, batch_size))
    lock_memory = [False] * n_ensemble_members
    condition = threading.Condition()  # Use threading.Condition to try and avoid deadlock with jax
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_ensemble_members) as executor:
        futures = [executor.submit(run_ensemble_member_loss_and_grad_in_parallel,
                                   training_input, training_labels, shared_data, lock_memory,
                                   condition, member_index, model)
                   for member_index, model in enumerate(models)]
        losses_and_grads = [future.result() for future in futures]
    losses, grads = zip(*losses_and_grads)
    return losses, grads
```

with the training loop:

```python
for epoch in range(epoch_num):
    ...
    for step, (batch_x, batch_y) in enumerate(training_set):
        ...
        losses, grads = setup_parallel_execution(batch_x, batch_y, ensemble_models)
        for i in range(len(member_optimizers)):
            member_optimizers[i].update(grads[i])
```

Validity TBD, and definitely not the best solution, but it seems to work and doesn't require running the ensemble more than necessary.
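A thread-free single-pass alternative might also be possible (just a sketch, untested, written against plain JAX rather than nnx; `apply_fn`, `params_list`, and `lam` are placeholder names): compute every member's prediction inside one differentiable function, freeze the other members' contribution to the centroid with `jax.lax.stop_gradient`, and differentiate the summed loss with respect to all members at once, so each member only receives the gradient of its own NCL loss.

```python
import jax
import jax.numpy as jnp


def total_ncl_loss(params_list, apply_fn, x, y, lam):
    # One forward pass per member, all inside a single differentiable function.
    preds = jnp.stack([apply_fn(p, x).squeeze() for p in params_list])  # (n_members, batch)
    n_members = preds.shape[0]
    centroid = preds.mean(axis=0)
    total = 0.0
    for i in range(n_members):
        f_i = preds[i]
        # Freeze the other members' share of the centroid so that the gradient of
        # member i's loss w.r.t. any other member's parameters is exactly zero.
        centroid_i = jax.lax.stop_gradient(centroid - f_i / n_members) + f_i / n_members
        member_error = jnp.square(f_i - y)
        member_diversity = jnp.square(centroid_i - f_i)
        total = total + (member_error - lam * member_diversity).mean()
    return total


# Usage (placeholder names): grads mirrors params_list, and entry i is the gradient of
# member i's own NCL loss because the cross-member terms were stop-gradiented away.
# loss_value, grads = jax.value_and_grad(total_ncl_loss)(params_list, apply_fn, x, y, lam)
```

The stop_gradient is what turns the cross-member terms into constants under differentiation, so a single `jax.value_and_grad` call replaces both the second forward pass and the thread synchronization; hooking this up to nnx modules (e.g. via `nnx.split`/`nnx.merge`) isn't shown here.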
-
I want to implement negative correlation learning (NCL) in JAX. NCL is a regression-ensemble training algorithm that updates each member of the ensemble with its own loss function: the squared error between the member prediction and the target value, and the squared error between the ensemble output and the member prediction.
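In symbols (taking the sign convention and the weight `resolved_lambda`, written $\lambda$, from the reply above, with $f_i$ the member prediction, $\bar f$ the mean over the $M$ members, and $y$ the target), the per-member loss is:

$$
L_i = \big(f_i(x) - y\big)^2 \;-\; \lambda\,\big(\bar f(x) - f_i(x)\big)^2,
\qquad
\bar f(x) = \frac{1}{M}\sum_{j=1}^{M} f_j(x)
$$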
Ideally I can:
I'm having two difficulties:
I solved this inefficiently in PyTorch by ignoring the second difficulty, and just running the ensemble twice:
I've been able to run the ensemble members in parallel with vmap, as I did in PyTorch, but I haven't been able to come up with an alternative approach to running the training step efficiently in JAX, and I'm looking for help.
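For reference, the kind of vmapped forward pass I mean looks roughly like this (a toy sketch with a placeholder linear member and made-up shapes, not my actual model):

```python
import jax
import jax.numpy as jnp

# Placeholder linear "member": params is a (features,) weight vector.
def apply_fn(params, x):   # x: (batch, features)
    return x @ params      # -> (batch,)

stacked_params = jnp.ones((4, 8))   # 4 hypothetical members, 8 features each
training_input = jnp.ones((32, 8))  # one shared batch

# vmap over the leading (member) axis of the parameters while broadcasting the batch.
all_predictions = jax.vmap(apply_fn, in_axes=(0, None))(stacked_params, training_input)
print(all_predictions.shape)        # (4, 32): one row of predictions per member
```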