Update kernels to remove cub dependencies
CUB is causing runtime compile problems for users, so this commit removes it.

This may reduce performance, as we now do all-with-all comparisons when activating groups and collections.
jlmaccal authored Oct 14, 2024
1 parent 8125432 commit 7565d1e
Showing 1 changed file with 175 additions and 110 deletions.
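The kernels below replace CUB's block-wide sort with a counting-based selection: an item is activated when fewer than numActive of its peers have a strictly lower energy, which selects the numActive lowest energies and also activates anything tied at the cutoff. As a minimal sketch of that rule (hypothetical host-side code, not part of this commit; selectLowestByCounting is an invented name, e.g. for testing against the kernels):

#include <vector>

// Reference implementation of the kernels' counting-based selection rule.
// active[i] == 1.0f when energies[i] is among the numActive lowest values;
// ties at the cutoff are all activated, matching the CUDA code below.
std::vector<float> selectLowestByCounting(const std::vector<float>& energies, int numActive)
{
    const int n = static_cast<int>(energies.size());
    std::vector<float> active(n, 0.0f);
    for (int i = 0; i < n; ++i) {
        int counter = 0;  // number of peers with strictly lower energy
        for (int j = 0; j < n; ++j) {
            if (energies[i] > energies[j]) {
                counter++;
            }
        }
        if (counter < numActive) {
            active[i] = 1.0f;
        }
    }
    return active;
}

The nested loop is the all-with-all comparison the commit message warns about: O(N²) comparisons per group instead of a single block-wide sort.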
285 changes: 175 additions & 110 deletions plugin/platforms/cuda/src/kernels/computeMeld.cu
@@ -3,7 +3,6 @@
All rights reserved
*/

-#include <cub/cub.cuh>
#include <cfloat>

__device__ void computeTorsionAngle(const real4* __restrict__ posq, int atom_i, int atom_j, int atom_k, int atom_l,
@@ -880,7 +879,31 @@ extern "C" __global__ void computeGridPotentialRest(
}



+/**
+ * CUDA kernel to evaluate restraint energies and activate restraints based on specific criteria.
+ *
+ * The kernel handles the following cases for each group of restraints:
+ * 1. **Special case: numActive == numRestraints**
+ *    - If the number of active restraints equals the total number of restraints, all restraints are activated.
+ *
+ * 2. **Special case: numActive == 1**
+ *    - If only one restraint needs to be active, the kernel finds the minimum energy value using a parallel reduction.
+ *    - All restraints with energy equal to the minimum are activated.
+ *
+ * 3. **General case**
+ *    - For any other value of numActive, each restraint's energy is compared to all others to determine whether it should be activated.
+ *    - A restraint is activated if its energy is among the numActive lowest energies in the group.
+ *
+ * After activating the restraints, the kernel computes the total energy of the active restraints in each group using a parallel reduction.
+ *
+ * @param numGroups The number of groups of restraints to be processed.
+ * @param numActiveArray An array containing the number of active restraints for each group.
+ * @param boundsArray An array of bounds giving the start and end indices of each group of restraints.
+ * @param indexArray An array of indices mapping restraints to their energy values.
+ * @param energyArray An array containing the energy values for all restraints.
+ * @param activeArray An output array indicating which restraints are activated (1.0 for active, 0.0 for inactive).
+ * @param groupEnergyArray An output array storing the total energy of each group's active restraints.
+ */
extern "C" __global__ void evaluateAndActivate(
const int numGroups,
const int* __restrict__ numActiveArray,
@@ -890,86 +913,114 @@ extern "C" __global__ void evaluateAndActivate(
float* __restrict__ activeArray,
float* __restrict__ groupEnergyArray)
{
-    // Setup type alias for collective operations
-    typedef cub::BlockRadixSort<float, NGROUPTHREADS, RESTS_PER_THREAD> BlockRadixSortT;
-    typedef cub::BlockReduce<float, NGROUPTHREADS> BlockReduceT;
-
-    // Setup shared memory for sorting.
-    __shared__ union {
-        typename BlockRadixSortT::TempStorage sort;
-        typename BlockReduceT::TempStorage reduce;
-        float cutoff;
-    } sharedScratch;
-
-    // local storage for energies to be sorted
-    float energyScratch[RESTS_PER_THREAD];
-
-    for (int groupIndex=blockIdx.x; groupIndex<numGroups; groupIndex+=gridDim.x) {
-        int numActive = numActiveArray[groupIndex];
-        int start = boundsArray[groupIndex].x;
-        int end = boundsArray[groupIndex].y;
-
-        // Load energies into statically allocated scratch buffer
-        for(int i=0; i<RESTS_PER_THREAD; i++) {
-            int index = threadIdx.x * RESTS_PER_THREAD + start + i;
-            if(index < end) {
-                energyScratch[i] = energyArray[indexArray[index]];
-            } else {
-                energyScratch[i] = FLT_MAX;
-            }
-        }
-        __syncthreads();
-
-        // Sort the energies.
-        BlockRadixSortT(sharedScratch.sort).Sort(energyScratch);
-        __syncthreads();
-
-        // find the nth largest energy and store in scratch
-        int myMin = threadIdx.x * RESTS_PER_THREAD;
-        int myMax = myMin + RESTS_PER_THREAD;
-        if((numActive - 1) >= myMin) {
-            if((numActive - 1) < myMax) {
-                // only one thread will get here
-                int offset = numActive - 1 - myMin;
-                sharedScratch.cutoff = energyScratch[offset];
-            }
-        }
-        __syncthreads();
-
-        // Read the nth largest energy from shared memory.
-        float cutoff = (volatile float)sharedScratch.cutoff;
-        __syncthreads();
-
-        // now we know the cutoff, so apply it to each group and
-        // load each energy into a scratch buffer.
-        for(int i=0; i<RESTS_PER_THREAD; i++) {
-            int index = threadIdx.x * RESTS_PER_THREAD + start + i;
-            if(index < end) {
-                if (energyArray[indexArray[index]] <= cutoff) {
-                    activeArray[indexArray[index]] = 1.0;
-                    energyScratch[i] = energyArray[indexArray[index]];
-                } else {
-                    activeArray[indexArray[index]] = 0.0;
-                    energyScratch[i] = 0.0;
-                }
-            } else {
-                energyScratch[i] = 0.0;
-            }
-        }
-        __syncthreads();
-
-        // Now sum all of the energies to get the total energy
-        // for the group.
-        float totalEnergy = BlockReduceT(sharedScratch.reduce).Sum(energyScratch);
-        if(threadIdx.x == 0) {
-            groupEnergyArray[groupIndex] = totalEnergy;
-        }
-        __syncthreads();
-    }
+    int groupIndex = blockIdx.x;
+    if (groupIndex >= numGroups) {
+        return;
+    }
+
+    int numActive = numActiveArray[groupIndex];
+    int start = boundsArray[groupIndex].x;
+    int end = boundsArray[groupIndex].y;
+    int numRestraints = end - start;
+
+    // Special case: activate all restraints if numActive equals numRestraints
+    if (numActive == numRestraints) {
+        for (int i = threadIdx.x; i < numRestraints; i += blockDim.x) {
+            int index = start + i;
+            activeArray[indexArray[index]] = 1.0;
+        }
+    }
+    else if (numActive == 1) {
+        // Find the minimum energy using warp-level reduction
+        float minEnergy = FLT_MAX;
+        for (int i = threadIdx.x; i < numRestraints; i += blockDim.x) {
+            int index = start + i;
+            float energy = energyArray[indexArray[index]];
+            minEnergy = fminf(minEnergy, energy);
+        }
+
+        // Use warp-level primitives to find the minimum energy across the warp
+        for (int offset = 16; offset > 0; offset /= 2) {
+            minEnergy = fminf(minEnergy, __shfl_down_sync(0xFFFFFFFF, minEnergy, offset));
+        }
+
+        // Broadcast the minimum energy from thread 0 to all threads in the warp
+        minEnergy = __shfl_sync(0xFFFFFFFF, minEnergy, 0);
+
+        // Activate restraints with energy equal to the minimum energy
+        for (int i = threadIdx.x; i < numRestraints; i += blockDim.x) {
+            int index = start + i;
+            if (energyArray[indexArray[index]] <= minEnergy) {
+                activeArray[indexArray[index]] = 1.0;
+            } else {
+                activeArray[indexArray[index]] = 0.0;
+            }
+        }
+    }
+    else {
+        // Each thread processes one restraint if possible
+        for (int i = threadIdx.x; i < numRestraints; i += blockDim.x) {
+            int indexA = start + i;
+            int counter = 0;
+
+            for (int j = 0; j < numRestraints; j++) {
+                int indexB = start + j;
+                if (energyArray[indexArray[indexA]] > energyArray[indexArray[indexB]]) {
+                    counter++;
+                }
+            }
+
+            if (counter < numActive) {
+                activeArray[indexArray[indexA]] = 1.0;
+            } else {
+                activeArray[indexArray[indexA]] = 0.0;
+            }
+        }
+        __syncthreads();
+    }
+
+    // Calculate the total energy for the group
+    float totalEnergy = 0.0f;
+    for (int i = threadIdx.x; i < numRestraints; i += blockDim.x) {
+        int index = start + i;
+        if (activeArray[indexArray[index]] == 1.0) {
+            totalEnergy += energyArray[indexArray[index]];
+        }
+    }
+
+    // Use warp-level primitives to reduce total energy across the warp
+    for (int offset = 16; offset > 0; offset /= 2) {
+        totalEnergy += __shfl_down_sync(0xFFFFFFFF, totalEnergy, offset);
+    }
+
+    // Store the result
+    if (threadIdx.x == 0) {
+        groupEnergyArray[groupIndex] = totalEnergy;
+    }
 }
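A note on the reductions above: __shfl_down_sync with offsets 16, 8, 4, 2, 1 and a full 0xFFFFFFFF mask only combines values within one 32-thread warp, and only thread 0 stores the group energy, so the kernel appears to assume one single-warp block per group (the old grid-stride loop over groups is gone). A hypothetical launch helper consistent with that assumption, not part of this commit:

// Hypothetical host-side helper (assumes one 32-thread block per group,
// since the warp-shuffle reductions in evaluateAndActivate do not
// combine values across warps).
void launchEvaluateAndActivate(int numGroups,
                               const int* d_numActiveArray, const int2* d_boundsArray,
                               const int* d_indexArray, const float* d_energyArray,
                               float* d_activeArray, float* d_groupEnergyArray)
{
    evaluateAndActivate<<<numGroups, 32>>>(
        numGroups, d_numActiveArray, d_boundsArray, d_indexArray,
        d_energyArray, d_activeArray, d_groupEnergyArray);
}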


+/**
+ * CUDA kernel to evaluate group energies and activate groups based on specific criteria, for multiple collections.
+ *
+ * The kernel handles the following cases for each collection of groups:
+ * 1. **Special case: numActive == numGroups**
+ *    - If the number of active groups equals the total number of groups, all groups are activated.
+ *
+ * 2. **Special case: numActive == 1**
+ *    - If only one group needs to be active, the kernel finds the minimum energy value using a parallel reduction.
+ *    - All groups with energy equal to the minimum are activated.
+ *
+ * 3. **General case**
+ *    - For any other value of numActive, each group's energy is compared to all others to determine whether it should be activated.
+ *    - A group is activated if its energy is among the numActive lowest energies in the collection.
+ *
+ * @param numCollections The number of collections of groups to be processed.
+ * @param numActiveArray An array containing the number of active groups for each collection.
+ * @param boundsArray An array of bounds giving the start and end indices of each collection of groups.
+ * @param indexArray An array of indices mapping groups to their energy values.
+ * @param energyArray An array containing the energy values for all groups.
+ * @param activeArray An output array indicating which groups are activated (1.0 for active, 0.0 for inactive).
+ */
extern "C" __global__ void evaluateAndActivateCollections(
const int numCollections,
const int* __restrict__ numActiveArray,
@@ -978,63 +1029,77 @@
const float* __restrict__ energyArray,
float* __restrict__ activeArray)
{
-    // Setup type alias for sorting.
-    typedef cub::BlockRadixSort<float, NCOLLTHREADS, GROUPS_PER_THREAD> BlockRadixSortT;
-
-    // Setup shared memory for sorting.
-    __shared__ union {
-        typename BlockRadixSortT::TempStorage sort;
-        float cutoff;
-    } sharedScratch;
-
-    // local storage for energies to be sorted
-    float energyScratch[GROUPS_PER_THREAD];
-
-    for (int collIndex=blockIdx.x; collIndex<numCollections; collIndex+=gridDim.x) {
-        int numActive = numActiveArray[collIndex];
-        int start = boundsArray[collIndex].x;
-        int end = boundsArray[collIndex].y;
-
-        // Load energies into statically allocated scratch buffer
-        for(int i=0; i<GROUPS_PER_THREAD; i++) {
-            int index = threadIdx.x * GROUPS_PER_THREAD + start + i;
-            if(index < end) {
-                energyScratch[i] = energyArray[indexArray[index]];
-            } else {
-                energyScratch[i] = FLT_MAX;
-            }
-        }
-        __syncthreads();
-
-        // Sort the energies.
-        BlockRadixSortT(sharedScratch.sort).Sort(energyScratch);
-        __syncthreads();
-
-        // find the nth largest energy and store in scratch
-        int myMin = threadIdx.x * GROUPS_PER_THREAD;
-        int myMax = myMin + GROUPS_PER_THREAD;
-        if((numActive - 1) >= myMin) {
-            if((numActive - 1) < myMax) {
-                // only one thread will get here
-                int offset = numActive - 1 - myMin;
-                sharedScratch.cutoff = energyScratch[offset];
-            }
-        }
-        __syncthreads();
-
-        // Read the nth largest energy from shared memory.
-        float cutoff = (volatile float)sharedScratch.cutoff;
-        __syncthreads();
-
-        // now we know the cutoff, so apply it to each group
-        for (int i=start + threadIdx.x; i<end; i+=blockDim.x) {
-            if (energyArray[indexArray[i]] <= cutoff) {
-                activeArray[indexArray[i]] = 1.0;
-            }
-            else {
-                activeArray[indexArray[i]] = 0.0;
-            }
-        }
-        __syncthreads();
-    }
+    __shared__ float sharedBuffer[1024];
+
+    int collectionIndex = blockIdx.x;
+    if (collectionIndex >= numCollections) {
+        return;
+    }
+
+    int numActive = numActiveArray[collectionIndex];
+    int start = boundsArray[collectionIndex].x;
+    int end = boundsArray[collectionIndex].y;
+    int numGroups = end - start;
+
+    // Special case: activate all groups if numActive equals numGroups
+    if (numActive == numGroups) {
+        for (int i = threadIdx.x; i < numGroups; i += blockDim.x) {
+            int index = start + i;
+            activeArray[indexArray[index]] = 1.0;
+        }
+    }
+    else if (numActive == 1) {
+        // Find the minimum energy using shared memory reduction
+        float minEnergy = FLT_MAX;
+        for (int i = threadIdx.x; i < numGroups; i += blockDim.x) {
+            int index = start + i;
+            float energy = energyArray[indexArray[index]];
+            minEnergy = fminf(minEnergy, energy);
+        }
+        __syncthreads();
+
+        sharedBuffer[threadIdx.x] = minEnergy;
+        __syncthreads();
+
+        // Reduce within the block to find the minimum energy
+        for (int stride = blockDim.x / 2; stride > 0; stride /= 2) {
+            if (threadIdx.x < stride) {
+                sharedBuffer[threadIdx.x] = fminf(sharedBuffer[threadIdx.x], sharedBuffer[threadIdx.x + stride]);
+            }
+            __syncthreads();
+        }
+
+        minEnergy = sharedBuffer[0];
+        __syncthreads();
+
+        // Activate groups with energy equal to the minimum energy
+        for (int i = threadIdx.x; i < numGroups; i += blockDim.x) {
+            int index = start + i;
+            if (energyArray[indexArray[index]] <= minEnergy) {
+                activeArray[indexArray[index]] = 1.0;
+            } else {
+                activeArray[indexArray[index]] = 0.0;
+            }
+        }
+    }
+    else {
+        // Each thread processes one group if possible
+        for (int i = threadIdx.x; i < numGroups; i += blockDim.x) {
+            int indexA = start + i;
+            int counter = 0;
+
+            for (int j = 0; j < numGroups; j++) {
+                int indexB = start + j;
+                if (energyArray[indexArray[indexA]] > energyArray[indexArray[indexB]]) {
+                    counter++;
+                }
+            }
+
+            if (counter < numActive) {
+                activeArray[indexArray[indexA]] = 1.0;
+            } else {
+                activeArray[indexArray[indexA]] = 0.0;
+            }
+        }
+        __syncthreads();
+    }
 }
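evaluateAndActivateCollections, by contrast, reduces through the 1024-entry sharedBuffer with a halving stride, so it can use a full thread block; the indexing assumes blockDim.x is a power of two of at most 1024. A hypothetical launch helper under those assumptions, not part of this commit:

// Hypothetical host-side helper (assumes one block per collection and a
// power-of-two block size <= 1024, matching the tree reduction over
// sharedBuffer; 256 threads chosen arbitrarily here).
void launchEvaluateAndActivateCollections(int numCollections,
                                          const int* d_numActiveArray, const int2* d_boundsArray,
                                          const int* d_indexArray, const float* d_energyArray,
                                          float* d_activeArray)
{
    evaluateAndActivateCollections<<<numCollections, 256>>>(
        numCollections, d_numActiveArray, d_boundsArray,
        d_indexArray, d_energyArray, d_activeArray);
}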

