Download LinearClassifier / Read the documentation / LinearClassifier source on Github

This library implements linear classification for numerical data. It accepts two labeled data sets, each element of which can contain two or more floats as data. Once trained, LinearClassifier can predict which of the two original sets any new piece of test data will fall into. LinearClassifier also includes tools for displaying data in the 2D case. See the 2D example for details.

LinearClassifier is based on examples in Chapter 9 of Programming Collective Intelligence by Toby Segaran, which is highly recommended.

Installation and Usage

To install LinearClassifier, download the library and unzip it into your Processing libraries folder. Restart Processing and LinearClassifier should show up under the “File > Examples” menu under “Contributed Libraries”.

2D Classification

Linear Classifier in Processing

The process of linear classification is best illustrated by the 2D case. LinearClassifier works by finding the average point for each of the two labeled data sets. It then draws a straight line between these two points. At the point halfway along that line, it constructs a second line perpendicular to the first. This second line divides the two data sets. To predict which set contains a new, unlabeled point, LinearClassifier simply checks to see which side of the line the point falls on.

This example uses a sample of height and weight data labeled by gender. It displays the average points for the men and women as well as the connecting line and the dividing line. It then uses LinearClassifier to classify the point represented by the mouse position and displays the results of that classification. The data for this example is available here or as an example accompanying the LinearClassifier library download.

import linearclassifier.*;
import processing.data.*;

// the rows loaded from height_and_weight.csv in setup()
Table data;
// the classifier that is fed the two labeled sets in setup()
// and queried on every frame in draw()
LinearClassifier classifier;

// Loads the labeled height/weight samples, splits them by the f/m column
// and hands both groups to the classifier.
void setup() {
  size(500, 500);

  // read the csv into a Table
  data = new Table(this, "height_and_weight.csv");

  classifier = new LinearClassifier(this);

  // one list of samples per label
  ArrayList<ArrayList<Float>> men = new ArrayList<ArrayList<Float>>();
  ArrayList<ArrayList<Float>> women = new ArrayList<ArrayList<Float>>();

  // drop the title row so only data rows are iterated below
  data.removeTitleRow();

  for (TableRow row : data) {
    // each sample is an ArrayList holding columns 1 and 2
    // (the height and the weight)
    ArrayList<Float> sample = new ArrayList<Float>();
    sample.add(row.getFloat(1));
    sample.add(row.getFloat(2));

    // column 3 is the f/m label: "f" goes to the women,
    // every other value goes to the men
    ArrayList<ArrayList<Float>> target = row.getString(3).equals("f") ? women : men;
    target.add(sample);
  }

  // men become set 1 and women set 2
  classifier.loadSet1(men);
  classifier.loadSet2(women);

  // output scale handed to the classifier; the numbers match the
  // 500x500 window
  // NOTE(review): presumably one (from, to) range per dimension, with the
  // second one flipped so larger values are drawn higher up - confirm
  // against the library documentation
  ArrayList<PVector> scales = new ArrayList<PVector>();
  scales.add(new PVector(0, 500));
  scales.add(new PVector(500, 0));
  classifier.setOutputScale(scales);
}

