Features
The Edux library provides the following machine learning algorithms:
- Multilayer Perceptron (Neural Network)
- K Nearest Neighbors
- Decision Tree
- Support Vector Machine
- Random Forest
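K Nearest Neighbors, the only algorithm above without an example below, can be sketched in a few lines of plain Java. This is a minimal, library-independent illustration under stated assumptions, not Edux's implementation: the class name `KnnSketch` is hypothetical, Euclidean distance is assumed, and labels are plain integer class indices rather than the `double[][]` label arrays the Edux examples use.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical k-nearest-neighbors sketch; not part of the Edux API. */
final class KnnSketch {
    private KnnSketch() {}

    /** Euclidean distance between two feature vectors of equal length (assumed metric). */
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /**
     * Predicts the majority class among the k training points closest to the query.
     * Returns -1 if k < 1 or there is no training data.
     */
    static int predict(double[][] trainX, int[] trainY, double[] query, int k) {
        Integer[] order = new Integer[trainX.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        // Sort training indices by their distance to the query point.
        Arrays.sort(order, Comparator.comparingDouble(i -> distance(trainX[i], query)));
        Map<Integer, Integer> votes = new HashMap<>();
        int best = -1;
        int bestCount = 0;
        for (int n = 0; n < Math.min(k, order.length); n++) {
            int label = trainY[order[n]];
            int count = votes.merge(label, 1, Integer::sum);
            // On a tie, the label that reached the winning count first is kept.
            if (count > bestCount) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }
}
```

For example, with two well-separated clusters labeled 0 and 1, `KnnSketch.predict(X, y, new double[]{0.2, 0.2}, 3)` votes among the three closest points and returns the label of the nearer cluster.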
Integration
Gradle
implementation 'io.github.samyssmile:edux:1.0.5'
Maven
<dependency>
    <groupId>io.github.samyssmile</groupId>
    <artifactId>edux</artifactId>
    <version>1.0.5</version>
</dependency>
Examples
Multilayer Perceptron
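// features, labels, testFeatures and testLabels are double[][] arrays,
// prepared for example with a dataset provider as in the Decision Tree example
NetworkConfiguration config = new NetworkConfiguration(
        ActivationFunction.LEAKY_RELU,               // hidden-layer activation
        ActivationFunction.SOFTMAX,                  // output-layer activation
        LossFunction.CATEGORICAL_CROSS_ENTROPY,
        Initialization.XAVIER, Initialization.XAVIER);
Classifier mlp = new MultilayerPerceptron(
        features,
        labels,
        testFeatures,
        testLabels,
        config);
mlp.train();
mlp.predict(...);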
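The configuration above names leaky ReLU, softmax, categorical cross-entropy and Xavier initialization. The following plain-Java sketch shows what these functions compute in their textbook form; it is not Edux's internal code. The class name `MlpMathSketch`, the leak slope of 0.01, and the uniform variant of Xavier initialization are assumptions for illustration.

```java
import java.util.Random;

/** Hypothetical, library-independent versions of the functions named in the configuration. */
final class MlpMathSketch {
    private MlpMathSketch() {}

    /** Assumed slope for negative inputs; 0.01 is a common default, not necessarily Edux's. */
    static final double LEAK = 0.01;

    /** Leaky ReLU: passes positive inputs through and scales negative inputs by LEAK. */
    static double leakyRelu(double x) {
        return x > 0 ? x : LEAK * x;
    }

    /** Softmax: turns raw output scores into probabilities that sum to 1. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) max = Math.max(max, v);
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            // Subtracting the max prevents overflow in Math.exp without changing the result.
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    /** Categorical cross-entropy between a one-hot target and predicted probabilities. */
    static double categoricalCrossEntropy(double[] target, double[] predicted) {
        final double eps = 1e-15; // guards against log(0)
        double loss = 0.0;
        for (int i = 0; i < target.length; i++) {
            loss -= target[i] * Math.log(Math.max(predicted[i], eps));
        }
        return loss;
    }

    /** Xavier/Glorot uniform init: weights drawn from U(-limit, limit), limit = sqrt(6 / (fanIn + fanOut)). */
    static double[][] xavierUniform(int fanIn, int fanOut, Random rng) {
        double limit = Math.sqrt(6.0 / (fanIn + fanOut));
        double[][] w = new double[fanOut][fanIn];
        for (double[] row : w) {
            for (int j = 0; j < row.length; j++) {
                row[j] = (rng.nextDouble() * 2 - 1) * limit; // maps [0, 1) to [-limit, limit)
            }
        }
        return w;
    }
}
```

For a one-hot target `{0, 1, 0}` and predicted probabilities `{0.2, 0.5, 0.3}`, the cross-entropy reduces to `-ln(0.5)`, because only the true class contributes to the sum.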
Decision Tree
// NORMALIZE and SHUFFLE are flags defined by the caller; 0.6 is the train/test split ratio
var datasetProvider = new IrisProvider(NORMALIZE, SHUFFLE, 0.6);
datasetProvider.printStatistics();
double[][] features = datasetProvider.getTrainFeatures();
double[][] labels = datasetProvider.getTrainLabels();
double[][] testFeatures = datasetProvider.getTestFeatures();
double[][] testLabels = datasetProvider.getTestLabels();
// Tree hyperparameters (see the DecisionTree constructor for their order and meaning)
Classifier decisionTree = new DecisionTree(8, 2, 1, 4);
decisionTree.train(features, labels);
// Evaluate on the held-out test split
decisionTree.evaluate(testFeatures, testLabels);
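A decision tree grows by repeatedly splitting the training data on a feature threshold that makes the resulting groups as pure as possible. The sketch below shows one common way to score such splits, Gini impurity; whether Edux's `DecisionTree` uses Gini, and how the constructor arguments `(8, 2, 1, 4)` translate into stopping rules, is not stated here. The class `GiniSplitSketch` is hypothetical, and labels are integer class indices for simplicity.

```java
import java.util.Arrays;

/** Hypothetical sketch of choosing a split threshold by Gini impurity; not the Edux API. */
final class GiniSplitSketch {
    private GiniSplitSketch() {}

    /** Gini impurity 1 - sum(p_c^2) of class labels in [0, numClasses). */
    static double gini(int[] labels, int numClasses) {
        if (labels.length == 0) return 0.0;
        int[] counts = new int[numClasses];
        for (int y : labels) counts[y]++;
        double sumSq = 0.0;
        for (int c : counts) {
            double p = (double) c / labels.length;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /**
     * Returns the threshold on one feature column that minimizes the size-weighted
     * Gini impurity of the two children (x <= t goes left, x > t goes right),
     * or Double.NaN if the column has fewer than two distinct values.
     */
    static double bestThreshold(double[] feature, int[] labels, int numClasses) {
        double[] sorted = feature.clone();
        Arrays.sort(sorted);
        double bestT = Double.NaN;
        double bestScore = Double.POSITIVE_INFINITY;
        for (int i = 0; i + 1 < sorted.length; i++) {
            if (sorted[i] == sorted[i + 1]) continue;
            // Candidate threshold: midpoint between consecutive distinct values.
            double t = (sorted[i] + sorted[i + 1]) / 2.0;
            int nLeft = 0;
            for (double x : feature) if (x <= t) nLeft++;
            int[] left = new int[nLeft];
            int[] right = new int[feature.length - nLeft];
            int l = 0, r = 0;
            for (int j = 0; j < feature.length; j++) {
                if (feature[j] <= t) left[l++] = labels[j];
                else right[r++] = labels[j];
            }
            double score = (left.length * gini(left, numClasses)
                    + right.length * gini(right, numClasses)) / feature.length;
            if (score < bestScore) {
                bestScore = score;
                bestT = t;
            }
        }
        return bestT;
    }
}
```

For a feature column `{1, 2, 3, 10, 11, 12}` with labels `{0, 0, 0, 1, 1, 1}`, the threshold 6.5 separates the classes perfectly (weighted impurity 0), so it is the one chosen.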
Support Vector Machine: Example on the Iris Dataset
/*
Layout of the Iris CSV (one sample row shown):
+--------------+-------------+--------------+-------------+---------+
| sepal.length | sepal.width | petal.length | petal.width | variety |
+--------------+-------------+--------------+-------------+---------+
| 5.1          | 3.5         | 1.4          | 0.2         | Setosa  |
+--------------+-------------+--------------+-------------+---------+
*/
// First 4 columns are the features
var featureColumnIndices = new int[]{0, 1, 2, 3};
// Last column is the target (class label)
var targetColumnIndex = 4;
// CSV_FILE, SKIP_HEAD and TRAIN_TEST_SPLIT_RATIO are constants defined by the caller:
// the path to the CSV file, whether to skip the header row, and the train/test split ratio
var irisDataProcessor = new DataProcessor(new CSVIDataReader())
        .loadDataSetFromCSV(CSV_FILE, ',', SKIP_HEAD,
                featureColumnIndices,
                targetColumnIndex)
        .normalize().shuffle()
        .split(TRAIN_TEST_SPLIT_RATIO);
Classifier svm = new SupportVectorMachine(SVMKernel.LINEAR, 2);
var trainFeatures = irisDataProcessor.getTrainFeatures(featureColumnIndices);
var testFeatures = irisDataProcessor.getTestFeatures(featureColumnIndices);
var trainLabels = irisDataProcessor.getTrainLabels(targetColumnIndex);
var testLabels = irisDataProcessor.getTestLabels(targetColumnIndex);
// Train on the training split, then evaluate on the held-out test split
svm.train(trainFeatures, trainLabels);
svm.evaluate(testFeatures, testLabels);
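The `normalize().shuffle().split(TRAIN_TEST_SPLIT_RATIO)` chain in the SVM example prepares the data in three steps. The plain-Java sketch below illustrates one way to implement them; it is not Edux's `DataProcessor`. Min-max scaling to [0, 1] is an assumption (the library may normalize differently), and the class `PreprocessSketch` and its methods are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical normalize, shuffle and split helpers; not the Edux DataProcessor. */
final class PreprocessSketch {
    private PreprocessSketch() {}

    /** Rescales every column to [0, 1] in place (assumed min-max scheme); constant columns become 0. */
    static void minMaxNormalize(double[][] data) {
        if (data.length == 0) return;
        for (int col = 0; col < data[0].length; col++) {
            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            for (double[] row : data) {
                min = Math.min(min, row[col]);
                max = Math.max(max, row[col]);
            }
            double range = max - min;
            for (double[] row : data) {
                row[col] = range == 0 ? 0.0 : (row[col] - min) / range;
            }
        }
    }

    /** Fisher-Yates shuffle that applies the same permutation to features and labels, keeping rows aligned. */
    static void shuffle(double[][] features, double[][] labels, Random rng) {
        for (int i = features.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            double[] f = features[i]; features[i] = features[j]; features[j] = f;
            double[] l = labels[i]; labels[i] = labels[j]; labels[j] = l;
        }
    }

    /** Splits rows into {train, test}; a ratio of 0.6 keeps the first 60% of rows for training. */
    static double[][][] split(double[][] rows, double trainRatio) {
        int cut = (int) Math.round(rows.length * trainRatio);
        return new double[][][] {
            Arrays.copyOfRange(rows, 0, cut),
            Arrays.copyOfRange(rows, cut, rows.length)
        };
    }
}
```

Shuffling before splitting matters for Iris in particular, because the raw CSV is commonly ordered by class; without a shuffle, a 60/40 split could leave whole classes out of the training set.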