Skip to content

Commit

Permalink
add Ray Train and Tune (#30)
Browse files Browse the repository at this point in the history
* train & tune

* ray train tune

* train & tune
  • Loading branch information
luweizheng authored Apr 12, 2024
1 parent 4ed0c65 commit 3a77bc3
Show file tree
Hide file tree
Showing 10 changed files with 1,593 additions and 135 deletions.
1 change: 1 addition & 0 deletions _toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ subtrees:
- file: ch-ray-train-tune/index
entries:
- file: ch-ray-train-tune/ray-train
- file: ch-ray-train-tune/ray-tune
- file: ch-mpi/index
entries:
- file: ch-mpi/mpi-intro
Expand Down
2 changes: 1 addition & 1 deletion ch-data-science/machine-learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"(machine-learning-intro)=\n",
"(sec-machine-learning-intro)=\n",
"# 机器学习\n",
"\n",
"机器学习指让计算机学习已有数据中的统计规律,并用来预测未知数据。机器学习项目总共分两个阶段:训练(Training)和推理(Inference)。计算机学习已有数据的过程被称为训练阶段,预测未知数据的过程被称为推理阶段。\n",
Expand Down
4 changes: 2 additions & 2 deletions ch-mpi-large-model/data-parallel.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
(data-parallel)=
(sec-data-parallel)=
# 数据并行

数据并行是一种最常见的大模型并行方法,相对其他并行,数据并行最简单。如 {numref}`data-parallel-img` 所示,模型被拷贝到不同的 GPU 设备上,训练数据被切分为多份,每份分给不同的 GPU 进行训练。这种编程范式又被称为单程序多数据(Single Program Multiple Data,SPMD)。
Expand All @@ -13,7 +13,7 @@ name: data-parallel-img

## 非并行训练

{numref}`machine-learning-intro` 介绍了神经网络模型训练的过程。我们先从非并行的场景开始,这里使用 MNIST 手写数字识别案例来演示,如 {numref}`data-parallel-single` 所示,它包含了一次前向传播和一次反向传播。
{numref}`sec-machine-learning-intro` 介绍了神经网络模型训练的过程。我们先从非并行的场景开始,这里使用 MNIST 手写数字识别案例来演示,如 {numref}`data-parallel-single` 所示,它包含了一次前向传播和一次反向传播。

```{figure} ../img/ch-mpi-large-model/data-parallel-single.svg
---
Expand Down
240 changes: 132 additions & 108 deletions ch-ray-train-tune/ray-train.ipynb

Large diffs are not rendered by default.

1,328 changes: 1,328 additions & 0 deletions ch-ray-train-tune/ray-tune.ipynb

Large diffs are not rendered by default.

67 changes: 44 additions & 23 deletions drawio/ch-ray-train-tune/ray-train-key-parts.drawio
Original file line number Diff line number Diff line change
@@ -1,52 +1,73 @@
<mxfile host="Electron" modified="2024-04-03T07:36:02.584Z" agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/22.1.18 Chrome/120.0.6099.199 Electron/28.1.2 Safari/537.36" etag="ro6oz1oZ-BwwbX8bOf9P" version="22.1.18" type="device">
<mxfile host="Electron" modified="2024-04-11T02:42:18.853Z" agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) draw.io/22.1.18 Chrome/120.0.6099.199 Electron/28.1.2 Safari/537.36" etag="OnOEuhLFypzp8dmplVJC" version="22.1.18" type="device">
<diagram name="第 1 页" id="YkZ-crnKk8mhjwdHUlSD">
<mxGraphModel dx="2060" dy="1104" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0" />
<mxCell id="1" parent="0" />
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-1" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=none;" vertex="1" parent="1">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-1" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=none;" parent="1" vertex="1">
<mxGeometry x="210" y="359" width="180" height="180" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-2" value="&lt;font style=&quot;font-size: 18px;&quot;&gt;Trainer&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" vertex="1" parent="1">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-2" value="&lt;font style=&quot;font-size: 18px;&quot;&gt;&lt;b&gt;Trainer&lt;/b&gt;&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" parent="1" vertex="1">
<mxGeometry x="270" y="360" width="60" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-3" value="&lt;font style=&quot;font-size: 14px;&quot; face=&quot;Garamond&quot;&gt;train_func&lt;/font&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#F5F5F5;" vertex="1" parent="1">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-3" value="&lt;font style=&quot;font-size: 14px;&quot;&gt;train_loop&lt;/font&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#F5F5F5;fontFamily=Garamond;fontStyle=2" parent="1" vertex="1">
<mxGeometry x="225" y="390" width="150" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-4" value="&lt;font face=&quot;Garamond&quot; style=&quot;font-size: 14px;&quot;&gt;ScalingConfig&lt;/font&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#F5F5F5;" vertex="1" parent="1">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-4" value="&lt;font style=&quot;font-size: 14px;&quot;&gt;ScalingConfig&lt;/font&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#F5F5F5;fontFamily=Garamond;fontStyle=2" parent="1" vertex="1">
<mxGeometry x="225" y="430" width="150" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-5" value="&lt;span style=&quot;font-size: 14px;&quot;&gt;&lt;font face=&quot;Garamond&quot;&gt;ray.data&lt;/font&gt;&lt;/span&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#90C9E6;" vertex="1" parent="1">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-5" value="&lt;span style=&quot;font-size: 14px;&quot;&gt;&lt;font&gt;ray.data&lt;/font&gt;&lt;/span&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#90C9E6;fontFamily=Garamond;fontStyle=2" parent="1" vertex="1">
<mxGeometry x="225" y="470" width="150" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-6" value="&lt;font style=&quot;font-size: 14px;&quot; face=&quot;Garamond&quot;&gt;prepare_data_loader&lt;/font&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#90C9E6;" vertex="1" parent="1">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-6" value="&lt;font style=&quot;font-size: 14px;&quot;&gt;prepare_data_loader&lt;/font&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#90C9E6;fontFamily=Garamond;fontStyle=2" parent="1" vertex="1">
<mxGeometry x="225" y="500" width="150" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-8" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=none;" vertex="1" parent="1">
<mxGeometry x="450" y="404" width="160" height="90" as="geometry" />
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-8" value="" style="rounded=1;whiteSpace=wrap;html=1;fillColor=none;" parent="1" vertex="1">
<mxGeometry x="450" y="367" width="160" height="90" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-9" value="" style="endArrow=classic;html=1;rounded=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;strokeWidth=1.5;" edge="1" parent="1" source="pZMyhNOI0UIP7ZEqHQ5e-1" target="pZMyhNOI0UIP7ZEqHQ5e-8">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-9" value="" style="endArrow=classic;html=1;rounded=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;entryPerimeter=0;strokeWidth=1.5;" parent="1" target="pZMyhNOI0UIP7ZEqHQ5e-8" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="440" y="330" as="sourcePoint" />
<mxPoint x="490" y="280" as="targetPoint" />
<mxPoint x="390" y="412" as="sourcePoint" />
<mxPoint x="490" y="243" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-10" value="&lt;font style=&quot;font-size: 18px;&quot;&gt;.fit()&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;strokeWidth=1.5;" vertex="1" parent="1">
<mxGeometry x="390" y="417" width="60" height="30" as="geometry" />
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-10" value="&lt;font face=&quot;Garamond&quot; style=&quot;font-size: 18px;&quot;&gt;&lt;i&gt;.fit()&lt;/i&gt;&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;strokeWidth=1.5;" parent="1" vertex="1">
<mxGeometry x="390" y="380" width="60" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-11" value="Checkpoint" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontSize=18;" vertex="1" parent="1">
<mxGeometry x="500" y="404" width="60" height="30" as="geometry" />
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-11" value="&lt;b&gt;Checkpoint&lt;/b&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;fontSize=18;" parent="1" vertex="1">
<mxGeometry x="500" y="367" width="60" height="30" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-12" value="&lt;b&gt;&lt;font style=&quot;font-size: 14px;&quot;&gt;本地&lt;/font&gt;&lt;/b&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#E6D0DE;" vertex="1" parent="1">
<mxGeometry x="460" y="445" width="60" height="40" as="geometry" />
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-12" value="&lt;b&gt;&lt;font style=&quot;font-size: 14px;&quot;&gt;本地&lt;/font&gt;&lt;/b&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#E6D0DE;" parent="1" vertex="1">
<mxGeometry x="460" y="408" width="60" height="40" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-13" value="&lt;span style=&quot;font-size: 14px;&quot;&gt;&lt;b&gt;持久化&lt;/b&gt;&lt;/span&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#E6D0DE;" vertex="1" parent="1">
<mxGeometry x="540" y="445" width="60" height="40" as="geometry" />
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-13" value="&lt;span style=&quot;font-size: 14px;&quot;&gt;&lt;b&gt;持久化&lt;/b&gt;&lt;/span&gt;" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#E6D0DE;" parent="1" vertex="1">
<mxGeometry x="540" y="408" width="60" height="40" as="geometry" />
</mxCell>
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-14" value="" style="endArrow=classic;html=1;rounded=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" edge="1" parent="1" source="pZMyhNOI0UIP7ZEqHQ5e-12" target="pZMyhNOI0UIP7ZEqHQ5e-13">
<mxCell id="pZMyhNOI0UIP7ZEqHQ5e-14" value="" style="endArrow=classic;html=1;rounded=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;" parent="1" source="pZMyhNOI0UIP7ZEqHQ5e-12" target="pZMyhNOI0UIP7ZEqHQ5e-13" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="590" y="775" as="sourcePoint" />
<mxPoint x="640" y="725" as="targetPoint" />
<mxPoint x="590" y="738" as="sourcePoint" />
<mxPoint x="640" y="688" as="targetPoint" />
</mxGeometry>
</mxCell>
<mxCell id="-tJ4ZFR9xjcvRMVxW-Dr-1" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=none;" vertex="1" parent="1">
<mxGeometry x="396" y="465" width="215" height="70" as="geometry" />
</mxCell>
<mxCell id="-tJ4ZFR9xjcvRMVxW-Dr-2" value="&lt;span style=&quot;font-size: 18px;&quot;&gt;&lt;b&gt;Ray 集群&lt;/b&gt;&lt;/span&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;rounded=0;" vertex="1" parent="1">
<mxGeometry x="463.5" y="464" width="80" height="30" as="geometry" />
</mxCell>
<mxCell id="-tJ4ZFR9xjcvRMVxW-Dr-3" value="Worker" style="rounded=0;whiteSpace=wrap;html=1;fontSize=14;fillColor=none;strokeColor=#36393d;fontFamily=Times New Roman;" vertex="1" parent="1">
<mxGeometry x="403.5" y="497" width="60" height="30" as="geometry" />
</mxCell>
<mxCell id="-tJ4ZFR9xjcvRMVxW-Dr-4" value="Worker" style="rounded=0;whiteSpace=wrap;html=1;fontSize=14;fillColor=none;strokeColor=#36393d;fontFamily=Times New Roman;" vertex="1" parent="1">
<mxGeometry x="473.5" y="497" width="60" height="30" as="geometry" />
</mxCell>
<mxCell id="-tJ4ZFR9xjcvRMVxW-Dr-5" value="Worker" style="rounded=0;whiteSpace=wrap;html=1;fontSize=14;fillColor=none;strokeColor=#36393d;fontFamily=Times New Roman;" vertex="1" parent="1">
<mxGeometry x="543.5" y="497" width="60" height="30" as="geometry" />
</mxCell>
<mxCell id="-tJ4ZFR9xjcvRMVxW-Dr-6" value="" style="endArrow=classic;startArrow=classic;html=1;rounded=0;" edge="1" parent="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="420" y="462" as="sourcePoint" />
<mxPoint x="420" y="412" as="targetPoint" />
</mxGeometry>
</mxCell>
</root>
Expand Down
Loading

0 comments on commit 3a77bc3

Please sign in to comment.