// Draws both data sets, their average points, the line connecting the
// averages, the dividing line, and the classification of the point
// under the mouse.
void draw(){
  background(255);
  noStroke();

  // set 1 in blue, set 2 in red
  fill(0,0,255);
  classifier.drawSet1();
  fill(255,0,0);
  classifier.drawSet2();

  // fetch the average of each set
  ArrayList<Float> set1Ave = classifier.getSet1Average();
  ArrayList<Float> set2Ave = classifier.getSet2Average();

  // mark each average with a labeled black box
  fill(0);
  rectMode(CENTER);

  float menX = set1Ave.get(0);
  float menY = set1Ave.get(1);
  rect(menX, menY, 10, 10);
  text("Men\nAverage", menX + 10, menY + 10);

  float womenX = set2Ave.get(0);
  float womenY = set2Ave.get(1);
  rect(womenX, womenY, 10, 10);
  text("Women\nAverage", womenX + 10, womenY + 10);

  // green line joining the two averages,
  // with a dot at the classifier's center point
  fill(0,255,0);
  stroke(0,255,0);
  line(menX, menY, womenX, womenY);
  ellipse(classifier.getCenterPoint().x, classifier.getCenterPoint().y, 10, 10);

  // black line perpendicular to the line between the averages;
  // this is the line that should divide the two sets
  stroke(0);
  drawPerpindicularLine(menX, menY, womenX, womenY);

  // color the mouse marker like the set the classifier assigns it to
  noStroke();
  boolean mouseInSet1 = classifier.isInSet1(new PVector(mouseX, mouseY));
  if (mouseInSet1) {
    fill(0,0,255);
  } else {
    fill(255,0,0);
  }

  ellipse(mouseX, mouseY, 10, 10);

  // label the marker with the unscaled values of the mouse position
  PVector p = classifier.getUnscaledPoint(new PVector(mouseX, mouseY));
  text("H: " + round(p.x) + "\"\nW: " + round(p.y) + "lbs", mouseX+7, mouseY+7);
}

// Draws a line through the classifier's center point that is
// perpendicular to the segment (x1, y1)-(x2, y2) and reaches
// 500 pixels to either side of the center point.
void drawPerpindicularLine(float x1, float y1, float x2, float y2){

  // vector pointing from the second point to the first
  PVector direction = new PVector(x1 - x2, y1 - y2);

  // crossing with the z axis yields a vector in the xy plane
  // at a right angle to the direction
  PVector halfSpan = direction.cross(new PVector(0,0,1));

  // stretch it to a length of 500
  halfSpan.normalize();
  halfSpan.mult(500);

  PVector from = PVector.sub(classifier.getCenterPoint(), halfSpan);
  PVector to = PVector.add(classifier.getCenterPoint(), halfSpan);

  line(from.x, from.y, to.x, to.y);
}

Matchmaker: Multi-Variable Classification

This example demonstrates the use of LinearClassifier on more complex multi-dimensional data, including the transformation of text and yes/no data into numerical data as well as the scaling of the data set. It is a matchmaking example: each entry contains data about a pair of people: their interests, ages, physical proximity, etc. Each entry is marked with whether or not a person judged it a good match. The classifier is then tested against random examples from the data set to see if its predictions agree with those human judgments. (Download the data here or get it with the example that accompanies the LinearClassifier download.)

LinearClassifier matchmaking on 8 dimensions in Processing

import linearclassifier.*;
import processing.data.*;

// the rows loaded from matchmaker.csv in setup()
Table data;
// the classifier that is fed the match / no-match sets in setup()
// and queried in displayNewResults()
LinearClassifier classifier;

// Loads the matchmaker data, converts every row into a list of 8 numbers,
// sorts the rows into matches and non-matches, trains the classifier on
// both groups and shows a first batch of predictions.
void setup() {
  size(500, 500);

  // read the csv into a Table
  data = new Table(this, "matchmaker.csv");

  classifier = new LinearClassifier(this);

  ArrayList<ArrayList<Float>> matches = new ArrayList<ArrayList<Float>>();
  ArrayList<ArrayList<Float>> noMatches = new ArrayList<ArrayList<Float>>();

  // turn each csv row into one numerical entry
  for (TableRow row : data) {
    ArrayList<Float> features = new ArrayList<Float>();

    // first person: age plus two yes/no columns
    features.add(row.getFloat(0)); // age1
    features.add(parseYesNo(row.getString(1)));
    features.add(parseYesNo(row.getString(2)));

    // second person: age plus two yes/no columns
    features.add(row.getFloat(5)); // age2
    features.add(parseYesNo(row.getString(6)));
    features.add(parseYesNo(row.getString(7)));

    // the two interest lists collapse into a single count
    features.add(matchCount( row.getString(3), row.getString(8) ));
    features.add(row.getFloat(11)); // distance

    // column 10 is the label: 1 marks a match, anything else a non-match
    ArrayList<ArrayList<Float>> target = (row.getInt(10) == 1) ? matches : noMatches;
    target.add(features);
  }

  // matches become set 1 and non-matches set 2
  classifier.loadSet1(matches);
  classifier.loadSet2(noMatches);

  // scale all the data to be between 0 and 1
  // so that each component carries equal weight
  classifier.scaleData(0,1);
  displayNewResults();
}

