VQ-VAE training example (v2) returned NaN loss #198
Hi @EBGU, there's quite a lot of code there! I recognize at least some of this from our vqvae example notebook. Rather than printing the whole file, it might be more useful if you could highlight what you have changed. I've just run our vqvae notebook using a free GPU instance on Google Colab, with TF 2.4.1; you can see the results in the gist below. As far as I can tell things are working correctly?
Hi @tomhennigan! I also tried your original code without any changes. The result was still NaN. I thought it could be an environment problem, but no error came up.
I upgraded my TF to 2.4.1 and it worked! I guess TF 2.2.0 is somehow incompatible with the code. Thanks a lot!
Dear Team Deepmind,
I am really grateful that you shared the vqvae_example with Sonnet 2. However, when running it, I get a NaN vqvae loss right from the beginning. The output is:
100 train loss: nan recon_error: 1.010 perplexity: 1.031 vqvae loss: nan
and so on.
The plot of the training set is fine, but the reconstruction is pure grey. I tried both vq_use_ema = False and vq_use_ema = True and got the same results.
I have slightly modified your code by replacing the download and data-loading steps with those from the previous version (https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb), using a local directory. Also, I'm using TensorFlow version 2.2.0 and Sonnet version 2.0.0. My code didn't return any error, just a NaN loss.
I wonder if you could kindly help me with this problem.
Thanks a lot!
Sincerely,
Harold
My code:
import os
import subprocess
import tempfile
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree
try:
  import sonnet.v2 as snt
  tf.enable_v2_behavior()
except ImportError:
  import sonnet as snt
from six.moves import cPickle
from six.moves import urllib
from six.moves import xrange

# For plt display.
os.environ['DISPLAY'] = ':0'  # os.system('export ...') would only affect a subshell, not this process.

print("TensorFlow version {}".format(tf.__version__))
print("Sonnet version {}".format(snt.__version__))
local_data_dir='/home/harold/Documents/VQ-VAE'
'''
#Downloading cifar10
cifar10 = tfds.as_numpy(tfds.load("cifar10:3.0.2", split="train+test", batch_size=-1))
cifar10.pop("id", None)
cifar10.pop("label")
tree.map_structure(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)
'''
#Data loading
'''
train_data_dict = tree.map_structure(lambda x: x[:40000], cifar10)
valid_data_dict = tree.map_structure(lambda x: x[40000:50000], cifar10)
test_data_dict = tree.map_structure(lambda x: x[50000:], cifar10)

def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['image']
  data_dict['image'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['image'] / 255.0)
print('train data variance: %s' % train_data_variance)
'''
def unpickle(filename):
  with open(filename, 'rb') as fo:
    return cPickle.load(fo, encoding='latin1')

def reshape_flattened_image_batch(flat_image_batch):
  return flat_image_batch.reshape(-1, 3, 32, 32).transpose([0, 2, 3, 1])  # convert from NCHW to NHWC

def combine_batches(batch_list):
  images = np.vstack([reshape_flattened_image_batch(batch['data'])
                      for batch in batch_list])
  labels = np.vstack([np.array(batch['labels']) for batch in batch_list]).reshape(-1, 1)
  return {'images': images, 'labels': labels}

train_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/data_batch_%d' % i))
    for i in range(1, 5)
])

valid_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir,
                          'cifar-10-batches-py/data_batch_5'))])

test_data_dict = combine_batches([
    unpickle(os.path.join(local_data_dir, 'cifar-10-batches-py/test_batch'))])

def cast_and_normalise_images(data_dict):
  """Convert images to floating point with the range [-0.5, 0.5]"""
  images = data_dict['images']
  data_dict['images'] = (tf.cast(images, tf.float32) / 255.0) - 0.5
  return data_dict

train_data_variance = np.var(train_data_dict['images'] / 255.0)
print('train data variance: %s' % train_data_variance)
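# (Not part of the original notebook.) Since the reconstruction term later
# divides by train_data_variance, a zero or non-finite variance would by
# itself turn the loss into NaN/inf, so a minimal sanity check on the loaded
# arrays can rule that out:
assert np.isfinite(train_data_variance) and train_data_variance > 0.0
for name, d in [('train', train_data_dict),
                ('valid', valid_data_dict),
                ('test', test_data_dict)]:
  imgs = d['images']
  assert imgs.dtype == np.uint8, (name, imgs.dtype)
  assert imgs.min() >= 0 and imgs.max() <= 255, (name, imgs.min(), imgs.max())
  print('%s images: shape=%s' % (name, imgs.shape))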
#Encoder & Decoder Architecture
class ResidualStack(snt.Module):
def init(self, num_hiddens, num_residual_layers, num_residual_hiddens,
name=None):
super(ResidualStack, self).init(name=name)
self._num_hiddens = num_hiddens
self._num_residual_layers = num_residual_layers
self._num_residual_hiddens = num_residual_hiddens
def call(self, inputs):
h = inputs
for conv3, conv1 in self._layers:
conv3_out = conv3(tf.nn.relu(h))
conv1_out = conv1(tf.nn.relu(conv3_out))
h += conv1_out
return tf.nn.relu(h) # Resnet V1 style
class Encoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

  def __call__(self, x):
    h = tf.nn.relu(self._enc_1(x))
    h = tf.nn.relu(self._enc_2(h))
    h = tf.nn.relu(self._enc_3(h))
    return self._residual_stack(h)
class Decoder(snt.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

  def __call__(self, x):
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = tf.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon
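# NOTE: the paste above appears to have lost the layer construction inside the
# three __init__ methods (their __call__ bodies refer to self._layers,
# self._enc_*, self._dec_* and self._residual_stack, which are never created
# here). A minimal sketch of what the vqvae_example notebook builds, assuming
# the standard snt.Conv2D / snt.Conv2DTranspose layers, is roughly:
#
#   In ResidualStack.__init__:
#     self._layers = []
#     for i in range(num_residual_layers):
#       conv3 = snt.Conv2D(output_channels=num_residual_hiddens,
#                          kernel_shape=(3, 3), stride=(1, 1),
#                          name="res3x3_%d" % i)
#       conv1 = snt.Conv2D(output_channels=num_hiddens,
#                          kernel_shape=(1, 1), stride=(1, 1),
#                          name="res1x1_%d" % i)
#       self._layers.append((conv3, conv1))
#
#   In Encoder.__init__:
#     self._enc_1 = snt.Conv2D(output_channels=num_hiddens // 2,
#                              kernel_shape=(4, 4), stride=(2, 2), name="enc_1")
#     self._enc_2 = snt.Conv2D(output_channels=num_hiddens,
#                              kernel_shape=(4, 4), stride=(2, 2), name="enc_2")
#     self._enc_3 = snt.Conv2D(output_channels=num_hiddens,
#                              kernel_shape=(3, 3), stride=(1, 1), name="enc_3")
#     self._residual_stack = ResidualStack(num_hiddens, num_residual_layers,
#                                          num_residual_hiddens)
#
#   In Decoder.__init__:
#     self._dec_1 = snt.Conv2D(output_channels=num_hiddens,
#                              kernel_shape=(3, 3), stride=(1, 1), name="dec_1")
#     self._residual_stack = ResidualStack(num_hiddens, num_residual_layers,
#                                          num_residual_hiddens)
#     self._dec_2 = snt.Conv2DTranspose(output_channels=num_hiddens // 2,
#                                       kernel_shape=(4, 4), stride=(2, 2),
#                                       name="dec_2")
#     self._dec_3 = snt.Conv2DTranspose(output_channels=3,
#                                       kernel_shape=(4, 4), stride=(2, 2),
#                                       name="dec_3")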
class VQVAEModel(snt.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1,
               data_variance, name=None):
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    recon_error = tf.reduce_mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }
# Build Model and train
#%%time

# Set hyper-parameters.
batch_size = 32
image_size = 32

# 100k steps should take < 30 minutes on a modern (>= 2017) GPU.
# 10k steps gives reasonable accuracy with VQVAE on Cifar10.
num_training_updates = 10000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
# These hyper-parameters define the size of the model (number of parameters and layers).
# The hyper-parameters in the paper were (for ImageNet):
#   batch_size = 128
#   image_size = 128
#   num_hiddens = 128
#   num_residual_hiddens = 32
#   num_residual_layers = 2

# This value is not that important, usually 64 works.
# This will not change the capacity in the information-bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity in the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It's often useful to try a couple
# of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)). So if the reconstruction cost is 100x higher, the
# commitment_cost should also be multiplied with the same amount.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster, and makes the model less dependent on choice
# of the optimizer. In the VQ-VAE paper EMA updates were not used (but were
# developed afterwards). See Appendix of the paper for more details.
vq_use_ema = False

# This is only used for EMA updates.
decay = 0.99

learning_rate = 3e-4
# Data Loading.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))

valid_dataset = (
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))
'''
train_batch = next(iter(train_dataset))

def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5

f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['images'].numpy()),
          interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')
plt.show()
'''
# Build modules.
encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,
                          kernel_shape=(1, 1),
                          stride=(1, 1),
                          name="to_vq")

if vq_use_ema:
  vq_vae = snt.nets.VectorQuantizerEMA(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost,
      decay=decay)
else:
  vq_vae = snt.nets.VectorQuantizer(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)

model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1,
                   data_variance=train_data_variance)

optimizer = snt.optimizers.Adam(learning_rate=learning_rate)

@tf.function
def train_step(data):
  with tf.GradientTape() as tape:
    model_output = model(data['images'], is_training=True)

  trainable_variables = model.trainable_variables
  grads = tape.gradient(model_output['loss'], trainable_variables)
  optimizer.apply(grads, trainable_variables)

  return model_output
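# (Not part of the original example.) To localize where the NaN first shows up,
# one option is an eager, instrumented variant of the step that checks the
# intermediate tensors; run it on a single batch before the compiled loop.
# debug_step is a hypothetical helper, not something from the notebook.
def debug_step(data):
  out = model(data['images'], is_training=True)
  tf.debugging.check_numerics(out['z'], 'encoder output (z) has NaN/Inf')
  tf.debugging.check_numerics(out['x_recon'], 'reconstruction has NaN/Inf')
  tf.debugging.check_numerics(out['vq_output']['loss'], 'vq loss has NaN/Inf')
  tf.debugging.check_numerics(out['loss'], 'total loss has NaN/Inf')
  return out

# Example usage:
# debug_step(next(iter(train_dataset)))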
train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

for step_index, data in enumerate(train_dataset):
  train_results = train_step(data)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  if (step_index + 1) % 100 == 0:
    print('%d train loss: %f ' % (step_index + 1,
                                  np.mean(train_losses[-100:])) +
          ('recon_error: %.3f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.3f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.3f' % np.mean(train_vqvae_loss[-100:])))
  if step_index == num_training_updates:
    break
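# (Not part of the original example.) With the histories collected above, it is
# easy to report the first step at which the loss became non-finite, which
# narrows down when things go wrong:
losses = np.array(train_losses, dtype=np.float64)
bad_steps = np.where(~np.isfinite(losses))[0]
if bad_steps.size:
  print('loss first became non-finite at step %d' % (bad_steps[0] + 1))
else:
  print('loss stayed finite for all %d recorded steps' % losses.size)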
#Plot loss
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
ax.set_title('NMSE.')
ax = f.add_subplot(1,2,2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')
plt.show()
# Visualization

# Reconstructions
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of
# using EMA the codebook is not updated.
train_reconstructions = model(train_batch['images'],
                              is_training=False)['x_recon'].numpy()
valid_reconstructions = model(valid_batch['images'],
                              is_training=False)['x_recon'].numpy()

def convert_batch_to_image_grid(image_batch):
  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)
              .transpose(0, 2, 1, 3, 4)
              .reshape(4 * 32, 8 * 32, 3))
  return reshaped + 0.5
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(2,2,1)
ax.imshow(convert_batch_to_image_grid(train_batch['images'].numpy()),
interpolation='nearest')
ax.set_title('training data originals')
plt.axis('off')
ax = f.add_subplot(2,2,2)
ax.imshow(convert_batch_to_image_grid(train_reconstructions),
interpolation='nearest')
ax.set_title('training data reconstructions')
plt.axis('off')
ax = f.add_subplot(2,2,3)
ax.imshow(convert_batch_to_image_grid(valid_batch['images'].numpy()),
interpolation='nearest')
ax.set_title('validation data originals')
plt.axis('off')
ax = f.add_subplot(2,2,4)
ax.imshow(convert_batch_to_image_grid(valid_reconstructions),
interpolation='nearest')
ax.set_title('validation data reconstructions')
plt.axis('off')
plt.show()