Robust Model Fitting

Robust model fitting attempts to find the best-fit model for a set of observations under the assumption that some of the observations are generated by noise (outliers). If those noisy observations are included in a standard model-fitting approach, the final solution can be extremely inaccurate. A robust model-fitting algorithm therefore finds both the best-fit parameters and the set of observations that were not generated by noise. Please check out all of the example code, since this example relies on additional classes in the same directory.

ExampleRobustModelFit.java

 1 public static void main( String args[] ) {
 2     Random rand = new Random(234);
 3 
 4     //------------------------ Create Observations
 5     // define a line in 2D space as the tangent from the origin
 6     double lineX = -2.1;
 7     double lineY = 1.3;
 8     List<Point2D> points = generateObservations(rand, lineX, lineY);
 9 
10     //------------------------ Compute the solution
11     // Let it know how to compute the model and fit errors
12     ModelManager<Line2D> manager = new LineManager();
13 
14     // RANSAC or LMedS work well here
15     ModelMatcherPost<Line2D,Point2D> alg = new Ransac<>(234234, 500, 0.01, manager, Point2D.class);
16     alg.setModel(LineGenerator::new, DistanceFromLine::new);
17 
18     // ModelMatcher<Line2D,Point2D> alg =
19     //      new LeastMedianOfSquares<Line2D, Point2D>(234234,100,0.1,0.5,generator,distance);
20 
21     if( !alg.process(points) )
22         throw new RuntimeException("Robust fit failed!");
23 
24     // let's look at the results
25     Line2D found = alg.getModelParameters();
26 
27     // notice how all the noisy points were removed and an accurate line was estimated?
28     System.out.println("Found line   "+found);
29     System.out.println("Actual line   x = "+lineX+" y = "+lineY);
30     System.out.println("Match set size = "+alg.getMatchSet().size());
31 }
32 
33 private static List<Point2D> generateObservations(Random rand, double lineX, double lineY) {
34     // randomly generate points along the line
35     List<Point2D> points = new ArrayList<Point2D>();
36     for( int i = 0; i < 20; i++ ) {
37         double t = (rand.nextDouble()-0.5)*10;
38         points.add( new Point2D(lineX + t*lineY, lineY - t*lineX) );
39     }
40 
41     // Add in some random points
42     for( int i = 0; i < 5; i++ ) {
43         points.add( new Point2D(rand.nextGaussian()*10,rand.nextGaussian()*10));
44     }
45 
46     // Shuffle the list to remove any structure
47     Collections.shuffle(points);
48     return points;
49 }

Below are all the support classes which define the model being fitted and perform data management.

 1 public class Line2D {
 2     /**
 3      * Coordinate of the closest point on the line to the origin.
 4      */
 5     double x,y;
 6 
 7     @Override
 8     public String toString() {
 9         return "Line2D( x="+x+" y="+y+" )";
10     }
11 }
 1 public class DistanceFromLine implements DistanceFromModel<Line2D, Point2D> {
 2 
 3     // parametric line equation
 4     double x0, y0;
 5     double slopeX, slopeY;
 6 
 7     @Override
 8     public void setModel( Line2D param ) {
 9         x0 = param.x;
10         y0 = param.y;
11 
12         slopeX = -y0;
13         slopeY = x0;
14     }
15 
16     @Override
17     public double distance( Point2D p ) {
18 
19         // find the closest point on the line to the point
20         double t = slopeX*(p.x - x0) + slopeY*(p.y - y0);
21         t /= slopeX*slopeX + slopeY*slopeY;
22 
23         double closestX = x0 + t*slopeX;
24         double closestY = y0 + t*slopeY;
25 
26         // compute the Euclidean distance
27         double dx = p.x - closestX;
28         double dy = p.y - closestY;
29 
30         return Math.sqrt(dx*dx + dy*dy);
31     }
32 
33     /**
34      * There are some situations where processing everything as a list can speed things up a lot.
35      * This is not one of them.
36      */
37     @Override
38     public void distances( List<Point2D> obs, double[] distance ) {
39         for (int i = 0; i < obs.size(); i++) {
40             distance[i] = distance(obs.get(i));
41         }
42     }
43 
44     @Override
45     public Class<Point2D> getPointType() {
46         return Point2D.class;
47     }
48 
49     @Override
50     public Class<Line2D> getModelType() {
51         return Line2D.class;
52     }
53 }
 1 public class LineGenerator implements ModelGenerator<Line2D, Point2D> {
 2 
 3     // a point at the origin (0,0)
 4     Point2D origin = new Point2D();
 5 
 6     @Override
 7     public boolean generate( List<Point2D> dataSet, Line2D output ) {
 8         Point2D p1 = dataSet.get(0);
 9         Point2D p2 = dataSet.get(1);
10 
11         // First find the slope of the line
12         double slopeX = p2.x - p1.x;
13         double slopeY = p2.y - p1.y;
14 
15         // Now that we have the slope, all we need is a line on the point (we pick p1) to find
16         // the closest point on the line to the origin. This closest point is the parametrization.
17         double t = slopeX*(origin.x - p1.x) + slopeY*(origin.y - p1.y);
18         t /= slopeX*slopeX + slopeY*slopeY;
19 
20         output.x = p1.x + t*slopeX;
21         output.y = p1.y + t*slopeY;
22 
23         return true;
24     }
25 
26     @Override
27     public int getMinimumPoints() {
28         return 2;
29     }
30 }