Skip to content

Commit

Permalink
Add Stats Example Noteboks (#5)
Browse files Browse the repository at this point in the history
Co-authored-by: jameswillis <[email protected]>
  • Loading branch information
james-willis and jameswillis authored Sep 17, 2024
1 parent 8a9a4f9 commit e713def
Show file tree
Hide file tree
Showing 3 changed files with 576 additions and 0 deletions.
197 changes: 197 additions & 0 deletions python/wherobots-ai/dbscan_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fa4dae809be2867",
"metadata": {},
"source": [
"![](https://wherobots.com/wp-content/uploads/2023/12/[email protected])\n",
"\n",
"# DBSCAN\n",
"DBSCAN is a popular algorithm for finding clusters of spatial data. It identifies core points that have enough (defined by the user) neighbors within some distance (also user defined). Points that are not core points but are within the distance of a core point are considered border points of the cluster. Points that are not core points and are not within the distance of a core point are considered outliers and not part of any cluster.\n",
"\n",
"The algorithm requires two parameters:\n",
"* `epsilon` - The farthest apart two points can be while still being considered connected or related. `epsilon` must be a positive double float.\n",
"* `minPoints` - The minimum number of neighbor points (as determined by epsilon). A point needs `minPoints` neighbors to be considered a core point. `minPoints` must be a positive integer."
]
},
{
"cell_type": "markdown",
"id": "d25f585f",
"metadata": {},
"source": [
"## Example overview\n",
"In this example, we will generate some random data and use DBSCAN to cluster that data. Then, we'll visualize the clusters using a scatter plot.\n",
"\n",
"This demo is derived from the [scikit-learn DBSCAN demo](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1a7540c75a535ce",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"%pip install scikit-learn"
]
},
{
"cell_type": "markdown",
"id": "32ec30f6e1885b75",
"metadata": {},
"source": [
"# Define Sedona Context"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a51adb8-f89f-4cb3-9a41-24a36d8f1fcf",
"metadata": {},
"outputs": [],
"source": [
"from sedona.spark import SedonaContext\n",
"import os\n",
"\n",
"config = SedonaContext.builder().getOrCreate()\n",
"sedona = SedonaContext.create(config)\n",
"bucket_path = os.getenv(\"USER_S3_PATH\")\n",
"\n",
"sedona.sparkContext.setCheckpointDir(bucket_path + \"checkpoints\")"
]
},
{
"cell_type": "markdown",
"id": "cff198142e2ebced",
"metadata": {},
"source": [
"## Data Generation\n",
"In the following code section, we'll generate some data using sklearn's `make_blobs` function. We've set the data to consist of 750 points with 3 clusters. After clustering the data, we'll visualize it in `pyplot`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99f8c27a-9c0b-4f8f-a388-54a781d892e3",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import make_blobs\n",
"from sklearn.preprocessing import StandardScaler\n",
"import matplotlib.pyplot as plt\n",
"\n",
"center_clusters = [[1, 1], [-1, -1], [1, -1]]\n",
"feature_matrix, labels_true = make_blobs(\n",
" n_samples=750, centers=center_clusters, cluster_std=0.4, random_state=0\n",
")\n",
"\n",
"feature_matrix = StandardScaler().fit_transform(feature_matrix)\n",
"\n",
"plt.scatter(feature_matrix[:, 0], feature_matrix[:, 1])\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "372ae8aee6714c8c",
"metadata": {},
"source": [
"## Clustering\n",
"In the following section, we'll use the DBSCAN implementation in Wherobots to cluster the data in a dataframe, setting `epsilon` to `0.3` and `minPoints` to `10`.\n",
"\n",
"Wherobots' DBSCAN returns outliers by default."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe5af921-957c-48cc-942a-e6c744c72bc5",
"metadata": {},
"outputs": [],
"source": [
"import pyspark.sql.functions as f\n",
"from sedona.sql.st_constructors import ST_MakePoint\n",
"from wherobots.clustering.dbscan import dbscan\n",
"\n",
"df = sedona.createDataFrame(X).select(ST_MakePoint(\"_1\", \"_2\").alias(\"geometry\"))\n",
"clusters_df = dbscan(df, 0.3, 10, include_outliers=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d060243f-c00b-436d-99d4-e0fbaba89930",
"metadata": {},
"outputs": [],
"source": [
"clusters_df.show()"
]
},
{
"cell_type": "markdown",
"id": "5a475ce35250afea",
"metadata": {},
"source": [
"## Visualization\n",
"Finally, we'll visualize the clusters using geopandas. Some manipulations are made to the data to improve the clarity of the visualization."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "244c8d55-0c69-4922-b769-666c056098c4",
"metadata": {},
"outputs": [],
"source": [
"import geopandas as gpd\n",
"import pyspark.sql.types as t\n",
"\n",
"pdf = (clusters_df\n",
" .withColumn(\"isCore\", (f.col(\"isCore\").cast(t.IntegerType()) + 1) * 40)\n",
" .withColumn(\"cluster\", f.hash(\"cluster\").cast(t.StringType()))\n",
" .toPandas()\n",
" )\n",
"gdf = gpd.GeoDataFrame(pdf, geometry=\"geometry\")\n",
"\n",
"gdf.plot(\n",
" figsize=(10, 8),\n",
" column=\"cluster\",\n",
" markersize=gdf['isCore'],\n",
" edgecolor='lightgray',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "057ef39f-0d20-464d-9891-590e5aff6d72",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit e713def

Please sign in to comment.