Skip to content

Commit

Permalink
Merge pull request #540 from c-bata/optuna-artifact-support
Browse files Browse the repository at this point in the history
Support optuna's artifact metadata
  • Loading branch information
HideakiImamura authored Aug 9, 2023
2 parents 0f4fc37 + 768d27b commit d534763
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 37 deletions.
48 changes: 38 additions & 10 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any
from typing import Optional
from typing import Union
import warnings

from bottle import Bottle
from bottle import redirect
Expand Down Expand Up @@ -36,10 +37,12 @@
from ._storage_url import get_storage
from .artifact._backend import delete_all_artifacts
from .artifact._backend import register_artifact_route
from .artifact._backend_to_store import to_artifact_store


if typing.TYPE_CHECKING:
from _typeshed.wsgi import WSGIApplication
from optuna.artifacts._protocol import ArtifactStore
from optuna_dashboard.artifact.protocol import ArtifactBackend


Expand All @@ -54,7 +57,7 @@

def create_app(
storage: BaseStorage,
artifact_backend: Optional[ArtifactBackend] = None,
artifact_store: Optional[ArtifactStore] = None,
debug: bool = False,
) -> Bottle:
app = Bottle()
Expand All @@ -76,7 +79,7 @@ def dashboard() -> BottleViewReturn:
@json_api_view
def api_meta() -> dict[str, Any]:
return {
"artifact_is_available": artifact_backend is not None,
"artifact_is_available": artifact_store is not None,
}

@app.get("/api/studies")
Expand Down Expand Up @@ -156,9 +159,8 @@ def rename_study(study_id: int) -> dict[str, Any]:
@app.delete("/api/studies/<study_id:int>")
@json_api_view
def delete_study(study_id: int) -> dict[str, Any]:
if artifact_backend is not None:
system_attrs = storage.get_study_system_attrs(study_id)
delete_all_artifacts(artifact_backend, system_attrs)
if artifact_store is not None:
delete_all_artifacts(artifact_store, storage, study_id)

try:
storage.delete_study(study_id)
Expand Down Expand Up @@ -347,33 +349,59 @@ def send_static(filename: str) -> BottleViewReturn:
return static_file(filename, root=STATIC_DIR)

register_rdb_migration_route(app, storage)
register_artifact_route(app, storage, artifact_backend)
register_artifact_route(app, storage, artifact_store)
return app


def run_server(
    storage: Union[str, BaseStorage],
    host: str = "localhost",
    port: int = 8080,
    artifact_store: Optional[ArtifactStore | ArtifactBackend] = None,
    *,
    artifact_backend: Optional[ArtifactBackend] = None,
) -> None:
    """Start running optuna-dashboard and block until the server terminates.

    This function uses the wsgiref module, which is not intended for production
    use. If you want to run optuna-dashboard more securely and/or faster,
    please use a WSGI server like Gunicorn or uWSGI via the :func:`wsgi` function.

    Args:
        storage: A storage URL or storage object; resolved with ``get_storage``.
        host: Host name forwarded to ``run``.
        port: Port number forwarded to ``run``.
        artifact_store: Either an Optuna ``ArtifactStore`` or a legacy dashboard
            ``ArtifactBackend``; it is normalized with ``to_artifact_store``.
            When given, it takes precedence over ``artifact_backend``.
        artifact_backend: Deprecated keyword-only argument. It is only used
            when ``artifact_store`` is ``None`` and emits a
            ``DeprecationWarning``.
    """
    # TODO(c-bata): Remove artifact_backend keyword argument in the future release.
    store: ArtifactStore | None = None
    if artifact_store is not None:
        store = to_artifact_store(artifact_store)
    elif artifact_backend is not None:
        warnings.warn(
            "The `artifact_backend` argument is deprecated. "
            "Please use `artifact_store` instead.",
            DeprecationWarning,
            # Attribute the warning to the caller's line instead of this module,
            # so users can see where the deprecated keyword is passed.
            stacklevel=2,
        )
        store = to_artifact_store(artifact_backend)

    app = create_app(get_storage(storage), artifact_store=store)
    run(app, host=host, port=port)


