Design and implementation of a small deep learning library from scratch in Java, inspired by Keras and Seth Weidman's book "Deep Learning From Scratch" (O'Reilly). The goal is to apply the main OOP design patterns and UML diagramming in a software project.
As a project for the Software Design and Databases courses at university, I designed and built a game inspired by QuickDraw, in which a neural network attempts to classify your hand-drawn sketches.
To support it, I created a deep learning "library" from scratch in Java and used it in the game. Inspired by Keras and its Sequential API, and applying the main design patterns and best practices, the library lets you compose different architectures from the classic layers and operations of a neural network.
On the database side, I used MongoDB with Morphia for Object Document Mapping (ODM) to persist the trained models.
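As a rough illustration of what that persistence layer looks like, here is a minimal Morphia sketch. It assumes Morphia 2.x and a local MongoDB instance; the `ModelDocument` entity and its fields are hypothetical stand-ins, not the project's actual NeuralNetworkEntity.

```java
import com.mongodb.client.MongoClients;
import dev.morphia.Datastore;
import dev.morphia.Morphia;
import dev.morphia.annotations.Entity;
import dev.morphia.annotations.Id;
import org.bson.types.ObjectId;

// Hypothetical entity for illustration; the project's actual
// NeuralNetworkEntity may have a different shape.
@Entity("models")
class ModelDocument {
    @Id
    private ObjectId id;
    private String name;
    private byte[] serializedModel; // trained weights, serialized to bytes

    ModelDocument() {} // Morphia requires a no-arg constructor

    ModelDocument(String name, byte[] serializedModel) {
        this.name = name;
        this.serializedModel = serializedModel;
    }
}

public class PersistSketch {
    public static void main(String[] args) {
        // Assumes a MongoDB instance running on localhost:27017
        Datastore datastore = Morphia.createDatastore(MongoClients.create(), "deeplearning4java");
        datastore.save(new ModelDocument("quickdraw-cnn", new byte[0]));
    }
}
```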
The "library" uses the ND4J library for tensor and matrix manipulation, utilizing NDArray as the main data structure. Similar to Numpy, the library allows for vectorized and efficient operations, enabling its use in somewhat more complex applications. There are certain redundancies in using this library, as it already possesses many of the operations and serves as the foundation for the algebraic operations of DL4J. The idea was to implement the relevant mathematical parts while taking advantage of the data structure it provides.
The examples package contains some use cases and tests for the library. The main ones are:

- Mnist: a simple example of a neural network for classifying handwritten digits from the MNIST dataset (demo: mnistplotter.mp4).
- QuickDraw: a neural network for classifying hand-drawn sketches from the QuickDraw dataset (demo: qdrawplotter.mp4).
- Linear: a simple example of a neural network for linear regression.
- NonLinearFunctions: a neural network for approximating non-linear functions, such as sine, saddle, Rosenbrock, and others (a sketch of such an experiment follows this list).

Plots: Saddle Function, Rosenbrock Function, Sine Function, Linear Regression
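As a taste of what a NonLinearFunctions-style experiment might look like, here is a minimal sketch that approximates sin(x) with a small dense network. It mirrors the builder API from the training example below; the MSE loss class name comes from the features list, and the exact constructor signatures are assumptions:

```java
// Build a toy dataset: y = sin(x) on [-pi, pi]
INDArray x = Nd4j.linspace(-Math.PI, Math.PI, 256).reshape(256, 1);
INDArray y = Transforms.sin(x);

// Small MLP with tanh hidden layers and a linear output
NeuralNetwork model = new ModelBuilder()
        .add(new Dense(32, Activation.create("tanh"), "he"))
        .add(new Dense(32, Activation.create("tanh"), "he"))
        .add(new Dense(1, Activation.create("linear"), "he"))
        .build();

// Train with MSE loss; constructor details are assumptions
Trainer trainer = new TrainerBuilder(model, x, y, x, y, new MSE())
        .setOptimizer(new RMSProp(new ExponentialDecayStrategy(0.01, 0.0001, 200)))
        .setEpochs(200)
        .setBatchSize(32)
        .build();
trainer.fit();
```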
Code example for training a neural network with the library:
```java
// Load the QuickDraw .npy data and slice train/test subsets
DataLoader dataLoader = new DataLoader(
        root + "/npy/train/x_train250.npy", root + "/npy/train/y_train250.npy",
        root + "/npy/test/x_test250.npy", root + "/npy/test/y_test250.npy");
INDArray xTrain = dataLoader.getAllTrainImages().get(NDArrayIndex.interval(0, trainSize));
INDArray yTrain = dataLoader.getAllTrainLabels().reshape(-1, 1).get(NDArrayIndex.interval(0, trainSize));
INDArray xTest = dataLoader.getAllTestImages().get(NDArrayIndex.interval(0, testSize));
INDArray yTest = dataLoader.getAllTestLabels().reshape(-1, 1).get(NDArrayIndex.interval(0, testSize));

// Normalization: scale pixel values to [0, 1]
xTrain = xTrain.divi(255);
xTest = xTest.divi(255);

// Reshape flat vectors into 28x28x1 images
xTrain = xTrain.reshape(xTrain.rows(), 28, 28, 1);
xTest = xTest.reshape(xTest.rows(), 28, 28, 1);

// Sequential-style model definition via the builder
NeuralNetwork model = new ModelBuilder()
        .add(new Conv2D(32, 2, Arrays.asList(2, 2), "valid", Activation.create("relu"), "he"))
        .add(new Conv2D(16, 1, Arrays.asList(1, 1), "valid", Activation.create("relu"), "he"))
        .add(new Flatten())
        .add(new Dense(178, Activation.create("relu"), "he"))
        .add(new Dropout(0.4))
        .add(new Dense(49, Activation.create("relu"), "he"))
        .add(new Dropout(0.3))
        .add(new Dense(numClasses, Activation.create("linear"), "he"))
        .build();

// Training configuration
int epochs = 20;
int batchSize = 64;
LearningRateDecayStrategy lr = new ExponentialDecayStrategy(0.01, 0.0001, epochs);
Optimizer optimizer = new RMSProp(lr);

Trainer trainer = new TrainerBuilder(model, xTrain, yTrain, xTest, yTest, new SoftmaxCrossEntropy())
        .setOptimizer(optimizer)
        .setBatchSize(batchSize)
        .setEpochs(epochs)
        .setEvalEvery(2)
        .setEarlyStopping(true)
        .setPatience(4)
        .setMetric(new Accuracy())
        .build();
trainer.fit();
```
Complete example: QuickDrawNN.java
The library offers the features listed below and is organized into the following packages:
```
src/main/java/br/deeplearning4java/neuralnetwork
├── core
│   ├── activation
│   ├── layers
│   ├── losses
│   ├── metrics
│   ├── optimizers
│   ├── models
│   └── train
├── data
│   └── preprocessing
├── database
└── examples
    ├── activations
    ├── classification
    ├── regression
    └── persist
```
- Layers: Dense, Dropout, Flatten, Conv2D, MaxPooling2D, ZeroPadding2D
- Activation functions: ReLU, LeakyReLU, SiLU, Sigmoid, Tanh, Softmax
- Loss functions: MSE, BinaryCrossEntropy, CategoricalCrossEntropy, SoftmaxCrossEntropy
- Optimizers: SGD, Adam, RMSProp, AdaGrad, AdaDelta, SGDMomentum, SGDNesterov, RegularizedSGD
- Learning rate decay: ExponentialDecay, LinearDecay (see the decay sketch after this list)
- Models: ModelBuilder, NeuralNetwork
- Train: Trainer, TrainerBuilder
- Metrics: Accuracy, Precision, Recall, F1Score, MSE, MAE, RMSE, R2
- Data: DataProcessor, DataPipeline, StandardScaler, MinMaxScaler, Util, DataLoader, PlotDataPredict
- Datastore: NeuralNetworkService, NeuralNetworkRepository, NeuralNetworkEntity
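For reference, the exponential schedule drops the learning rate from an initial to a final value over the training epochs. Here is a minimal sketch of that arithmetic; the actual ExponentialDecayStrategy implementation may differ in its details:

```java
// Sketch of an exponential learning-rate schedule: lr decays by a constant
// factor each epoch so it lands on finalLr at the last epoch.
double initialLr = 0.01, finalLr = 0.0001;
int epochs = 20;
double decay = Math.pow(finalLr / initialLr, 1.0 / (epochs - 1));

double lr = initialLr;
for (int epoch = 0; epoch < epochs; epoch++) {
    System.out.printf("epoch %2d: lr = %.6f%n", epoch, lr);
    lr *= decay;
}
```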
Full diagram: UML Diagram, UML Diagram (Dia)
The game essentially uses the model trained with the "library" in QuickDrawNN to classify the drawings made by the player. Ten classes from the original dataset were selected, and each session consists of 4 rounds. In each round, if the prediction confidence is greater than 50%, the drawing and the round are saved to the database. There is also a visualization screen where users can view all the drawings saved in the database and delete them.
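Here is how that 50% gate might be computed with ND4J, assuming the model outputs one score per class; this helper is illustrative, not the game's actual code:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class ConfidenceGateSketch {
    // True when the most probable class clears the 50% threshold
    static boolean confidentPrediction(INDArray logits) {
        INDArray probs = Transforms.softmax(logits);       // scores -> probabilities
        double confidence = probs.maxNumber().doubleValue(); // top-class probability
        return confidence > 0.5;
    }

    public static void main(String[] args) {
        INDArray logits = Nd4j.create(new double[]{0.1, 2.3, 0.4, 1.1}, new int[]{1, 4});
        System.out.println("save round? " + confidentPrediction(logits));
    }
}
```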
The game implementation uses JavaFX with the MVC pattern.