Skip to content

Commit

Permalink
Split worker.py into multiple worker files under workers/
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 710938412
  • Loading branch information
jaindeepali authored and copybara-github committed Jan 30, 2025
1 parent a99e5c8 commit 561e4c3
Show file tree
Hide file tree
Showing 43 changed files with 1,568 additions and 840 deletions.
2 changes: 1 addition & 1 deletion iris/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import abc
import pathlib
from typing import Any, Dict, Sequence, Union
from iris import worker_util
from iris.workers import worker_util
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/ars_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from typing import Any, Callable, Dict, Optional, Sequence

from iris import normalizer
from iris import worker_util
from iris.algorithms import algorithm
from iris.algorithms import stateless_perturbation_generators
from iris.workers import worker_util
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/ars_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

from iris import normalizer
from iris import worker_util
from iris.algorithms import ars_algorithm
from iris.workers import worker_util
import numpy as np
import tensorflow as tf
from absl.testing import absltest
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/cma_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import cma
from iris import normalizer
from iris import worker_util
from iris.algorithms import algorithm
from iris.workers import worker_util
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/cma_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from iris import worker_util
from iris.algorithms import cma_algorithm
from iris.workers import worker_util
import numpy as np
from absl.testing import absltest

Expand Down
3 changes: 1 addition & 2 deletions iris/algorithms/es_enas_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
import functools
from multiprocessing import dummy as mp_threads
from typing import Any, Dict, Sequence

from iris import worker_util
from iris.algorithms import ars_algorithm
from iris.algorithms import controllers
from iris.workers import worker_util
import pyglove as pg


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/es_enas_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

# pytype: disable=attribute-error
from gym import spaces
from iris import worker_util
from iris.algorithms import es_enas_algorithm
from iris.policies import nas_policy
from iris.workers import worker_util
import numpy as np
import pyglove as pg
from absl.testing import absltest
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/learnable_ars_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from flax import linen as nn
from iris import checkpoint_util
from iris import normalizer
from iris import worker_util
from iris.algorithms import ars_algorithm
from iris.algorithms import stateless_perturbation_generators
from iris.workers import worker_util
import jax
import jax.numpy as jnp
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/multi_agent_ars_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from absl import logging
from iris import checkpoint_util
from iris import normalizer
from iris import worker_util
from iris.algorithms import ars_algorithm
from iris.workers import worker_util
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/multi_agent_ars_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import os
from iris import checkpoint_util
from iris import normalizer
from iris import worker_util
from iris.algorithms import multi_agent_ars_algorithm
from iris.workers import worker_util
import numpy as np
import tensorflow as tf
from absl.testing import absltest
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/pes_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from typing import Any, Dict, Optional, Sequence

from iris import normalizer
from iris import worker_util
from iris.algorithms import algorithm
from iris.algorithms import stateless_perturbation_generators
from iris.workers import worker_util
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/pes_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from iris import worker_util
from iris.algorithms import pes_algorithm
from iris.workers import worker_util
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/piars_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import gym
from gym import spaces
from gym.spaces import utils
from iris import worker_util
from iris.algorithms import ars_algorithm
from iris.policies import keras_pi_policy
from iris.workers import worker_util
import numpy as np
import tensorflow as tf
from tf_agents.agents.categorical_dqn import categorical_dqn_agent
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/piars_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.

import gym
from iris import worker_util
from iris.algorithms import piars_algorithm
from iris.policies import keras_pi_policy
from iris.workers import worker_util
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/pyglove_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import functools
from multiprocessing import dummy as mp_threads
from typing import Any, Dict, Sequence
from iris import worker_util
from iris.algorithms import algorithm
from iris.algorithms import controllers
from iris.workers import worker_util
import numpy as np
import pyglove as pg

Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/pyglove_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from iris import worker_util
from iris.algorithms import pyglove_algorithm
from iris.workers import worker_util
import numpy as np
import pyglove as pg
from absl.testing import absltest
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/pyribs_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from typing import Any, Dict, Sequence

