# Robust Model Fitting

Robust model fitting finds the model that best fits a set of observations under the assumption that some of those observations are outliers, i.e. generated by noise rather than by the model. If the outliers are included in a standard, non-robust fitting approach such as least squares, the final solution can be extremely inaccurate. A robust model fitting algorithm therefore finds both the best-fit parameters and the set of observations that were not generated by noise (the inliers). Please check out all of the example code, since this example relies on additional classes in the same directory.
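The following sketch makes the claim about standard fitting concrete. It uses ordinary least squares on a slope-intercept line `y = a*x + b`, chosen only because its solution has a short closed form. This is not the closest-point line model used by the example below. The class and method names are hypothetical, and the data are synthetic: ten points lie exactly on `y = 2x + 1`, and three extra points do not.

```java
import java.util.ArrayList;
import java.util.List;

/** Ordinary least-squares fit of y = a*x + b, used to show how a few outliers corrupt a non-robust fit. */
class LeastSquaresOutlierDemo {

    /** Returns {a, b} minimizing the sum of squared vertical residuals (y - a*x - b)^2. */
    static double[] fitLine(List<double[]> pts) {
        double n = pts.size(), sx = 0, sy = 0, sxx = 0, sxy = 0;
        for (double[] p : pts) {
            sx += p[0];
            sy += p[1];
            sxx += p[0] * p[0];
            sxy += p[0] * p[1];
        }
        double a = (n * sxy - sx * sy) / (n * sxx - sx * sx);
        double b = (sy - a * sx) / n;
        return new double[]{a, b};
    }

    /** Ten points lying exactly on y = 2x + 1. */
    static List<double[]> cleanPoints() {
        List<double[]> pts = new ArrayList<>();
        for (int x = 0; x < 10; x++) {
            pts.add(new double[]{x, 2 * x + 1});
        }
        return pts;
    }

    /** The same points plus three outliers that were not generated by the line. */
    static List<double[]> pointsWithOutliers() {
        List<double[]> pts = cleanPoints();
        pts.add(new double[]{1, 40});
        pts.add(new double[]{2, 45});
        pts.add(new double[]{8, -30});
        return pts;
    }

    public static void main(String[] args) {
        double[] clean = fitLine(cleanPoints());
        double[] noisy = fitLine(pointsWithOutliers());
        System.out.println("inliers only : a=" + clean[0] + " b=" + clean[1]);
        System.out.println("with outliers: a=" + noisy[0] + " b=" + noisy[1]);
    }
}
```

With only the ten inliers, the fit recovers `a = 2, b = 1` exactly. Adding the three outliers drags the slope to -2115/1466 (about -1.44) and the intercept to about 18.1. This is the failure a robust method is designed to avoid.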

ExampleRobustModelFit.java

```java
import org.ddogleg.fitting.modelset.DistanceFromModel;
import org.ddogleg.fitting.modelset.ModelGenerator;
import org.ddogleg.fitting.modelset.ModelManager;
import org.ddogleg.fitting.modelset.ModelMatcher;
import org.ddogleg.fitting.modelset.ransac.Ransac;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class ExampleRobustModelFit {
    public static void main( String[] args ) {
        Random rand = new Random(234);

        //------------------------ Create Observations
        // define a line in 2D space by its closest point to the origin
        double lineX = -2.1;
        double lineY = 1.3;
        List<Point2D> points = generateObservations(rand, lineX, lineY);

        //------------------------ Compute the solution
        // Let it know how to compute the model and fit errors
        ModelManager<Line2D> manager = new LineManager();
        ModelGenerator<Line2D,Point2D> generator = new LineGenerator();
        DistanceFromModel<Line2D,Point2D> distance = new DistanceFromLine();

        // RANSAC or LMedS work well here
        ModelMatcher<Line2D,Point2D> alg =
                new Ransac<>(234234, manager, generator, distance, 500, 0.01);
        // ModelMatcher<Line2D,Point2D> alg =
        //         new LeastMedianOfSquares<>(234234, 100, 0.1, 0.5, generator, distance);

        if( !alg.process(points) )
            throw new RuntimeException("Robust fit failed!");

        // let's look at the results
        Line2D found = alg.getModelParameters();

        // notice how all the noisy points were removed and an accurate line was estimated?
        System.out.println("Found line "+found);
        System.out.println("Actual line x = "+lineX+" y = "+lineY);
        System.out.println("Match set size = "+alg.getMatchSet().size());
    }

    private static List<Point2D> generateObservations(Random rand, double lineX, double lineY) {
        // randomly generate points along the line
        List<Point2D> points = new ArrayList<>();
        for( int i = 0; i < 20; i++ ) {
            double t = (rand.nextDouble()-0.5)*10;
            points.add( new Point2D(lineX + t*lineY, lineY - t*lineX) );
        }

        // Add in some random points
        for( int i = 0; i < 5; i++ ) {
            points.add( new Point2D(rand.nextGaussian()*10, rand.nextGaussian()*10) );
        }

        // Shuffle the list to remove any structure; passing rand keeps the run reproducible
        Collections.shuffle(points, rand);
        return points;
    }
}
```
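The last two arguments of the `Ransac` constructor above (500 and 0.01) appear to be the iteration count and the inlier distance threshold. The following minimal, self-contained RANSAC sketch shows the hypothesize-and-verify idea behind them, using the same closest-point line model. It is not the implementation behind the `Ransac` class. The class `MiniRansacLine`, its methods and its `Result` holder are hypothetical. Points are plain `double[]` pairs, and a point counts as an inlier when its distance is at most the threshold. Each iteration samples two points (the same minimum that `LineGenerator.getMinimumPoints()` reports), builds a line, and counts how many observations lie near it. The largest such set wins.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Minimal RANSAC sketch for a 2D line stored as its closest point to the origin (x0, y0). */
class MiniRansacLine {

    /** Best model found and the indices of the observations it explains. */
    static final class Result {
        double[] model;                              // {x0, y0}, or null if nothing was found
        List<Integer> inliers = new ArrayList<>();
    }

    /** Line through p1 and p2 as its closest point to the origin; null if it cannot be represented. */
    static double[] lineFromTwoPoints(double[] p1, double[] p2) {
        double sx = p2[0] - p1[0], sy = p2[1] - p1[1];
        double len2 = sx * sx + sy * sy;
        if (len2 == 0) return null;                  // identical points do not define a line
        double t = -(sx * p1[0] + sy * p1[1]) / len2;
        double[] m = {p1[0] + t * sx, p1[1] + t * sy};
        return (m[0] == 0 && m[1] == 0) ? null : m;  // lines through the origin are not representable
    }

    /** Euclidean distance from p to the line, via projection onto the direction (-y0, x0). */
    static double distance(double[] m, double[] p) {
        double sx = -m[1], sy = m[0];
        double t = (sx * (p[0] - m[0]) + sy * (p[1] - m[1])) / (sx * sx + sy * sy);
        double dx = p[0] - (m[0] + t * sx);
        double dy = p[1] - (m[1] + t * sy);
        return Math.sqrt(dx * dx + dy * dy);
    }

    static Result fit(List<double[]> pts, int maxIterations, double threshold, Random rand) {
        Result best = new Result();
        for (int iter = 0; iter < maxIterations; iter++) {
            // 1. Hypothesize: a line from a minimal random sample of two distinct points
            int i = rand.nextInt(pts.size());
            int j = rand.nextInt(pts.size() - 1);
            if (j >= i) j++;
            double[] model = lineFromTwoPoints(pts.get(i), pts.get(j));
            if (model == null) continue;

            // 2. Verify: collect every observation within `threshold` of the hypothesis
            List<Integer> inliers = new ArrayList<>();
            for (int k = 0; k < pts.size(); k++) {
                if (distance(model, pts.get(k)) <= threshold) inliers.add(k);
            }

            // 3. Keep the hypothesis with the largest consensus set
            if (inliers.size() > best.inliers.size()) {
                best.model = model;
                best.inliers = inliers;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<double[]> pts = new ArrayList<>();
        double lineX = -2.1, lineY = 1.3;
        for (int i = 0; i < 20; i++) {               // 20 points exactly on the line
            double t = -5 + 0.5 * i;
            pts.add(new double[]{lineX + t * lineY, lineY - t * lineX});
        }
        double[][] outliers = {{5, 5}, {-5, 3}, {10, -4}, {3, 8}, {-6, -6}};
        for (double[] o : outliers) pts.add(o);

        Result r = fit(pts, 500, 0.01, new Random(234));
        System.out.println("x0=" + r.model[0] + " y0=" + r.model[1] + " inliers=" + r.inliers.size());
    }
}
```

Any sample of two true points reproduces the line exactly. All five hand-placed outliers lie more than 0.5 units from it, so the sketch should report the 20 true points as inliers and a line at `x0 = -2.1, y0 = 1.3`.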

Below are all the support classes which define the model being fitted and perform data management.

```java
public class Line2D {
    /**
     * Coordinate of the closest point on the line to the origin.
     */
    double x,y;

    @Override
    public String toString() {
        return "Line2D( x="+x+" y="+y+" )";
    }
}
```
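`Line2D` stores a line by the point on it closest to the origin, $\mathbf{p}_0 = (x_0, y_0)$. The segment from the origin to $\mathbf{p}_0$ meets the line at a right angle, so $\mathbf{p}_0$ is itself a normal vector of the line. This gives the implicit and parametric forms that the support classes rely on:

$$
\mathbf{p}_0 \cdot (\mathbf{p} - \mathbf{p}_0) = 0
\;\Longleftrightarrow\;
x_0 x + y_0 y = x_0^2 + y_0^2,
\qquad
\mathbf{p}(t) = \mathbf{p}_0 + t\,\mathbf{d},\quad \mathbf{d} = (-y_0,\ x_0).
$$

For the line in the example, $(x_0, y_0) = (-2.1, 1.3)$, so the implicit form is $-2.1x + 1.3y = 6.1$. One consequence of this parametrization is that a line passing through the origin has $\mathbf{p}_0 = \mathbf{0}$. Its direction then cannot be recovered from $(x_0, y_0)$, and the projection in `DistanceFromLine` becomes $0/0$ (NaN).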
```java
import org.ddogleg.fitting.modelset.DistanceFromModel;

import java.util.List;

public class DistanceFromLine implements DistanceFromModel<Line2D,Point2D> {

    // parametric line equation
    double x0, y0;          // point on the line closest to the origin
    double slopeX,slopeY;   // direction of the line, perpendicular to (x0, y0)

    @Override
    public void setModel(Line2D param) {
        x0 = param.x;
        y0 = param.y;

        slopeX = -y0;
        slopeY = x0;
    }

    @Override
    public double computeDistance(Point2D p) {

        // find the closest point on the line to the point
        double t = slopeX * ( p.x - x0) + slopeY * ( p.y - y0);
        t /= slopeX * slopeX + slopeY * slopeY;

        double closestX = x0 + t*slopeX;
        double closestY = y0 + t*slopeY;

        // compute the Euclidean distance
        double dx = p.x - closestX;
        double dy = p.y - closestY;

        return Math.sqrt(dx*dx + dy*dy);
    }

    /**
     * There are some situations where processing everything as a list can speed things up a lot.
     * This is not one of them.
     */
    @Override
    public void computeDistance(List<Point2D> obs, double[] distance) {
        for( int i = 0; i < obs.size(); i++ ) {
            distance[i] = computeDistance(obs.get(i));
        }
    }

    @Override
    public Class<Point2D> getPointType() {
        return Point2D.class;
    }

    @Override
    public Class<Line2D> getModelType() {
        return Line2D.class;
    }
}
```
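The stored point (x0, y0) is perpendicular to the line, so the projection in `computeDistance` could also be replaced by a closed form. The distance from p = (px, py) is `|x0*px + y0*py - (x0^2 + y0^2)| / sqrt(x0^2 + y0^2)`. The following standalone check confirms that the two formulas agree. The class `LineDistanceCheck` is hypothetical and uses plain doubles instead of `Point2D`. It is not part of the original example.

```java
/** Two equivalent ways to measure the distance from a point to a line given by its closest point to the origin. */
class LineDistanceCheck {

    /** Projection approach, mirroring DistanceFromLine.computeDistance. */
    static double byProjection(double x0, double y0, double px, double py) {
        double slopeX = -y0, slopeY = x0;
        double t = (slopeX * (px - x0) + slopeY * (py - y0)) / (slopeX * slopeX + slopeY * slopeY);
        double dx = px - (x0 + t * slopeX);
        double dy = py - (y0 + t * slopeY);
        return Math.sqrt(dx * dx + dy * dy);
    }

    /** Closed form: (x0, y0) is a normal of the line x0*x + y0*y = x0^2 + y0^2. NaN if x0 = y0 = 0. */
    static double closedForm(double x0, double y0, double px, double py) {
        double norm2 = x0 * x0 + y0 * y0;
        return Math.abs(x0 * px + y0 * py - norm2) / Math.sqrt(norm2);
    }

    public static void main(String[] args) {
        double x0 = -2.1, y0 = 1.3;
        double[][] samples = {{0, 0}, {5, 5}, {-3.4, -1.0}, {10, -4}};
        for (double[] p : samples) {
            System.out.println(byProjection(x0, y0, p[0], p[1]) + " vs " + closedForm(x0, y0, p[0], p[1]));
        }
    }
}
```

The closed form avoids computing the closest point explicitly. Like the projection version, it is undefined when the model is (0, 0).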
```java
import org.ddogleg.fitting.modelset.ModelGenerator;

import java.util.List;

public class LineGenerator implements ModelGenerator<Line2D,Point2D> {

    // a point at the origin (0,0)
    Point2D origin = new Point2D();

    @Override
    public boolean generate(List<Point2D> dataSet, Line2D output) {
        Point2D p1 = dataSet.get(0);
        Point2D p2 = dataSet.get(1);

        // First find the slope of the line
        double slopeX = p2.x - p1.x;
        double slopeY = p2.y - p1.y;

        // Now that we have the slope, all we need is a point on the line (we pick p1) to find
        // the closest point on the line to the origin. This closest point is the parametrization.
        double t = slopeX * ( origin.x - p1.x) + slopeY * ( origin.y - p1.y);
        t /= slopeX * slopeX + slopeY * slopeY;

        output.x = p1.x + t*slopeX;
        output.y = p1.y + t*slopeY;

        return true;
    }

    @Override
    public int getMinimumPoints() {
        return 2;
    }
}
```
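`LineGenerator` needs the point on the line through $\mathbf{p}_1$ and $\mathbf{p}_2$ that is closest to the origin $\mathbf{o}$. Write the line as $\mathbf{p}_1 + t\,\mathbf{s}$ with $\mathbf{s} = \mathbf{p}_2 - \mathbf{p}_1$. Minimizing the squared distance to $\mathbf{o}$ gives the value of `t` computed in `generate()`:

$$
\frac{d}{dt}\,\lVert \mathbf{p}_1 + t\,\mathbf{s} - \mathbf{o} \rVert^2
= 2\,\mathbf{s}\cdot(\mathbf{p}_1 + t\,\mathbf{s} - \mathbf{o}) = 0
\;\Longrightarrow\;
t = \frac{\mathbf{s}\cdot(\mathbf{o} - \mathbf{p}_1)}{\mathbf{s}\cdot\mathbf{s}},
\qquad
(x_0, y_0) = \mathbf{p}_1 + t\,\mathbf{s}.
$$

As a worked check, take $\mathbf{p}_1 = (0, 2)$ and $\mathbf{p}_2 = (2, 0)$. Then $\mathbf{s} = (2, -2)$ and $t = 4/8 = 0.5$, giving $(x_0, y_0) = (1, 1)$, the closest point to the origin on the line $x + y = 2$. If the two sampled points coincide, $\mathbf{s}\cdot\mathbf{s} = 0$ and the output is NaN. `generate()` as written still returns `true` in that case.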