Robust Model Fitting

Robust model fitting attempts to find the best fit model to observations under the assumption that a few of the observations are generated by noise. If those noisy observations are included in standard model fitting approaches, the final solution will be extremely inaccurate. Thus, a robust model fitting algorithm finds both the best fit parameters and the set of observations which are not generated by noise. Please check out all of the example code, since this example relies on additional classes found in the example directory.

ExampleRobustModelFit.java

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/**
 * Demonstrates robust line fitting: creates a data set containing outliers, runs a robust
 * estimator on it and prints the estimated line next to the true one.
 */
public static void main( String args[] ) {
    //------------------------ Create Observations
    // The true line is identified by the point on it that lies closest to the origin
    final double lineX = -2.1;
    final double lineY = 1.3;
    final Random rand = new Random(234);
    final List<Point2D> observations = generateObservations(rand, lineX, lineY);

    //------------------------ Compute the solution
    // Tell the estimator how to compute the model and the fit errors
    final ModelManager<Line2D> lineManager = new LineManager();
    final ModelGenerator<Line2D,Point2D> lineGenerator = new LineGenerator();
    final DistanceFromModel<Line2D,Point2D> lineDistance = new DistanceFromLine();

    // RANSAC or LMedS work well here
    final ModelMatcher<Line2D,Point2D> matcher =
            new Ransac<>(234234, lineManager, lineGenerator, lineDistance, 500, 0.01);
    // Alternative estimator:
    // ModelMatcher<Line2D,Point2D> matcher =
    //      new LeastMedianOfSquares<Line2D, Point2D>(234234,100,0.1,0.5,lineGenerator,lineDistance);

    final boolean success = matcher.process(observations);
    if( !success ) {
        throw new RuntimeException("Robust fit failed!");
    }

    // Inspect the outcome: the estimated model and how many observations agree with it
    final Line2D estimate = matcher.getModelParameters();
    final int matchCount = matcher.getMatchSet().size();

    // notice how all the noisy points were removed and an accurate line was estimated?
    System.out.println("Found line   "+estimate);
    System.out.println("Actual line   x = "+lineX+" y = "+lineY);
    System.out.println("Match set size = "+matchCount);
}

/**
 * Creates a synthetic data set made of points lying exactly on a line plus a few random outliers.
 *
 * <p>The line passes through (lineX, lineY) and runs along the direction (lineY, -lineX), which is
 * perpendicular to the vector from the origin to (lineX, lineY). That makes (lineX, lineY) the
 * point on the line closest to the origin.</p>
 *
 * @param rand  source of all randomness, including the final shuffle, so that a seeded
 *              generator yields the same list on every run
 * @param lineX x-coordinate of the point on the line closest to the origin
 * @param lineY y-coordinate of the point on the line closest to the origin
 * @return 20 points on the line and 5 outliers, in shuffled order
 */
private static List<Point2D> generateObservations(Random rand, double lineX, double lineY) {
    // randomly generate points along the line
    List<Point2D> points = new ArrayList<>();
    for( int i = 0; i < 20; i++ ) {
        // t is uniformly distributed in [-5, 5) and selects a location along the line
        double t = (rand.nextDouble()-0.5)*10;
        points.add( new Point2D(lineX + t*lineY, lineY - t*lineX) );
    }

    // Add in some random points which act as the noise
    for( int i = 0; i < 5; i++ ) {
        points.add( new Point2D(rand.nextGaussian()*10,rand.nextGaussian()*10));
    }

    // Shuffle the list to remove any structure. The caller's generator is used so that the
    // ordering is reproducible when it is seeded.
    Collections.shuffle(points, rand);
    return points;
}

Below are all the support classes which define the model being fitted and perform data management.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
/**
 * A 2D line described by a single point: the location on the line that is closest to the origin.
 */
public class Line2D {
    /** x-coordinate of the closest point on the line to the origin. */
    double x;
    /** y-coordinate of the closest point on the line to the origin. */
    double y;

