-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: jameswillis <[email protected]>
- Loading branch information
1 parent
8a9a4f9
commit e713def
Showing
3 changed files
with
576 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "fa4dae809be2867", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"# DBSCAN\n", | ||
"DBSCAN is a popular algorithm for finding clusters of spatial data. It identifies core points that have enough (defined by the user) neighbors within some distance (also user defined). Points that are not core points but are within the distance of a core point are considered border points of the cluster. Points that are not core points and are not within the distance of a core point are considered outliers and not part of any cluster.\n", | ||
"\n", | ||
"The algorithm requires two parameters:\n", | ||
"* `epsilon` - The farthest apart two points can be while still being considered connected or related. `epsilon` must be a positive double float.\n", | ||
"* `minPoints` - The minimum number of neighbor points (as determined by epsilon). A point needs `minPoints` neighbors to be considered a core point. `minPoints` must be a positive integer." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d25f585f", | ||
"metadata": {}, | ||
"source": [ | ||
"## Example overview\n", | ||
"In this example, we will generate some random data and use DBSCAN to cluster that data. Then, we'll visualize the clusters using a scatter plot.\n", | ||
"\n", | ||
"This demo is derived from the [scikit-learn DBSCAN demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html)." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a1a7540c75a535ce", | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install scikit-learn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "32ec30f6e1885b75", | ||
"metadata": {}, | ||
"source": [ | ||
"# Define Sedona Context" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "3a51adb8-f89f-4cb3-9a41-24a36d8f1fcf", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sedona.spark import SedonaContext\n", | ||
"import os\n", | ||
"\n", | ||
"config = SedonaContext.builder().getOrCreate()\n", | ||
"sedona = SedonaContext.create(config)\n", | ||
"bucket_path = os.getenv(\"USER_S3_PATH\")\n", | ||
"\n", | ||
"sedona.sparkContext.setCheckpointDir(bucket_path + \"checkpoints\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cff198142e2ebced", | ||
"metadata": {}, | ||
"source": [ | ||
"## Data Generation\n", | ||
"In the following code section, we'll generate some data using sklearn's `make_blobs` function. We've set the data to consist of 750 points with 3 clusters. After clustering the data, we'll visualize it in `pyplot`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "99f8c27a-9c0b-4f8f-a388-54a781d892e3", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sklearn.datasets import make_blobs\n", | ||
"from sklearn.preprocessing import StandardScaler\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"\n", | ||
"center_clusters = [[1, 1], [-1, -1], [1, -1]]\n", | ||
"feature_matrix, labels_true = make_blobs(\n", | ||
" n_samples=750, centers=center_clusters, cluster_std=0.4, random_state=0\n", | ||
")\n", | ||
"\n", | ||
"feature_matrix = StandardScaler().fit_transform(feature_matrix)\n", | ||
"\n", | ||
"plt.scatter(feature_matrix[:, 0], feature_matrix[:, 1])\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "372ae8aee6714c8c", | ||
"metadata": {}, | ||
"source": [ | ||
"## Clustering\n", | ||
"In the following section, we'll use the DBSCAN implementation in Wherobots to cluster the data in a dataframe, setting `epsilon` to `0.3` and `minPoints` to `10`.\n", | ||
"\n", | ||
"Wherobots' DBSCAN returns outliers by default." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fe5af921-957c-48cc-942a-e6c744c72bc5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import pyspark.sql.functions as f\n", | ||
"from sedona.sql.st_constructors import ST_MakePoint\n", | ||
"from wherobots.clustering.dbscan import dbscan\n", | ||
"\n", | ||
"df = sedona.createDataFrame(X).select(ST_MakePoint(\"_1\", \"_2\").alias(\"geometry\"))\n", | ||
"clusters_df = dbscan(df, 0.3, 10, include_outliers=True)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "d060243f-c00b-436d-99d4-e0fbaba89930", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"clusters_df.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5a475ce35250afea", | ||
"metadata": {}, | ||
"source": [ | ||
"## Visualization\n", | ||
"Finally, we'll visualize the clusters using geopandas. Some manipulations are made to the data to improve the clarity of the visualization." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "244c8d55-0c69-4922-b769-666c056098c4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import geopandas as gpd\n", | ||
"import pyspark.sql.types as t\n", | ||
"\n", | ||
"pdf = (clusters_df\n", | ||
" .withColumn(\"isCore\", (f.col(\"isCore\").cast(t.IntegerType()) + 1) * 40)\n", | ||
" .withColumn(\"cluster\", f.hash(\"cluster\").cast(t.StringType()))\n", | ||
" .toPandas()\n", | ||
" )\n", | ||
"gdf = gpd.GeoDataFrame(pdf, geometry=\"geometry\")\n", | ||
"\n", | ||
"gdf.plot(\n", | ||
" figsize=(10, 8),\n", | ||
" column=\"cluster\",\n", | ||
" markersize=gdf['isCore'],\n", | ||
" edgecolor='lightgray',\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "057ef39f-0d20-464d-9891-590e5aff6d72", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.11" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.