def wsgi(
    storage: Union[str, BaseStorage],
    artifact_store: Optional[ArtifactBackend | ArtifactStore] = None,
    *,
    artifact_backend: Optional[ArtifactBackend] = None,
) -> WSGIApplication:
    """Expose a WSGI application for production-class WSGI servers.

    This function is for people who want to run optuna-dashboard on servers
    like Gunicorn or uWSGI.

    Args:
        storage: A storage URL or storage object; resolved with ``get_storage``.
        artifact_store: Either an Optuna ``ArtifactStore`` or a legacy dashboard
            ``ArtifactBackend``; it is normalized with ``to_artifact_store``.
            When given, it takes precedence over ``artifact_backend``.
        artifact_backend: Deprecated keyword-only argument. It is only used
            when ``artifact_store`` is ``None`` and emits a
            ``DeprecationWarning``.

    Returns:
        The application object built by ``create_app``.
    """
    # TODO(c-bata): Remove artifact_backend keyword argument in the future release.
    store: ArtifactStore | None = None
    if artifact_store is not None:
        store = to_artifact_store(artifact_store)
    elif artifact_backend is not None:
        warnings.warn(
            "The `artifact_backend` argument is deprecated. "
            "Please use `artifact_store` instead.",
            DeprecationWarning,
            # Attribute the warning to the caller's line instead of this module,
            # so users can see where the deprecated keyword is passed.
            stacklevel=2,
        )
        store = to_artifact_store(artifact_backend)

    return create_app(get_storage(storage), artifact_store=store)
18 changes: 15 additions & 3 deletions optuna_dashboard/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,22 @@
from bottle import run
from optuna.storages import BaseStorage
from optuna.storages import RDBStorage
from optuna.version import __version__ as optuna_ver
from packaging import version

from . import __version__
from ._app import create_app
from ._sql_profiler import register_profiler_view
from ._storage_url import get_storage
from .artifact._backend_to_store import ArtifactBackendToStore
from .artifact.file_system import FileSystemBackend


if TYPE_CHECKING:
from typing import Literal

from optuna.artifacts._protocol import ArtifactStore


DEBUG = os.environ.get("OPTUNA_DASHBOARD_DEBUG") == "1"
SERVER_CHOICES = ["auto", "wsgiref", "gunicorn"]
Expand Down Expand Up @@ -113,10 +118,17 @@ def main() -> None:
storage: BaseStorage
storage = get_storage(args.storage, storage_class=args.storage_class)

artifact_backend = None
if args.artifact_dir is not None:
artifact_store: ArtifactStore | None
if args.artifact_dir is None:
artifact_store = None
elif version.parse(optuna_ver) >= version.Version("3.3.0"):
from optuna.artifacts import FileSystemArtifactStore

artifact_store = FileSystemArtifactStore(args.artifact_dir)
else:
artifact_backend = FileSystemBackend(args.artifact_dir)
app = create_app(storage, artifact_backend=artifact_backend, debug=DEBUG)
artifact_store = ArtifactBackendToStore(artifact_backend)
app = create_app(storage, artifact_store=artifact_store, debug=DEBUG)

if DEBUG and isinstance(storage, RDBStorage):
app = register_profiler_view(app, storage)
Expand Down
2 changes: 1 addition & 1 deletion optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def serialize_frozen_trial(
{k: trial_system_attrs[k] for k in trial_system_attrs if not k.startswith("dashboard")}
),
"note": note.get_note_from_system_attrs(study_system_attrs, trial._trial_id),
"artifacts": list_trial_artifacts(study_system_attrs, trial._trial_id),
"artifacts": list_trial_artifacts(study_system_attrs, trial),
"constraints": trial_system_attrs.get(CONSTRAINTS_KEY, []),
}

Expand Down
85 changes: 62 additions & 23 deletions optuna_dashboard/artifact/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import os.path
from typing import TYPE_CHECKING
import uuid
import warnings

from bottle import BaseRequest
from bottle import Bottle
from bottle import HTTPResponse
from bottle import request
from bottle import response
import optuna
from optuna.trial import FrozenTrial

from .._bottle_util import json_api_view
from .._bottle_util import parse_data_uri
Expand All @@ -23,6 +25,7 @@
from typing import Optional
from typing import TypedDict

from optuna.artifacts._protocol import ArtifactStore
from optuna.storages import BaseStorage

