diff --git a/ch-ray-ml/tune-algorithm-scheduler.ipynb b/ch-ray-ml/tune-algorithm-scheduler.ipynb
deleted file mode 100644
index ea50f31..0000000
--- a/ch-ray-ml/tune-algorithm-scheduler.ipynb
+++ /dev/null
@@ -1,1907 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "(sec-tune-algorithm-scheduler)=\n",
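- "# Search Algorithms and Schedulers for Hyperparameter Tuning\n",
- "\n",
- "Two key concepts in Ray Tune's hyperparameter search are the search algorithm and the scheduler. The search algorithm determines how new hyperparameter combinations (that is, trials) are selected from the search space; the scheduler decides which unpromising trials to terminate early, which saves compute resources. A search algorithm is required; a scheduler is optional. The two can work together: for example, random search can be paired with the Async Successive Halving Algorithm (ASHA) scheduler, which stops trials that look unpromising. In addition, many hyperparameter optimization packages ship their own search algorithms, and some also ship schedulers, each with its own usage conventions. Ray Tune wraps these packages so that they can be used in as uniform a way as possible. The rest of this section briefly introduces some common search algorithms and schedulers.\n",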
- "\n",
- "## Hyperband\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "tags": [
- "hide-cell"
- ]
- },
- "outputs": [],
- "source": [
- "import os\n",
- "import tempfile\n",
- "\n",
- "import sys\n",
- "sys.path.append(\"..\")\n",
- "from utils import nyc_flights\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "\n",
- "import torch\n",
- "import torchvision\n",
- "import torch.nn as nn\n",
- "from torch.utils.data import DataLoader\n",
- "from torchvision.models import resnet18\n",
- "\n",
- "import ray\n",
- "from sklearn.model_selection import train_test_split\n",
- "from ray.tune.search.hyperopt import HyperOptSearch\n",
- "import xgboost as xgb\n",
- "from ray import tune\n",
- "from ray.tune.schedulers import AsyncHyperBandScheduler\n",
- "from ray.tune.integration.xgboost import TuneReportCheckpointCallback\n",
- "from ray.tune.schedulers import PopulationBasedTraining\n",
- "\n",
- "folder_path = nyc_flights()\n",
- "file_path = os.path.join(folder_path, \"nyc-flights\", \"1991.csv\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Load the data and apply the necessary preprocessing:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "input_cols = [\n",
- "    \"Year\",\n",
- "    \"Month\",\n",
- "    \"DayofMonth\",\n",
- "    \"DayOfWeek\",\n",
- "    \"CRSDepTime\",\n",
- "    \"CRSArrTime\",\n",
- "    \"UniqueCarrier\",\n",
- "    \"FlightNum\",\n",
- "    \"ActualElapsedTime\",\n",
- "    \"Origin\",\n",
- "    \"Dest\",\n",
- "    \"Distance\",\n",
- "    \"Diverted\",\n",
- "    \"ArrDelay\",\n",
- "]\n",
- "\n",
- "df = pd.read_csv(file_path, usecols=input_cols)\n",
- "\n",
- "# Binary target: 1.0 if ArrDelay > 10 (the flight counts as delayed), else 0.0\n",
- "df[\"ArrDelayBinary\"] = 1.0 * (df[\"ArrDelay\"] > 10)\n",
- "\n",
- "df = df[df.columns.difference([\"ArrDelay\"])]\n",
- "\n",
- "for col in df.select_dtypes([\"object\"]).columns:\n",
- "    df[col] = df[col].astype(\"category\").cat.codes.astype(np.int32)\n",
- "\n",
- "for col in df.columns:\n",
- "    df[col] = df[col].astype(np.float32)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
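- "The `params` argument of XGBoost's `train()` function accepts hyperparameters such as tree depth. Note that the `train()` function of frameworks like XGBoost has no explicit training loop of the kind PyTorch code writes as `for epoch in range(...)`. To report metrics after every training iteration, pass a callback in the `callbacks` argument of `train()`. Ray provides [`TuneReportCheckpointCallback`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.integration.xgboost.TuneReportCheckpointCallback.html), which reports the metrics to Ray Tune after each training iteration. In this example, `params` contains `\"eval_metric\": [\"logloss\", \"error\"]`, which selects the evaluation metrics, and `evals=[(test_set, \"eval\")]` evaluates them on the validation set only, under the name `eval`. Together, these compute `logloss` and `error` on the validation set and report them to Ray Tune as `eval-logloss` and `eval-error`."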
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "def train_flight(config: dict):\n",
- "    config.update({\n",
- "        \"objective\": \"binary:logistic\",\n",
- "        \"eval_metric\": [\"logloss\", \"error\"]\n",
- "    })\n",
- "    _y_label = \"ArrDelayBinary\"\n",
- "    train_x, test_x, train_y, test_y = train_test_split(\n",
- "        df.loc[:, df.columns != _y_label],\n",
- "        df[_y_label],\n",
- "        test_size=0.25\n",
- "    )\n",
- "\n",
- "    train_set = xgb.DMatrix(train_x, label=train_y)\n",
- "    test_set = xgb.DMatrix(test_x, label=test_y)\n",
- "\n",
- "    xgb.train(\n",
- "        params=config,\n",
- "        dtrain=train_set,\n",
- "        evals=[(test_set, \"eval\")],\n",
- "        verbose_eval=False,\n",
- "        # After each iteration, `TuneReportCheckpointCallback` reports the evaluation metrics to Ray Tune\n",
- "        callbacks=[TuneReportCheckpointCallback(frequency=1)]\n",
- "    )"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
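- "Ray Tune wraps the Bayesian optimization search algorithm of the `hyperopt` package as `HyperOptSearch` (imported above); if the package is not installed, install it first with `pip install hyperopt`. Such packages usually define their own search space format, but you can also use Ray Tune's own search space definitions directly. Note that a search algorithm only takes effect when it is passed as `search_alg` in `tune.TuneConfig`; the `Tuner` below does not set `search_alg`, so its trials are drawn by Tune's default random sampling.\n",
- "\n",
- "For the scheduler, we use the HyperBand scheduling algorithm. [`AsyncHyperBandScheduler`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.AsyncHyperBandScheduler.html) is the HyperBand implementation that Ray Tune recommends; it is asynchronous and makes fuller use of compute resources. Its `time_attr` parameter names the unit in which training time is measured. The default is `training_iteration`, meaning one training iteration, so `time_attr` is the basic unit of the compute budget. The other parameters are expressed in this unit: `max_t` is the total budget a trial can receive, that is, at most `max_t` units of `time_attr`; `grace_period` guarantees every trial at least `grace_period` units of `time_attr` before it can be stopped. `reduction_factor` is the $\\eta$ in the mathematical description of successive halving, and `brackets` is the number of brackets, HyperBand's way of grouping trials."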
- ]
- },
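The milestones implied by `grace_period`, `max_t`, and `reduction_factor` can be worked out by hand. The following is a minimal sketch, not Ray Tune's internal code: `rung_milestones` is a hypothetical helper, and it assumes that bracket `s` starts at `grace_period * reduction_factor ** s` and keeps multiplying by `reduction_factor` while the budget stays within `max_t`. Under that assumption, the settings used in this notebook reproduce the `Iter` milestones that the `Bracket:` lines of the status output list for the three brackets.

```python
def rung_milestones(grace_period, max_t, reduction_factor, bracket=0):
    """Return the budgets (in units of time_attr) at which a bracket compares trials.

    Hypothetical helper for illustration; not part of the Ray Tune API.
    """
    milestones = []
    # Assumption: later brackets skip the earliest rungs, so they start at a larger budget.
    t = grace_period * reduction_factor ** bracket
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


# Settings taken from the AsyncHyperBandScheduler configured in this notebook.
for s in range(3):
    print(s, rung_milestones(grace_period=1, max_t=10, reduction_factor=2, bracket=s))
# 0 [1, 2, 4, 8]
# 1 [2, 4, 8]
# 2 [4, 8]
```

With `reduction_factor=2`, a trial that reaches a milestone is allowed to continue only if its metric ranks roughly in the better half of the results recorded at that milestone so far; otherwise the scheduler stops it.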
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "Tune Status\n",
- "\n",
- "Current time: | 2024-04-17 23:23:16 |\n",
- "Running for: | 00:00:10.45 |\n",
- "Memory: | 12.6/90.0 GiB |\n",
- "\n",
- "System Info\n",
- "\n",
- "Using AsyncHyperBand: num_stopped=16\n",
- "Bracket: Iter 8.000: -0.2197494153541173 | Iter 4.000: -0.21991977574377797 | Iter 2.000: -0.2211587603958556 | Iter 1.000: -0.22190215118710216\n",
- "Bracket: Iter 8.000: -0.2228778516006133 | Iter 4.000: -0.2228778516006133 | Iter 2.000: -0.22374514085706762\n",
- "Bracket: Iter 8.000: -0.2238690393222754 | Iter 4.000: -0.2238690393222754\n",
- "Logical resource usage: 1.0/64 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n",
- "\n",
- "Trial Status\n",
- "\n",
- "Trial name | status | loc | eta | max_depth | min_child_weight | subsample | iter | total time (s) | eval-logloss | eval-error |\n",
- "train_flight_63737_00000 | TERMINATED | 10.0.0.3:46765 | 0.0160564 | 2 | 3 | 0.963344 | 10 | 0.958357 | 0.522242 | 0.222878 |\n",
- "train_flight_63737_00001 | TERMINATED | 10.0.0.3:46724 | 0.0027667 | 3 | 3 | 0.930057 | 10 | 1.11445 | 0.525986 | 0.219595 |\n",
- "train_flight_63737_00002 | TERMINATED | 10.0.0.3:46700 | 0.00932612 | 3 | 1 | 0.532473 | 1 | 0.698576 | 0.53213 | 0.223699 |\n",
- "train_flight_63737_00003 | TERMINATED | 10.0.0.3:46795 | 0.0807042 | 7 | 1 | 0.824932 | 10 | 1.27819 | 0.42436 | 0.176524 |\n",
- "train_flight_63737_00004 | TERMINATED | 10.0.0.3:46796 | 0.0697454 | 1 | 2 | 0.908686 | 10 | 1.01485 | 0.516239 | 0.223466 |\n",
- "train_flight_63737_00005 | TERMINATED | 10.0.0.3:46868 | 0.00334937 | 4 | 2 | 0.799064 | 10 | 0.983133 | 0.528863 | 0.223869 |\n",
- "train_flight_63737_00006 | TERMINATED | 10.0.0.3:46932 | 0.00637837 | 5 | 2 | 0.555629 | 2 | 0.691448 | 0.528233 | 0.22136 |\n",
- "train_flight_63737_00007 | TERMINATED | 10.0.0.3:46935 | 0.000145799 | 8 | 3 | 0.84289 | 1 | 0.668353 | 0.532382 | 0.223079 |\n",
- "train_flight_63737_00008 | TERMINATED | 10.0.0.3:46959 | 0.0267405 | 5 | 1 | 0.766606 | 2 | 0.692802 | 0.520686 | 0.221159 |\n",
- "train_flight_63737_00009 | TERMINATED | 10.0.0.3:46989 | 0.00848009 | 2 | 3 | 0.576874 | 2 | 0.610592 | 0.53193 | 0.223745 |\n",
- "train_flight_63737_00010 | TERMINATED | 10.0.0.3:47125 | 0.0016903 | 8 | 3 | 0.824537 | 2 | 0.716938 | 0.532519 | 0.22407 |\n",
- "train_flight_63737_00011 | TERMINATED | 10.0.0.3:47127 | 0.005344 | 7 | 1 | 0.921332 | 1 | 0.609434 | 0.532074 | 0.223993 |\n",
- "train_flight_63737_00012 | TERMINATED | 10.0.0.3:47193 | 0.0956213 | 1 | 2 | 0.682057 | 8 | 0.791444 | 0.511592 | 0.219904 |\n",
- "train_flight_63737_00013 | TERMINATED | 10.0.0.3:47196 | 0.00796245 | 5 | 2 | 0.570677 | 1 | 0.619144 | 0.531066 | 0.223172 |\n",
- "train_flight_63737_00014 | TERMINATED | 10.0.0.3:47198 | 0.0106115 | 2 | 3 | 0.85295 | 1 | 0.582307 | 0.530977 | 0.222444 |\n",
- "train_flight_63737_00015 | TERMINATED | 10.0.0.3:47200 | 0.0507297 | 1 | 1 | 0.720122 | 2 | 0.655333 | 0.527164 | 0.221283 |\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[33m(raylet)\u001b[0m Warning: The actor ImplicitFunc is very large (27 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "\u001b[36m(train_flight pid=46796)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_flight_2024-04-17_23-23-05/train_flight_63737_00004_4_eta=0.0697,max_depth=1,min_child_weight=2,subsample=0.9087_2024-04-17_23-23-08/checkpoint_000000)\n",
- "2024-04-17 23:23:16,344\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_flight_2024-04-17_23-23-05' in 0.0653s.\n",
- "2024-04-17 23:23:16,362\tINFO tune.py:1048 -- Total run time: 10.73 seconds (10.38 seconds for the tuning loop).\n"
- ]
- }
- ],
- "source": [
- "search_space = {\n",
- "    \"max_depth\": tune.randint(1, 9),\n",
- "    \"min_child_weight\": tune.choice([1, 2, 3]),\n",
- "    \"subsample\": tune.uniform(0.5, 1.0),\n",
- "    \"eta\": tune.loguniform(1e-4, 1e-1),\n",
- "}\n",
- "\n",
- "scheduler = AsyncHyperBandScheduler(\n",
- "    max_t=10,\n",
- "    grace_period=1,\n",
- "    reduction_factor=2,\n",
- "    brackets=3,\n",
- ")\n",
- "\n",
- "tuner = tune.Tuner(\n",
- "    train_flight,\n",
- "    tune_config=tune.TuneConfig(\n",
- "        metric=\"eval-error\",\n",
- "        mode=\"min\",\n",
- "        scheduler=scheduler,\n",
- "        num_samples=16,\n",
- "    ),\n",
- "    param_space=search_space,\n",
- ")\n",
- "results = tuner.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
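- "`Tuner.fit()` returns the results of all trials as a `ResultGrid` and also writes them to persistent storage, so you can inspect, analyze, and compare how different hyperparameters perform:"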
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Best model parameters: {'max_depth': 7, 'min_child_weight': 1, 'subsample': 0.8249317015751376, 'eta': 0.08070421841931029, 'objective': 'binary:logistic', 'eval_metric': ['logloss', 'error']}\n",
- "Best model total accuracy: 0.8235\n"
- ]
- }
- ],
- "source": [
- "def get_best_model_checkpoint(results):\n",
- "    best_result = results.get_best_result()\n",
- "\n",
- "    # `TuneReportCheckpointCallback.get_model` loads the trained model from the best result's checkpoint\n",
- "    best_bst = TuneReportCheckpointCallback.get_model(best_result.checkpoint)\n",
- "\n",
- "    accuracy = 1.0 - best_result.metrics[\"eval-error\"]\n",
- "    print(f\"Best model parameters: {best_result.config}\")\n",
- "    print(f\"Best model total accuracy: {accuracy:.4f}\")\n",
- "    return best_bst\n",
- "\n",
- "best_bst = get_best_model_checkpoint(results)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- " | eval-error | training_iteration | config/max_depth | config/min_child_weight | config/subsample |\n",
- "0 | 0.222878 | 10 | 2 | 3 | 0.963344 |\n",
- "1 | 0.219595 | 10 | 3 | 3 | 0.930057 |\n",
- "2 | 0.223699 | 1 | 3 | 1 | 0.532473 |\n",
- "3 | 0.176524 | 10 | 7 | 1 | 0.824932 |\n",
- "4 | 0.223466 | 10 | 1 | 2 | 0.908686 |\n",
- "5 | 0.223869 | 10 | 4 | 2 | 0.799064 |\n",
- "6 | 0.221360 | 2 | 5 | 2 | 0.555629 |\n",
- "7 | 0.223079 | 1 | 8 | 3 | 0.842890 |\n",
- "8 | 0.221159 | 2 | 5 | 1 | 0.766606 |\n",
- "9 | 0.223745 | 2 | 2 | 3 | 0.576874 |\n",
- "10 | 0.224070 | 2 | 8 | 3 | 0.824537 |\n",
- "11 | 0.223993 | 1 | 7 | 1 | 0.921332 |\n",
- "12 | 0.219904 | 8 | 1 | 2 | 0.682057 |\n",
- "13 | 0.223172 | 1 | 5 | 2 | 0.570677 |\n",
- "14 | 0.222444 | 1 | 2 | 3 | 0.852950 |\n",
- "15 | 0.221283 | 2 | 1 | 1 | 0.720122 |\n"
- ],
- "text/plain": [
- " eval-error training_iteration config/max_depth config/min_child_weight \\\n",
- "0 0.222878 10 2 3 \n",
- "1 0.219595 10 3 3 \n",
- "2 0.223699 1 3 1 \n",
- "3 0.176524 10 7 1 \n",
- "4 0.223466 10 1 2 \n",
- "5 0.223869 10 4 2 \n",
- "6 0.221360 2 5 2 \n",
- "7 0.223079 1 8 3 \n",
- "8 0.221159 2 5 1 \n",
- "9 0.223745 2 2 3 \n",
- "10 0.224070 2 8 3 \n",
- "11 0.223993 1 7 1 \n",
- "12 0.219904 8 1 2 \n",
- "13 0.223172 1 5 2 \n",
- "14 0.222444 1 2 3 \n",
- "15 0.221283 2 1 1 \n",
- "\n",
- " config/subsample \n",
- "0 0.963344 \n",
- "1 0.930057 \n",
- "2 0.532473 \n",
- "3 0.824932 \n",
- "4 0.908686 \n",
- "5 0.799064 \n",
- "6 0.555629 \n",
- "7 0.842890 \n",
- "8 0.766606 \n",
- "9 0.576874 \n",
- "10 0.824537 \n",
- "11 0.921332 \n",
- "12 0.682057 \n",
- "13 0.570677 \n",
- "14 0.852950 \n",
- "15 0.720122 "
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "results_df = results.get_dataframe()\n",
- "results_df[[\"eval-error\", \"training_iteration\", \"config/max_depth\", \"config/min_child_weight\", \"config/subsample\"]]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
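- "### Example: Image Classification with PBT\n",
- "\n",
- "PBT adjusts both the model weights and the hyperparameters during training, so the training code must be able to load (update) model weights from a checkpoint. Another difference lies in the training loop: most PyTorch training code has an explicit loop such as `for epoch in range(...)`, which normally has a termination condition. A PBT training function sets no termination condition of its own; Ray Tune terminates the trial when the metric reaches the target or the trial needs to stop early. The training function therefore iterates with `while True` until Ray Tune stops it."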
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "tags": [
- "hide-cell"
- ]
- },
- "outputs": [],
- "source": [
- "data_dir = os.path.join(os.getcwd(), \"../data\")\n",
- "\n",
- "def train_func(model, optimizer, criterion, train_loader):\n",
- "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
- "    model.train()\n",
- "    for data, target in train_loader:\n",
- "        data, target = data.to(device), target.to(device)\n",
- "        output = model(data)\n",
- "        loss = criterion(output, target)\n",
- "        optimizer.zero_grad()\n",
- "        loss.backward()\n",
- "        optimizer.step()\n",
- "\n",
- "\n",
- "def test_func(model, data_loader):\n",
- "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
- "    model.eval()\n",
- "    correct = 0\n",
- "    total = 0\n",
- "    with torch.no_grad():\n",
- "        for data, target in data_loader:\n",
- "            data, target = data.to(device), target.to(device)\n",
- "            outputs = model(data)\n",
- "            _, predicted = torch.max(outputs.data, 1)\n",
- "            total += target.size(0)\n",
- "            correct += (predicted == target).sum().item()\n",
- "\n",
- "    return correct / total"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "def train_mnist(config):\n",
- "    step = 1\n",
- "    transform = torchvision.transforms.Compose(\n",
- "        [torchvision.transforms.ToTensor(),\n",
- "         torchvision.transforms.Normalize((0.5,), (0.5,))]\n",
- "    )\n",
- "\n",
- "    train_loader = DataLoader(\n",
- "        torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),\n",
- "        batch_size=128,\n",
- "        shuffle=True)\n",
- "    test_loader = DataLoader(\n",
- "        torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),\n",
- "        batch_size=128,\n",
- "        shuffle=True)\n",
- "\n",
- "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
- "\n",
- "    model = resnet18(num_classes=10)\n",
- "    model.conv1 = torch.nn.Conv2d(\n",
- "        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n",
- "    )\n",
- "    model.to(device)\n",
- "\n",
- "    criterion = nn.CrossEntropyLoss()\n",
- "\n",
- "    optimizer = torch.optim.SGD(\n",
- "        model.parameters(),\n",
- "        lr=config.get(\"lr\", 0.01),\n",
- "        momentum=config.get(\"momentum\", 0.9)\n",
- "    )\n",
- "\n",
- "    checkpoint = ray.train.get_checkpoint()\n",
- "    if checkpoint:\n",
- "        with checkpoint.as_directory() as checkpoint_dir:\n",
- "            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, \"checkpoint.pt\"))\n",
- "\n",
- "        model.load_state_dict(checkpoint_dict[\"model_state_dict\"])\n",
- "        optimizer.load_state_dict(checkpoint_dict[\"optimizer_state_dict\"])\n",
- "\n",
- "        # Apply the lr and momentum passed in through config to the restored optimizer\n",
- "        for param_group in optimizer.param_groups:\n",
- "            if \"lr\" in config:\n",
- "                param_group[\"lr\"] = config[\"lr\"]\n",
- "            if \"momentum\" in config:\n",
- "                param_group[\"momentum\"] = config[\"momentum\"]\n",
- "\n",
- "        last_step = checkpoint_dict[\"step\"]\n",
- "        step = last_step + 1\n",
- "\n",
- "    # No exit condition here: Ray Tune terminates the trial based on its metrics\n",
- "    while True:\n",
- "        train_func(model, optimizer, criterion, train_loader)\n",
- "        acc = test_func(model, test_loader)\n",
- "        metrics = {\"mean_accuracy\": acc, \"lr\": config[\"lr\"]}\n",
- "\n",
- "        if step % config[\"checkpoint_interval\"] == 0:\n",
- "            with tempfile.TemporaryDirectory() as tmpdir:\n",
- "                torch.save(\n",
- "                    {\n",
- "                        \"step\": step,\n",
- "                        \"model_state_dict\": model.state_dict(),\n",
- "                        \"optimizer_state_dict\": optimizer.state_dict(),\n",
- "                    },\n",
- "                    os.path.join(tmpdir, \"checkpoint.pt\"),\n",
- "                )\n",
- "                ray.train.report(metrics, checkpoint=ray.train.Checkpoint.from_directory(tmpdir))\n",
- "        else:\n",
- "            ray.train.report(metrics)\n",
- "\n",
- "        step += 1"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
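- "Next, define the PBT scheduler with [PopulationBasedTraining](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.PopulationBasedTraining.html). As with the scheduler used earlier, `time_attr` is the unit of time. `perturbation_interval` is how often, in units of `time_attr`, the hyperparameters are perturbed to produce new ones. It is usually set to the same value as `checkpoint_interval`: perturbing hyperparameters goes hand in hand with writing a checkpoint to persistent storage, which adds overhead, so the interval should not be too short. PBT picks mutated values from `hyperparam_mutations`, a dictionary that maps each hyperparameter name to the values it may mutate to."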
- ]
- },
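To make the role of `hyperparam_mutations` more concrete, here is a simplified sketch of a single perturbation ("explore") step. It is an illustration under stated assumptions rather than Ray Tune's implementation: `explore` is a hypothetical helper, and the 0.8/1.2 scaling factors and the 0.25 resampling probability are illustrative values. Continuous hyperparameters are represented by a zero-argument sampling function, discrete ones by a list of candidates.

```python
import random


def explore(config, mutations, resample_probability=0.25, rng=random):
    """Return a perturbed copy of `config` (simplified PBT explore step).

    Hypothetical helper for illustration; not Ray Tune's implementation.
    `mutations` maps a name to a list of candidates or a zero-argument sampler.
    """
    new_config = dict(config)
    for key, spec in mutations.items():
        if callable(spec):
            if rng.random() < resample_probability:
                # Occasionally draw a completely fresh value from the sampler.
                new_config[key] = spec()
            else:
                # Otherwise scale the current value down or up by 20%.
                new_config[key] = config[key] * rng.choice([0.8, 1.2])
        else:
            # Discrete candidates: pick one of the listed values.
            new_config[key] = rng.choice(spec)
    return new_config


rng = random.Random(0)
mutations = {
    "lr": lambda: rng.uniform(0.0001, 1),   # mirrors tune.uniform(0.0001, 1)
    "momentum": [0.8, 0.9, 0.99],
}
print(explore({"lr": 0.1, "momentum": 0.9}, mutations, rng=rng))
```

In PBT this step follows an "exploit" step, in which a poorly performing trial takes over the checkpoint and hyperparameters of a well performing one. That is why `train_mnist` restores the model and optimizer state from `ray.train.get_checkpoint()` and then overwrites `lr` and `momentum` with the values in `config`.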
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [],
- "source": [
- "perturbation_interval = 5\n",
- "scheduler = PopulationBasedTraining(\n",
- "    time_attr=\"training_iteration\",\n",
- "    perturbation_interval=perturbation_interval,\n",
- "    metric=\"mean_accuracy\",\n",
- "    mode=\"max\",\n",
- "    hyperparam_mutations={\n",
- "        \"lr\": tune.uniform(0.0001, 1),\n",
- "        \"momentum\": [0.8, 0.9, 0.99],\n",
- "    },\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
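- "Now we can start training. PBT needs a stopping condition; in this example, a trial stops when `mean_accuracy` reaches 0.9 or after 20 training iterations, whichever comes first."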
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "Tune Status\n",
- "\n",
- "Current time: | 2024-04-17 18:09:24 |\n",
- "Running for: | 00:07:35.75 |\n",
- "Memory: | 16.7/90.0 GiB |\n",
- "\n",
- "System Info\n",
- "\n",
- "PopulationBasedTraining: 9 checkpoints, 1 perturbs\n",
- "Logical resource usage: 0/64 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n",
- "\n",
- "Trial Status\n",
- "\n",
- "Trial name | status | loc | lr | momentum | acc | iter | total time (s) | lr |\n",
- "train_mnist_817a7_00000 | TERMINATED | 10.0.0.3:26907 | 0.291632 | 0.578225 | 0.901 | 7 | 163.06 | 0.291632 |\n",
- "train_mnist_817a7_00001 | TERMINATED | 10.0.0.3:26904 | 0.63272 | 0.94472 | 0.0996 | 20 | 446.483 | 0.63272 |\n",
- "train_mnist_817a7_00002 | TERMINATED | 10.0.0.3:26903 | 0.615735 | 0.0790379 | 0.901 | 9 | 219.548 | 0.615735 |\n",
- "train_mnist_817a7_00003 | TERMINATED | 10.0.0.3:26906 | 0.127736 | 0.486793 | 0.9084 | 8 | 181.952 | 0.127736 |\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-04-17 18:03:53,880\tINFO pbt.py:716 -- [pbt]: no checkpoint for trial train_mnist_817a7_00003. Skip exploit for Trial train_mnist_817a7_00001\n",
- "2024-04-17 18:09:24,486\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n",
- "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n",
- "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n",
- "2024-04-17 18:09:24,492\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/pbt_mnist' in 0.0111s.\n",
- "2024-04-17 18:09:24,501\tINFO tune.py:1048 -- Total run time: 455.82 seconds (455.74 seconds for the tuning loop).\n"
- ]
- }
- ],
- "source": [
- "tuner = tune.Tuner(\n",
- "    tune.with_resources(train_mnist, {\"gpu\": 1}),\n",
- "    run_config=ray.train.RunConfig(\n",
- "        name=\"pbt_mnist\",\n",
- "        # Stopping condition: whichever of `mean_accuracy` or `training_iteration` reaches its threshold first\n",
- "        stop={\"mean_accuracy\": 0.9, \"training_iteration\": 20},\n",
- "        checkpoint_config=ray.train.CheckpointConfig(\n",
- "            checkpoint_score_attribute=\"mean_accuracy\",\n",
- "            num_to_keep=4,\n",
- "        ),\n",
- "        storage_path=\"~/ray_results\",\n",
- "    ),\n",
- "    tune_config=tune.TuneConfig(\n",
- "        scheduler=scheduler,\n",
- "        num_samples=4,\n",
- "    ),\n",
- "    param_space={\n",
- "        \"lr\": tune.uniform(0.001, 1),\n",
- "        \"momentum\": tune.uniform(0.001, 1),\n",
- "        \"checkpoint_interval\": perturbation_interval,\n",
- "    },\n",
- ")\n",
- "\n",
- "results_grid = tuner.fit()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
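- "After tuning, we can inspect the results for the different hyperparameters. We pick the best result and plot how its `mean_accuracy` changes over the training iterations."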
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Best result path: /home/u20200002/ray_results/pbt_mnist/train_mnist_817a7_00003_3_lr=0.1277,momentum=0.4868_2024-04-17_18-01-48\n",
- "Best final iteration hyperparameter config:\n",
- " {'lr': 0.1277359940819796, 'momentum': 0.48679312797681595, 'checkpoint_interval': 5}\n"
- ]
- },
- {
- "data": {
- "image/svg+xml": [
- "\n",
- "\n",
- "\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "%config InlineBackend.figure_format = 'svg'\n",
- "\n",
- "best_result = results_grid.get_best_result(metric=\"mean_accuracy\", mode=\"max\")\n",
- "\n",
- "print('Best result path:', best_result.path)\n",
- "print(\"Best final iteration hyperparameter config:\\n\", best_result.config)\n",
- "\n",
- "df = best_result.metrics_dataframe\n",
- "df = df.drop_duplicates(subset=\"training_iteration\", keep=\"last\")\n",
- "df.plot(\"training_iteration\", \"mean_accuracy\")\n",
- "plt.xlabel(\"Training Iterations\")\n",
- "plt.ylabel(\"Test Accuracy\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.7"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}