Skip to content

Commit

Permalink
Update launch of kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmaccal authored Oct 14, 2024
1 parent 7565d1e commit 44fde24
Showing 1 changed file with 2 additions and 22 deletions.
24 changes: 2 additions & 22 deletions plugin/platforms/cuda/src/MeldCudaKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1237,26 +1237,6 @@ void CudaCalcMeldForceKernel::initialize(const System &system, const MeldForce &
defines["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms());
defines["NUM_DERIVS"] = cu.intToString(cu.getEnergyParamDerivNames().size());

// This should be determined by hardware, rather than hard-coded.
const int maxThreadsPerGroup = 1024;
// Note: x / y + (x % y != 0) is integer division rounded up (i.e. ceil(x / y)).
const int restraintsPerThread = std::max(
4,
largestGroup / maxThreadsPerGroup + (largestGroup % maxThreadsPerGroup != 0));
threadsPerGroup = largestGroup / restraintsPerThread + (largestGroup % restraintsPerThread != 0);
replacements["NGROUPTHREADS"] = cu.intToString(threadsPerGroup);
replacements["RESTS_PER_THREAD"] = cu.intToString(restraintsPerThread);

// This should be determined by hardware, rather than hard-coded.
const int maxThreadsPerCollection = 1024;
// Note: x / y + (x % y != 0) is integer division rounded up (i.e. ceil(x / y)).
const int groupsPerThread = std::max(
4,
largestCollection / maxThreadsPerCollection + (largestCollection % maxThreadsPerCollection != 0));
threadsPerCollection = largestCollection / groupsPerThread + (largestCollection % groupsPerThread != 0);
replacements["NCOLLTHREADS"] = cu.intToString(threadsPerCollection);
replacements["GROUPS_PER_THREAD"] = cu.intToString(groupsPerThread);

CUmodule module = cu.createModule(cu.replaceStrings(CudaMeldKernelSources::vectorOps + CudaMeldKernelSources::computeMeld, replacements), defines);
computeRDCRestKernel = cu.getKernel(module, "computeRDCRest");
computeDistRestKernel = cu.getKernel(module, "computeDistRest");
Expand Down Expand Up @@ -1461,7 +1441,7 @@ double CudaCalcMeldForceKernel::execute(ContextImpl &context, bool includeForces
&restraintActive->getDevicePointer(),
&groupEnergies->getDevicePointer(),
};
cu.executeKernel(evaluateAndActivateKernel, groupArgs, threadsPerGroup * numGroups, threadsPerGroup);
cu.executeKernel(evaluateAndActivateKernel, groupArgs, 32 * numGroups);

// now evaluate and activate groups based on collections
void *collArgs[] = {
Expand All @@ -1471,7 +1451,7 @@ double CudaCalcMeldForceKernel::execute(ContextImpl &context, bool includeForces
&collectionGroupIndices->getDevicePointer(),
&groupEnergies->getDevicePointer(),
&groupActive->getDevicePointer()};
cu.executeKernel(evaluateAndActivateCollectionsKernel, collArgs, threadsPerCollection * numCollections, threadsPerCollection);
cu.executeKernel(evaluateAndActivateCollectionsKernel, collArgs, 1024 * numCollections, 1024);

// Now set the restraints active based on if the groups are active
void *applyGroupsArgs[] = {
Expand Down

0 comments on commit 44fde24

Please sign in to comment.