EDUX - Java Machine Learning Library

A Java library for solving problems with machine learning.

Features

The Edux library implements the following machine learning algorithms:

  • Multilayer Perceptron (Neural Network)
  • K-Nearest Neighbors
  • Decision Tree
  • Support Vector Machine
  • Random Forest

Integration

Gradle

    implementation 'io.github.samyssmile:edux:1.0.5'

Maven

    <dependency>
        <groupId>io.github.samyssmile</groupId>
        <artifactId>edux</artifactId>
        <version>1.0.5</version>
    </dependency>

Examples

Multilayer Perceptron

    NetworkConfiguration config = new NetworkConfiguration(
                    ActivationFunction.LEAKY_RELU,
                    ActivationFunction.SOFTMAX,
                    LossFunction.CATEGORICAL_CROSS_ENTROPY,
                    Initialization.XAVIER, Initialization.XAVIER);

    Classifier mlp = new MultilayerPerceptron(
                    features,
                    labels,
                    testFeatures,
                    testLabels,
                    config);
    mlp.train();
    mlp.predict(...);

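The configuration above names an activation function for the hidden layers and the output layer, a loss function, and a weight initialization scheme. The standalone sketch below shows the textbook definitions behind those names (leaky ReLU, softmax, categorical cross-entropy, Xavier/Glorot uniform bound) so the options are easier to read. It is not Edux source code; the 0.01 leaky slope and the uniform Xavier variant are assumptions, and Edux may use different constants.

```java
import java.util.Arrays;

// Standalone sketch of the functions named in the MLP configuration.
// Not Edux code; the 0.01 leaky slope and the uniform Xavier variant are assumptions.
public class ActivationSketch {

    // Leaky ReLU: passes positive inputs unchanged, scales negative inputs by a small slope.
    static double leakyRelu(double x) {
        return x > 0 ? x : 0.01 * x;
    }

    // Softmax: turns raw output scores into probabilities that sum to 1.
    // Subtracting the maximum score first avoids overflow in Math.exp.
    static double[] softmax(double[] scores) {
        double max = Arrays.stream(scores).max().orElse(0.0);
        double[] out = new double[scores.length];
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            out[i] = Math.exp(scores[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Categorical cross-entropy between a one-hot target and predicted probabilities.
    // A small epsilon guards against log(0).
    static double categoricalCrossEntropy(double[] target, double[] predicted) {
        double loss = 0.0;
        for (int i = 0; i < target.length; i++) {
            loss -= target[i] * Math.log(predicted[i] + 1e-15);
        }
        return loss;
    }

    // Xavier/Glorot uniform initialization draws weights from [-limit, limit]
    // with limit = sqrt(6 / (fanIn + fanOut)).
    static double xavierLimit(int fanIn, int fanOut) {
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    public static void main(String[] args) {
        double[] probs = softmax(new double[]{2.0, 1.0, 0.1});
        System.out.println(Arrays.toString(probs));
        System.out.println(categoricalCrossEntropy(new double[]{1, 0, 0}, probs));
        System.out.println(leakyRelu(-2.0));
        System.out.println(xavierLimit(4, 3));
    }
}
```

Softmax pairs naturally with categorical cross-entropy for multi-class problems such as Iris: the output is a probability distribution, and the loss only penalizes the probability assigned to the true class.
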
Decision Tree

    var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6);
    datasetProvider.printStatistics();

    double[][] features = datasetProvider.getTrainFeatures();
    double[][] labels = datasetProvider.getTrainLabels();

    double[][] testFeatures = datasetProvider.getTestFeatures();
    double[][] testLabels = datasetProvider.getTestLabels();

    Classifier decisionTree = new DecisionTree(8, 2, 1, 4);
    decisionTree.train(features, labels);
    decisionTree.evaluate(testFeatures, testLabels);

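IrisProvider(NORMALIZE, SHUFFLE, 0.6) bundles three common preprocessing steps. The plain-Java sketch below approximates what those arguments typically mean, assuming NORMALIZE is per-column min-max scaling to [0, 1] and 0.6 is the fraction of samples kept for training. PreprocessingSketch and its methods are hypothetical and not part of Edux; the library's own implementation may differ. With the 150 Iris samples, a 0.6 ratio gives the 90/60 split that appears in the log at the end of this section.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of normalize -> shuffle -> split; not Edux's IrisProvider.
public class PreprocessingSketch {

    // Min-max scaling of each column to [0, 1] (assumed meaning of NORMALIZE).
    static void minMaxNormalize(double[][] data) {
        int cols = data[0].length;
        for (int c = 0; c < cols; c++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                min = Math.min(min, row[c]);
                max = Math.max(max, row[c]);
            }
            double range = max - min;
            for (double[] row : data) {
                // A constant column is mapped to 0 to avoid dividing by zero.
                row[c] = range == 0 ? 0.0 : (row[c] - min) / range;
            }
        }
    }

    // Shuffles row indices with a fixed seed, then keeps the first trainRatio share for training.
    static int[][] shuffleAndSplit(int size, double trainRatio, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random(seed));
        // Math.round absorbs floating-point error, e.g. 150 * 0.6 is not exactly 90.0.
        int trainSize = (int) Math.round(size * trainRatio);
        int[] train = indices.subList(0, trainSize).stream().mapToInt(Integer::intValue).toArray();
        int[] test = indices.subList(trainSize, size).stream().mapToInt(Integer::intValue).toArray();
        return new int[][]{train, test};
    }

    public static void main(String[] args) {
        int[][] split = shuffleAndSplit(150, 0.6, 42L);
        System.out.println("Training: " + split[0].length + ", test: " + split[1].length); // Training: 90, test: 60

        double[][] sample = {{5.1, 3.5}, {4.9, 3.0}, {6.3, 3.3}};
        minMaxNormalize(sample);
        System.out.println(Arrays.deepToString(sample));
    }
}
```

Shuffling before splitting matters for Iris because the raw file is ordered by class; without it, a 60/40 cut would leave whole classes out of the training set.
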
Support Vector Machine: Example on the Iris Dataset

    /*
        +-------------+------------+-------------+------------+---------+
        | sepal.length| sepal.width| petal.length| petal.width| variety |
        +-------------+------------+-------------+------------+---------+
        |     5.1     |     3.5    |     1.4     |     .2     | Setosa  |
        +-------------+------------+-------------+------------+---------+
    */

    // First 4 columns are features
    var featureColumnIndices = new int[]{0, 1, 2, 3};
    // Last column is the target
    var targetColumnIndex = 4;

    var irisDataProcessor = new DataProcessor(new CSVIDataReader())
        .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD,
                    featureColumnIndices,
                    targetColumnIndex)
        .normalize().shuffle()
        .split(TRAIN_TEST_SPLIT_RATIO);

    Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 2);

    var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices);
    var testFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices);
    var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex);
    var testLabels = irisDataProcessor.getTestLabels(targetColumnIndex);

    // Train on the training split, then evaluate on the held-out test split
    svm.train(trainFeatures, trainLabels);
    svm.evaluate(testFeatures, testLabels);

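The comment in the example describes the CSV layout: columns 0 to 3 hold the numeric features and column 4 holds the class name. The sketch below illustrates that mapping for a single row, turning the variety into a one-hot vector over the three Iris classes. IrisRowSketch, its methods, the class order, and the one-hot label format are assumptions for illustration; they are not part of Edux's DataProcessor API.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper showing how one Iris CSV row maps to features and a label.
// Not Edux's DataProcessor; the class order and one-hot format are assumptions.
public class IrisRowSketch {

    // Assumed class order for the one-hot encoding.
    static final List<String> CLASSES = List.of("Setosa", "Versicolor", "Virginica");

    // Splits a row on ',' and copies the requested columns into a feature vector.
    static double[] features(String row, int[] featureColumnIndices) {
        String[] cells = row.split(",");
        double[] result = new double[featureColumnIndices.length];
        for (int i = 0; i < featureColumnIndices.length; i++) {
            // Double.parseDouble accepts values without a leading zero, such as ".2".
            result[i] = Double.parseDouble(cells[featureColumnIndices[i]].trim());
        }
        return result;
    }

    // Encodes the target column as a one-hot vector, e.g. Setosa -> [1, 0, 0].
    static double[] oneHotLabel(String row, int targetColumnIndex) {
        String[] cells = row.split(",");
        // Strip surrounding quotes, since the class name may be quoted in the CSV.
        String name = cells[targetColumnIndex].trim().replace("\"", "");
        int classIndex = CLASSES.indexOf(name);
        if (classIndex < 0) {
            throw new IllegalArgumentException("Unknown class: " + name);
        }
        double[] label = new double[CLASSES.size()];
        label[classIndex] = 1.0;
        return label;
    }

    public static void main(String[] args) {
        String row = "5.1,3.5,1.4,.2,\"Setosa\"";
        int[] featureColumnIndices = {0, 1, 2, 3};
        int targetColumnIndex = 4;
        System.out.println(Arrays.toString(features(row, featureColumnIndices)));  // [5.1, 3.5, 1.4, 0.2]
        System.out.println(Arrays.toString(oneHotLabel(row, targetColumnIndex)));  // [1.0, 0.0, 0.0]
    }
}
```
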
Log (Decision Tree example)

    ========================= Data Statistic ==================
    [main] INFO - Total dataset size: 150
    [main] INFO - Training dataset size: 90
    [main] INFO - Test data set size: 60
    [main] INFO - Classes: 3
    ===========================================================
    [main] INFO - Decision Tree - accuracy: 93,33%
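
The reported accuracy is the share of test samples whose predicted class matches the true class (the comma in 93,33% is a locale-dependent decimal separator). Assuming it is computed over the 60 test samples listed in the statistics, 93,33% corresponds to 56 correct predictions, since 56 / 60 ≈ 0.9333. The sketch below shows that calculation for one-hot labels using argmax; AccuracySketch is hypothetical and is not Edux's evaluate method.

```java
// Hypothetical sketch of accuracy over one-hot labels; not Edux's evaluate() implementation.
public class AccuracySketch {

    // Index of the largest value, i.e. the predicted (or true) class.
    static int argmax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of rows where the predicted class equals the true class.
    static double accuracy(double[][] predictions, double[][] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(predictions[i]) == argmax(labels[i])) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {0, 1, 0}};
        double[][] predictions = {{0.8, 0.1, 0.1}, {0.2, 0.7, 0.1}, {0.1, 0.6, 0.3}, {0.3, 0.4, 0.3}};
        // The third row is misclassified (class 1 predicted, class 2 expected).
        System.out.println(accuracy(predictions, labels)); // 0.75

        // 56 correct out of 60 test samples, as implied by the log above.
        System.out.println(100.0 * 56 / 60);
    }
}
```
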