diff --git a/_toc.yml b/_toc.yml
index dad8117..fd31798 100644
--- a/_toc.yml
+++ b/_toc.yml
@@ -27,6 +27,9 @@ subtrees:
       - file: ch-dask-dataframe/indexing
       - file: ch-dask-dataframe/map-partitions
       - file: ch-dask-dataframe/shuffle
+  - file: ch-dask-ml/index
+    entries:
+      - file: ch-dask-ml/distributed-training
   - file: ch-ray-core/index
     entries:
       - file: ch-ray-core/ray-intro
diff --git a/ch-dask-dataframe/shuffle.ipynb b/ch-dask-dataframe/shuffle.ipynb
index 964a9ed..9456359 100644
--- a/ch-dask-dataframe/shuffle.ipynb
+++ b/ch-dask-dataframe/shuffle.ipynb
@@ -11,9 +11,9 @@
     "\n",
     "## How Shuffle Is Implemented\n",
     "\n",
-    "{numref}`sec-dask-task-graph` explained that Dask is built around the Task Graph, which is a directed acyclic graph. A directed edge indicates that the input of a downstream Partition depends on the output of an upstream Partition, so any data movement adds a directed edge to the Task Graph. Many workloads shuffle large amounts of data, and in some cases all of the data is scattered, meaning an upstream node has many edges pointing downstream. A Shuffle based on the Task Graph therefore makes the graph enormous; an oversized Task Graph overloads the Dask Scheduler, which in turn makes computation extremely slow. As shown on the left of {numref}`fig-dask-shuffle`, `tasks` is the Task-Graph-based mechanism: directed edges are built between upstream and downstream, and an intermediate layer (usually needed because the incoming data is too large and must be further split into more Partitions) increases the complexity of the Task Graph even more.\n",
+    "{numref}`sec-dask-task-graph` explained that Dask is built around the Task Graph, which is a directed acyclic graph. A directed edge indicates that the input of a downstream Partition depends on the output of an upstream Partition, so any data movement adds a directed edge to the Task Graph. Many workloads shuffle large amounts of data, and in some cases all of the data is scattered, meaning an upstream node has many edges pointing downstream. A Shuffle based on the Task Graph therefore makes the graph enormous; an oversized Task Graph overloads the Dask Scheduler, which in turn makes computation extremely slow. As shown on the left of {numref}`fig-shuffle-tasks-p2p`, `tasks` is the Task-Graph-based mechanism: directed edges are built between upstream and downstream, and an intermediate layer (usually needed because the incoming data is too large and must be further split into more Partitions) increases the complexity of the Task Graph even more.\n",
     "\n",
-    "To solve the problem of an oversized Task Graph, Dask designed a peer-to-peer Shuffle mechanism. As shown on the right of {numref}`fig-dask-shuffle`, `p2p` introduces a virtual Barrier node into the Task Graph. The Barrier is not a real Task; introducing it reduces the complexity of the Task Graph significantly.\n",
+    "To solve the problem of an oversized Task Graph, Dask designed a peer-to-peer Shuffle mechanism. As shown on the right of {numref}`fig-shuffle-tasks-p2p`, `p2p` introduces a virtual Barrier node into the Task Graph. The Barrier is not a real Task; introducing it reduces the complexity of the Task Graph significantly.\n",
     "\n",
     "```{figure} ../img/ch-dask-dataframe/shuffle-tasks-p2p.png\n",
     "---\n",
@@ -26,7 +26,7 @@
     "Dask currently provides two classes of Shuffle strategies: single-machine and distributed.\n",
     "\n",
    "* Single-machine. If the data exceeds available memory, intermediate data can be written to disk. This strategy is the default in single-machine scenarios.\n",
-    "* Distributed. As shown in {numref}`fig-dask-shuffle`, distributed scenarios offer two Shuffle strategies: `tasks` and `p2p`. `tasks` is the Task-Graph-based implementation; in many scenarios it is inefficient and runs into the oversized-Task-Graph problem described above. `p2p` is the peer-to-peer implementation; it reduces Task Graph complexity significantly and performs markedly better. Dask prefers `p2p`.\n",
+    "* Distributed. As shown in {numref}`fig-shuffle-tasks-p2p`, distributed scenarios offer two Shuffle strategies: `tasks` and `p2p`. `tasks` is the Task-Graph-based implementation; in many scenarios it is inefficient and runs into the oversized-Task-Graph problem described above. `p2p` is the peer-to-peer implementation; it reduces Task Graph complexity significantly and performs markedly better. Dask prefers `p2p`.\n",
     "\n",
     "`dask.config.set({\"dataframe.shuffle.method\": \"p2p\"})` applies the `p2p` Shuffle method to all computation in the current Python script. The strategy can also be set per operator, for example `ddf.merge(shuffle_method=\"p2p\")`.\n",
     "\n",
diff --git a/ch-dask-ml/distributed-training.ipynb b/ch-dask-ml/distributed-training.ipynb
new file mode 100644
index 0000000..f70fbb3
--- /dev/null
+++ b/ch-dask-ml/distributed-training.ipynb
@@ -0,0 +1,3583 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "(sec-dask-ml-distributed-training)=\n",
+    "# Distributed Machine Learning\n",
+    "\n",
+    "When the training data is very large, Dask-ML provides distributed machine learning capabilities that train on big data across a cluster. Dask currently offers two classes of distributed machine learning APIs:\n",
+    "\n",
+    "* scikit-learn style: used much like scikit-learn\n",
+    "* XGBoost and LightGBM style: used much like XGBoost and LightGBM\n",
+    "\n",
+    "## scikit-learn API\n",
+    "\n",
+    "Building on the distributed computing capabilities of Dask Array, Dask DataFrame, and Dask Delayed, and modeled on scikit-learn, Dask-ML implements distributed versions of machine learning algorithms, for example linear regression [`LinearRegression`](https://ml.dask.org/modules/generated/dask_ml.linear_model.LinearRegression.html) and logistic regression [`LogisticRegression`](https://ml.dask.org/modules/generated/dask_ml.linear_model.LogisticRegression.html) in `dask_ml.linear_model`, and [`KMeans`](https://ml.dask.org/modules/generated/dask_ml.cluster.KMeans.html) in `dask_ml.cluster`. Dask-ML keeps the usage of these algorithms as close to scikit-learn as possible.\n",
+    "\n",
+    "We use the linear models in `dask_ml.linear_model` on a Dask cluster of 2 nodes. Each node in this cluster has 90 GiB of memory; we randomly generate a 37 GiB dataset and split it into a training set and a test set."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "tags": [
+     "hide-cell"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "%config InlineBackend.figure_format = 'svg'\n",
+    "import time\n",
+    "\n",
+    "import seaborn as sns\n",
+    "import pandas as pd\n",
+    "\n",
+    "from dask.distributed import Client, LocalCluster"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dask_ml.datasets\n",
+    "import sklearn.linear_model\n",
+    "import dask_ml.linear_model\n",
+    "from dask_ml.model_selection import train_test_split"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>
\n", + "
\n", + "
\n", + "

Client

\n", + "

Client-ad77e682-0ae4-11ef-8730-000012e4fe80

\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Connection method: Direct
\n", + " Dashboard: http://10.0.0.3:43549/status\n", + "
\n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "

Scheduler Info

\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-c7851ab9-9963-4c85-b394-bb74e8e2967f

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.0.0.3:8786\n", + " \n", + " Workers: 2\n", + "
\n", + " Dashboard: http://10.0.0.3:43549/status\n", + " \n", + " Total threads: 128\n", + "
\n", + " Started: 5 hours ago\n", + " \n", + " Total memory: 180.00 GiB\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: tcp://10.0.0.2:46501

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.0.0.2:46501\n", + " \n", + " Total threads: 64\n", + "
\n", + " Dashboard: http://10.0.0.2:42539/status\n", + " \n", + " Memory: 90.00 GiB\n", + "
\n", + " Nanny: tcp://10.0.0.2:40241\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-rxylv59_\n", + "
\n", + " Tasks executing: \n", + " \n", + " Tasks in memory: \n", + "
\n", + " Tasks ready: \n", + " \n", + " Tasks in flight: \n", + "
\n", + " CPU usage: 6.0%\n", + " \n", + " Last seen: Just now\n", + "
\n", + " Memory usage: 301.68 MiB\n", + " \n", + " Spilled bytes: 0 B\n", + "
\n", + " Read bytes: 572.9739612289254 B\n", + " \n", + " Write bytes: 1.71 kiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: tcp://10.0.0.3:39997

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.0.0.3:39997\n", + " \n", + " Total threads: 64\n", + "
\n", + " Dashboard: http://10.0.0.3:40955/status\n", + " \n", + " Memory: 90.00 GiB\n", + "
\n", + " Nanny: tcp://10.0.0.3:34825\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-kdphx4zv\n", + "
\n", + " Tasks executing: \n", + " \n", + " Tasks in memory: \n", + "
\n", + " Tasks ready: \n", + " \n", + " Tasks in flight: \n", + "
\n", + " CPU usage: 4.0%\n", + " \n", + " Last seen: Just now\n", + "
\n", + " Memory usage: 300.18 MiB\n", + " \n", + " Spilled bytes: 0 B\n", + "
\n", + " Read bytes: 8.27 kiB\n", + " \n", + " Write bytes: 10.57 kiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client = Client(\"10.0.0.3:8786\")\n", + "client" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Array Chunk
Bytes 37.25 GiB 381.47 MiB
Shape (10000000, 500) (100000, 500)
Dask graph 100 chunks in 1 graph layer
Data type float64 numpy.ndarray
\n", + "
\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + "\n", + " \n", + " 500\n", + " 10000000\n", + "\n", + "
" + ], + "text/plain": [ + "dask.array" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X, y = dask_ml.datasets.make_classification(n_samples=10_000_000, \n", + " n_features=500, \n", + " random_state=42,\n", + " chunks=10_000_000 // 100\n", + ")\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n", + "X" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "调用 `fit` 方法(与 scikit-learn 类似):" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "lr = dask_ml.linear_model.LogisticRegression(solver=\"lbfgs\").fit(X_train, y_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "训练好的模型可以用来预测(`predict`),也可以计算准确度(`score`)。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ True, False, True, True, True])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_predicted = lr.predict(X_test)\n", + "y_predicted[:5].compute()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.668674" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lr.score(X_test, y_test).compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "如果在单机的 scikit-learn 上使用同样大小的数据训练模型,会因为内存不足而报错。\n", + "\n", + "尽管 Dask-ML 这种分布式训练的 API 与 scikit-learn 极其相似,scikit-learn 只能使用单核,Dask-ML 
可以使用多核甚至集群,但并不意味着所有场景下都选择 Dask-ML,因为有些时候 Dask-ML 并非性能或性价比最优的选择。这一点与 Dask DataFrame 和 pandas 关系一样,如果数据量能放进单机内存,原生的 pandas 、NumPy 和 scikit-learn 的性能和兼容性总是最优的。\n", + "\n", + "下面的代码对不同规模的训练数据进行了性能分析,在单机多核且数据量较小的场景,Dask-ML 的性能并不比 scikit-learn 更快。主要因为:很多机器学习算法是迭代式的,scikit-learn 中,迭代式算法使用 Python 原生 `for` 循环来实现;Dask-ML 参考了这种 `for` 循环,但对于 Dask 的 Task Graph 来说,`for` 循环会使得 Task Graph 很臃肿,执行效率并不是很高。\n", + "\n", + "你也可以根据你所拥有的内存来测试一下性能。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Client

\n", + "

Client-b4f64c31-0ae4-11ef-8730-000012e4fe80

\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Connection method: Cluster objectCluster type: distributed.LocalCluster
\n", + " Dashboard: http://127.0.0.1:8787/status\n", + "
\n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "

Cluster Info

\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

LocalCluster

\n", + "

1872fd25

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "\n", + " \n", + "
\n", + " Dashboard: http://127.0.0.1:8787/status\n", + " \n", + " Workers: 8\n", + "
\n", + " Total threads: 64\n", + " \n", + " Total memory: 90.00 GiB\n", + "
Status: runningUsing processes: True
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-897dca6c-6012-4df7-9a10-bd08f8810617

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://127.0.0.1:38477\n", + " \n", + " Workers: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:8787/status\n", + " \n", + " Total threads: 64\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 90.00 GiB\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 0

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:44219\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:36081/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:34355\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-439c1uaa\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 1

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:41549\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:44857/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:41265\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-hyxlvh30\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 2

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:42877\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:40235/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:40939\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-e70v3hq2\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 3

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:34321\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:40295/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:35007\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-udlmb2zo\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 4

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:36039\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:45691/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:34883\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-g5h5ob4b\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 5

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:35057\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:43309/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:43945\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-61hsl1ap\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 6

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:36811\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:44197/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:44607\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-syjczr8e\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: 7

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://127.0.0.1:42081\n", + " \n", + " Total threads: 8\n", + "
\n", + " Dashboard: http://127.0.0.1:35819/status\n", + " \n", + " Memory: 11.25 GiB\n", + "
\n", + " Nanny: tcp://127.0.0.1:33971\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-1rw7_3km\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client = Client(LocalCluster())\n", + "client" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " warnings.warn(\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", + " warnings.warn(\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " warnings.warn(\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", + " warnings.warn(\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " warnings.warn(\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 
2024-05-05T21:42:17.956299\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "num_sample = [500_000, 1_000_000, 1_500_000]\n", + "num_feature = 1_000\n", + "timings = []\n", + "\n", + "for n in num_sample:\n", + " X, y = dask_ml.datasets.make_classification(n_samples=n, \n", + " n_features=num_feature, \n", + " random_state=42,\n", + " chunks=n // 10\n", + " )\n", + " t1 = time.time()\n", + " sklearn.linear_model.LogisticRegression(solver=\"lbfgs\").fit(X, y)\n", + " timings.append(('scikit-learn', n, time.time() - t1))\n", + " t1 = time.time()\n", + " dask_ml.linear_model.LogisticRegression(solver=\"lbfgs\").fit(X, y)\n", + " timings.append(('dask-ml', n, time.time() - t1))\n", + "\n", + "df = pd.DataFrame(timings, columns=['method', '# of samples', 'time'])\n", + "sns.barplot(data=df, x='# of samples', y='time', hue='method')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可见,在逻辑回归这个场景上,比起 scikit-learn,Dask-ML 在单机多核上并无太多优势。而且很多传统机器学习算法对训练数据量的要求没那么高,随着训练数据的增加,传统的机器学习算法的性能不会显著增加。训练数据量和模型性能之间的关系可以通过学习曲线(Learning Curves)来可视化,随着训练数据量增加,像朴素贝叶斯等算法的性能提升十分有限。如果一些机器学习算法无法进行分布式训练或分布式训练成本很高,可以考虑对训练数据采样,数据大小能够放进单机内存,使用 scikit-learn 这种单机框架训练。\n", + "\n", + "综上,如果有一个超出单机内存的训练数据,要根据问题特点、所使用的算法和成本等多方面因素来决定使用何种方式处理。\n", + "\n", + "## XGBoost 和 LightGBM\n", + "\n", + "XGBoost 和 LightGBM 是两种决策树模型的实现,他们本身就对分布式训练友好,且集成了 Dask 的分布式能力。下面以 XGBoost 为例,介绍 XGBoost 如何基于 Dask 实现分布式训练,LightGBM 与之类似。\n", + "\n", + "在 XGBoost 中,训练一个模型既可以使用 `train` 方法,也可以使用 scikit-learn 式的 `fit` 方法。这两种方式都支持 Dask 分布式训练。\n", + "\n", + "下面的代码对单机的 XGBoost 和 Dask 分布式训练两种方式进行了性能对比。如果使用 Dask,需要将 [`xgboost.DMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.DMatrix) 修改为 [`xgboost.dask.DaskDMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.DaskDMatrix),将 [`xgboost.train`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.train) 修改为 
[`xgboost.dask.train`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.train);并传入 Dask 集群客户端 `client`。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " warnings.warn(\n", + "[22:13:43] task [xgboost.dask-0]:tcp://127.0.0.1:44219 got new rank 0\n", + "[22:13:43] task [xgboost.dask-1]:tcp://127.0.0.1:41549 got new rank 1\n", + "[22:13:43] task [xgboost.dask-2]:tcp://127.0.0.1:42877 got new rank 2\n", + "[22:13:43] task [xgboost.dask-3]:tcp://127.0.0.1:34321 got new rank 3\n", + "[22:13:43] task [xgboost.dask-4]:tcp://127.0.0.1:36039 got new rank 4\n", + "[22:13:43] task [xgboost.dask-5]:tcp://127.0.0.1:35057 got new rank 5\n", + "[22:13:43] task [xgboost.dask-6]:tcp://127.0.0.1:36811 got new rank 6\n", + "[22:13:43] task [xgboost.dask-7]:tcp://127.0.0.1:42081 got new rank 7\n", + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " warnings.warn(\n", + "[22:16:27] task [xgboost.dask-0]:tcp://127.0.0.1:44219 got new rank 0\n", + "[22:16:27] task [xgboost.dask-1]:tcp://127.0.0.1:41549 got new rank 1\n", + "[22:16:27] task [xgboost.dask-2]:tcp://127.0.0.1:42877 got new rank 2\n", + "[22:16:27] task [xgboost.dask-3]:tcp://127.0.0.1:34321 got new rank 3\n", + "[22:16:27] task [xgboost.dask-4]:tcp://127.0.0.1:36039 got new rank 4\n", + "[22:16:27] task [xgboost.dask-5]:tcp://127.0.0.1:35057 got new rank 5\n", + "[22:16:28] task [xgboost.dask-6]:tcp://127.0.0.1:36811 got new rank 6\n", + "[22:16:28] task [xgboost.dask-7]:tcp://127.0.0.1:42081 got new rank 7\n" + ] + }, + { + 
"data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " 2024-05-05T22:16:40.172826\n", + " image/svg+xml\n", + " \n", + " \n", + " Matplotlib v3.8.4, https://matplotlib.org/\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n" + ], + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import xgboost as xgb\n", + "\n", + "num_sample = [100_000, 500_000]\n", + "num_feature = 1_000\n", + "xgb_timings = []\n", + "\n", + "for n in num_sample:\n", + " X, y = dask_ml.datasets.make_classification(n_samples=n, \n", + " n_features=num_feature, \n", + " random_state=42,\n", + " chunks=n // 10\n", + " )\n", + " dtrain = xgb.DMatrix(X, y)\n", + " t1 = time.time()\n", + " xgb.train(\n", + " {\"tree_method\": \"hist\", \"objective\": \"binary:hinge\"},\n", + " dtrain,\n", + " num_boost_round=4,\n", + " evals=[(dtrain, \"train\")],\n", + " verbose_eval=False,\n", + " )\n", + " xgb_timings.append(('xgboost', n, time.time() - t1))\n", + " dtrain_dask = xgb.dask.DaskDMatrix(client, X, y)\n", + " t1 = time.time()\n", + " xgb.dask.train(\n", + " client,\n", + " {\"tree_method\": \"hist\", \"objective\": \"binary:hinge\"},\n", + " dtrain_dask,\n", + " num_boost_round=4,\n", + " evals=[(dtrain_dask, \"train\")],\n", + " verbose_eval=False,\n", + " )\n", + " xgb_timings.append(('dask-ml', n, time.time() - t1))\n", + "\n", + "df = pd.DataFrame(xgb_timings, columns=['method', '# of samples', 'time'])\n", + "sns.barplot(data=df, x='# of samples', y='time', hue='method')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "如果是 XGBoost 的 scikit-learn 式 API,需要将 [`xgboost.XGBClassifier`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier) 修改为 [`xgboost.dask.DaskXGBClassifier`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.dask.DaskXGBClassifier) 或者 [`xgboost.XGBRegressor`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor) 修改为 [`xgboost.dask.DaskXGBRegressor`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor)。\n", + "\n", + "### 分布式 GPU 训练\n", + "\n", + "Dask 可以管理多块 GPU,XGBoost 可以基于 Dask 进行多 GPU 训练,我们需要安装 Dask-CUDA 以启动一个多 GPU 的 Dask 
集群。Dask 可以将 XGBoost 分布到多张 GPU 卡上进行训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/distributed/deploy/spec.py:324: UserWarning: Port 8787 is already in use.\n", + "Perhaps you already have a cluster running?\n", + "Hosting the HTTP server on port 44607 instead\n", + " self.scheduler = cls(**self.scheduler_spec.get(\"options\", {}))\n" + ] + }, + { + "data": { + "text/html": [ + "
\n",
 + "<b>LocalCUDACluster</b> e461dd92, Status: running, Workers: 4, Total threads: 4, Total memory: 90.00 GiB\n",
 + "<br>Dashboard: <a href=\"http://127.0.0.1:44607/status\">http://127.0.0.1:44607/status</a>, Scheduler: tcp://127.0.0.1:33619\n",
 + "<br>Workers 0 to 3: 1 thread and 22.50 GiB memory each\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dask_cuda import LocalCUDACluster\n", + "import xgboost as xgb\n", + "client = Client(LocalCUDACluster())\n", + "client" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "clf = xgb.dask.DaskXGBClassifier(verbosity=1)\n", + "clf.set_params(tree_method=\"hist\", device=\"cuda\")\n", + "clf.client = client" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask_ml/datasets.py:373: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", + " informative_idx, beta = dask.compute(\n", + "[23:01:19] task [xgboost.dask-0]:tcp://127.0.0.1:45305 got new rank 0\n", + "[23:01:19] task [xgboost.dask-1]:tcp://127.0.0.1:38835 got new rank 1\n", + "[23:01:19] task [xgboost.dask-2]:tcp://127.0.0.1:46315 got new rank 2\n", + "[23:01:20] task [xgboost.dask-3]:tcp://127.0.0.1:38331 got new rank 3\n" + ] + } + ], + "source": [ + "X, y = dask_ml.datasets.make_classification(n_samples=100_000, \n", + " n_features=1_000, \n", + " random_state=42,\n", + " chunks=100_000 // 100\n", + ")\n", + "clf.fit(X, y, eval_set=[(X, y)], verbose=False)\n", + "prediction = clf.predict(X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + 
}
 + },
 + "nbformat": 4,
 + "nbformat_minor": 2
 +}
diff --git a/ch-dask-ml/index.md b/ch-dask-ml/index.md
new file mode 100644
index 0000000..7068513
--- /dev/null
+++ b/ch-dask-ml/index.md
@@ -0,0 +1,4 @@
+# Dask Machine Learning
+
+```{tableofcontents}
+```
\ No newline at end of file
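The `xgb.dask.train` cells above follow a data-parallel pattern: the input is split into partitions (`chunks=n // 10` yields 10 chunks per array), each worker computes partial statistics on its own partitions, and the partials are merged into a global result. As a stdlib-only toy sketch of that split-compute-merge idea (no Dask or XGBoost involved; the function names here are illustrative, not part of either library):

```python
from statistics import mean

def partial_stats(partition):
    # What each "worker" computes on its own partition: a partial sum and count,
    # standing in for the per-partition gradient histograms XGBoost merges.
    return sum(partition), len(partition)

def data_parallel_mean(data, n_partitions):
    # Split the dataset into equal partitions, mimicking `chunks=n // 10`.
    size = len(data) // n_partitions
    partitions = [data[i * size:(i + 1) * size] for i in range(n_partitions)]
    # Merge step: combine the per-partition partials into one global statistic.
    partials = [partial_stats(p) for p in partitions]
    total = sum(s for s, _ in partials)
    count = sum(c for _, c in partials)
    return total / count

data = list(range(100))
assert data_parallel_mean(data, 10) == mean(data)  # both are 49.5
```

The merged result equals the single-machine computation; distributed training relies on the same property, with an AllReduce-style exchange among workers taking the place of the local merge loop.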