EDUX - Java Machine Learning Library

Java library for solving problems with a machine learning approach.

Features

Edux supports the following machine learning methods and algorithms:

  • Multilayer Perceptron (Neural Network)
  • K Nearest Neighbors
  • Decision Tree
  • Support Vector Machine
  • Random Forest

Integration

Gradle

    implementation 'io.github.samyssmile:edux:1.0.7'

Maven

    <dependency>
        <groupId>io.github.samyssmile</groupId>
        <artifactId>edux</artifactId>
        <version>1.0.7</version>
    </dependency>

Examples

Multilayer Perceptron

Step 1: Data Processing

First, load and prepare the Iris dataset. The CSV file has four feature columns followed by the target column:

    sepal.length  sepal.width  petal.length  petal.width  variety
    5.1           3.5          1.4           0.2          Setosa

Read the file, then shuffle, normalize, and split the data:

    var featureColumnIndices = new int[]{0, 1, 2, 3}; // Specify your feature columns
    var targetColumnIndex = 4;                        // Specify your target column

    var dataProcessor = new DataProcessor(new CSVIDataReader());
    var dataset = dataProcessor.loadDataSetFromCSV(
            new File("path/to/your/data.csv"), // Replace with your CSV file path
            ',',                               // CSV delimiter
            true,                              // Whether to skip the header
            featureColumnIndices,
            targetColumnIndex
    );
    dataset.shuffle();
    dataset.normalize();
    dataProcessor.split(0.8); // Replace with your train-test split ratio
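The shuffle, normalize, and split calls above can be pictured with a small standalone sketch. It does not use Edux: the class `DataPrepSketch` and its methods are hypothetical names, min-max scaling is only an assumption about what `normalize()` does, and the rounding rule for the split size is likewise assumed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper, independent of Edux; it only illustrates the three steps.
final class DataPrepSketch {

    /** Rescales every column to the range [0, 1] in place (min-max scaling, an assumption). */
    static void minMaxNormalize(double[][] rows) {
        int columns = rows[0].length;
        for (int c = 0; c < columns; c++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : rows) {
                min = Math.min(min, row[c]);
                max = Math.max(max, row[c]);
            }
            double range = max - min;
            for (double[] row : rows) {
                // A constant column carries no information; map it to 0.
                row[c] = range == 0.0 ? 0.0 : (row[c] - min) / range;
            }
        }
    }

    /** Number of rows assigned to the training set for a given ratio (rounded). */
    static int trainSize(int totalRows, double trainRatio) {
        return (int) Math.round(totalRows * trainRatio);
    }

    public static void main(String[] args) {
        // A few hand-typed sample rows (four features each), for illustration only.
        double[][] data = {
            {5.1, 3.5, 1.4, 0.2},
            {7.0, 3.2, 4.7, 1.4},
            {6.3, 3.3, 6.0, 2.5},
            {4.9, 3.0, 1.4, 0.2},
            {5.8, 2.7, 5.1, 1.9}
        };

        // Shuffle with a fixed seed so the run is repeatable.
        List<double[]> rows = new ArrayList<>(Arrays.asList(data));
        Collections.shuffle(rows, new Random(42));
        double[][] shuffled = rows.toArray(new double[0][]);

        minMaxNormalize(shuffled);

        // Split: the first part becomes training data, the rest test data.
        int cut = trainSize(shuffled.length, 0.8);
        double[][] train = Arrays.copyOfRange(shuffled, 0, cut);
        double[][] test = Arrays.copyOfRange(shuffled, cut, shuffled.length);
        System.out.println("train rows: " + train.length + ", test rows: " + test.length);
    }
}
```

With 150 rows and a ratio of 0.8, `trainSize` yields 120 training rows and leaves 30 for testing.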

Step 2: Extract Features and Labels

Extract the features and labels for both training and test sets:

    var trainFeatures = dataProcessor.getTrainFeatures(featureColumnIndices);
    var trainLabels = dataProcessor.getTrainLabels(targetColumnIndex);
    var testFeatures = dataProcessor.getTestFeatures(featureColumnIndices);
    var testLabels = dataProcessor.getTestLabels(targetColumnIndex);

Step 3: Network Configuration

    var networkConfiguration = new NetworkConfiguration(
            trainFeatures[0].length,                // Number of input neurons
            List.of(128, 256, 512),                 // Number of neurons in each hidden layer
            3,                                      // Number of output neurons
            0.01,                                   // Learning rate
            300,                                    // Number of epochs
            ActivationFunction.LEAKY_RELU,          // Activation function for hidden layers
            ActivationFunction.SOFTMAX,             // Activation function for output layer
            LossFunction.CATEGORICAL_CROSS_ENTROPY, // Loss function
            Initialization.XAVIER,                  // Weight initialization for hidden layers
            Initialization.XAVIER                   // Weight initialization for output layer
    );
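For orientation, the activation, loss, and initialization names in this configuration correspond to well-known formulas. The sketch below writes them out in plain Java. It is not Edux's implementation: the class and method names are hypothetical, the leaky-ReLU slope is a free parameter here, and the uniform variant of Xavier initialization is an assumption.

```java
// Hypothetical, library-independent helpers that spell out the textbook formulas.
final class NetworkMathSketch {

    /** Leaky ReLU: passes positive inputs through and scales negative inputs by alpha. */
    static double leakyRelu(double x, double alpha) {
        return x > 0 ? x : alpha * x;
    }

    /** Softmax: turns raw scores into probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max); // subtracting the max avoids overflow
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Categorical cross-entropy for one sample with a one-hot encoded label. */
    static double categoricalCrossEntropy(double[] oneHot, double[] predicted) {
        double loss = 0.0;
        for (int i = 0; i < oneHot.length; i++) {
            // Clamp to a tiny positive value so log(0) cannot occur.
            loss -= oneHot[i] * Math.log(Math.max(predicted[i], 1e-15));
        }
        return loss;
    }

    /** Xavier (Glorot) uniform bound: weights are drawn from [-limit, +limit]. */
    static double xavierLimit(int fanIn, int fanOut) {
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    public static void main(String[] args) {
        double[] probabilities = softmax(new double[] {2.0, 1.0, 0.1});
        double loss = categoricalCrossEntropy(new double[] {1, 0, 0}, probabilities);
        System.out.println("loss when the first class is correct: " + loss);
        System.out.println("weight bound for a 4 -> 128 layer: " + xavierLimit(4, 128));
    }
}
```

Softmax paired with categorical cross-entropy is a common choice for a three-class problem such as Iris, since the output layer then yields one probability per class.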

Step 4: Training and Evaluation

    MultilayerPerceptron multilayerPerceptron = new MultilayerPerceptron(
            networkConfiguration,
            testFeatures,
            testLabels
    );
    multilayerPerceptron.train(trainFeatures, trainLabels);
    multilayerPerceptron.evaluate(testFeatures, testLabels);

Results

... MultilayerPerceptron - Best accuracy after restoring best MLP model: 93.33%
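An accuracy figure like this is the share of test rows whose predicted class matches the label. A minimal sketch, assuming one-hot encoded labels and per-class scores stored as `double[][]` (the class `AccuracySketch` is hypothetical and independent of Edux):

```java
import java.util.Locale;

// Hypothetical helper; it shows how an accuracy percentage can be derived.
final class AccuracySketch {

    /** Index of the largest value, i.e. the predicted (or labelled) class. */
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows where the predicted class equals the labelled class. */
    static double accuracy(double[][] predictions, double[][] oneHotLabels) {
        int correct = 0;
        for (int row = 0; row < predictions.length; row++) {
            if (argMax(predictions[row]) == argMax(oneHotLabels[row])) {
                correct++;
            }
        }
        return (double) correct / predictions.length;
    }

    /** Formats a fraction as a percentage with two decimals and a dot separator. */
    static String asPercent(double fraction) {
        return String.format(Locale.ROOT, "%.2f%%", 100.0 * fraction);
    }

    public static void main(String[] args) {
        double[][] predictions = {
            {0.1, 0.8, 0.1},
            {0.7, 0.2, 0.1},
            {0.2, 0.3, 0.5}
        };
        double[][] labels = {
            {0, 1, 0},
            {1, 0, 0},
            {0, 1, 0}
        };
        System.out.println(asPercent(accuracy(predictions, labels))); // prints 66.67%
        System.out.println(asPercent(28.0 / 30.0));                   // prints 93.33%
    }
}
```

If the test set held 30 rows (20% of a 150-row dataset), 28 correct predictions would give 93.33%.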

Decision Tree

    var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6);
    datasetProvider.printStatistics();

    double[][] features = datasetProvider.getTrainFeatures();
    double[][] labels = datasetProvider.getTrainLabels();

    double[][] testFeatures = datasetProvider.getTestFeatures();
    double[][] testLabels = datasetProvider.getTestLabels();

    Classifier decisionTree = new DecisionTree(8, 2, 1, 4);
    decisionTree.train(features, labels);
    decisionTree.evaluate(testFeatures, testLabels);
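
Decision trees commonly choose splits by how much they reduce class impurity. The sketch below computes Gini impurity for a candidate split. Whether Edux's `DecisionTree` uses Gini or another criterion is not stated here, so treat this as a generic illustration with hypothetical names and made-up class counts.

```java
// Hypothetical, library-independent illustration of a split criterion.
final class GiniSketch {

    /** Sum of all class counts in a node. */
    static int total(int[] classCounts) {
        int sum = 0;
        for (int count : classCounts) {
            sum += count;
        }
        return sum;
    }

    /** Gini impurity: 1 minus the sum of squared class proportions. */
    static double gini(int[] classCounts) {
        int n = total(classCounts);
        if (n == 0) {
            return 0.0; // an empty node is treated as pure
        }
        double sumOfSquares = 0.0;
        for (int count : classCounts) {
            double p = (double) count / n;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Impurity of a split: child impurities weighted by child size. */
    static double weightedGini(int[] left, int[] right) {
        int nLeft = total(left);
        int nRight = total(right);
        int n = nLeft + nRight;
        if (n == 0) {
            return 0.0;
        }
        return (nLeft * gini(left) + nRight * gini(right)) / n;
    }

    public static void main(String[] args) {
        // Made-up counts: 90 samples, three equally sized classes.
        int[] parent = {30, 30, 30};
        // A candidate split that isolates the first class completely.
        int[] left = {30, 0, 0};
        int[] right = {0, 30, 30};
        System.out.println("parent impurity: " + gini(parent));
        System.out.println("impurity after split: " + weightedGini(left, right));
    }
}
```

For these counts the impurity drops from about 0.667 to about 0.333, which is why a tree would favor a split like this one.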

Support Vector Machine - Example on IRIS Dataset

Step 1: Initialize the variables

    private static final double TRAIN_TEST_SPLIT_RATIO = 0.70;
    private static final File CSV_FILE = new File("path/to/your/iris.csv");
    private static final boolean SKIP_HEAD = true;

Step 2: Create, train, and evaluate the classifier

    /*
     +--------------+-------------+--------------+-------------+---------+
     | sepal.length | sepal.width | petal.length | petal.width | variety |
     +--------------+-------------+--------------+-------------+---------+
     | 5.1          | 3.5         | 1.4          | 0.2         | Setosa  |
     +--------------+-------------+--------------+-------------+---------+
    */
    // First 4 columns are features
    var featureColumnIndices = new int[]{0, 1, 2, 3};
    // Last column is the target
    var targetColumnIndex = 4;

    var irisDataProcessor = new DataProcessor(new CSVIDataReader())
            .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD, featureColumnIndices, targetColumnIndex)
            .normalize()
            .shuffle()
            .split(TRAIN_TEST_SPLIT_RATIO);

    // Note: a DecisionTree is instantiated here; replace it with the library's
    // SVM classifier to train a Support Vector Machine.
    Classifier classifier = new DecisionTree(2, 2, 3, 12);

    var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices);
    var trainTestFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices);
    var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex);
    var trainTestLabels = irisDataProcessor.getTestLabels(targetColumnIndex);

    classifier.train(trainFeatures, trainLabels);
    classifier.evaluate(trainTestFeatures, trainTestLabels);

Log

    ========================= Data Statistic ==================
    [main] INFO - Total dataset size: 150
    [main] INFO - Training dataset size: 90
    [main] INFO - Test data set size: 60
    [main] INFO - Classes: 3
    ===========================================================
    [main] INFO - Decision Tree - accuracy: 93,33%