diff --git a/alphafold/model/folding_multimer.py b/alphafold/model/folding_multimer.py
index c2e47c8ed..2d8049327 100644
--- a/alphafold/model/folding_multimer.py
+++ b/alphafold/model/folding_multimer.py
@@ -789,7 +789,7 @@ def backbone_loss(gt_rigid: geometry.Rigid3Array,
   loss_fn = functools.partial(
       all_atom_multimer.frame_aligned_point_error,
       l1_clamp_distance=config.atom_clamp_distance,
-      loss_unit_distance=config.loss_unit_distance)
+      length_scale=config.loss_unit_distance)
 
   loss_fn = jax.vmap(loss_fn, (0, None, None, 0, None, None, None))
   fape = loss_fn(target_rigid, gt_rigid, gt_frames_mask,
diff --git a/alphafold/model/utils.py b/alphafold/model/utils.py
index 2347ffa7b..97d2e2749 100644
--- a/alphafold/model/utils.py
+++ b/alphafold/model/utils.py
@@ -34,7 +34,7 @@ def final_init(config):
 
 def batched_gather(params, indices, axis=0, batch_dims=0):
   """Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
-  take_fn = lambda p, i: jnp.take(p, i, axis=axis)
+  take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip')
   for _ in range(batch_dims):
     take_fn = jax.vmap(take_fn)
   return take_fn(params, indices)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 6aa25df4a..6fda3a172 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -18,13 +18,10 @@ FROM nvidia/cuda:${CUDA_FULL}-base-ubuntu20.04 AS build
 # previously set).
 ARG CUDA_FULL
 ARG CUDA=11.2
-# JAXLIB no longer built for all minor CUDA versions:
-# https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-0166-may-11-2021
-ARG CUDA_JAXLIB=11.1
 ARG DEBIAN_FRONTEND=noninteractive
 
 # Use bash to support string substitution.
-SHELL ["/bin/bash", "-c"]
+SHELL ["/bin/bash", "-o", "pipefail", "-c"]
 
 RUN apt-get update \
     && apt-get upgrade -y \
@@ -71,11 +68,11 @@ RUN wget -q -P /app/alphafold/alphafold/common/ \
   https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
 
 # Install pip packages.
-RUN pip3 install --upgrade --no-cache-dir pip \
-    && pip3 install --no-cache-dir -r /app/alphafold/requirements.txt \
+RUN pip3 install --upgrade pip --no-cache-dir \
+    && pip3 install -r /app/alphafold/requirements.txt --no-cache-dir \
     && pip3 install --upgrade --no-cache-dir \
-      jax==0.2.14 \
-      jaxlib==0.1.69+cuda$(cut -f1,2 -d. <<< ${CUDA_JAXLIB} | sed 's/\.//g') \
+      jax==0.3.17 \
+      jaxlib==0.3.15+cuda11.cudnn805 \
       -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 
 # Apply OpenMM patch.
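The `mode='clip'` added to `jnp.take` in `batched_gather` clamps out-of-bounds indices to the nearest valid position instead of relying on the default out-of-bounds handling. A minimal sketch of that behavior, using made-up inputs rather than anything from the AlphaFold pipeline:

```python
import jax.numpy as jnp

params = jnp.array([10., 20., 30.])
indices = jnp.array([0, 2, 5])   # 5 is out of bounds along axis 0

# With mode='clip', the out-of-bounds index 5 is clamped to the last valid
# index (2), so the gather returns the final element deterministically.
print(jnp.take(params, indices, axis=0, mode='clip'))   # [10. 30. 30.]
```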
diff --git a/docker/requirements.txt b/docker/requirements.txt
index 3f056f7b3..2956cb2f4 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -1,3 +1,3 @@
 # Dependencies necessary to execute run_docker.py
-absl-py==0.13.0
+absl-py==1.0.0
 docker==5.0.0
diff --git a/notebooks/AlphaFold.ipynb b/notebooks/AlphaFold.ipynb
index 37c8550f5..9ec5d7075 100644
--- a/notebooks/AlphaFold.ipynb
+++ b/notebooks/AlphaFold.ipynb
@@ -478,6 +478,12 @@
     "\n",
     "run_relax = True #@param {type:\"boolean\"}\n",
     "\n",
+    "#@markdown Relaxation is faster with a GPU, but we have found it to be less stable.\n",
+    "#@markdown You may wish to enable GPU for higher performance, but if it doesn't\n",
+    "#@markdown converge, we suggest reverting to running without the GPU.\n",
+    "\n",
+    "relax_use_gpu = False #@param {type:\"boolean\"}\n",
+    "\n",
     "# --- Run the model ---\n",
     "if model_type_to_use == notebook_utils.ModelType.MONOMER:\n",
     "  model_names = config.MODEL_PRESETS['monomer'] + ('model_2_ptm',)\n",
@@ -554,7 +560,7 @@
     "        stiffness=10.0,\n",
     "        exclude_residues=[],\n",
     "        max_outer_iterations=3,\n",
-    "        use_gpu=True)\n",
+    "        use_gpu=relax_use_gpu)\n",
     "    relaxed_pdb, _, _ = amber_relaxer.process(prot=unrelaxed_proteins[best_model_name])\n",
     "  else:\n",
     "    print('Warning: Running without the relaxation stage.')\n",
diff --git a/requirements.txt b/requirements.txt
index be68cd0c1..e571bb10b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,14 @@
-absl-py==0.13.0
+absl-py==1.0.0
 biopython==1.79
 chex==0.0.7
-dm-haiku==0.0.4
+dm-haiku==0.0.7
 dm-tree==0.1.6
 docker==5.0.0
 immutabledict==2.0.0
-jax==0.2.14
+jax==0.3.17
 ml-collections==0.1.0
-numpy==1.19.5
+numpy==1.21.6
 pandas==1.3.4
 protobuf==3.20.1
 scipy==1.7.0
-tensorflow-cpu==2.5.0
+tensorflow-cpu==2.9.0
diff --git a/setup.py b/setup.py
index 572bc0b72..3871014c1 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
 setup(
     name='alphafold',
-    version='2.2.3',
+    version='2.2.4',
     description='An implementation of the inference pipeline of AlphaFold v2.0.'
     'This is a completely new model that was entered as AlphaFold2 in CASP14 '
     'and published in Nature.',
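Since the patch pins new jax/jaxlib wheels in both requirements.txt and the Dockerfile, a quick interpreter check inside the rebuilt image, assuming it builds cleanly, confirms the pinned build is the one actually imported and that a CUDA backend is visible (the expected values below simply mirror the pins above):

```python
import jax

# Should report the pinned release and a GPU backend when the
# cuda11.cudnn805 jaxlib wheel matches the host driver/cuDNN setup.
print(jax.__version__)          # expected: 0.3.17
print(jax.default_backend())    # 'gpu' if the CUDA wheel is usable, else 'cpu'
print(jax.devices())            # lists the visible accelerator devices
```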