diff --git a/ch-ray-ml/tune-algorithm-scheduler.ipynb b/ch-ray-ml/tune-algorithm-scheduler.ipynb deleted file mode 100644 index ea50f31..0000000 --- a/ch-ray-ml/tune-algorithm-scheduler.ipynb +++ /dev/null @@ -1,1907 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "(sec-tune-algorithm-scheduler)=\n", - "# 超参数调优搜索算法和调度器\n", - "\n", - "Ray Tune 的超参数搜索中比较重要的概念是搜索算法和调度器:搜索算法确定如何从搜索空间中选择新的超参数组合(即试验);调度器决定对提前结束一些不太有前景的试验,节省计算资源。搜索算法是必须的,调度器不是必须的。这两者可以协作来选择超参数,比如使用随机搜索算法和异步连续减半算法(Async Successive Halving Algorithm,ASHA)调度器,调度器对一些看起来没希望的试验提前结束。另外,一些超参数优化的包通常提供了封装好的搜索算法,有的还提供了调度器,这些包有自己的使用方式和习惯,Ray Tune 对这些包进行了封装,尽量使得这些包的使用方式统一。下面简单介绍一些常见的搜索算法和调度器。\n", - "\n", - "## Hyperband\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "import sys\n", - "sys.path.append(\"..\")\n", - "from utils import nyc_flights\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "import torch\n", - "import torchvision\n", - "import torch.nn as nn\n", - "from torch.utils.data import DataLoader\n", - "from torchvision.models import resnet18\n", - "\n", - "import ray\n", - "from sklearn.model_selection import train_test_split\n", - "from ray.tune.search.hyperopt import HyperOptSearch\n", - "import xgboost as xgb\n", - "from ray import tune\n", - "from ray.tune.schedulers import AsyncHyperBandScheduler\n", - "from ray.tune.integration.xgboost import TuneReportCheckpointCallback\n", - "from ray.tune.schedulers import PopulationBasedTraining\n", - "\n", - "folder_path = nyc_flights()\n", - "file_path = os.path.join(folder_path, \"nyc-flights\", \"1991.csv\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "读取数据,进行必要的数据预处理:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "input_cols = [\n", - " \"Year\",\n", - " \"Month\",\n", - " \"DayofMonth\",\n", - " \"DayOfWeek\",\n", - " \"CRSDepTime\",\n", - " \"CRSArrTime\",\n", - " \"UniqueCarrier\",\n", - " \"FlightNum\",\n", - " \"ActualElapsedTime\",\n", - " \"Origin\",\n", - " \"Dest\",\n", - " \"Distance\",\n", - " \"Diverted\",\n", - " \"ArrDelay\",\n", - "]\n", - "\n", - "df = pd.read_csv(file_path, usecols=input_cols,)\n", - "\n", - "# 预测是否延误\n", - "df[\"ArrDelayBinary\"] = 1.0 * (df[\"ArrDelay\"] > 10)\n", - "\n", - "df = df[df.columns.difference([\"ArrDelay\"])]\n", - "\n", - "for col in df.select_dtypes([\"object\"]).columns:\n", - " df[col] = df[col].astype(\"category\").cat.codes.astype(np.int32)\n", - "\n", - "for col in df.columns:\n", - " df[col] = df[col].astype(np.float32)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "XGBoost `train()` 函数的 `params` 参数接收树深度等超参数。需要注意的是,XGBoost 等训练框架提供的 `train()` 函数不像 PyTorch 那样有 `for epoch in range(...)` 这样的显式迭代训练过程,如果希望每次训练迭代后立即反馈性能指标,需要在 `train()` 的 `callbacks` 中传入回调函数,Ray 提供了 [`TuneReportCheckpointCallback`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.integration.xgboost.TuneReportCheckpointCallback.html),这个回调函数会在每次训练迭代后将相关指标报告给 Ray Tune。具体到本例中,XGBoost 的 `train()` 函数的 `params` 参数传入了 `\"eval_metric\": [\"logloss\", \"error\"]`,表示评估时的指标; `evals=[(test_set, \"eval\")]` 表示只关注验证集的指标;以上两者合起来,表示对验证集计算 `logloss` 和 `error` 指标,汇报给 Ray Tune 时,指标名称为 `eval-logloss` 和 `eval-error`。" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def train_flight(config: dict):\n", - " config.update({\n", - " \"objective\": \"binary:logistic\",\n", - " \"eval_metric\": [\"logloss\", \"error\"]\n", - " })\n", - " _y_label = \"ArrDelayBinary\"\n", - " train_x, test_x, train_y, test_y = train_test_split(\n", - " df.loc[:, df.columns != _y_label], \n", - " df[_y_label], \n", - " test_size=0.25\n", - " )\n", - " \n", - " train_set = xgb.DMatrix(train_x, label=train_y)\n", - " test_set = xgb.DMatrix(test_x, label=test_y)\n", - " \n", - " xgb.train(\n", - " params=config,\n", - " dtrain=train_set,\n", - " evals=[(test_set, \"eval\")],\n", - " verbose_eval=False,\n", - " # 每次迭代后, `TuneReportCheckpointCallback` 将评估指标反馈给 Ray Tune\n", - " callbacks=[TuneReportCheckpointCallback(frequency=1)]\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "我们底层使用 `hyperopt` 包所提供的贝叶斯优化搜索算法,如果没安装这个包,请先安装:`pip install hyperopt`。这些包通常有自己的定义搜索空间格式,用户也可以直接使用 Ray Tune 提供的搜索空间定义方式。\n", - "\n", - "调度器方面,我们使用 HyperBand 调度算法。[`AsyncHyperBandScheduler`](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.AsyncHyperBandScheduler.html) 是 Ray Tune 推荐的 HyperBand 算法的实现,它是异步的,能够更充分利用计算资源。`AsyncHyperBandScheduler` 中 `time_attr` 是描述训练时间的单位,默认为 `training_iteration`,表示一次训练迭代周期,`time_attr` 是计算资源额度的基本时间单位。`AsyncHyperBandScheduler` 的其他参数与 `time_attr` 规定的时间单位高度相关,比如 `max_t` 是每个试验所能获得的总时间,即一个试验最多能获得 `max_t` * `time_attr` 的计算资源额度;`grace_period` 表示至少给每个试验 `grace_period` * `time_attr` 的计算资源额度。`reduction_factor` 是上述数学描述中的 $\\eta$,`brackets` 为 HyperBand 算法所涉及的组合的概念。" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
\n", - "
\n", - "

Tune Status

\n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
Current time:2024-04-17 23:23:16
Running for: 00:00:10.45
Memory: 12.6/90.0 GiB
\n", - "
\n", - "
\n", - "
\n", - "

System Info

\n", - " Using AsyncHyperBand: num_stopped=16
Bracket: Iter 8.000: -0.2197494153541173 | Iter 4.000: -0.21991977574377797 | Iter 2.000: -0.2211587603958556 | Iter 1.000: -0.22190215118710216
Bracket: Iter 8.000: -0.2228778516006133 | Iter 4.000: -0.2228778516006133 | Iter 2.000: -0.22374514085706762
Bracket: Iter 8.000: -0.2238690393222754 | Iter 4.000: -0.2238690393222754
Logical resource usage: 1.0/64 CPUs, 0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n", - "
\n", - " \n", - "
\n", - "
\n", - "
\n", - "

Trial Status

\n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
Trial name status loc eta max_depth min_child_weight subsample iter total time (s) eval-logloss eval-error
train_flight_63737_00000TERMINATED10.0.0.3:467650.0160564 2 3 0.963344 10 0.958357 0.522242 0.222878
train_flight_63737_00001TERMINATED10.0.0.3:467240.0027667 3 3 0.930057 10 1.11445 0.525986 0.219595
train_flight_63737_00002TERMINATED10.0.0.3:467000.00932612 3 1 0.532473 1 0.698576 0.53213 0.223699
train_flight_63737_00003TERMINATED10.0.0.3:467950.0807042 7 1 0.824932 10 1.27819 0.42436 0.176524
train_flight_63737_00004TERMINATED10.0.0.3:467960.0697454 1 2 0.908686 10 1.01485 0.516239 0.223466
train_flight_63737_00005TERMINATED10.0.0.3:468680.00334937 4 2 0.799064 10 0.983133 0.528863 0.223869
train_flight_63737_00006TERMINATED10.0.0.3:469320.00637837 5 2 0.555629 2 0.691448 0.528233 0.22136
train_flight_63737_00007TERMINATED10.0.0.3:469350.000145799 8 3 0.84289 1 0.668353 0.532382 0.223079
train_flight_63737_00008TERMINATED10.0.0.3:469590.0267405 5 1 0.766606 2 0.692802 0.520686 0.221159
train_flight_63737_00009TERMINATED10.0.0.3:469890.00848009 2 3 0.576874 2 0.610592 0.53193 0.223745
train_flight_63737_00010TERMINATED10.0.0.3:471250.0016903 8 3 0.824537 2 0.716938 0.532519 0.22407
train_flight_63737_00011TERMINATED10.0.0.3:471270.005344 7 1 0.921332 1 0.609434 0.532074 0.223993
train_flight_63737_00012TERMINATED10.0.0.3:471930.0956213 1 2 0.682057 8 0.791444 0.511592 0.219904
train_flight_63737_00013TERMINATED10.0.0.3:471960.00796245 5 2 0.570677 1 0.619144 0.531066 0.223172
train_flight_63737_00014TERMINATED10.0.0.3:471980.0106115 2 3 0.85295 1 0.582307 0.530977 0.222444
train_flight_63737_00015TERMINATED10.0.0.3:472000.0507297 1 1 0.720122 2 0.655333 0.527164 0.221283
\n", - "
\n", - "
\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[33m(raylet)\u001b[0m Warning: The actor ImplicitFunc is very large (27 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[36m(train_flight pid=46796)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/ray_results/train_flight_2024-04-17_23-23-05/train_flight_63737_00004_4_eta=0.0697,max_depth=1,min_child_weight=2,subsample=0.9087_2024-04-17_23-23-08/checkpoint_000000)\n", - "2024-04-17 23:23:16,344\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/train_flight_2024-04-17_23-23-05' in 0.0653s.\n", - "2024-04-17 23:23:16,362\tINFO tune.py:1048 -- Total run time: 10.73 seconds (10.38 seconds for the tuning loop).\n" - ] - } - ], - "source": [ - "search_space = {\n", - " \"max_depth\": tune.randint(1, 9),\n", - " \"min_child_weight\": tune.choice([1, 2, 3]),\n", - " \"subsample\": tune.uniform(0.5, 1.0),\n", - " \"eta\": tune.loguniform(1e-4, 1e-1),\n", - "}\n", - "\n", - "scheduler = AsyncHyperBandScheduler(\n", - " max_t=10,\n", - " grace_period=1,\n", - " reduction_factor=2,\n", - " brackets=3,\n", - ")\n", - "\n", - "tuner = tune.Tuner(\n", - " train_flight,\n", - " tune_config=tune.TuneConfig(\n", - " metric=\"eval-error\",\n", - " mode=\"min\",\n", - " scheduler=scheduler,\n", - " num_samples=16,\n", - " ),\n", - " param_space=search_space,\n", - ")\n", - "results = tuner.fit()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "`Tuner.fit()` 会将所有试验的结果返回成 `ResultGrid` ,也会把各类信息写到持久化存储上,用户可以查看不同超参数下的效果并进行分析和对比:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Best model parameters: {'max_depth': 7, 'min_child_weight': 1, 'subsample': 0.8249317015751376, 'eta': 0.08070421841931029, 'objective': 'binary:logistic', 'eval_metric': ['logloss', 'error']}\n", - "Best model total accuracy: 0.8235\n" - ] - } - ], - "source": [ - "def get_best_model_checkpoint(results):\n", - " best_result = results.get_best_result()\n", - "\n", - " # `TuneReportCheckpointCallback` 提供了从最优结果中返回 Checkpoint 的方法\n", - " best_bst = TuneReportCheckpointCallback.get_model(best_result.checkpoint)\n", - "\n", - " accuracy = 1.0 - best_result.metrics[\"eval-error\"]\n", - " print(f\"Best model parameters: {best_result.config}\")\n", - " print(f\"Best model total accuracy: {accuracy:.4f}\")\n", - " return best_bst\n", - "\n", - "best_bst = get_best_model_checkpoint(results)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
eval-errortraining_iterationconfig/max_depthconfig/min_child_weightconfig/subsample
00.22287810230.963344
10.21959510330.930057
20.2236991310.532473
30.17652410710.824932
40.22346610120.908686
50.22386910420.799064
60.2213602520.555629
70.2230791830.842890
80.2211592510.766606
90.2237452230.576874
100.2240702830.824537
110.2239931710.921332
120.2199048120.682057
130.2231721520.570677
140.2224441230.852950
150.2212832110.720122
\n", - "
" - ], - "text/plain": [ - " eval-error training_iteration config/max_depth config/min_child_weight \\\n", - "0 0.222878 10 2 3 \n", - "1 0.219595 10 3 3 \n", - "2 0.223699 1 3 1 \n", - "3 0.176524 10 7 1 \n", - "4 0.223466 10 1 2 \n", - "5 0.223869 10 4 2 \n", - "6 0.221360 2 5 2 \n", - "7 0.223079 1 8 3 \n", - "8 0.221159 2 5 1 \n", - "9 0.223745 2 2 3 \n", - "10 0.224070 2 8 3 \n", - "11 0.223993 1 7 1 \n", - "12 0.219904 8 1 2 \n", - "13 0.223172 1 5 2 \n", - "14 0.222444 1 2 3 \n", - "15 0.221283 2 1 1 \n", - "\n", - " config/subsample \n", - "0 0.963344 \n", - "1 0.930057 \n", - "2 0.532473 \n", - "3 0.824932 \n", - "4 0.908686 \n", - "5 0.799064 \n", - "6 0.555629 \n", - "7 0.842890 \n", - "8 0.766606 \n", - "9 0.576874 \n", - "10 0.824537 \n", - "11 0.921332 \n", - "12 0.682057 \n", - "13 0.570677 \n", - "14 0.852950 \n", - "15 0.720122 " - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "results_df = results.get_dataframe()\n", - "results_df[[\"eval-error\", \"training_iteration\", \"config/max_depth\", \"config/min_child_weight\", \"config/subsample\"]]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 案例:基于 PBT 进行图像分类\n", - "\n", - "PBT 在训练过程中会对模型权重和超参数都进行调整,因此其训练代码部分必须有更新(加载)模型权重的代码。另外一个区别是训练迭代部分,大部分 PyTorch 训练过程都有 `for epoch in range(...)` 这样显式定义迭代训练的循环,循环一般有终止条件;PBT 训练过程不设置终止条件,当模型指标达到预期或者需要早停,Ray Tune 终止,因此训练迭代处使用 `while True` 一直循环迭代,直到被 Ray Tune 终止。" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "hide-cell" - ] - }, - "outputs": [], - "source": [ - "data_dir = os.path.join(os.getcwd(), \"../data\")\n", - "\n", - "def train_func(model, optimizer, criterion, train_loader):\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " model.train()\n", - " for data, target in train_loader:\n", - " data, target = data.to(device), target.to(device)\n", - " output = model(data)\n", - " loss = criterion(output, target)\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - "\n", - "def test_func(model, data_loader):\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - " model.eval()\n", - " correct = 0\n", - " total = 0\n", - " with torch.no_grad():\n", - " for data, target in data_loader:\n", - " data, target = data.to(device), target.to(device)\n", - " outputs = model(data)\n", - " _, predicted = torch.max(outputs.data, 1)\n", - " total += target.size(0)\n", - " correct += (predicted == target).sum().item()\n", - "\n", - " return correct / total" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "def train_mnist(config):\n", - " step = 1\n", - " transform = torchvision.transforms.Compose(\n", - " [torchvision.transforms.ToTensor(), \n", - " torchvision.transforms.Normalize((0.5,), (0.5,))]\n", - " )\n", - "\n", - " train_loader = DataLoader(\n", - " torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),\n", - " batch_size=128,\n", - " shuffle=True)\n", - " test_loader = DataLoader(\n", - " torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),\n", - " batch_size=128,\n", - " shuffle=True)\n", - "\n", - " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - " model = resnet18(num_classes=10)\n", - " model.conv1 = torch.nn.Conv2d(\n", - " 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n", - " )\n", - " model.to(device)\n", - "\n", - " criterion = nn.CrossEntropyLoss()\n", - "\n", - " optimizer = torch.optim.SGD(\n", - " model.parameters(), \n", - " lr=config.get(\"lr\", 0.01), \n", - " momentum=config.get(\"momentum\", 0.9)\n", - " )\n", - "\n", - " checkpoint = ray.train.get_checkpoint()\n", - " if checkpoint:\n", - " with checkpoint.as_directory() as checkpoint_dir:\n", - " checkpoint_dict = torch.load(os.path.join(checkpoint_dir, \"checkpoint.pt\"))\n", - " \n", - " model.load_state_dict(checkpoint_dict[\"model_state_dict\"])\n", - " optimizer.load_state_dict(checkpoint_dict[\"optimizer_state_dict\"])\n", - " \n", - " # 将 config 传进来的 lr 和 momentum 更新到优化器中 \n", - " for param_group in optimizer.param_groups:\n", - " if \"lr\" in config:\n", - " param_group[\"lr\"] = config[\"lr\"]\n", - " if \"momentum\" in config:\n", - " param_group[\"momentum\"] = config[\"momentum\"]\n", - " \n", - " last_step = checkpoint_dict[\"step\"]\n", - " step = last_step + 1\n", - " \n", - " # Ray Tune 会根据性能指标终止试验\n", - " while True:\n", - " train_func(model, optimizer, criterion, train_loader)\n", - " acc = test_func(model, test_loader)\n", - " metrics = {\"mean_accuracy\": acc, \"lr\": config[\"lr\"]}\n", - "\n", - " if step % config[\"checkpoint_interval\"] == 0:\n", - " with tempfile.TemporaryDirectory() as tmpdir:\n", - " torch.save(\n", - " {\n", - " \"step\": step,\n", - " \"model_state_dict\": model.state_dict(),\n", - " \"optimizer_state_dict\": optimizer.state_dict(),\n", - " },\n", - " os.path.join(tmpdir, \"checkpoint.pt\"),\n", - " )\n", - " ray.train.report(metrics, checkpoint=ray.train.Checkpoint.from_directory(tmpdir))\n", - " else:\n", - " ray.train.report(metrics)\n", - "\n", - " step += 1" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接下来使用 [PopulationBasedTraining](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.schedulers.PopulationBasedTraining.html) 定义 PBT 调度器。`time_attr` 跟刚才提到的其他调度器一样,是一个时间单位。`perturbation_interval` 表示每隔一定时间对超参数进行一些变异扰动,生成新的超参数,通常与 `checkpoint_interval` 使用同一个值,因为超参数变异扰动的同时也将 Checkpoint 写入持久化存储,会带来额外的开销,因此这个值不宜设置得过频繁。PBT 算法从 `hyperparam_mutations` 里选择可能变异的值,`hyperparam_mutations` 是一个键值字典,里面的内容就是变异值。" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "perturbation_interval = 5\n", - "scheduler = PopulationBasedTraining(\n", - " time_attr=\"training_iteration\",\n", - " perturbation_interval=perturbation_interval,\n", - " metric=\"mean_accuracy\",\n", - " mode=\"max\",\n", - " hyperparam_mutations={\n", - " \"lr\": tune.uniform(0.0001, 1),\n", - " \"momentum\": [0.8, 0.9, 0.99],\n", - " },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "接下来就可以进行训练了。我们需要给 PBT 设置停止的条件,本例是 `mean_accuracy` 达到 0.9 或一共完成 20 次迭代。" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
\n", - "
\n", - "

Tune Status

\n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
Current time:2024-04-17 18:09:24
Running for: 00:07:35.75
Memory: 16.7/90.0 GiB
\n", - "
\n", - "
\n", - "
\n", - "

System Info

\n", - " PopulationBasedTraining: 9 checkpoints, 1 perturbs
Logical resource usage: 0/64 CPUs, 1.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)\n", - "
\n", - " \n", - "
\n", - "
\n", - "
\n", - "

Trial Status

\n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
Trial name status loc lr momentum acc iter total time (s) lr
train_mnist_817a7_00000TERMINATED10.0.0.3:269070.291632 0.578225 0.901 7 163.06 0.291632
train_mnist_817a7_00001TERMINATED10.0.0.3:269040.63272 0.94472 0.0996 20 446.4830.63272
train_mnist_817a7_00002TERMINATED10.0.0.3:269030.615735 0.07903790.901 9 219.5480.615735
train_mnist_817a7_00003TERMINATED10.0.0.3:269060.127736 0.486793 0.9084 8 181.9520.127736
\n", - "
\n", - "
\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-04-17 18:03:53,880\tINFO pbt.py:716 -- [pbt]: no checkpoint for trial train_mnist_817a7_00003. Skip exploit for Trial train_mnist_817a7_00001\n", - "2024-04-17 18:09:24,486\tWARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.\n", - "You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.\n", - "You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).\n", - "2024-04-17 18:09:24,492\tINFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/ray_results/pbt_mnist' in 0.0111s.\n", - "2024-04-17 18:09:24,501\tINFO tune.py:1048 -- Total run time: 455.82 seconds (455.74 seconds for the tuning loop).\n" - ] - } - ], - "source": [ - "tuner = tune.Tuner(\n", - " tune.with_resources(train_mnist, {\"gpu\": 1}),\n", - " run_config=ray.train.RunConfig(\n", - " name=\"pbt_mnist\",\n", - " # 停止条件:`stop` 或者 `training_iteration` 两个条件任一先达到\n", - " stop={\"mean_accuracy\": 0.9, \"training_iteration\": 20},\n", - " checkpoint_config=ray.train.CheckpointConfig(\n", - " checkpoint_score_attribute=\"mean_accuracy\",\n", - " num_to_keep=4,\n", - " ),\n", - " storage_path=\"~/ray_results\",\n", - " ),\n", - " tune_config=tune.TuneConfig(\n", - " scheduler=scheduler,\n", - " num_samples=4,\n", - " ),\n", - " param_space={\n", - " \"lr\": tune.uniform(0.001, 1),\n", - " \"momentum\": tune.uniform(0.001, 1),\n", - " \"checkpoint_interval\": perturbation_interval,\n", - " },\n", - ")\n", - "\n", - "results_grid = tuner.fit()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "调优之后,就可以查看不同超参数的结果了,我们选择最优的那个结果,查看 `lr` 的变化过程。" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Best result path: /home/u20200002/ray_results/pbt_mnist/train_mnist_817a7_00003_3_lr=0.1277,momentum=0.4868_2024-04-17_18-01-48\n", - "Best final iteration hyperparameter config:\n", - " {'lr': 0.1277359940819796, 'momentum': 0.48679312797681595, 'checkpoint_interval': 5}\n" - ] - }, - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " 2024-04-17T19:01:04.895229\n", - " image/svg+xml\n", - " \n", - " \n", - " Matplotlib v3.8.4, https://matplotlib.org/\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n" - ], - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "%config InlineBackend.figure_format = 'svg'\n", - "\n", - "best_result = results_grid.get_best_result(metric=\"mean_accuracy\", mode=\"max\")\n", - "\n", - "print('Best result path:', best_result.path)\n", - "print(\"Best final iteration hyperparameter config:\\n\", best_result.config)\n", - "\n", - "df = best_result.metrics_dataframe\n", - "df = df.drop_duplicates(subset=\"training_iteration\", keep=\"last\")\n", - "df.plot(\"training_iteration\", \"mean_accuracy\")\n", - "plt.xlabel(\"Training Iterations\")\n", - "plt.ylabel(\"Test Accuracy\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.7" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}