From 0f3dce2baf85719a4af9b1375303d78693a681b7 Mon Sep 17 00:00:00 2001
From: Julia Varga
Date: Wed, 4 Dec 2024 18:50:12 +0200
Subject: [PATCH] Actifptm (#10)

* version of extended metrics, doesn't work because of tracing and dynamic conflict

* pairwise actifptm, iptm and chain-ptm calculation and plotting

* updated version

* working actifptm, including for the whole complex

---------

Co-authored-by: jvarga
---
 alphafold/common/confidence.py      | 10 +++++++---
 alphafold/model/model.py            |  8 ++++----
 alphafold/model/modules.py          |  3 ++-
 alphafold/model/modules_multimer.py |  9 +++++----
 setup.py                            |  2 +-
 5 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/alphafold/common/confidence.py b/alphafold/common/confidence.py
index e5ea99af0..d83dd7432 100644
--- a/alphafold/common/confidence.py
+++ b/alphafold/common/confidence.py
@@ -20,6 +20,7 @@
 from alphafold.common import residue_constants
 import scipy.special
 
+
 def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
   # Early stopping criteria based on criteria used in
   # AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
@@ -168,7 +169,7 @@ def predicted_tm_score(logits, breaks, residue_weights = None,
   return (per_alignment * residue_weights).max()
 
 
-def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False):
+def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False, keep_pae=False):
   """Post processes prediction_result to get confidence metrics."""
   confidence_metrics = {}
   plddt = compute_plddt(prediction_result['predicted_lddt']['logits'], use_jnp=use_jnp)
@@ -176,16 +177,19 @@ def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=F
   confidence_metrics["mean_plddt"] = (plddt * mask).sum()/mask.sum()
 
   if 'predicted_aligned_error' in prediction_result:
+    if keep_pae:
+      prediction_result['pae_matrix_with_logits'] = prediction_result['predicted_aligned_error']
+
     confidence_metrics.update(compute_predicted_aligned_error(
         logits=prediction_result['predicted_aligned_error']['logits'],
         breaks=prediction_result['predicted_aligned_error']['breaks'],
         use_jnp=use_jnp))
-
+
     confidence_metrics['ptm'] = predicted_tm_score(
         logits=prediction_result['predicted_aligned_error']['logits'],
         breaks=prediction_result['predicted_aligned_error']['breaks'],
         residue_weights=mask,
-        use_jnp=use_jnp)
+        use_jnp=use_jnp)
 
     if "asym_id" in prediction_result["predicted_aligned_error"]:
       # Compute the ipTM only for the multimer model.
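For context, a minimal usage sketch of the new keep_pae flag added above (this sketch is not part of the patch): with keep_pae=True, get_confidence_metrics stashes the raw predicted_aligned_error head output (logits and breaks) under 'pae_matrix_with_logits' before post-processing, so later actifptm/chain-pTM code can reuse the logits. The array shapes, the 64-bin PAE head and the break values below are assumptions chosen only to make the sketch runnable; in practice prediction_result comes from RunModel.predict(), and the sketch assumes only the keys visible in this hunk are required.

    # Illustrative sketch (not from the patch): calling the patched function with keep_pae=True.
    import numpy as np
    from alphafold.common import confidence

    num_res, num_pae_bins = 50, 64                       # assumed sizes for illustration
    prediction_result = {
        'predicted_lddt': {'logits': np.random.randn(num_res, 50)},
        'predicted_aligned_error': {
            'logits': np.random.randn(num_res, num_res, num_pae_bins),
            'breaks': np.linspace(0.0, 31.0, num_pae_bins - 1),   # assumed bin edges
        },
    }
    mask = np.ones(num_res)

    metrics = confidence.get_confidence_metrics(
        prediction_result, mask, rank_by="plddt", use_jnp=False, keep_pae=True)

    # With keep_pae=True the untouched PAE head output remains available here,
    # ready for the pairwise actifptm / per-chain pTM calculations.
    raw_pae = prediction_result['pae_matrix_with_logits']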
diff --git a/alphafold/model/model.py b/alphafold/model/model.py
index 88e90f1f4..7a6a8a621 100644
--- a/alphafold/model/model.py
+++ b/alphafold/model/model.py
@@ -35,11 +35,13 @@ class RunModel:
 
   def __init__(self,
                config: ml_collections.ConfigDict,
                params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
-               is_training = False):
+               is_training = False, extended_ptm_config=None):
     self.config = config
     self.params = params
     self.multimer_mode = config.model.global_config.multimer_mode
+    self.config.model.calc_extended_ptm = extended_ptm_config['calc_extended_ptm'] if extended_ptm_config else False
+    self.config.model.use_probs_extended = extended_ptm_config['use_probs_extended'] if extended_ptm_config else False
 
     if self.multimer_mode:
       def _forward_fn(batch):
@@ -148,7 +150,6 @@ def predict(self,
     L = aatype.shape[1]
 
     # initialize
-
     zeros = lambda shape: np.zeros(shape, dtype=np.float16)
     prev = {'prev_msa_first_row': zeros([L,256]),
             'prev_pair': zeros([L,L,128]),
@@ -170,7 +171,7 @@ def _jnp_to_np(x):
     # initialize random key
     key = jax.random.PRNGKey(random_seed)
 
-    # iterate through recyckes
+    # iterate through recycles
     for r in range(num_iters):
       # grab subset of features
       if self.multimer_mode:
@@ -197,6 +198,5 @@ def _jnp_to_np(x):
           break
         if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance:
           break
-    logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result))
 
     return result, r
\ No newline at end of file
diff --git a/alphafold/model/modules.py b/alphafold/model/modules.py
index b67707a9e..d4f211ed2 100644
--- a/alphafold/model/modules.py
+++ b/alphafold/model/modules.py
@@ -206,6 +206,7 @@ def get_prev(ret):
         prediction_result=ret,
         mask=batch["seq_mask"],
         rank_by=self.config.rank_by,
+        keep_pae=self.config.calc_extended_ptm,
         use_jnp=True))
 
     ret["tol"] = confidence.compute_tol(
@@ -467,7 +468,7 @@ def slice_recycle_idx(x):
         mask=batch["seq_mask"][0],
         rank_by=self.config.rank_by,
         use_jnp=True))
-
+
     ret["tol"] = confidence.compute_tol(
       prev["prev_pos"],
       ret["prev"]["prev_pos"],
diff --git a/alphafold/model/modules_multimer.py b/alphafold/model/modules_multimer.py
index 7cd8a6fd5..8f4f17c94 100644
--- a/alphafold/model/modules_multimer.py
+++ b/alphafold/model/modules_multimer.py
@@ -416,7 +416,6 @@ def __call__(
       is_training,
       return_representations=False,
       safe_key=None):
-
     c = self.config
     impl = AlphaFoldIteration(c, self.global_config)
 
@@ -444,7 +443,7 @@ def apply_network(prev, safe_key):
           safe_key=safe_key)
 
     # initialize
-    prev = batch.pop("prev", None)
+    prev = batch.pop("prev", None)
     if prev is None:
       L = num_residues
       prev = {'prev_msa_first_row': jnp.zeros([L,256]),
@@ -457,7 +456,7 @@ def apply_network(prev, safe_key):
 
     ret = apply_network(prev=prev, safe_key=safe_key)
     ret["prev"] = get_prev(ret)
-
+
     if not return_representations:
       del ret['representations']
 
@@ -466,7 +465,9 @@ def apply_network(prev, safe_key):
         prediction_result=ret,
         mask=batch["seq_mask"],
         rank_by=self.config.rank_by,
-        use_jnp=True))
+        keep_pae=self.config.calc_extended_ptm,
+        use_jnp=True
+        ))
 
     ret["tol"] = confidence.compute_tol(
       prev["prev_pos"],
diff --git a/setup.py b/setup.py
index 244b7503c..bde2e912b 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
 
 setup(
     name='alphafold-colabfold',
-    version='2.3.6',
+    version='2.3.7',
     long_description_content_type='text/markdown',
     description='An implementation of the inference pipeline of AlphaFold v2.3.1. '
                 'This is a completely new model that was entered as AlphaFold2 in CASP14 '
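For context, a sketch of how a caller might wire up the new extended_ptm_config argument introduced in model.py (this sketch is not part of the patch): the dict's two flags are copied onto config.model, and calc_extended_ptm is then forwarded to get_confidence_metrics as keep_pae in modules.py and modules_multimer.py. The model preset name and data_dir below are assumptions, and what use_probs_extended controls downstream is not established by this diff.

    # Illustrative sketch (not from the patch): constructing RunModel with extended_ptm_config.
    from alphafold.model import config, data, model

    model_name = 'model_1_multimer_v3'                    # assumed preset name
    cfg = config.model_config(model_name)
    params = data.get_model_haiku_params(model_name=model_name, data_dir='.')  # assumed weights location

    extended_ptm_config = {
        'calc_extended_ptm': True,    # sets config.model.calc_extended_ptm -> keep_pae in the heads
        'use_probs_extended': True,   # stored on config.model; consumed outside this diff
    }

    model_runner = model.RunModel(cfg, params, is_training=False,
                                  extended_ptm_config=extended_ptm_config)
    # Passing extended_ptm_config=None (the default) leaves both flags False,
    # i.e. the previous behaviour is unchanged.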