Robust Model Fitting
Robust model fitting attempts to find the model that best fits a set of observations under the assumption that some of the observations are generated by noise (outliers). If those noisy observations are included in a standard, non-robust model fitting approach, the final solution can be extremely inaccurate. A robust model fitting algorithm therefore finds both the best-fit parameters and the set of observations that were not generated by noise (the inliers). Please check out all of the example code, since this example relies on additional classes in the same directory.
1public static void main( String[] args ) {
2 Random rand = new Random(234);
3
4 //------------------------ Create Observations
5 // define a line in 2D space as the tangent from the origin
6 double lineX = -2.1;
7 double lineY = 1.3;
8 List<Point2D> points = generateObservations(rand, lineX, lineY);
9
10 //------------------------ Compute the solution
11 // Let it know how to compute the model and fit errors
12 ModelManager<Line2D> manager = new LineManager();
13
14 // RANSAC or LMedS work well here
15 ModelMatcherPost<Line2D,Point2D> alg = new Ransac<>(234234, 500, 0.01, manager, Point2D.class);
16 alg.setModel(LineGenerator::new, DistanceFromLine::new);
17
18 // ModelMatcher<Line2D,Point2D> alg =
19 // new LeastMedianOfSquares<Line2D, Point2D>(234234,100,0.1,0.5,generator,distance);
20
21 if( !alg.process(points) )
22 throw new RuntimeException("Robust fit failed!");
23
24 // let's look at the results
25 Line2D found = alg.getModelParameters();
26
27 // notice how all the noisy points were removed and an accurate line was estimated?
28 System.out.println("Found line "+found);
29 System.out.println("Actual line x = "+lineX+" y = "+lineY);
30 System.out.println("Match set size = "+alg.getMatchSet().size());
31}
32
33private static List<Point2D> generateObservations(Random rand, double lineX, double lineY) {
34 // randomly generate points along the line
35 List<Point2D> points = new ArrayList<Point2D>();
36 for( int i = 0; i < 20; i++ ) {
37 double t = (rand.nextDouble()-0.5)*10;
38 points.add( new Point2D(lineX + t*lineY, lineY - t*lineX) );
39 }
40
41 // Add in some random points
42 for( int i = 0; i < 5; i++ ) {
43 points.add( new Point2D(rand.nextGaussian()*10,rand.nextGaussian()*10));
44 }
45
46 // Shuffle the list to remove any structure
47 Collections.shuffle(points);
48 return points;
49}
Below are the support classes that define the model being fitted and perform data management.
/**
 * A line in the plane, described by the point on it that lies nearest to the origin.
 *
 * <p>The coordinates are mutable and package-visible; LineGenerator writes its estimate
 * directly into them.
 */
public class Line2D {
    /** Coordinates of the point on the line closest to the origin. */
    double x, y;

    /**
     * Describes this line for printing.
     *
     * @return text of the form {@code Line2D( x=<x> y=<y> )}
     */
    @Override
    public String toString() {
        return new StringBuilder("Line2D( x=")
                .append(x)
                .append(" y=")
                .append(y)
                .append(" )")
                .toString();
    }
}
1public class DistanceFromLine implements DistanceFromModel<Line2D, Point2D> {
2
3 // parametric line equation
4 double x0, y0;
5 double slopeX, slopeY;
6
7 @Override
8 public void setModel( Line2D param ) {
9 x0 = param.x;
10 y0 = param.y;
11
12 slopeX = -y0;
13 slopeY = x0;
14 }
15
16 @Override
17 public double distance( Point2D p ) {
18
19 // find the closest point on the line to the point
20 double t = slopeX*(p.x - x0) + slopeY*(p.y - y0);
21 t /= slopeX*slopeX + slopeY*slopeY;
22
23 double closestX = x0 + t*slopeX;
24 double closestY = y0 + t*slopeY;
25
26 // compute the Euclidean distance
27 double dx = p.x - closestX;
28 double dy = p.y - closestY;
29
30 return Math.sqrt(dx*dx + dy*dy);
31 }
32
33 /**
34 * There are some situations where processing everything as a list can speed things up a lot.
35 * This is not one of them.
36 */
37 @Override
38 public void distances( List<Point2D> obs, double[] distance ) {
39 for (int i = 0; i < obs.size(); i++) {
40 distance[i] = distance(obs.get(i));
41 }
42 }
43
44 @Override
45 public Class<Point2D> getPointType() {
46 return Point2D.class;
47 }
48
49 @Override
50 public Class<Line2D> getModelType() {
51 return Line2D.class;
52 }
53}
1 // a point at the origin (0,0)
2 Point2D origin = new Point2D();
3
4 @Override
5 public boolean generate( List<Point2D> dataSet, Line2D output ) {
6 Point2D p1 = dataSet.get(0);
7 Point2D p2 = dataSet.get(1);
8
9 // First find the slope of the line
10 double slopeX = p2.x - p1.x;
11 double slopeY = p2.y - p1.y;
12
13 // Now that we have the slope, all we need is a line on the point (we pick p1) to find
14 // the closest point on the line to the origin. This closest point is the parametrization.
15 double t = slopeX*(origin.x - p1.x) + slopeY*(origin.y - p1.y);
16 t /= slopeX*slopeX + slopeY*slopeY;
17
18 output.x = p1.x + t*slopeX;
19 output.y = p1.y + t*slopeY;
20
21 return true;
22 }
23
24 @Override
25 public int getMinimumPoints() {
26 return 2;
27 }
28}