Skip to content

Commit

Permalink
Merge branch 'master' of github.com:karpathy/llm.c
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Jun 18, 2024
2 parents 4fd2df8 + a33a49a commit efbbdc8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
9 changes: 4 additions & 5 deletions llmc/dataloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,16 @@ void prepare_intra_shard_indices_(DataLoader *loader) {
free(loader->intra_shard_indices);
}
loader->intra_shard_indices = (int*)mallocCheck(loader->shard_num_samples * sizeof(int));
random_permutation_with_init(loader->intra_shard_indices, loader->shard_num_samples, &loader->shuffle_rng, 1);
init_identity_permutation(loader->intra_shard_indices, loader->shard_num_samples);
random_permutation(loader->intra_shard_indices, loader->shard_num_samples, &loader->shuffle_rng);
}

void dataloader_reset(DataLoader *loader) {
loader->current_shard_idx = 0;
loader->current_sample_idx = 0;

if (loader->should_shuffle) { // shuffle the shards
random_permutation_with_init(loader->shard_indices, loader->glob_result.gl_pathc, &loader->shuffle_rng, 0);
random_permutation(loader->shard_indices, loader->glob_result.gl_pathc, &loader->shuffle_rng);
}

dataloader_load_shard_(loader, loader->current_shard_idx);
Expand Down Expand Up @@ -171,9 +172,7 @@ void dataloader_init(DataLoader *loader,
manual_seed(&shuffle_rng, 42 + process_rank);
loader->shuffle_rng = shuffle_rng;
loader->shard_indices = (int*)mallocCheck(loader->glob_result.gl_pathc * sizeof(int));
for (int i = 0; i < loader->glob_result.gl_pathc; i++) {
loader->shard_indices[i] = i; // start with identity permutation
}
init_identity_permutation(loader->shard_indices, loader->glob_result.gl_pathc);
loader->intra_shard_indices = NULL; // dynamically allocated allowing different shard sizes
}

Expand Down
10 changes: 5 additions & 5 deletions llmc/rand.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ void normal_(float* data, unsigned int numel, float mean, float std, mt19937_sta
}
}

void random_permutation_with_init(int* data, int numel, mt19937_state* state, int should_init) {
if (should_init) {
for (int i = 0; i < numel; i++) {
data[i] = i;
}
// Writes the identity permutation into data: after the call, data[k] == k
// for every k in [0, numel). The caller provides storage for at least numel
// ints. A numel of zero or less leaves data untouched.
void init_identity_permutation(int *data, int numel) {
    int idx = 0;
    while (idx < numel) {
        data[idx] = idx;
        idx++;
    }
}

void random_permutation(int* data, int numel, mt19937_state* state) {
for (int i = numel - 1; i > 0; i--) {
// pick an index j in [0, i] with equal probability
int j = randint32(state) % (i + 1);
Expand Down

0 comments on commit efbbdc8

Please sign in to comment.