
Commit 4440a18

Merge branch 'main' of https://github.com/deepmind/alphafold

dialvarezs committed Jun 13, 2022
2 parents 26af707 + b0b7afa
Showing 9 changed files with 26 additions and 20 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -54,6 +54,9 @@ If you have any questions, please contact the AlphaFold team at

 ## First time setup

+You will need a machine running Linux; AlphaFold does not support other
+operating systems.
+
 The following steps are required in order to run AlphaFold:

 1. Install [Docker](https://www.docker.com/).
2 changes: 1 addition & 1 deletion alphafold/model/all_atom_multimer.py
@@ -426,7 +426,7 @@ def torsion_angles_to_frames(
   chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
   chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]

-  all_frames_to_backb = jax.tree_multimap(
+  all_frames_to_backb = jax.tree_map(
       lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
       chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
       chi4_frame_to_backb[:, None])
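
JAX deprecated `jax.tree_multimap` in favor of `jax.tree_map`, which accepts additional trees after the first, so the rename above is a drop-in replacement. A minimal sketch of the multi-tree pattern used in this hunk, with toy dict pytrees and shapes invented for illustration:

```python
import jax
import jax.numpy as jnp

# Two pytrees with identical structure; tree_map pairs up corresponding leaves.
a = {'rot': jnp.ones((4, 3)), 'trans': jnp.zeros((4, 3))}
b = {'rot': jnp.ones((4, 1)), 'trans': jnp.ones((4, 1))}

# Concatenate corresponding leaves along the last axis, as the hunk above does
# with the per-residue frame arrays.
merged = jax.tree_map(lambda *x: jnp.concatenate(x, axis=-1), a, b)
assert merged['rot'].shape == (4, 4)
```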
5 changes: 2 additions & 3 deletions alphafold/model/folding_multimer.py
@@ -546,7 +546,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
         )
       outputs.append(output)

-    output = jax.tree_multimap(lambda *x: jnp.stack(x), *outputs)
+    output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
     # Pass along for LDDT-Head.
     output['act'] = activations['act']
@@ -823,7 +823,7 @@ def compute_frames(
   alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames']
   use_alt = use_alt[:, None]

-  renamed_gt_frames = jax.tree_multimap(
+  renamed_gt_frames = jax.tree_map(
       lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames)

   return renamed_gt_frames, frames_batch['rigidgroups_gt_exists']
@@ -1160,4 +1160,3 @@ def __call__(self,
         'frames': all_frames_to_global,  # geometry.Rigid3Array (N, 8)
     })
     return outputs
-
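
The first hunk above stacks a Python list of per-iteration output pytrees into a single pytree of batched arrays; the second uses the same multi-tree `tree_map` mechanism to blend two ground-truth frame trees leaf by leaf with a 0/1 mask. A small sketch of the stacking pattern, with made-up leaf names and shapes:

```python
import jax
import jax.numpy as jnp

# Per-iteration outputs with matching structure, like the fold iteration above.
outputs = [{'act': jnp.ones(8), 'frames': jnp.zeros((8, 12))} for _ in range(3)]

# Splatting the list gives tree_map one tree per iteration; stacking the
# corresponding leaves adds a leading iteration axis.
stacked = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
assert stacked['act'].shape == (3, 8)
assert stacked['frames'].shape == (3, 8, 12)
```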
4 changes: 2 additions & 2 deletions alphafold/model/geometry/vector.py
@@ -53,10 +53,10 @@ def __post_init__(self):
     assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])

   def __add__(self, other: Vec3Array) -> Vec3Array:
-    return jax.tree_multimap(lambda x, y: x + y, self, other)
+    return jax.tree_map(lambda x, y: x + y, self, other)

   def __sub__(self, other: Vec3Array) -> Vec3Array:
-    return jax.tree_multimap(lambda x, y: x - y, self, other)
+    return jax.tree_map(lambda x, y: x - y, self, other)

   def __mul__(self, other: Float) -> Vec3Array:
     return jax.tree_map(lambda x: x * other, self)
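
These one-line operators work because `Vec3Array` is registered with JAX as a pytree whose leaves are the three coordinate arrays. AlphaFold wires this up with its own dataclass machinery; the toy class below uses the plain `jax.tree_util.register_pytree_node_class` API instead, just to show why `tree_map` can implement the arithmetic (class name and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Vec3:
  """Toy stand-in for Vec3Array: three coordinate arrays as pytree leaves."""

  def __init__(self, x, y, z):
    self.x, self.y, self.z = x, y, z

  def tree_flatten(self):
    return (self.x, self.y, self.z), None

  @classmethod
  def tree_unflatten(cls, aux, children):
    del aux
    return cls(*children)

  def __add__(self, other):
    # tree_map pairs the leaves of both vectors: x with x, y with y, z with z.
    return jax.tree_map(lambda a, b: a + b, self, other)

v = Vec3(jnp.ones(5), jnp.zeros(5), jnp.ones(5)) + Vec3(jnp.ones(5), jnp.ones(5), jnp.ones(5))
assert float(v.y[0]) == 1.0
```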
4 changes: 2 additions & 2 deletions alphafold/model/layer_stack_test.py
@@ -198,8 +198,8 @@ def outer_fn_layer_stack(x):
     assert_fn = functools.partial(
         np.testing.assert_allclose, atol=1e-4, rtol=1e-4)

-    jax.tree_multimap(assert_fn, unrolled_grad,
-                      _slice_layers_params(layer_stack_grad))
+    jax.tree_map(assert_fn, unrolled_grad,
+                 _slice_layers_params(layer_stack_grad))

   def test_random(self):
     """Random numbers should be handled correctly."""
14 changes: 7 additions & 7 deletions alphafold/model/mapping.py
@@ -125,7 +125,7 @@ def mapped_fn(*args):
     # Expand in axes and Determine Loop range
     in_axes_ = _expand_axes(in_axes, args)

-    in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
+    in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
     flat_sizes = jax.tree_flatten(in_sizes)[0]
     in_size = max(flat_sizes)
     assert all(i in {in_size, -1} for i in flat_sizes)
@@ -137,7 +137,7 @@ def mapped_fn(*args):
     last_shard_size = shard_size if last_shard_size == 0 else last_shard_size

     def apply_fun_to_slice(slice_start, slice_size):
-      input_slice = jax.tree_multimap(
+      input_slice = jax.tree_map(
          lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
                                          ), args, in_axes_)
       return fun(*input_slice)
@@ -158,19 +158,19 @@ def make_output_shape(axis, shard_shape, remainder_shape):
           shard_shape[axis] * num_extra_shards +
           remainder_shape[axis],) + shard_shape[axis + 1:]

-    out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
-                                   out_shapes)
+    out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
+                              out_shapes)

     # Calls dynamic Update slice with different argument order
-    # This is here since tree_multimap only works with positional arguments
+    # This is here since tree_map only works with positional arguments
     def dynamic_update_slice_in_dim(full_array, update, axis, i):
       return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)

     def compute_shard(outputs, slice_start, slice_size):
       slice_out = apply_fun_to_slice(slice_start, slice_size)
       update_slice = partial(
           dynamic_update_slice_in_dim, i=slice_start)
-      return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)
+      return jax.tree_map(update_slice, outputs, slice_out, out_axes_)

     def scan_iteration(outputs, i):
       new_outputs = compute_shard(outputs, i, shard_size)
@@ -181,7 +181,7 @@ def scan_iteration(outputs, i):
     def allocate_buffer(dtype, shape):
       return jnp.zeros(shape, dtype=dtype)

-    outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)
+    outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)

     if slice_starts.shape[0] > 0:
       outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
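
The sharding machinery in this file preallocates each output buffer and then, shard by shard, writes results into it inside an `hk.scan`. The write primitive is `jax.lax.dynamic_update_slice_in_dim`, sketched here on a single array (shapes and the row offset are made up):

```python
import jax
import jax.numpy as jnp

# Preallocated output buffer, as allocate_buffer produces above.
outputs = jnp.zeros((8, 3))

# Result computed for one shard of two rows.
shard = jnp.ones((2, 3))

# Write the shard into rows 4..5 of the buffer along axis 0; compute_shard
# applies this update leaf-by-leaf across the output pytree via tree_map.
outputs = jax.lax.dynamic_update_slice_in_dim(outputs, shard, 4, axis=0)
assert float(outputs[4, 0]) == 1.0
```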
11 changes: 7 additions & 4 deletions docker/Dockerfile
@@ -21,13 +21,14 @@ ARG CUDA=11.2
 # JAXLIB no longer built for all minor CUDA versions:
 # https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-0166-may-11-2021
 ARG CUDA_JAXLIB=11.1
+ARG DEBIAN_FRONTEND=noninteractive

 # Use bash to support string substitution.
 SHELL ["/bin/bash", "-c"]

 RUN apt-get update \
-    && DEBIAN_FRONTEND=noninteractive apt-get upgrade -y \
-    && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    && apt-get upgrade -y \
+    && apt-get install -y --no-install-recommends \
       build-essential \
       cmake \
       cuda-command-line-tools-${CUDA/./-} \
@@ -72,8 +73,10 @@ RUN wget -q -P /app/alphafold/alphafold/common/
 # Install pip packages.
 RUN pip3 install --upgrade --no-cache-dir pip \
     && pip3 install --no-cache-dir -r /app/alphafold/requirements.txt \
-    && pip3 install --upgrade --no-cache-dir jax==0.2.14 jaxlib==0.1.69+cuda${CUDA_JAXLIB/./} -f \
-      https://storage.googleapis.com/jax-releases/jax_releases.html
+    && pip3 install --upgrade --no-cache-dir \
+      jax==0.2.14 \
+      jaxlib==0.1.69+cuda${CUDA_JAXLIB/./} \
+      -f https://storage.googleapis.com/jax-releases/jax_releases.html

 # Apply OpenMM patch.
 WORKDIR /opt/conda/lib/python3.8/site-packages
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,5 +9,6 @@ jax==0.2.14
 ml-collections==0.1.0
 numpy==1.19.5
 pandas==1.3.4
+protobuf==3.20.1
 scipy==1.7.0
 tensorflow-cpu==2.5.0
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@

 setup(
     name='alphafold',
-    version='2.2.0',
+    version='2.2.2',
     description='An implementation of the inference pipeline of AlphaFold v2.0.'
     'This is a completely new model that was entered as AlphaFold2 in CASP14 '
     'and published in Nature.',
