-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCrossValidation.java
More file actions
126 lines (109 loc) · 4.89 KB
/
Copy pathCrossValidation.java
File metadata and controls
126 lines (109 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Created by VenkataRamesh on 12/7/2016.
 *
 * <p>Helpers that prepare a row-major {@code double[][]} data set (one row per
 * observation, one column per attribute) for model evaluation: k-fold splits,
 * random samples drawn with replacement, threshold partitions and min-max
 * normalization. All methods expect a rectangular matrix (every row has the
 * same length).
 */
public class CrossValidation {

    /** Fraction of the data set size that is drawn, with replacement, for each random training sample. */
    private static final double TRAIN_SAMPLE_FRACTION = 0.7;

    /** Data set supplied to the constructor; the methods below work on their own arguments instead. */
    public double[][] dataSet;

    /** Fold count supplied to the constructor; the methods below work on their own arguments instead. */
    public int kFold;

    /** Storage for generated splits; not read or written by the methods of this class. */
    List<double[][]> kSplits = new ArrayList<>();

    /** Creates an instance without a data set. */
    public CrossValidation() {
    }

    /**
     * Creates an instance that remembers the given data set and fold count.
     *
     * @param dataSet data matrix, stored by reference
     * @param kFold number of folds
     */
    public CrossValidation(double[][] dataSet, int kFold) {
        this.dataSet = dataSet;
        this.kFold = kFold;
    }

    /**
     * Splits the data set into {@code splits} consecutive folds.
     *
     * <p>Each returned element is a two-slot array: slot 0 holds the training rows and
     * slot 1 the test rows, both as {@code double[][]}. Rows are shared with
     * {@code dataSet} (not copied) and keep their original order. When the row count is
     * not divisible by {@code splits}, the first {@code rowCount % splits} folds receive
     * one extra test row so that every row is part of exactly one test fold.
     *
     * @param dataSet data matrix to split
     * @param splits number of folds, must be positive
     * @return one {@code {train, test}} pair per fold
     * @throws IllegalArgumentException if {@code splits} is not positive
     */
    public List<Object[]> generateKFoldSplit(double[][] dataSet, int splits) {
        if (splits <= 0) {
            throw new IllegalArgumentException("splits must be positive, got " + splits);
        }
        int rowCount = dataSet.length;
        int baseSize = rowCount / splits;
        int remainder = rowCount % splits;

        List<Object[]> splitList = new ArrayList<>();
        int testStart = 0;
        for (int k = 0; k < splits; k++) {
            // Spread the leftover rows over the leading folds instead of never testing them.
            int testEnd = testStart + baseSize + (k < remainder ? 1 : 0);

            double[][] testData = Arrays.copyOfRange(dataSet, testStart, testEnd);
            double[][] trainData = new double[rowCount - testData.length][];
            System.arraycopy(dataSet, 0, trainData, 0, testStart);
            System.arraycopy(dataSet, testEnd, trainData, testStart, rowCount - testEnd);

            splitList.add(new Object[] {trainData, testData});
            testStart = testEnd;
        }
        return splitList;
    }

    /**
     * Draws {@code samples} random train/test pairs using a freshly seeded generator.
     *
     * @param dataSet data matrix to sample from
     * @param samples number of pairs to generate
     * @return one {@code {train, test}} pair per sample
     * @see #generateRandomSamples(double[][], int, Random)
     */
    public List<Object[]> generateRandomSamples(double[][] dataSet, int samples) {
        return generateRandomSamples(dataSet, samples, new Random());
    }

    /**
     * Draws {@code samples} random train/test pairs using the supplied generator.
     *
     * <p>For every pair, the training set consists of {@code (int) (rowCount * 0.7)} rows
     * drawn uniformly with replacement. The test set consists of the rows that were never
     * drawn for that training set, in their original order. All rows are copies, so the
     * returned matrices can be modified without affecting {@code dataSet}.
     *
     * @param dataSet data matrix to sample from
     * @param samples number of pairs to generate; a non-positive value yields an empty list
     * @param random source of randomness; pass a seeded instance for reproducible samples
     * @return one {@code {train, test}} pair per sample
     */
    public List<Object[]> generateRandomSamples(double[][] dataSet, int samples, Random random) {
        int rowCount = dataSet.length;
        int trainSize = (int) (rowCount * TRAIN_SAMPLE_FRACTION);

        List<Object[]> splitList = new ArrayList<>();
        for (int k = 0; k < samples; k++) {
            boolean[] drawn = new boolean[rowCount];
            int distinctDrawn = 0;

            double[][] trainData = new double[trainSize][];
            for (int i = 0; i < trainSize; i++) {
                int index = random.nextInt(rowCount);
                if (!drawn[index]) {
                    drawn[index] = true;
                    distinctDrawn++;
                }
                trainData[i] = dataSet[index].clone();
            }

            // Every row that was never picked for training becomes a test row.
            double[][] testData = new double[rowCount - distinctDrawn][];
            for (int i = 0, t = 0; i < rowCount; i++) {
                if (!drawn[i]) {
                    testData[t++] = dataSet[i].clone();
                }
            }

            splitList.add(new Object[] {trainData, testData});
        }
        return splitList;
    }

    /**
     * Partitions the rows around a threshold on one attribute.
     *
     * <p>Slot 0 of the result holds the rows whose attribute value is {@code <= cutValue},
     * slot 1 the remaining rows (including rows whose attribute is NaN). Both partitions
     * are {@code double[][]}, sorted ascending by the attribute (ties keep their input
     * order) and made of copied rows. Either partition may be empty. {@code trainData}
     * itself is neither reordered nor modified.
     *
     * @param trainData rows to partition
     * @param cutValue threshold value
     * @param attributeIndex column the threshold applies to
     * @return two-slot array {@code {rows <= cutValue, rows > cutValue}}
     */
    public static Object[] generatePartitionsForSplit(double[][] trainData, double cutValue, int attributeIndex) {
        // Sort a copy of the row references so the caller's row order is left alone.
        double[][] sorted = trainData.clone();
        Arrays.sort(sorted, (r1, r2) -> Double.compare(r1[attributeIndex], r2[attributeIndex]));

        int leftCount = 0;
        while (leftCount < sorted.length && sorted[leftCount][attributeIndex] <= cutValue) {
            leftCount++;
        }

        return new Object[] {
            copyRows(sorted, 0, leftCount),
            copyRows(sorted, leftCount, sorted.length)
        };
    }

    /**
     * Min-max normalizes the attribute columns of a data matrix.
     *
     * <p>The last two columns are always copied unchanged; when {@code flag == 1} the
     * third-to-last column is copied unchanged as well. Every other column is rescaled to
     * {@code (x - min) / (max - min)}, except that:
     * <ul>
     *   <li>columns whose index is in {@code ignoreList} are not copied at all and stay
     *       {@code 0.0} in the result;</li>
     *   <li>columns whose values are all equal become {@code 0.0} rather than NaN.</li>
     * </ul>
     * NaN entries are skipped when determining a column's minimum and maximum. The input
     * matrix is not modified.
     *
     * @param dataMatrix matrix to normalize
     * @param ignoreList indices of attribute columns to leave out of the result
     * @param flag {@code 1} to also pass the third-to-last column through unchanged
     * @return a new matrix with the same dimensions as {@code dataMatrix}
     */
    public double[][] getNormalizedMatrix(double[][] dataMatrix, List<Integer> ignoreList, int flag) {
        int rowCount = dataMatrix.length;
        if (rowCount == 0) {
            return new double[0][0];
        }
        int columnCount = dataMatrix[0].length;
        double[][] normalized = new double[rowCount][columnCount];

        // Columns from this index onwards are copied as-is.
        int passThroughColumns = (flag == 1) ? 3 : 2;
        int firstRawColumn = Math.max(0, columnCount - passThroughColumns);

        for (int col = 0; col < firstRawColumn; col++) {
            if (!ignoreList.contains(col)) {
                normalizeColumn(dataMatrix, normalized, col);
            }
        }
        for (int row = 0; row < rowCount; row++) {
            System.arraycopy(dataMatrix[row], firstRawColumn, normalized[row], firstRawColumn,
                    columnCount - firstRawColumn);
        }
        return normalized;
    }

    /** Returns copies of {@code rows[from..to)} in a new array. */
    private static double[][] copyRows(double[][] rows, int from, int to) {
        double[][] copy = new double[to - from][];
        for (int i = from; i < to; i++) {
            copy[i - from] = rows[i].clone();
        }
        return copy;
    }

    /** Writes the min-max scaled values of one column of {@code source} into {@code target}. */
    private static void normalizeColumn(double[][] source, double[][] target, int col) {
        double min = Double.NaN;
        double max = Double.NaN;
        for (double[] row : source) {
            double value = row[col];
            if (!Double.isNaN(value)) {
                // Comparisons against the initial NaN are false, so the first real value wins.
                min = (min < value) ? min : value;
                max = (max > value) ? max : value;
            }
        }

        double range = max - min;
        if (range == 0.0) {
            // Constant column: dividing would give NaN, so keep the zeros already in target.
            return;
        }
        for (int row = 0; row < source.length; row++) {
            target[row][col] = (source[row][col] - min) / range;
        }
    }
}