Skip to content

Commit

Permalink
Actifptm (#10)
Browse files Browse the repository at this point in the history
* version of extended metrics, doesnt work because of tracing and dynamic conflict

* pairwise actifptm, iptm and chain-ptm calculation and plotting

* updated version

* working actifptm, including for the whole complex

---------

Co-authored-by: jvarga <[email protected]>
  • Loading branch information
gezmi and jvarga authored Dec 4, 2024
1 parent 9c3ec6e commit 0f3dce2
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 13 deletions.
10 changes: 7 additions & 3 deletions alphafold/common/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from alphafold.common import residue_constants
import scipy.special


def compute_tol(prev_pos, current_pos, mask, use_jnp=False):
# Early stopping criteria based on criteria used in
# AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
Expand Down Expand Up @@ -168,24 +169,27 @@ def predicted_tm_score(logits, breaks, residue_weights = None,

return (per_alignment * residue_weights).max()

def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False):
def get_confidence_metrics(prediction_result, mask, rank_by = "plddt", use_jnp=False, keep_pae=False):
"""Post processes prediction_result to get confidence metrics."""
confidence_metrics = {}
plddt = compute_plddt(prediction_result['predicted_lddt']['logits'], use_jnp=use_jnp)
confidence_metrics['plddt'] = plddt
confidence_metrics["mean_plddt"] = (plddt * mask).sum()/mask.sum()

if 'predicted_aligned_error' in prediction_result:
if keep_pae:
prediction_result['pae_matrix_with_logits'] = prediction_result['predicted_aligned_error']

confidence_metrics.update(compute_predicted_aligned_error(
logits=prediction_result['predicted_aligned_error']['logits'],
breaks=prediction_result['predicted_aligned_error']['breaks'],
use_jnp=use_jnp))

confidence_metrics['ptm'] = predicted_tm_score(
logits=prediction_result['predicted_aligned_error']['logits'],
breaks=prediction_result['predicted_aligned_error']['breaks'],
residue_weights=mask,
use_jnp=use_jnp)
use_jnp=use_jnp)

if "asym_id" in prediction_result["predicted_aligned_error"]:
# Compute the ipTM only for the multimer model.
Expand Down
8 changes: 4 additions & 4 deletions alphafold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class RunModel:
def __init__(self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
is_training = False):
is_training = False, extended_ptm_config=None):

self.config = config
self.params = params
self.multimer_mode = config.model.global_config.multimer_mode
self.config.model.calc_extended_ptm = extended_ptm_config['calc_extended_ptm'] if extended_ptm_config else False
self.config.model.use_probs_extended = extended_ptm_config['use_probs_extended'] if extended_ptm_config else False

if self.multimer_mode:
def _forward_fn(batch):
Expand Down Expand Up @@ -148,7 +150,6 @@ def predict(self,
L = aatype.shape[1]

# initialize

zeros = lambda shape: np.zeros(shape, dtype=np.float16)
prev = {'prev_msa_first_row': zeros([L,256]),
'prev_pair': zeros([L,L,128]),
Expand All @@ -170,7 +171,7 @@ def _jnp_to_np(x):
# initialize random key
key = jax.random.PRNGKey(random_seed)

# iterate through recyckes
# iterate through recycles
for r in range(num_iters):
# grab subset of features
if self.multimer_mode:
Expand All @@ -197,6 +198,5 @@ def _jnp_to_np(x):
break
if r > 0 and result["tol"] < self.config.model.recycle_early_stop_tolerance:
break

logging.info('Output shape was %s', tree.map_structure(lambda x: x.shape, result))
return result, r
3 changes: 2 additions & 1 deletion alphafold/model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def get_prev(ret):
prediction_result=ret,
mask=batch["seq_mask"],
rank_by=self.config.rank_by,
keep_pae=self.config.calc_extended_ptm,
use_jnp=True))

ret["tol"] = confidence.compute_tol(
Expand Down Expand Up @@ -467,7 +468,7 @@ def slice_recycle_idx(x):
mask=batch["seq_mask"][0],
rank_by=self.config.rank_by,
use_jnp=True))

ret["tol"] = confidence.compute_tol(
prev["prev_pos"],
ret["prev"]["prev_pos"],
Expand Down
9 changes: 5 additions & 4 deletions alphafold/model/modules_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ def __call__(
is_training,
return_representations=False,
safe_key=None):

c = self.config
impl = AlphaFoldIteration(c, self.global_config)

Expand Down Expand Up @@ -444,7 +443,7 @@ def apply_network(prev, safe_key):
safe_key=safe_key)

# initialize
prev = batch.pop("prev", None)
prev = batch.pop("prev", None)
if prev is None:
L = num_residues
prev = {'prev_msa_first_row': jnp.zeros([L,256]),
Expand All @@ -457,7 +456,7 @@ def apply_network(prev, safe_key):

ret = apply_network(prev=prev, safe_key=safe_key)
ret["prev"] = get_prev(ret)

if not return_representations:
del ret['representations']

Expand All @@ -466,7 +465,9 @@ def apply_network(prev, safe_key):
prediction_result=ret,
mask=batch["seq_mask"],
rank_by=self.config.rank_by,
use_jnp=True))
keep_pae=self.config.calc_extended_ptm,
use_jnp=True
))

ret["tol"] = confidence.compute_tol(
prev["prev_pos"],
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

setup(
name='alphafold-colabfold',
version='2.3.6',
version='2.3.7',
long_description_content_type='text/markdown',
description='An implementation of the inference pipeline of AlphaFold v2.3.1. '
'This is a completely new model that was entered as AlphaFold2 in CASP14 '
Expand Down

0 comments on commit 0f3dce2

Please sign in to comment.