MLE of lambda parameter in PowerTransform bijector #1821

Open
mpetteno opened this issue Jul 4, 2024 · 3 comments

Comments

mpetteno commented Jul 4, 2024

Hi everyone, I'm trying to find the optimal lambda parameter of the PowerTransform bijector via maximum likelihood estimation.
To do so I had to modify the bijector's constructor to allow power to be a trainable tf.Variable.
The code is the following:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd
from scipy.stats import boxcox

np.random.seed(42)
tf.random.set_seed(42)

# Generate exponential data
data_dist = tfd.Sample(tfd.Exponential(1.), sample_shape=1000)
x_train = data_dist.sample()
plt.hist(x_train.numpy(), bins=120, density=True, alpha=0.6, color='blue', label='Samples')
plt.show()

nf = tfd.TransformedDistribution(
    tfd.Normal(loc=0, scale=1),
    bijector=tfb.PowerTransform(power=tf.Variable(initial_value=1., name='power'))
)

# Training loop
num_steps = 2000
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
for step in range(num_steps):
    with tf.GradientTape() as tape:
        loss = -tf.reduce_sum(nf.log_prob(x_train))
    grads = tape.gradient(loss, nf.trainable_variables)
    optimizer.apply_gradients(zip(grads, nf.trainable_variables))

    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss.numpy()}, Grads: {grads}, Power: {nf.trainable_variables[0].numpy()}")

_, llm_lmbda = boxcox(x_train.numpy(), lmbda=None)
print(f"Scipy MLE power is: {llm_lmbda}") # 0.24647694003577084
print(f"My MLE power is: {nf.trainable_variables[0].numpy()}") # 0.6407918334007263

z_samples = nf.sample(1000)
plt.hist(z_samples.numpy(), bins=120, density=True, alpha=0.6, color='green', label='Samples')
plt.show()

The MLE from the scipy library gives lambda = 0.25, but mine gives lambda = 0.64. If I use the bijector with a static power of 0.25, I can recover a distribution that is closer to the original exponential, so I believe there might be a problem either with the training procedure or with the computation of the forward Jacobian in the PowerTransform bijector, but I can't find it.
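
For reference, one way to sanity-check the Jacobian term would be to compare nf.log_prob against a hand-written Box-Cox log-density. This is only a sketch; it assumes PowerTransform.inverse(y) = (y**power - 1) / power with an inverse log det Jacobian of (power - 1) * log(y), and it uses a fixed static power so the stock bijector can be used:

import numpy as np
import tensorflow as tf
from tensorflow_probability import bijectors as tfb
from tensorflow_probability import distributions as tfd

# Use a fixed, static power so the stock PowerTransform can be constructed.
power = 0.25
y = tf.constant(np.random.exponential(size=10), dtype=tf.float32)

nf_check = tfd.TransformedDistribution(
    tfd.Normal(loc=0., scale=1.),
    bijector=tfb.PowerTransform(power=power)
)

# Manual log-density: standard normal log-prob of the Box-Cox-transformed data
# plus the assumed inverse log det Jacobian term (power - 1) * log(y).
z = (y ** power - 1.) / power
manual_log_prob = tfd.Normal(loc=0., scale=1.).log_prob(z) + (power - 1.) * tf.math.log(y)

print(tf.reduce_max(tf.abs(nf_check.log_prob(y) - manual_log_prob)).numpy())  # expected ~0

If that difference comes out at ~0, the log-density itself is presumably fine and the gap to scipy would have to come from the optimization setup.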

Can anyone help with this?

csuter (Member) commented Jul 8, 2024

@srvasude it looks like the power parameter of the PowerTransform bijector is not permitted to be a Variable. Do you know why?

https://github.com/tensorflow/probability/blob/v0.23.0/tensorflow_probability/python/bijectors/power_transform.py#L62

@SiegeLordEx (Member)

Probably a combination of being an ancient bijector and this line:

I guess from a performance perspective you'd still want to maintain the power == 0 static branch.
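
To make that concrete, a rough sketch of a variant that keeps a statically resolved power == 0 fast path while accepting a trainable power might look like the following (hypothetical TrainablePowerTransform, not the actual TFP implementation):

import tensorflow as tf
import tensorflow_probability as tfp


class TrainablePowerTransform(tfp.bijectors.Bijector):
    """Hypothetical PowerTransform variant whose `power` may be a tf.Variable.

    Keeps an exp/log fast path when `power` is a plain constant equal to 0,
    and otherwise uses tensor ops so gradients can flow into `power`.
    """

    def __init__(self, power, validate_args=False, name='trainable_power_transform'):
        parameters = dict(locals())
        self._power = power
        # Resolve the power == 0 branch statically, but only for plain
        # Python/NumPy constants; tensors and variables stay dynamic.
        self._static_zero = (not tf.is_tensor(power)) and power == 0
        super().__init__(
            forward_min_event_ndims=0,
            validate_args=validate_args,
            parameters=parameters,
            name=name)

    @property
    def power(self):
        return self._power

    def _forward(self, x):
        if self._static_zero:
            return tf.exp(x)
        power = tf.convert_to_tensor(self.power)
        return tf.exp(tf.math.log1p(x * power) / power)  # (1 + x * power) ** (1 / power)

    def _inverse(self, y):
        if self._static_zero:
            return tf.math.log(y)
        power = tf.convert_to_tensor(self.power)
        return tf.math.expm1(power * tf.math.log(y)) / power  # (y ** power - 1) / power

    def _inverse_log_det_jacobian(self, y):
        if self._static_zero:
            return -tf.math.log(y)
        power = tf.convert_to_tensor(self.power)
        return (power - 1.) * tf.math.log(y)

    def _forward_log_det_jacobian(self, x):
        if self._static_zero:
            return tf.identity(x)
        power = tf.convert_to_tensor(self.power)
        return (1. / power - 1.) * tf.math.log1p(x * power)

With something along these lines, the training loop above could pass power=tf.Variable(1.) without modifying TFP itself; the exp/log behavior is kept only when power is given as a literal 0.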

csuter (Member) commented Jul 9, 2024

Indeed, the Exp bijector is just implemented as PowerTransform with power=0. We could decouple these; Exp would be very simple on its own. Or we could keep the static path for efficiency but also allow a tensor input. I started removing the static path but then realized Exp was using it... will take another look later.
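
For reference, a decoupled Exp written directly against the Bijector base class would be roughly the following (just a sketch, not the actual TFP code):

import tensorflow as tf
import tensorflow_probability as tfp


class StandaloneExp(tfp.bijectors.Bijector):
    """Sketch of an Exp bijector decoupled from PowerTransform: Y = exp(X)."""

    def __init__(self, validate_args=False, name='exp'):
        parameters = dict(locals())
        super().__init__(
            forward_min_event_ndims=0,
            validate_args=validate_args,
            parameters=parameters,
            name=name)

    def _forward(self, x):
        return tf.exp(x)

    def _inverse(self, y):
        return tf.math.log(y)

    def _forward_log_det_jacobian(self, x):
        # d/dx exp(x) = exp(x), so the log-Jacobian is just x.
        return tf.identity(x)

    def _inverse_log_det_jacobian(self, y):
        # d/dy log(y) = 1 / y, so the log-Jacobian is -log(y).
        return -tf.math.log(y)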
