From 4814b0c4a8bc74bf98f34ba99440f24ea21ab93b Mon Sep 17 00:00:00 2001 From: Oleg Smirnov Date: Tue, 15 Feb 2022 12:34:19 +0100 Subject: [PATCH] Fix Keras imports for optimizer algorithms --- tensorflow_riemopt/optimizers/constrained_rmsprop.py | 4 ++-- tensorflow_riemopt/optimizers/riemannian_adam.py | 4 ++-- tensorflow_riemopt/optimizers/riemannian_gradient_descent.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow_riemopt/optimizers/constrained_rmsprop.py b/tensorflow_riemopt/optimizers/constrained_rmsprop.py index d5441a3..1bf9e87 100644 --- a/tensorflow_riemopt/optimizers/constrained_rmsprop.py +++ b/tensorflow_riemopt/optimizers/constrained_rmsprop.py @@ -9,19 +9,19 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.keras import backend_config -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import gen_training_ops +from keras.optimizer_v2.optimizer_v2 import OptimizerV2 from tensorflow_riemopt.variable import get_manifold @generic_utils.register_keras_serializable(name="ConstrainedRMSprop") -class ConstrainedRMSprop(optimizer_v2.OptimizerV2): +class ConstrainedRMSprop(OptimizerV2): """Optimizer that implements the RMSprop algorithm.""" _HAS_AGGREGATE_GRAD = True diff --git a/tensorflow_riemopt/optimizers/riemannian_adam.py b/tensorflow_riemopt/optimizers/riemannian_adam.py index 7612bfc..621cd37 100644 --- a/tensorflow_riemopt/optimizers/riemannian_adam.py +++ b/tensorflow_riemopt/optimizers/riemannian_adam.py @@ -6,19 +6,19 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.keras import backend_config -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import gen_training_ops +from keras.optimizer_v2.optimizer_v2 import OptimizerV2 from tensorflow_riemopt.variable import get_manifold @generic_utils.register_keras_serializable(name="RiemannianAdam") -class RiemannianAdam(optimizer_v2.OptimizerV2): +class RiemannianAdam(OptimizerV2): """Optimizer that implements the Riemannian Adam algorithm.""" _HAS_AGGREGATE_GRAD = True diff --git a/tensorflow_riemopt/optimizers/riemannian_gradient_descent.py b/tensorflow_riemopt/optimizers/riemannian_gradient_descent.py index 91da4a3..88d5a9d 100644 --- a/tensorflow_riemopt/optimizers/riemannian_gradient_descent.py +++ b/tensorflow_riemopt/optimizers/riemannian_gradient_descent.py @@ -6,19 +6,19 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import ops from tensorflow.python.keras import backend_config -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.utils import generic_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import gen_training_ops +from keras.optimizer_v2.optimizer_v2 import OptimizerV2 from tensorflow_riemopt.variable import get_manifold @generic_utils.register_keras_serializable(name="RiemannianSGD") -class RiemannianSGD(optimizer_v2.OptimizerV2): +class RiemannianSGD(OptimizerV2): """Optimizer that implements the Riemannian SGD algorithm.""" _HAS_AGGREGATE_GRAD = True