    /**
     * Returns a human readable summary of the two coordinates.
     */
    @Override
    public String toString() {
        String coordinates = "x="+x+" y="+y;
        return "Line2D( "+coordinates+" )";
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/**
 * Measures the Euclidean distance between observed points and a {@link Line2D} model.
 *
 * <p>The model must be supplied through {@link #setModel(Line2D)} before distances are computed.</p>
 */
public class DistanceFromLine implements DistanceFromModel<Line2D,Point2D> {

    // Parametric form of the current line: (x0, y0) + t*(slopeX, slopeY)
    double x0, y0;
    double slopeX,slopeY;

    /**
     * Converts the model into its parametric form.
     *
     * @param param line to measure distances from
     */
    @Override
    public void setModel(Line2D param) {
        // the model point is used as the anchor of the parametric line
        x0 = param.x;
        y0 = param.y;

        // the direction is the anchor vector rotated by 90 degrees
        slopeX = -param.y;
        slopeY = param.x;
    }

    /**
     * Computes the distance between a point and the current line.
     *
     * @param p observed point
     * @return Euclidean distance from the point to its projection onto the line
     */
    @Override
    public double computeDistance(Point2D p) {

        // location along the line, in units of the direction vector, of the point closest to p
        double numerator = slopeX * ( p.x - x0) + slopeY * ( p.y - y0);
        double denominator = slopeX * slopeX + slopeY * slopeY;
        double t = numerator / denominator;

        // offset between p and that closest point
        double dx = p.x - (x0 + t*slopeX);
        double dy = p.y - (y0 + t*slopeY);

        return Math.sqrt(dx*dx + dy*dy);
    }

    /**
     * Computes the distance for every observation by delegating to the single point version.
     * Batch processing can be a large speed up for some models, but not for this one.
     *
     * @param obs      observed points
     * @param distance output array; element i receives the distance of observation i
     */
    @Override
    public void computeDistance(List<Point2D> obs, double[] distance) {
        int index = 0;
        for( Point2D observation : obs ) {
            distance[index++] = computeDistance(observation);
        }
    }

    /** Type of observation this class operates on. */
    @Override
    public Class<Point2D> getPointType() {
        return Point2D.class;
    }

    /** Type of model this class operates on. */
    @Override
    public Class<Line2D> getModelType() {
        return Line2D.class;
    }
}
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
 * Estimates a {@link Line2D} from two points.
 *
 * <p>The line is encoded by the point on it that is closest to the origin, which is found by
 * projecting the origin onto the line passing through the two samples.</p>
 */
public class LineGenerator implements ModelGenerator<Line2D,Point2D> {

    // a point at the origin (0,0)
    Point2D origin = new Point2D();

    /**
     * Computes the line passing through the first two points in the data set.
     *
     * @param dataSet sample of observations; only elements 0 and 1 are read
     * @param output  storage for the estimated line; left unmodified when no line can be computed
     * @return true if a line was computed, false if the two points are identical and therefore
     *         do not define a line
     */
    @Override
    public boolean generate(List<Point2D> dataSet, Line2D output) {
        Point2D p1 = dataSet.get(0);
        Point2D p2 = dataSet.get(1);

        // First find the slope of the line
        double slopeX = p2.x - p1.x;
        double slopeY = p2.y - p1.y;

        // Coincident points have no direction. Dividing by the zero length below would fill the
        // output with NaN, so report the failure instead.
        double lengthSq = slopeX * slopeX + slopeY * slopeY;
        if( lengthSq == 0.0 ) {
            return false;
        }

        // Now that we have the slope, all we need is a point on the line (we pick p1) to find
        // the closest point on the line to the origin. This closest point is the parametrization.
        double t = slopeX * ( origin.x - p1.x) + slopeY * ( origin.y - p1.y);
        t /= lengthSq;

        output.x = p1.x + t*slopeX;
        output.y = p1.y + t*slopeY;

        return true;
    }

    /**
     * Two points are needed to define a line.
     */
    @Override
    public int getMinimumPoints() {
        return 2;
    }
}