// picks 5 entries randomly from both set 1 and set 2 and
// runs them through the classifier to see if it
// categorizes them correctly
// set 1 holds the matches and set 2 the non-matches, so
// all examples from set 1 should be "match", i.e. green
// and all from set 2 should be "no match", i.e. red
// the classifier won't always be right but should be most of the time
// this gets called in setup() and in keyPressed() to display new data
void displayNewResults() {
  background(255);
  fill(0);
  text("Set 1: These should all be matches", 10, 20);

  for (int i = 0; i < 5; i ++) {
    int j = (int)random(classifier.set1.size());
    ArrayList<Float> entry = classifier.set1.get(j);
    drawPrediction(entry, 40 + i*35);
  }

  fill(0);
  text("Set 2: None of these should be matches", 10, 40 + 5*35 + 20);

  for (int i = 0; i < 5; i ++) {
    int j = (int)random(classifier.set2.size());
    ArrayList<Float> entry = classifier.set2.get(j);
    drawPrediction(entry, 40 + 5*35 + 40 + i*35);
  }
}

// prints one entry at height y together with the classifier's answer to
// "is this a match?", in green for a predicted match and in red for a
// predicted non-match
void drawPrediction(ArrayList<Float> entry, int y) {
  // set 1 is the "match" set, so membership in set 1 is the match
  // prediction for entries drawn from either set; reporting isInSet2()
  // under the "Predicted match?" label would state the opposite
  boolean predictedMatch = classifier.isInSet1(entry);

  if (predictedMatch) {
    fill(0, 255, 0);
  }
  else {
    fill(255, 0, 0);
  }
  text("["+ toString(entry) + "]\nPredicted match? " + predictedMatch, 10, y);
}

// Repaints the key-press hint near the bottom of the window on every
// frame; the results themselves are drawn by displayNewResults().
void draw() {
  String hint = "(PRESS ANY KEY FOR MORE DATA)";
  int hintY = height-20;

  fill(0);
  text(hint, 10, hintY);
}

// any key press replaces the displayed results with a freshly
// drawn random sample
void keyPressed(){
  displayNewResults();
}

// helper function to turn a "yes" or "no" string
// into a number for numerical comparison:
// "yes" gives 1, "no" gives -1 and anything else
// (including a missing, i.e. null, value) gives 0
float parseYesNo(String s) {
  // the literal goes first so that a null value falls through
  // to the 0 case instead of throwing a NullPointerException
  if ("yes".equals(s)) {
    return 1;
  }
  if ("no".equals(s)) {
    return -1;
  }
  return 0;
}

// helper function to turn two ":"-separated lists of interests
// into a float based on how many are shared: every interest of the
// first list that also appears anywhere in the second list adds 1,
// no matter at which position it is listed
// a missing (null) list is treated as sharing nothing
float matchCount(String interests1, String interests2) {
  if (interests1 == null || interests2 == null) {
    return 0;
  }

  // the limit of -1 keeps trailing empty fields instead of dropping them
  String[] i1 = interests1.split(":", -1);
  String[] i2 = interests2.split(":", -1);

  float result = 0;
  for (String interest : i1) {
    // search the whole second list; comparing position by position
    // would miss shared interests that are listed in a different order
    for (String other : i2) {
      if (interest.equals(other)) {
        result = result + 1;
        break; // one hit per interest of the first list is enough
      }
    }
  }

  return result;
}

// helper function to render an ArrayList of floats as a single
// string with the values separated by ", "
String toString(ArrayList<Float> a) {
  StringBuilder joined = new StringBuilder();
  for (float value : a) {
    // every value after the first is preceded by the separator
    if (joined.length() > 0) {
      joined.append(", ");
    }
    joined.append(value);
  }
  return joined.toString();
}