diff --git a/docs/api.rst b/docs/api.rst
index f1400334..c285a1c3 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -230,6 +230,7 @@ Utilities
     :nosignatures:
     :toctree: api/generated/
 
+    _CachedPortalClient
     portal_client_manager
 
 Exceptions
@@ -262,4 +263,4 @@ Exceptions
     PCMSettingError
     InvalidDatasetError
     DatasetRegisterError
-    RecordTypeError
\ No newline at end of file
+    RecordTypeError
diff --git a/openff/qcsubmit/_tests/results/test_filters.py b/openff/qcsubmit/_tests/results/test_filters.py
index 486e961b..c9fac248 100644
--- a/openff/qcsubmit/_tests/results/test_filters.py
+++ b/openff/qcsubmit/_tests/results/test_filters.py
@@ -1,5 +1,6 @@
 import datetime
 import logging
+from tempfile import TemporaryDirectory
 
 import numpy
 import pytest
@@ -31,6 +32,7 @@
     SMILESFilter,
     UnperceivableStereoFilter,
 )
+from openff.qcsubmit.utils import _CachedPortalClient, portal_client_manager
 
 from . import RecordStatusEnum, SinglepointRecord
 
@@ -545,5 +547,9 @@ def test_unperceivable_stereo_filter(toolkits, n_expected, public_client):
     )
     assert collection.n_results == 1
 
-    filtered = collection.filter(UnperceivableStereoFilter(toolkits=toolkits))
+    with (
+        TemporaryDirectory() as d,
+        portal_client_manager(lambda a: _CachedPortalClient(a, cache_dir=d)),
+    ):
+        filtered = collection.filter(UnperceivableStereoFilter(toolkits=toolkits))
     assert filtered.n_results == n_expected
diff --git a/openff/qcsubmit/_tests/results/test_results.py b/openff/qcsubmit/_tests/results/test_results.py
index 2308ff3c..61903862 100644
--- a/openff/qcsubmit/_tests/results/test_results.py
+++ b/openff/qcsubmit/_tests/results/test_results.py
@@ -3,6 +3,7 @@
 """
 
 import datetime
+from tempfile import TemporaryDirectory
 
 import pytest
 from openff.toolkit.topology import Molecule
@@ -26,6 +27,7 @@
 )
 from openff.qcsubmit.results.filters import ResultFilter
 from openff.qcsubmit.results.results import TorsionDriveResult, _BaseResultCollection
+from openff.qcsubmit.utils import _CachedPortalClient, portal_client_manager
 
 from . import (
     OptimizationRecord,
@@ -315,7 +317,33 @@ def test_to_records(
         public_client, collection_name, spec_name=spec_name
     )
     assert collection.n_molecules == expected_n_mols
-    records_and_molecules = collection.to_records()
+
+    def disconnected_client(addr, cache_dir):
+        ret = _CachedPortalClient(addr, cache_dir)
+        ret._req_session = None
+        return ret
+
+    with TemporaryDirectory() as d:
+        client = _CachedPortalClient(public_client.address, d)
+        with portal_client_manager(lambda _: client):
+            with (
+                client._no_session(),
+                pytest.raises(Exception, match="no attribute 'prepare_request'"),
+            ):
+                collection.to_records()
+            records_and_molecules = collection.to_records()
+            # TorsionDriveResultCollection.to_records requires fetching
+            # molecules, which cannot currently be cached
+            if collection_type is not TorsionDriveResultCollection:
+                with client._no_session():
+                    assert len(collection.to_records()) == len(records_and_molecules)
+        # the previous checks show that the *same* client can access
+        # its cache without making new requests. disconnected_client
+        # instead shows that a newly-constructed client pointing at the
+        # same cache_dir can still access the cache
+        with portal_client_manager(lambda addr: disconnected_client(addr, d)):
+            assert len(collection.to_records()) == len(records_and_molecules)
+
     assert len(records_and_molecules) == expected_n_recs
 
     record, molecule = records_and_molecules[0]
@@ -351,9 +379,13 @@ def test_optimization_to_basic_result_collection(public_client):
     optimization_result_collection = OptimizationResultCollection.from_server(
         public_client, ["OpenFF Gen 2 Opt Set 3 Pfizer Discrepancy"]
     )
-    basic_collection = optimization_result_collection.to_basic_result_collection(
-        "hessian"
-    )
+    with (
+        TemporaryDirectory() as d,
+        portal_client_manager(lambda a: _CachedPortalClient(a, d)),
+    ):
+        basic_collection = optimization_result_collection.to_basic_result_collection(
+            "hessian"
+        )
 
     assert basic_collection.n_results == 197
     assert basic_collection.n_molecules == 49
diff --git a/openff/qcsubmit/_tests/test_submissions.py b/openff/qcsubmit/_tests/test_submissions.py
index c47e1f38..6e5823e1 100644
--- a/openff/qcsubmit/_tests/test_submissions.py
+++ b/openff/qcsubmit/_tests/test_submissions.py
@@ -4,6 +4,8 @@
 Here we use the qcfractal snowflake fixture to set up the database.
 """
 
+from tempfile import TemporaryDirectory
+
 import pytest
 from openff.toolkit.topology import Molecule
 from qcelemental.models.procedures import OptimizationProtocols
@@ -37,7 +39,7 @@
     OptimizationResultCollection,
     TorsionDriveResultCollection,
 )
-from openff.qcsubmit.utils import get_data
+from openff.qcsubmit.utils import _CachedPortalClient, get_data, portal_client_manager
 
 
 def await_results(client, timeout=120, check_fn=PortalClient.get_singlepoints, ids=[1]):
@@ -1408,7 +1410,11 @@ def test_invalid_cmiles(fulltest_client, factory_type, result_collection_type):
     assert ds.specifications.keys() == {"default"}
     results = result_collection_type.from_datasets(datasets=ds)
     assert results.n_molecules == 1
-    records = results.to_records()
+    with (
+        TemporaryDirectory() as d,
+        portal_client_manager(lambda a: _CachedPortalClient(a, d)),
+    ):
+        records = results.to_records()
     assert len(records) == 1
     # Single points and optimizations look here
     fulltest_client.modify_molecule(
@@ -1427,6 +1433,10 @@ def test_invalid_cmiles(fulltest_client, factory_type, result_collection_type):
     ds._cache_data.update_entries(entries)
     results = result_collection_type.from_datasets(datasets=ds)
     assert results.n_molecules == 1
-    with pytest.warns(UserWarning, match="invalid CMILES"):
+    with (
+        pytest.warns(UserWarning, match="invalid CMILES"),
+        TemporaryDirectory() as d,
+        portal_client_manager(lambda a: _CachedPortalClient(a, d)),
+    ):
         records = results.to_records()
     assert len(records) == 0
diff --git a/openff/qcsubmit/utils/__init__.py b/openff/qcsubmit/utils/__init__.py
index 2aea616d..cd5bcaf0 100644
--- a/openff/qcsubmit/utils/__init__.py
+++ b/openff/qcsubmit/utils/__init__.py
@@ -1,4 +1,5 @@
 from openff.qcsubmit.utils.utils import (
+    _CachedPortalClient,
     check_missing_stereo,
     chunk_generator,
     clean_strings,
@@ -22,4 +23,5 @@
     "get_symmetry_classes",
     "get_symmetry_group",
     "portal_client_manager",
+    "_CachedPortalClient",
 ]
diff --git a/openff/qcsubmit/utils/utils.py b/openff/qcsubmit/utils/utils.py
index 4cf219a1..fa61cecd 100644
--- a/openff/qcsubmit/utils/utils.py
+++ b/openff/qcsubmit/utils/utils.py
@@ -1,5 +1,17 @@
+import logging
+import os
 from contextlib import contextmanager
-from typing import Callable, Dict, Generator, List, Tuple
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 from openff.toolkit import topology as off
 from openff.toolkit.utils.toolkits import (
@@ -7,6 +19,241 @@
     UndefinedStereochemistryError,
 )
 from qcportal import PortalClient
+from qcportal.cache import RecordCache, get_records_with_cache
+from qcportal.optimization.record_models import OptimizationRecord
+from qcportal.singlepoint.record_models import SinglepointRecord
+from qcportal.torsiondrive.record_models import TorsiondriveRecord
+
+logger = logging.getLogger(__name__)
+
+
+class _CachedPortalClient(PortalClient):
+    """A cached version of a `qcportal.PortalClient
+    `_.
+    """
+
+    def __init__(
+        self,
+        address: str,
+        cache_dir: str,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        verify: bool = True,
+        show_motd: bool = True,
+        *,
+        cache_max_size: int = 0,
+        memory_cache_key: Optional[str] = None,
+    ):
+        """Parameters
+        ----------
+        address
+            The host or IP address of the FractalServer instance, including
+            protocol and port if necessary ("https://ml.qcarchive.molssi.org",
+            "http://192.168.1.10:8888")
+        cache_dir
+            Directory to store an internal cache of records and other data.
+            Unlike a normal ``PortalClient``, this argument is required.
+        username
+            The username to authenticate with.
+        password
+            The password to authenticate with.
+        verify
+            Verifies the SSL connection with a third party server. This may be
+            False if a FractalServer was not provided an SSL certificate and
+            defaults back to self-signed SSL keys.
+        show_motd
+            If a Message-of-the-Day is available, display it
+        cache_max_size
+            Maximum size of the cache directory
+        """
+        super().__init__(
+            address,
+            username=username,
+            password=password,
+            verify=verify,
+            show_motd=show_motd,
+            cache_dir=cache_dir,
+            cache_max_size=cache_max_size,
+            memory_cache_key=memory_cache_key,
+        )
+        self.record_cache = RecordCache(
+            os.path.join(self.cache.cache_dir, "cache.sqlite"), read_only=False
+        )
+
+    def __repr__(self) -> str:
+        """A short representation of the current PortalClient.
+
+        Returns
+        -------
+        str
+            The desired representation.
+        """
+        ret = "CachedPortalClient(server_name='{}', address='{}', username='{}', cache_dir='{}')".format(
+            self.server_name, self.address, self.username, self.cache.cache_dir
+        )
+        return ret
+
+    def get_optimizations(
+        self,
+        record_ids: Union[int, Sequence[int]],
+        missing_ok: bool = False,
+        *,
+        include: Optional[Iterable[str]] = None,
+    ) -> Union[Optional[OptimizationRecord], List[Optional[OptimizationRecord]]]:
+        """Obtain optimization records with the specified IDs.
+
+        Records will be returned in the same order as the record ids.
+
+        Parameters
+        ----------
+        record_ids
+            Single ID or sequence/list of records to obtain
+        missing_ok
+            Unlike a ``PortalClient``, this argument is ignored. If set to
+            True, a warning will be printed. Any missing records will cause a
+            ``RuntimeError`` to be raised.
+        include
+            Additional fields to include in the returned record
+
+        Returns
+        -------
+        :
+            If a single ID was specified, returns just that record. Otherwise,
+            returns a list of records.
+        """
+        if missing_ok:
+            logger.warning(
+                "missing_ok=True is not supported by _CachedPortalClient and"
+                " will be ignored; missing records will raise a RuntimeError"
+            )
+        if unpack := not isinstance(record_ids, Sequence):
+            record_ids = [record_ids]
+        res = get_records_with_cache(
+            client=self,
+            record_cache=self.record_cache,
+            record_type=OptimizationRecord,
+            record_ids=record_ids,
+            include=include,
+            force_fetch=False,
+        )
+        if unpack:
+            return res[0]
+        else:
+            return res
+
+    def get_singlepoints(
+        self,
+        record_ids: Union[int, Sequence[int]],
+        missing_ok: bool = False,
+        *,
+        include: Optional[Iterable[str]] = None,
+    ) -> Union[Optional[SinglepointRecord], List[Optional[SinglepointRecord]]]:
+        """
+        Obtain singlepoint records with the specified IDs.
+
+        Records will be returned in the same order as the record ids.
+
+        Parameters
+        ----------
+        record_ids
+            Single ID or sequence/list of records to obtain
+        missing_ok
+            Unlike a ``PortalClient``, this argument is ignored. If set to
+            True, a warning will be printed. Any missing records will cause a
+            ``RuntimeError`` to be raised.
+        include
+            Additional fields to include in the returned record
+
+        Returns
+        -------
+        :
+            If a single ID was specified, returns just that record. Otherwise,
+            returns a list of records.
+        """
+        if missing_ok:
+            logger.warning(
+                "missing_ok=True is not supported by _CachedPortalClient and"
+                " will be ignored; missing records will raise a RuntimeError"
+            )
+        if unpack := not isinstance(record_ids, Sequence):
+            record_ids = [record_ids]
+        res = get_records_with_cache(
+            client=self,
+            record_cache=self.record_cache,
+            record_type=SinglepointRecord,
+            record_ids=record_ids,
+            include=include,
+            force_fetch=False,
+        )
+        if unpack:
+            return res[0]
+        else:
+            return res
+
+    def get_torsiondrives(
+        self,
+        record_ids: Union[int, Sequence[int]],
+        missing_ok: bool = False,
+        *,
+        include: Optional[Iterable[str]] = None,
+    ) -> Union[Optional[TorsiondriveRecord], List[Optional[TorsiondriveRecord]]]:
+        """
+        Obtain torsiondrive records with the specified IDs.
+
+        Records will be returned in the same order as the record ids.
+
+        Parameters
+        ----------
+        record_ids
+            Single ID or sequence/list of records to obtain
+        missing_ok
+            Unlike a ``PortalClient``, this argument is ignored. If set to
+            True, a warning will be printed. Any missing records will cause a
+            ``RuntimeError`` to be raised.
+        include
+            Additional fields to include in the returned record
+
+        Returns
+        -------
+        :
+            If a single ID was specified, returns just that record. Otherwise,
+            returns a list of records.
+        """
+        if missing_ok:
+            logger.warning(
+                "missing_ok=True is not supported by _CachedPortalClient and"
+                " will be ignored; missing records will raise a RuntimeError"
+            )
+        if unpack := not isinstance(record_ids, Sequence):
+            record_ids = [record_ids]
+        res = get_records_with_cache(
+            client=self,
+            record_cache=self.record_cache,
+            record_type=TorsiondriveRecord,
+            record_ids=record_ids,
+            include=include,
+            force_fetch=False,
+        )
+        if unpack:
+            return res[0]
+        else:
+            return res
+
+    @contextmanager
+    def _no_session(self):
+        """This is a supplemental context manager to the ``no_internet``
+        manager in _tests/utils/test_manager.py. ``PortalClient`` creates a
+        ``requests.Session`` on initialization that can be reused without
+        accessing ``socket.socket`` again. Combining ``no_internet`` and
+        ``client._no_session`` should completely ensure that the local cache is
+        used rather than re-fetching data from QCArchive.
+        """
+        tmp = self._req_session
+        self._req_session = None
+        try:
+            yield
+        finally:
+            self._req_session = tmp
 
 
 def _default_portal_client(client_address) -> PortalClient:
@@ -22,6 +269,10 @@ def portal_client_manager(portal_client_fn: Callable[[str], PortalClient]):
     keyword arguments to the ``PortalClient``, such as ``verify=False`` or a
     ``cache_dir``.
 
+    .. warning::
+        It is not safe to share the same client across threads or to construct
+        multiple clients accessing the same cache database.
+
     Parameters
     ----------
     portal_client_fn:
@@ -40,7 +291,6 @@ def portal_client_manager(portal_client_fn: Callable[[str], PortalClient]):
     >>>     return PortalClient(client_address, cache_dir=".")
     >>> with portal_client_manager(my_portal_client):
     >>>     records_and_molecules = ds.to_records()
-
     """
     global _default_portal_client
     original_client_fn = _default_portal_client
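
Usage note (reviewer aside, not part of the patch): the caching pattern the new tests exercise can also be used directly in downstream scripts. The sketch below is illustrative only — the QCArchive address is an assumption, the dataset name is simply reused from `test_optimization_to_basic_result_collection` above, and `_CachedPortalClient` is underscore-prefixed internal API that may change. An `OptimizationResultCollection` is used because, per the comment in `test_to_records`, torsion drive molecules cannot currently be served from the cache.

```python
# Illustrative sketch: route qcsubmit's internal client through _CachedPortalClient
# so repeated to_records() calls can be answered from the on-disk record cache.
from tempfile import TemporaryDirectory

from qcportal import PortalClient

from openff.qcsubmit.results import OptimizationResultCollection
from openff.qcsubmit.utils import _CachedPortalClient, portal_client_manager

# assumed public QCArchive address; any reachable QCFractal server works here
public_client = PortalClient("https://api.qcarchive.molssi.org:443")

# dataset name reused from the tests above
collection = OptimizationResultCollection.from_server(
    public_client, ["OpenFF Gen 2 Opt Set 3 Pfizer Discrepancy"]
)

with (
    TemporaryDirectory() as d,
    portal_client_manager(lambda addr: _CachedPortalClient(addr, cache_dir=d)),
):
    records_and_molecules = collection.to_records()  # first call populates the cache
    # a second call inside the same manager can be served from the cache
    assert len(collection.to_records()) == len(records_and_molecules)
```

In a real workflow the `TemporaryDirectory` would be replaced with a persistent path so the cache survives between interpreter sessions.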