from .protocol import ArtifactBackend
Expand Down Expand Up @@ -56,14 +59,14 @@ def get_artifact_path(


def register_artifact_route(
app: Bottle, storage: BaseStorage, artifact_backend: Optional[ArtifactBackend]
app: Bottle, storage: BaseStorage, artifact_store: Optional[ArtifactStore]
) -> None:
@app.get("/artifacts/<study_id:int>/<trial_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
def proxy_artifact(study_id: int, trial_id: int, artifact_id: str) -> HTTPResponse | bytes:
if artifact_backend is None:
if artifact_store is None:
response.status = 400 # Bad Request
return b"Cannot access to the artifacts."
artifact_dict = _get_artifact_meta(storage, study_id, trial_id, artifact_id)
artifact_dict = get_artifact_meta(storage, study_id, trial_id, artifact_id)
if artifact_dict is None:
response.status = 404
return b"Not Found"
Expand All @@ -72,13 +75,14 @@ def proxy_artifact(study_id: int, trial_id: int, artifact_id: str) -> HTTPRespon
if encoding:
headers["Content-Encodings"] = encoding

fp = artifact_backend.open(artifact_id)
fp = artifact_store.open_reader(artifact_id)
return HTTPResponse(fp, headers=headers)

@app.post("/api/artifacts/<study_id:int>/<trial_id:int>")
@json_api_view
def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
if artifact_backend is None:
# TODO(c-bata): Use optuna.artifacts.upload_artifact()
if artifact_store is None:
response.status = 400 # Bad Request
return {"reason": "Cannot access to the artifacts."}
file = request.json.get("file")
Expand All @@ -89,7 +93,7 @@ def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
_, data = parse_data_uri(file)
filename = request.json.get("filename", "")
artifact_id = str(uuid.uuid4())
artifact_backend.write(artifact_id, io.BytesIO(data))
artifact_store.write(artifact_id, io.BytesIO(data))

mimetype, encoding = mimetypes.guess_type(filename)
artifact = {
Expand All @@ -102,18 +106,22 @@ def upload_artifact_api(study_id: int, trial_id: int) -> dict[str, Any]:
storage.set_study_system_attr(study_id, attr_key, json.dumps(artifact))
response.status = 201

trial = storage.get_trial(trial_id)
if trial is None:
response.status = 400
return {"reason": "Invalid study_id or trial_id"}
return {
"artifact_id": artifact_id,
"artifacts": list_trial_artifacts(storage.get_study_system_attrs(study_id), trial_id),
"artifacts": list_trial_artifacts(storage.get_study_system_attrs(study_id), trial),
}

@app.delete("/api/artifacts/<study_id:int>/<trial_id:int>/<artifact_id:re:[0-9a-fA-F-]+>")
@json_api_view
def delete_artifact(study_id: int, trial_id: int, artifact_id: str) -> dict[str, Any]:
if artifact_backend is None:
if artifact_store is None:
response.status = 400 # Bad Request
return {"reason": "Cannot access to the artifacts."}
artifact_backend.remove(artifact_id)
artifact_store.remove(artifact_id)

attr_key = _artifact_prefix(trial_id) + artifact_id
storage.set_study_system_attr(study_id, attr_key, json.dumps(None))
Expand All @@ -131,6 +139,12 @@ def upload_artifact(
) -> str:
"""Upload an artifact (files), which is associated with the trial.
.. warning::
This function is deprecated. Please use `optuna.artifacts.upload_artifact
<https://optuna.readthedocs.io/en/latest/reference/generated/optuna.artifacts.
upload_artifact.html>`_ instead.
Example:
.. code-block:: python
Expand All @@ -146,6 +160,13 @@ def objective(trial: optuna.Trial) -> float:
upload_artifact(artifact_backend, trial, file_path)
return ...
"""
warnings.warn(
"This function is deprecated. Please use optuna.artifacts.upload_artifact() instead.\n"
"See https://optuna.readthedocs.io/en/latest/reference/generated/"
"optuna.artifacts.upload_artifact.html",
DeprecationWarning,
)

filename = os.path.basename(file_path)
storage = trial.storage
trial_id = trial._trial_id
Expand All @@ -170,31 +191,49 @@ def _artifact_prefix(trial_id: int) -> str:
return ARTIFACTS_ATTR_PREFIX + f"{trial_id}:"


def get_artifact_meta(
    storage: BaseStorage, study_id: int, trial_id: int, artifact_id: str
) -> Optional[ArtifactMeta]:
    """Look up the metadata of one artifact attached to a trial.

    The dashboard's own record in the study system attrs is consulted first.
    Only when no such record exists is the trial's system attrs checked for an
    entry under the ``"artifacts:"`` key prefix.

    Returns:
        The JSON-decoded metadata, or ``None`` when neither location has an
        entry for ``artifact_id``.
    """
    study_attrs = storage.get_study_system_attrs(study_id)
    serialized = study_attrs.get(_artifact_prefix(trial_id=trial_id) + artifact_id)

    if serialized is None:
        # See https://github.com/optuna/optuna/blob/f827582a8/optuna/artifacts/_upload.py#L71
        trial_attrs = storage.get_trial_system_attrs(trial_id)
        serialized = trial_attrs.get("artifacts:" + artifact_id)

    return None if serialized is None else json.loads(serialized)

def delete_all_artifacts(backend: ArtifactStore, storage: BaseStorage, study_id: int) -> None:
    """Remove every artifact that belongs to any trial of the given study.

    The metadata of all trials is gathered first via ``list_trial_artifacts``;
    only afterwards is ``backend.remove`` called once per collected entry.
    """
    study_attrs = storage.get_study_system_attrs(study_id)
    collected = [
        meta
        for trial in storage.get_all_trials(study_id)
        for meta in list_trial_artifacts(study_attrs, trial)
    ]

    for meta in collected:
        backend.remove(meta["artifact_id"])


def list_trial_artifacts(
    study_system_attrs: dict[str, Any], trial: FrozenTrial
) -> list[ArtifactMeta]:
    """Collect the artifact metadata associated with ``trial``.

    Two sources are merged, dashboard entries first:

    * study system attrs whose key starts with this trial's dashboard
      artifact prefix, and
    * the trial's own system attrs whose key starts with ``"artifacts:"``.

    Args:
        study_system_attrs: System attrs of the study that owns ``trial``.
        trial: The trial whose artifacts are listed.

    Returns:
        The JSON-decoded metadata entries, with ``None`` entries removed.
    """
    # Computed once; it does not depend on the attribute key being tested.
    dashboard_prefix = _artifact_prefix(trial._trial_id)
    dashboard_artifact_metas = [
        json.loads(value)
        for key, value in study_system_attrs.items()
        if key.startswith(dashboard_prefix)
    ]

    # See https://github.com/optuna/optuna/blob/f827582a8/optuna/artifacts/_upload.py#L16
    optuna_artifact_metas = [
        json.loads(value)
        for key, value in trial.system_attrs.items()
        if key.startswith("artifacts:")
    ]

    # An attribute holding JSON ``null`` decodes to None and is filtered out.
    artifact_metas = dashboard_artifact_metas + optuna_artifact_metas
    return [a for a in artifact_metas if a is not None]
39 changes: 39 additions & 0 deletions optuna_dashboard/artifact/_backend_to_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from typing import TYPE_CHECKING


if TYPE_CHECKING:
from typing import BinaryIO
from typing import TypeGuard

from optuna.artifacts._protocol import ArtifactStore

from .protocol import ArtifactBackend


def is_artifact_backend(store: ArtifactBackend | ArtifactStore) -> TypeGuard[ArtifactBackend]:
    """Tell whether ``store`` should be treated as a dashboard ``ArtifactBackend``.

    The check is duck-typed: an object counts as a backend when it has no
    ``open_reader`` attribute, or when that attribute is ``None``.
    """
    reader = getattr(store, "open_reader", None)
    return reader is None


def to_artifact_store(store: ArtifactBackend | ArtifactStore) -> ArtifactStore:
    """Normalize ``store`` to an object exposing the ``ArtifactStore`` interface.

    An object recognized by ``is_artifact_backend`` is wrapped in
    ``ArtifactBackendToStore``; anything else is returned unchanged.
    """
    if is_artifact_backend(store):
        adapter = ArtifactBackendToStore(store)
        return adapter
    # mypy cannot infer the type of `store` here.
    return store  # type: ignore


class ArtifactBackendToStore:
    """Adapter that presents a Dashboard ``ArtifactBackend`` as an Optuna ``ArtifactStore``.

    Every method delegates to the wrapped backend; the only difference between
    the two interfaces visible here is that the backend's ``open`` is exposed
    as ``open_reader``.
    """

    def __init__(self, artifact_backend: ArtifactBackend) -> None:
        # The wrapped backend that receives every delegated call.
        self._backend = artifact_backend

    def open_reader(self, artifact_id: str) -> BinaryIO:
        """Return whatever the backend's ``open`` yields for ``artifact_id``."""
        reader = self._backend.open(artifact_id)
        return reader

    def write(self, artifact_id: str, content_body: BinaryIO) -> None:
        """Forward ``content_body`` to the backend's ``write`` under ``artifact_id``."""
        backend = self._backend
        backend.write(artifact_id, content_body)

    def remove(self, artifact_id: str) -> None:
        """Ask the backend to remove the artifact identified by ``artifact_id``."""
        backend = self._backend
        backend.remove(artifact_id)
Loading

0 comments on commit d534763

Please sign in to comment.