Skip to content

Commit

Permalink
Charge filter (#151)
Browse files Browse the repository at this point in the history
* add charge filter

* add tests and merge master

* update the naming to be consistent with the results filters

* fix test

* update tests
  • Loading branch information
jthorton authored Jun 29, 2021
1 parent a0d9134 commit 5f1ca32
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 2 deletions.
32 changes: 30 additions & 2 deletions openff/qcsubmit/tests/test_workflow_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_deregister_component_error():
"""

with pytest.raises(ComponentRegisterError):
deregister_component(component="ChargeFilter")
deregister_component(component="BadComponent")


def test_get_component():
Expand All @@ -131,7 +131,7 @@ def test_get_component_error():
Make sure an error is rasied when we try to get a component that was not registered.
"""
with pytest.raises(ComponentRegisterError):
get_component(component_name="ChargeFilter")
get_component(component_name="BadComponent")


def test_custom_component():
Expand Down Expand Up @@ -1065,3 +1065,31 @@ def test_improper_enumerator():
indexer = mol.properties["dihedrals"]
assert indexer.n_impropers == 1
assert indexer.imporpers[0].scan_increment == [4]


def test_formal_charge_filter_exclusive():
"""
Raise an error if both allowed and filtered charges are supplied
"""

with pytest.raises(ValidationError):
workflow_components.ChargeFilter(charges_to_include=[0, 1], charges_to_exclude=[-1])


def test_formal_charge_filter():
"""
Make sure we can correctly filter by the molecules net formal charge.
"""

molecule = Molecule.from_mapped_smiles("[N+:1](=[O:2])([O-:3])[O-:4]")

# filter out the molecule
charge_filter = workflow_components.ChargeFilter(charges_to_exclude=[-1, 0])
result = charge_filter.apply([molecule], processors=1)
assert result.n_molecules == 0
assert result.n_filtered == 1

# now allow it through
charge_filter = workflow_components.ChargeFilter(charges_to_include=[-1])
result = charge_filter.apply([molecule], processors=1)
assert result.n_molecules == 1
1 change: 1 addition & 0 deletions openff/qcsubmit/workflow_components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
StandardConformerGenerator,
)
from openff.qcsubmit.workflow_components.filters import (
ChargeFilter,
CoverageFilter,
ElementFilter,
MolecularWeightFilter,
Expand Down
3 changes: 3 additions & 0 deletions openff/qcsubmit/workflow_components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
StandardConformerGenerator,
)
from openff.qcsubmit.workflow_components.filters import (
ChargeFilter,
CoverageFilter,
ElementFilter,
MolecularWeightFilter,
Expand Down Expand Up @@ -53,6 +54,7 @@
EnumerateStereoisomers,
WBOFragmenter,
PfizerFragmenter,
ChargeFilter,
ScanFilter,
ScanEnumerator,
]
Expand Down Expand Up @@ -166,6 +168,7 @@ def list_components() -> List[Type[CustomWorkflowComponent]]:
register_component(MolecularWeightFilter)
register_component(ElementFilter)
register_component(ScanFilter)
register_component(ChargeFilter)

# state enumeration
register_component(EnumerateTautomers)
Expand Down
68 changes: 68 additions & 0 deletions openff/qcsubmit/workflow_components/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from openff.toolkit.typing.engines.smirnoff import ForceField
from pydantic import Field, root_validator, validator
from rdkit.Chem.rdMolAlign import AlignMol
from simtk import unit
from typing_extensions import Literal

from openff.qcsubmit.common_structures import ComponentProperties
Expand Down Expand Up @@ -694,3 +695,70 @@ def _apply(self, molecules: List[Molecule]) -> ComponentResult:
result.add_molecule(molecule)

return result


class ChargeFilter(BasicSettings, CustomWorkflowComponent):
"""
Filter molecules if their formal charge is not in the `charges_to_include` list or is in the `charges_to_exclude` list.
"""

type: Literal["ChargeFilter"] = "ChargeFilter"

charges_to_include: Optional[List[int]] = Field(
None,
description="The list of net molecule formal charges which are allowed in the dataset."
"This option is mutually exclusive with ``charges_to_exclude``.",
)
charges_to_exclude: Optional[List[int]] = Field(
None,
description="The list of net molecule formal charges which are to be removed from the dataset."
"This option is mutually exclusive with ``charges_to_include``.",
)

@classmethod
def description(cls) -> str:
return "Filter molecules by net formal charge."

@classmethod
def fail_reason(cls) -> str:
return "The molecules net formal charge was not requested or was in the `charges_to_exclude`."

@classmethod
def properties(cls) -> ComponentProperties:
return ComponentProperties(process_parallel=True, produces_duplicates=False)

@root_validator
def _validate_mutually_exclusive(cls, values):
charges_to_include = values.get("charges_to_include")
charges_to_exclude = values.get("charges_to_exclude")

message = "exactly one of ``charges_to_include` and `charges_to_exclude` must specified."

assert charges_to_include is not None or charges_to_exclude is not None, message
assert charges_to_include is None or charges_to_exclude is None, message

return values

def _apply(self, molecules: List[Molecule]) -> ComponentResult:
"""
Filter molecules based on their net formal charge
"""

result = self._create_result()

for molecule in molecules:
total_charge = molecule.total_charge.value_in_unit(unit.elementary_charge)

if (
self.charges_to_include is not None
and total_charge not in self.charges_to_include
) or (
self.charges_to_exclude is not None
and total_charge in self.charges_to_exclude
):
result.filter_molecule(molecule=molecule)

else:
result.add_molecule(molecule)

return result

0 comments on commit 5f1ca32

Please sign in to comment.