from iris import normalizer
from iris import worker_util
from iris.algorithms import algorithm
from iris.workers import worker_util
import numpy as np
from ribs import archives
from ribs import emitters
Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/pyribs_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from unittest import mock

from iris import normalizer
from iris import worker_util
from iris.algorithms import algorithm
from iris.algorithms import pyribs_algorithm
from iris.workers import worker_util
import numpy as np
from ribs import archives

Expand Down
3 changes: 1 addition & 2 deletions iris/algorithms/rbo_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
"""

from typing import Sequence

from iris import worker_util
from iris.algorithms import ars_algorithm
from iris.algorithms import optimizers
from iris.workers import worker_util
import numpy as np


Expand Down
2 changes: 1 addition & 1 deletion iris/algorithms/rbo_algorithm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from iris import worker_util
from iris.algorithms import rbo_algorithm
from iris.workers import worker_util
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
Expand Down
13 changes: 8 additions & 5 deletions iris/configs/simple_example_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

"""Example configuration for Iris experiments."""


from iris import worker
from iris.algorithms import ars_algorithm
from iris.workers import simple_worker
from ml_collections import config_dict
import numpy as np

Expand All @@ -37,9 +36,13 @@ def get_worker_config():

return config_dict.ConfigDict(
dict(
worker_class=worker.SimpleWorker,
worker_args={'initial_params': np.ones(10),
'blackbox_function': lambda x: -1 * np.sum(x**2)}))
worker_class=simple_worker.SimpleWorker,
worker_args={
'initial_params': np.ones(10),
'blackbox_function': lambda x: -1 * np.sum(x**2),
},
)
)


def get_algo_config():
Expand Down
4 changes: 2 additions & 2 deletions iris/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Coordinator class for distributed blackbox optimization library."""

from collections.abc import Sequence, Mapping
from collections.abc import Mapping, Sequence
from concurrent import futures
import dataclasses
import os
Expand All @@ -28,8 +28,8 @@
import courier
from iris import checkpoint_util
from iris import logger
from iris import worker_util
from iris.algorithms import algorithm
from iris.workers import worker_util
import launchpad as lp
import numpy as np
from tensorflow.io import gfile
Expand Down
4 changes: 2 additions & 2 deletions iris/coordinator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import gym
from iris import checkpoint_util
from iris import coordinator
from iris import worker
from iris.algorithms import ars_algorithm
from iris.policies import nn_policy
from iris.workers import rl_worker
import launchpad as lp
from ml_collections import config_dict
import numpy as np
Expand Down Expand Up @@ -159,7 +159,7 @@ def setUp(self):
),
worker=config_dict.ConfigDict(
dict(
worker_class=worker.RLWorker,
worker_class=rl_worker.RLWorker,
worker_args=dict(
env=TestEnv,
policy=nn_policy.FullyConnectedNeuralNetworkPolicy,
Expand Down
2 changes: 1 addition & 1 deletion iris/maml/adaptation_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Adaptation optimizers mapping parameter input to objective value."""
import enum
from typing import Callable, Sequence, Tuple, Union
from iris import worker_util
from iris.workers import worker_util
import numpy as np

FloatLike = Union[float, np.float32, np.float64]
Expand Down
8 changes: 4 additions & 4 deletions iris/maml/adaptation_optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for adaptation_optimizers."""
from iris import worker
from iris.maml import adaptation_optimizers
from iris.workers import simple_worker
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -24,10 +23,11 @@ class AdaptationOptimizersTest(parameterized.TestCase):
"""Tests adaptation optimizers."""

def setUp(self):
self.worker_obj = worker.SimpleWorker(
self.worker_obj = simple_worker.SimpleWorker(
worker_id=0,
initial_params=5.0 * np.ones(2),
blackbox_function=lambda x: -1 * np.sum(x**2))
blackbox_function=lambda x: -1 * np.sum(x**2),
)
self.init_params = self.worker_obj._init_state['init_params']
super().setUp()

Expand Down
Loading

0 comments on commit 561e4c3

Please sign in to comment.