Hello, this is Sumiyama water.
This article continues from the previous one: [Introduction to Computer Science Part 2: Let's Try Machine Learning] Let's implement k-means clustering in Java: distance between data points
If this is your first time here, please start from the beginning of the series: [Introduction to Computer Science Part 0: Let's Try Machine Learning] Let's implement k-means clustering in Java
This time, I will talk about the idea of the "center of a data set".
I described how to set up the development environment in this article: Notes when starting new development with IntelliJ + Gradle + SpringBoot + JUnit5 (jupiter)
Before explaining why this concept is needed, let me first describe what I want to be able to do.
A figure shows this faster than words can, so please take a look at the figure.
When there are multiple data points, like the blue dots in the figure, where is the center of the group? That is what I want to define.
As I mentioned in Part 0, a human can judge that the center is "somewhere around here", but a computer has to calculate it numerically, so we need to give it a definition it can compute.
After that long build-up, here I will use the **arithmetic mean**, which is familiar from everyday life. [^ 1]
It is the kind of average used for things like "the average math score": add up everyone's scores and divide by the number of people.
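Written as a formula, the arithmetic mean of m values $x_1, x_2, \ldots, x_m$ looks like this (the symbols are introduced here only for illustration):

$$
\bar{x} = \frac{x_1 + x_2 + \cdots + x_m}{m} = \frac{1}{m}\sum_{i=1}^{m} x_i
$$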
| Name | Math score |
|---|---|
| Mr. A | 70 points |
| Mr. B | 60 points |
| Mr. C | 90 points |
| Mr. D | 50 points |
In this case, the average score is
(70+60+90+50)/4 = 270/4 = 67.5
In other words, the average is 67.5 points.
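As a quick check of the arithmetic, here is a minimal Java sketch that averages the four scores from the table. The class name `AverageScore` and its `mean` method are hypothetical and exist only for this illustration; the last line shows that the standard library's `DoubleStream.average()` gives the same value.

```java
import java.util.Arrays;

// Hypothetical helper class, used only to illustrate the arithmetic mean.
class AverageScore {
    // Arithmetic mean: sum all values, then divide by the number of values.
    static double mean(double[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        double sum = 0.0;
        for (double s : scores) {
            sum += s;
        }
        return sum / scores.length;
    }

    public static void main(String[] args) {
        double[] scores = {70, 60, 90, 50}; // Mr. A, B, C, D
        System.out.println(mean(scores)); // prints 67.5
        // The standard library computes the same mean.
        System.out.println(Arrays.stream(scores).average().getAsDouble()); // prints 67.5
    }
}
```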
We apply this idea to coordinate values. The figure shows what this looks like for two-dimensional data: the "center coordinates" are the average of the horizontal values and the average of the vertical values.
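Here is a minimal sketch of the two-dimensional case, assuming each point is stored as a `double[]` pair of (horizontal, vertical) values. `Centroid2D` is a hypothetical class name; the three sample points are the ones used in the test case later in this article.

```java
// Hypothetical class: centroid of 2-D points, computed by averaging each axis separately.
class Centroid2D {
    static double[] centroid(double[][] points) {
        if (points.length == 0) {
            throw new IllegalArgumentException("no points");
        }
        double sumX = 0.0, sumY = 0.0;
        for (double[] p : points) {
            sumX += p[0]; // horizontal component
            sumY += p[1]; // vertical component
        }
        // Average of the horizontal values, average of the vertical values.
        return new double[] {sumX / points.length, sumY / points.length};
    }

    public static void main(String[] args) {
        double[][] points = {{2.0, -4.0}, {1.0, 0.0}, {6.0, 1.0}};
        double[] c = centroid(points);
        System.out.println(c[0] + ", " + c[1]); // prints 3.0, -1.0
    }
}
```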
I added a getCentroid method to the EuclideanSpace class created so far.
```java
package net.tan3sugarless.clusteringsample.lib.data;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.Value;
import net.tan3sugarless.clusteringsample.exception.DimensionNotUnifiedException;
import net.tan3sugarless.clusteringsample.exception.NullCoordinateException;
import net.tan3sugarless.clusteringsample.exception.UnexpectedCentroidException;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Coordinate set on a Euclidean metric space
 */
@Getter
@ToString
@EqualsAndHashCode
@Value
public class EuclideanSpace {

    private final List<List<Double>> points;
    private final int dimension;

    //
    //~~ Omitted ~~
    //

    /**
     * Finds the coordinates of the center point of all points that belong to this instance.
     * <p>
     * The center is defined as the arithmetic mean.
     *
     * <pre>
     * In n-dimensional space, let xij be the j-th element of the i-th point.
     * Given m points
     * [x11, x12,...,x1n], [x21, x22,...,x2n], ..., [xm1, xm2,...,xmn]
     * the coordinates of the center point are
     *
     * [(x11+x21+....+xm1)/m, (x12+x22+....+xm2)/m, ..., (x1n+x2n+....+xmn)/m]
     *
     * This method performs that calculation.
     * </pre>
     *
     * @return Center point coordinates
     * @throws UnexpectedCentroidException thrown only if there are no points while dimension is at least 1; should not normally happen
     */
    public List<Double> getCentroid() {
        return IntStream
                .range(0, dimension)
                .boxed()
                // For each axis i, average the i-th element over all points
                .map(i -> points.stream().mapToDouble(point -> point.get(i)).average().orElseThrow(UnexpectedCentroidException::new))
                .collect(Collectors.toList());
    }
}
```
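If you want to try the stream pipeline without Lombok or the project's custom exceptions, the following stand-alone sketch mirrors `getCentroid`. It assumes the dimension is passed in as a parameter instead of being read from a field, swaps in `IllegalStateException` for `UnexpectedCentroidException`, and uses the hypothetical class name `CentroidSketch`.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

// Hypothetical stand-alone version of the getCentroid stream pipeline.
class CentroidSketch {
    static List<Double> centroid(List<List<Double>> points, int dimension) {
        return IntStream
                .range(0, dimension)                             // one pass per axis j
                .boxed()
                .map(j -> points.stream()
                        .mapToDouble(point -> point.get(j))      // j-th element of every point
                        .average()                               // (x1j + ... + xmj) / m
                        // Empty only when there are no points (IllegalStateException is a stand-in)
                        .orElseThrow(() -> new IllegalStateException("no points")))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<List<Double>> points = Arrays.asList(
                Arrays.asList(2.0, -4.0),
                Arrays.asList(1.0, 0.0),
                Arrays.asList(6.0, 1.0));
        System.out.println(centroid(points, 2)); // prints [3.0, -1.0]
    }
}
```

With `dimension` equal to 0 the `IntStream` range is empty, so the result is an empty list; assuming the constructor sets `dimension` to 0 for those inputs, this is how the first two test cases below end up expecting empty lists. `average()` returns an empty `OptionalDouble` only when there are no points but `dimension` is at least 1, which is the situation the exception guards against.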
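And here is the test:
```java
// Assumed: the test lives in the same package as EuclideanSpace; the class name is illustrative.
package net.tan3sugarless.clusteringsample.lib.data;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;
import static java.util.Arrays.asList;

class EuclideanSpaceTest {
    /**
     * points: no points / 3 points in 0 dimensions / 3 points in 2 dimensions
     */
    static Stream<Arguments> testGetCentroidProvider() {
        return Stream.of(
                //@formatter:off
                Arguments.of(Collections.emptyList(), Collections.emptyList()),
                Arguments.of(asList(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), Collections.emptyList()),
                Arguments.of(asList(asList(2.0, -4.0), asList(1.0, 0.0), asList(6.0, 1.0)), asList(3.0, -1.0))
                //@formatter:on
        );
    }

    @ParameterizedTest
    @MethodSource("testGetCentroidProvider")
    @DisplayName("Center calculation test")
    void testGetCentroid(List<List<Double>> points, List<Double> centroid) {
        EuclideanSpace space = new EuclideanSpace(points);
        Assertions.assertEquals(centroid, space.getCentroid());
    }
}
```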
That's all for this time. Now that all the parts are ready, next time I will finally explain the main logic of the k-means method.
The version explained in this article is tagged on GitHub, so please take a look if you like: https://github.com/tan3nonsugar/clusteringsample/releases/tag/v0.0.3
[^ 1]: There are other definitions of the "middle", such as the median, but to keep the discussion simple this series sticks to the arithmetic mean.