samuellimabraz/DeepLearning4Java

DeepLearning4Java

Design and implementation of a small DeepLearning library from scratch in Java, inspired by Keras and the book "Deep Learning From Scratch" by O'Reilly. The goal is to apply the main OOP Design Patterns and UML diagramming in a software project.

Introduction

As a project for the Software Design and Databases courses at university, I designed and built a game inspired by QuickDraw, in which a neural network attempts to classify the player's hand-drawn sketches.

To support the game, I decided to create a deep learning "library" from scratch in Java. It applies the main design patterns and best practices, and it takes inspiration from the Keras library and its Sequential API, so that different architectures can be built from the classic layers and operations of a neural network.

For the database part, I used MongoDB, with Morphia for Object Document Mapping (ODM), to persist the trained models.


Neural Network

The "library" uses ND4J for tensor and matrix manipulation, with INDArray as the main data structure. Like NumPy, ND4J provides efficient vectorized operations, which makes the library usable in somewhat more complex applications. There is some redundancy in this choice, since ND4J already implements many of these operations and serves as the algebraic foundation of DL4J. The idea was to implement the relevant mathematical parts by hand while taking advantage of the data structure ND4J provides.
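As a rough illustration of the kind of math that is implemented by hand on top of the tensor type, here is a dense-layer forward pass, `relu(x * W + b)`. To stay runnable without dependencies, this sketch uses plain `double[][]` arrays instead of ND4J; the class and method names are hypothetical and are not taken from the repository.

```java
import java.util.Arrays;

// Hypothetical sketch: plain arrays stand in for ND4J tensors,
// and this is not the repository's Dense layer.
class DenseForwardSketch {

    // Computes relu(x * w + b) for a batch x of shape [batch][in],
    // weights w of shape [in][out] and a bias b of length out.
    static double[][] forward(double[][] x, double[][] w, double[] b) {
        int batch = x.length;
        int in = w.length;
        int out = b.length;
        double[][] y = new double[batch][out];
        for (int n = 0; n < batch; n++) {
            for (int j = 0; j < out; j++) {
                double sum = b[j];
                for (int i = 0; i < in; i++) {
                    sum += x[n][i] * w[i][j];
                }
                y[n][j] = Math.max(0.0, sum); // ReLU activation
            }
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 2.0}};
        double[][] w = {{1.0, -1.0}, {0.5, 0.5}};
        double[] b = {0.0, -1.0};
        // prints [[2.0, 0.0]]
        System.out.println(Arrays.deepToString(forward(x, w, b)));
    }
}
```

With a tensor library, the three nested loops collapse into a single matrix multiplication plus a broadcast addition, which is the main reason to build on an existing array type.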

Examples

The examples package contains some use cases and tests for the library. The main ones are:

  • Classification:

    • Mnist: A simple example of a neural network for classifying handwritten digits from the MNIST dataset.

    • QuickDraw: A neural network for classifying hand-drawn sketches from the QuickDraw dataset.

    Demo videos: mnistplotter.mp4, qdrawplotter.mp4
  • Regression:

    • Linear: A simple example of a neural network for linear regression.

    • NonLinearFunctions: A neural network for approximating non-linear functions, such as sine, saddle, rosenbrock, and others.

      Plots: Saddle Function, Rosenbrock Function, Sine Function, Linear Regression
  • Code example for training a neural network with the library:

    // root, trainSize, testSize and numClasses are assumed to be defined earlier (see the complete example below)
    DataLoader dataLoader = new DataLoader(
        root + "/npy/train/x_train250.npy",
        root + "/npy/train/y_train250.npy",
        root + "/npy/test/x_test250.npy",
        root + "/npy/test/y_test250.npy");
    INDArray xTrain = dataLoader.getAllTrainImages().get(NDArrayIndex.interval(0, trainSize));
    INDArray yTrain = dataLoader.getAllTrainLabels().reshape(-1, 1).get(NDArrayIndex.interval(0, trainSize));
    INDArray xTest = dataLoader.getAllTestImages().get(NDArrayIndex.interval(0, testSize));
    INDArray yTest = dataLoader.getAllTestLabels().reshape(-1, 1).get(NDArrayIndex.interval(0, testSize));
    
    // Normalization: divide pixel values by 255 (divi operates in place)
    xTrain = xTrain.divi(255);
    xTest = xTest.divi(255);
    // Reshape to (samples, 28, 28, 1)
    xTrain = xTrain.reshape(xTrain.rows(), 28, 28, 1);
    xTest = xTest.reshape(xTest.rows(), 28, 28, 1);
    
    NeuralNetwork model = new ModelBuilder()
        .add(new Conv2D(32, 2, Arrays.asList(2, 2), "valid", Activation.create("relu"), "he"))
        .add(new Conv2D(16, 1, Arrays.asList(1, 1), "valid", Activation.create("relu"), "he"))
        .add(new Flatten())
        .add(new Dense(178, Activation.create("relu"), "he"))
        .add(new Dropout(0.4))
        .add(new Dense(49, Activation.create("relu"), "he"))
        .add(new Dropout(0.3))
        .add(new Dense(numClasses,  Activation.create("linear"), "he"))
        .build();
    
    int epochs = 20;
    int batchSize = 64;
    
    LearningRateDecayStrategy lr = new ExponentialDecayStrategy(0.01, 0.0001, epochs);
    Optimizer optimizer = new RMSProp(lr);
    Trainer trainer = new TrainerBuilder(model, xTrain, yTrain, xTest, yTest, new SoftmaxCrossEntropy())
            .setOptimizer(optimizer)
            .setBatchSize(batchSize)
            .setEpochs(epochs)
            .setEvalEvery(2)
            .setEarlyStopping(true)
            .setPatience(4)
            .setMetric(new Accuracy())
            .build();
    trainer.fit();

    Complete example: QuickDrawNN.java
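In the example above, the last Dense layer uses a linear activation and the loss is SoftmaxCrossEntropy. A common reason for this pairing is that the loss applies the softmax itself, working on raw logits for numerical stability. The sketch below shows that idea for a single sample with an integer class label. It assumes this is what the loss does and does not reproduce the repository's implementation; all names are hypothetical.

```java
// Hypothetical sketch of softmax cross-entropy on raw logits;
// not the repository's SoftmaxCrossEntropy class.
class SoftmaxCrossEntropySketch {

    // Softmax with the maximum logit subtracted first, so exp() cannot overflow.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    // Cross-entropy for one sample: -log(softmax(logits)[label]),
    // computed as logSumExp(logits) - logits[label].
    static double loss(double[] logits, int label) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        for (double v : logits) {
            sum += Math.exp(v - max);
        }
        return max + Math.log(sum) - logits[label];
    }

    public static void main(String[] args) {
        double[] logits = {0.0, 0.0, 0.0, 0.0};
        // Uniform logits over 4 classes give a loss of ln(4).
        System.out.println(loss(logits, 2));
    }
}
```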

Features

The library's features are organized into the following packages:

src/main/java/br/deeplearning4java/neuralnetwork
├── core
│   ├── activation
│   ├── layers
│   ├── losses
│   ├── metrics
│   ├── optimizers
│   ├── models
│   └── train
├── data
│   └── preprocessing
├── database
└── examples
    ├── activations
    ├── classification
    ├── regression
    └── persist

Full diagram: UML Diagram, UML Diagram (Dia)
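The ModelBuilder used in the training example suggests a Builder over an ordered list of layers, in the spirit of the Keras Sequential API. The following is a minimal sketch of that pattern with a hypothetical `Layer` interface operating on plain arrays; the names and signatures are illustrative and do not mirror the classes in the packages above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of a Sequential-style builder; not the repository's ModelBuilder.
class SequentialBuilderSketch {

    // A layer maps an input vector to an output vector.
    interface Layer {
        double[] forward(double[] input);
    }

    // An immutable model that applies its layers in insertion order.
    static final class Model {
        private final List<Layer> layers;

        Model(List<Layer> layers) {
            this.layers = new ArrayList<>(layers); // defensive copy
        }

        double[] predict(double[] input) {
            double[] out = input;
            for (Layer layer : layers) {
                out = layer.forward(out);
            }
            return out;
        }
    }

    // Fluent builder: add(...) returns the builder so calls can be chained.
    static final class Builder {
        private final List<Layer> layers = new ArrayList<>();

        Builder add(Layer layer) {
            layers.add(layer);
            return this;
        }

        Model build() {
            return new Model(layers);
        }
    }

    // Toy layer: multiplies every element by a constant factor.
    static Layer scale(double factor) {
        return in -> {
            double[] out = new double[in.length];
            for (int i = 0; i < in.length; i++) {
                out[i] = in[i] * factor;
            }
            return out;
        };
    }

    // Toy layer: element-wise ReLU.
    static Layer relu() {
        return in -> {
            double[] out = new double[in.length];
            for (int i = 0; i < in.length; i++) {
                out[i] = Math.max(0.0, in[i]);
            }
            return out;
        };
    }

    public static void main(String[] args) {
        Model model = new Builder().add(scale(2.0)).add(relu()).build();
        double[] y = model.predict(new double[]{-1.0, 3.0});
        System.out.println(y[0] + ", " + y[1]); // prints 0.0, 6.0
    }
}
```

Keeping construction in the builder leaves the model immutable after `build()`, so a trained model can be passed around or persisted without its layer list changing.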


Game

The game essentially uses the model trained with the "library" in QuickDrawNN to classify the drawings made by the player. Ten classes from the original dataset were selected, and each session consists of 4 rounds. In each round, if the prediction confidence is greater than 50%, the drawing and the round are saved in the database. There is also a visualization screen that allows users to view all the drawings saved in the database and delete them.
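The save rule described above can be sketched as a small check on the model's output. This sketch assumes that "confidence" means the highest class probability in the prediction, which the text does not state explicitly; the class and method names are hypothetical.

```java
// Hypothetical sketch of the "save only if confidence > 50%" rule.
class ConfidenceGateSketch {

    static final double THRESHOLD = 0.5;

    // Index of the largest probability, i.e. the predicted class.
    static int argMax(double[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    // True when the predicted class has a probability strictly above the threshold.
    static boolean shouldSave(double[] probs) {
        return probs[argMax(probs)] > THRESHOLD;
    }

    public static void main(String[] args) {
        System.out.println(shouldSave(new double[]{0.1, 0.7, 0.2}));   // prints true
        System.out.println(shouldSave(new double[]{0.4, 0.35, 0.25})); // prints false
    }
}
```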

The game implementation uses JavaFX with the MVC pattern.

Demo video: game.mp4
