Predicate option for BatchGenerator #162
Comments
Just learned about |
While experimenting with some xbatcher ideas, I came up with a function that basically implements a predicate as would be desirable here. Here is a minimal example of what this functionality could look like. This code is hacky and pretty slow, but I think the core functionality (creating fixed-size batches with a predicate) could be very useful in some situations, especially with NaN-filtering. @maxrjones @rabernat What do y'all think?

This example script makes a list of 1000 integers valued 0-9, then finds the indices of all the eights and returns them in batches of five. The use case I was interested in was having fixed-size batches of size 32, while discarding NxN samples that had either any NaNs, or were all NaNs.

    import random

    def my_gen(batch_set, batch_size=5, predicate=None, sample_dim_name=None):
        n = 0  # index of the next sample to test
        while n < 1000:  # hardcoded length of the example data
            m = 0  # number of samples accepted into the current batch
            batch = []
            while m < batch_size:
                if n >= 1000:
                    break
                # Index by dimension name for xarray-like inputs, else by position.
                this_batch = batch_set[{sample_dim_name: n}] if sample_dim_name else batch_set[n]
                if not predicate or predicate(this_batch):
                    batch.append(n)
                    m += 1
                n += 1
            if n >= 1000:
                break  # input exhausted, so the batch in progress is dropped
            yield batch

    batch_set = []
    for i in range(1000):
        batch_set.append(random.randint(0, 9))

    pred = lambda this_batch: this_batch == 8
    gen = my_gen(batch_set, predicate=pred)
    for g in gen:
        print(g)
Better code sample, which wraps xbatcher and also offers fixed batch sizes:

    import xarray as xr
    import xbatcher as xb
    import numpy as np

    da1 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
    da2 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
    da3 = xr.DataArray(np.random.randint(0, 9, (400, 400)), dims=['d1', 'd2'])
    ds = xr.Dataset({'da1': da1, 'da2': da2, 'da3': da3})

    bgen = xb.BatchGenerator(
        ds,
        {'d1': 5, 'd2': 5},
        {'d1': 2, 'd2': 2}
    )

    def my_gen2(bgen, batch_size=5, predicate=None):
        b = (batch for batch in bgen)
        n = 0
        batch_stack = []
        while n < 400:  # hardcoded n is a kludge; while-loop is necessary
            this_batch = next(b)
            n += 1
            if not predicate or predicate(this_batch):
                batch_stack.append(this_batch)
            else:
                continue
            if len(batch_stack) == batch_size:
                # Stack the accepted samples along a new 'sample' dimension.
                yield xr.concat(batch_stack, 'sample')
                batch_stack = []

    pred2 = lambda batch: np.mod(batch['da1'].sum(), 10) == 0
    gen = my_gen2(bgen, batch_size=2, predicate=pred2)
    res = []
    for g in gen:
        res.append(g)
    len(res)
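The hardcoded `n < 400` in the wrapper above could probably be avoided by consuming the source until it is exhausted. The following is only a rough sketch of that idea, not part of xbatcher: `fixed_size_batches` and its `drop_last` parameter are hypothetical names, and the sketch works on any Python iterable rather than on a `BatchGenerator` specifically.

```python
from itertools import islice


def fixed_size_batches(samples, batch_size, predicate=None, drop_last=True):
    """Yield lists of `batch_size` samples that satisfy `predicate`.

    Hypothetical helper: `samples` can be any iterable, and `predicate`
    is any callable that returns something truthy for samples to keep.
    """
    # Lazily filter the source so rejected samples never reach a batch.
    kept = (s for s in samples if predicate is None or predicate(s))
    while True:
        batch = list(islice(kept, batch_size))
        if not batch:
            return  # source exhausted
        if len(batch) < batch_size and drop_last:
            return  # discard the trailing partial batch
        yield batch


# Toy usage mirroring the first example: keep only the eights.
values = [8, 1, 8, 8, 3, 8, 8, 8, 2, 8]
batches = list(fixed_size_batches(values, 3, predicate=lambda v: v == 8))
print(batches)
```

Passing `bgen` as `samples` and calling `xr.concat(batch, 'sample')` on each yielded list should reproduce the stacking behavior of `my_gen2`, although that combination is untested here.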
Is your feature request related to a problem?
When you create a batch generator, what happens when your data contains NaNs? For example, in an ocean data set such as a map of sea surface temperature, you may iterate through different regions where the stencil is valid, partially valid, or completely full of NaNs. Because xbatcher can't filter for these situations, you have to apply any such filter inside the batch loop, which leaves you with load imbalances.
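To make the three stencil cases and the resulting imbalance concrete, here is a small NumPy-only sketch. The 4x4 `field`, the `patches` helper, and the `classify` function are all made up for illustration and are not xbatcher functionality; with xarray objects, `.isnull()` would be the analogous check.

```python
import numpy as np

# Made-up 4x4 "sea surface temperature" field; NaN marks land.
field = np.array([
    [1.0, 2.0, 3.0, np.nan],
    [4.0, 5.0, np.nan, np.nan],
    [6.0, 7.0, np.nan, np.nan],
    [8.0, 9.0, np.nan, np.nan],
])


def patches(arr, size):
    """Yield non-overlapping size x size patches in row-major order."""
    for i in range(0, arr.shape[0], size):
        for j in range(0, arr.shape[1], size):
            yield arr[i:i + size, j:j + size]


def classify(patch):
    """Label a patch as 'valid', 'partial', or 'empty' by its NaN content."""
    nan_mask = np.isnan(patch)
    if nan_mask.all():
        return "empty"
    if nan_mask.any():
        return "partial"
    return "valid"


labels = [classify(p) for p in patches(field, 2)]

# Group patches into batches of two first, then filter inside the loop:
# the number of surviving samples differs from batch to batch.
batch_size = 2
kept_per_batch = [
    sum(label != "empty" for label in labels[k:k + batch_size])
    for k in range(0, len(labels), batch_size)
]
print(labels, kept_per_batch)
```

In this toy case the first batch keeps two patches and the second keeps one, which is the kind of uneven batch size that filtering before batching would avoid.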
Describe the solution you'd like
I would like to see an option in BatchGenerator for a selection predicate. Basically, you would pass a function to BatchGenerator that takes slices as inputs and evaluates to either True or False. BatchGenerator would then use the result to select only the slices that returned True, thereby restoring load balance.
Describe alternatives you've considered
No response
Additional context
I think this is similar to #158