weka.classifiers.meta.ensembleSelection
Class ModelBag

java.lang.Object
  extended by weka.classifiers.meta.ensembleSelection.ModelBag
All Implemented Interfaces:
RevisionHandler

public class ModelBag
extends java.lang.Object
implements RevisionHandler

This class implements a bag of models for use with the EnsembleSelection meta classifier. It handles shuffling the models, sort initialization, forward selection, backward elimination, and related operations.

The class uses a simple "virtual indexing" scheme internally. Shuffling or sorting the models changes only the "virtual" order; the order of m_models itself is never changed. Each virtual index maps to a real index in m_models, and the elements of the bag are always those with virtual indices 0..(m_bagSize-1). Not every model in m_models gets a virtual index in that range, so the virtual indexing is what defines the subset of models that make up the bag. This lets the class refer to models in the bag by their virtual index while preserving the original indexing for its clients.
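A minimal sketch of the virtual indexing idea, assuming the mapping is a plain int[] that covers every model and that the bag is its first bagSize entries. The class VirtualIndexSketch and its methods are hypothetical illustrations, not part of Weka's API.

```java
import java.util.Arrays;

// Hypothetical sketch of a "virtual indexing" scheme; not Weka's actual code.
class VirtualIndexSketch {
    // virtualToReal[v] is the real index (into the full model array) of virtual slot v.
    final int[] virtualToReal;
    // Only virtual slots 0..bagSize-1 belong to the bag.
    final int bagSize;

    VirtualIndexSketch(int numModels, int bagSize) {
        this.virtualToReal = new int[numModels];
        for (int i = 0; i < numModels; i++) {
            virtualToReal[i] = i; // identity mapping before any shuffle or sort
        }
        this.bagSize = bagSize;
    }

    // Maps a virtual index inside the bag to the real index that clients use.
    int realIndex(int virtualIndex) {
        if (virtualIndex < 0 || virtualIndex >= bagSize) {
            throw new IndexOutOfBoundsException("not in bag: " + virtualIndex);
        }
        return virtualToReal[virtualIndex];
    }

    // Swaps two virtual slots; the underlying model array is never reordered.
    void swap(int a, int b) {
        int tmp = virtualToReal[a];
        virtualToReal[a] = virtualToReal[b];
        virtualToReal[b] = tmp;
    }

    // Real indices of the models currently in the bag.
    int[] bagMembers() {
        return Arrays.copyOf(virtualToReal, bagSize);
    }

    public static void main(String[] args) {
        VirtualIndexSketch v = new VirtualIndexSketch(5, 3);
        v.swap(0, 4); // model 4 moves into the bag, model 0 moves out
        System.out.println(Arrays.toString(v.bagMembers())); // [4, 1, 2]
    }
}
```

Because only the mapping changes, a client holding real index 4 still refers to the same model after any number of swaps.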

Version:
$Revision: 1.2 $
Author:
David Michael

Constructor Summary
ModelBag(double[][][] models, double bag_percent, boolean debug)
          Constructor for ModelBag.
 
Method Summary
 void backwardEliminate(Instances instances, int metric)
          Find the model whose removal will help the ensemble's performance the most, and remove it.
 void forwardSelect(boolean withReplacement, Instances instances, int metric)
          Forward select one model.
 void forwardSelectOrBackwardEliminate(boolean with_replacement, Instances instances, int metric)
          Find the best action to perform, be it adding a model or removing a model, and perform it.
 double[] getIndividualPerformance(Instances instances, int metric)
          Gets the individual performances of all the models in the bag.
 int[] getModelWeights()
          Returns the model weights.
 java.lang.String getRevision()
          Returns the revision string.
 void shuffle(java.util.Random rand)
          Shuffle the models.
 int[] sortInitialize(int num, boolean greedy, Instances instances, int metric)
          Sort initialize the bag.
 void weightAll(int weight)
          Add "weight" to the number of times each model in the bag was chosen.
 
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

ModelBag

public ModelBag(double[][][] models,
                double bag_percent,
                boolean debug)
Constructor for ModelBag.

Parameters:
models - The complete set of models from which to draw the bag. Each model is represented only by its predictions on the validation data, since that is all ensemble selection needs to know. The first index selects the model, the second selects the instance, and the third indexes the predicted class distribution for that instance.
bag_percent - The percentage of the set of given models that should be used in the Model Bag.
debug - Whether the ModelBag should print debug information.
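To make the three-dimensional models layout concrete, here is a minimal sketch that combines such an array into an ensemble prediction by weighted averaging. It assumes each model's weight is the number of times that model was chosen; EnsemblePredictionSketch and averagePrediction are hypothetical names, and this is not Weka's implementation.

```java
// Hypothetical sketch: combines a models[model][instance][class] array into an
// ensemble prediction for one instance by weighted averaging. Not Weka's code.
class EnsemblePredictionSketch {
    // weights[m] is how many times model m was chosen (0 = not in the ensemble).
    static double[] averagePrediction(double[][][] models, int[] weights, int instance) {
        int numClasses = models[0][instance].length;
        double[] avg = new double[numClasses];
        int total = 0;
        for (int m = 0; m < models.length; m++) {
            if (weights[m] == 0) continue; // unchosen models contribute nothing
            for (int c = 0; c < numClasses; c++) {
                avg[c] += weights[m] * models[m][instance][c];
            }
            total += weights[m];
        }
        if (total > 0) {
            for (int c = 0; c < numClasses; c++) {
                avg[c] /= total; // normalize so the result is still a distribution
            }
        }
        return avg;
    }
}
```

For two single-instance models predicting {0.8, 0.2} and {0.2, 0.8}, weights {3, 1} yield approximately {0.65, 0.35}.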
Method Detail

shuffle

public void shuffle(java.util.Random rand)
Shuffle the models. The order in m_models is preserved, but we change our virtual indexes around.

Parameters:
rand - the random number generator to use
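One standard way to shuffle only the virtual order is a Fisher-Yates shuffle of the index-mapping array, sketched below under the assumption that the mapping is a plain int[]. ShuffleSketch is a hypothetical name; the source does not say which shuffling algorithm ModelBag uses.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch: shuffles a virtual-index array in place (Fisher-Yates),
// leaving the underlying model array untouched. Not Weka's actual implementation.
class ShuffleSketch {
    static void shuffle(int[] virtualToReal, Random rand) {
        for (int i = virtualToReal.length - 1; i > 0; i--) {
            int j = rand.nextInt(i + 1); // uniform in [0, i]
            int tmp = virtualToReal[i];
            virtualToReal[i] = virtualToReal[j];
            virtualToReal[j] = tmp;
        }
    }

    public static void main(String[] args) {
        int[] order = {0, 1, 2, 3, 4, 5};
        shuffle(order, new Random(42));
        // The same seed always yields the same permutation, so runs are reproducible.
        System.out.println(Arrays.toString(order));
    }
}
```

Passing the Random instance in, as shuffle(java.util.Random rand) does, lets the caller control the seed and therefore reproduce a given bag.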

sortInitialize

public int[] sortInitialize(int num,
                            boolean greedy,
                            Instances instances,
                            int metric)
                     throws java.lang.Exception
Sort initialize the bag.

Parameters:
num - the maximum number of models to initialize with
greedy - true to add models greedily, up to num. Greedy sort initialization adds models (at most num) in order from best to worst performance, stopping once adding the next model no longer improves performance.
instances - the data set (needed for performance evaluation)
metric - metric for which to optimize. See EnsembleMetricHelper.
Returns:
an array of the indexes of the selected models, ordered from the model with the best performance downward.
Throws:
java.lang.Exception - if something goes wrong
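A minimal sketch of greedy and non-greedy sort initialization, under several assumptions: accuracy (arg-max of the weighted prediction sum) stands in for the EnsembleMetricHelper metric, int[] labels stand in for Instances, model weights are an int[] of selection counts, and the bag restriction is omitted. SortInitSketch and its methods are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

// Hypothetical sketch of sort initialization, not Weka's code. Accuracy stands in
// for the EnsembleMetricHelper metric, and int labels stand in for Instances.
class SortInitSketch {
    // Fraction of instances whose arg-max weighted prediction matches the label.
    static double accuracy(double[][][] models, int[] weights, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            double[] sum = new double[models[0][i].length];
            for (int m = 0; m < models.length; m++) {
                for (int c = 0; c < sum.length; c++) {
                    sum[c] += weights[m] * models[m][i][c];
                }
            }
            int best = 0;
            for (int c = 1; c < sum.length; c++) {
                if (sum[c] > sum[best]) best = c;
            }
            if (best == labels[i]) correct++;
        }
        return (double) correct / labels.length;
    }

    // Adds up to num models, best individual accuracy first. With greedy == true,
    // stops as soon as the next model fails to improve ensemble accuracy.
    // Updates weights in place and returns the chosen model indexes in order.
    static int[] sortInitialize(double[][][] models, int[] labels,
                                int num, boolean greedy, int[] weights) {
        int n = models.length;
        double[] individual = new double[n];
        Integer[] order = new Integer[n];
        for (int m = 0; m < n; m++) {
            int[] single = new int[n];
            single[m] = 1; // evaluate model m on its own
            individual[m] = accuracy(models, single, labels);
            order[m] = m;
        }
        // Sort model indexes from best to worst individual accuracy.
        Arrays.sort(order, Comparator.comparingDouble((Integer m) -> individual[m]).reversed());

        int limit = Math.min(num, n);
        int added = 0;
        double current = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < limit; k++) {
            int m = order[k];
            weights[m]++;
            double p = accuracy(models, weights, labels);
            if (greedy && p <= current) {
                weights[m]--; // undo: this model did not improve the ensemble
                break;
            }
            current = p;
            added++;
        }
        int[] chosen = new int[added];
        for (int k = 0; k < added; k++) chosen[k] = order[k];
        return chosen;
    }
}
```

The first model is always accepted because any accuracy beats the initial negative infinity; afterwards, greedy mode requires a strict improvement.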

weightAll

public void weightAll(int weight)
Add "weight" to the number of times each model in the bag was chosen. Typically for use with backward elimination.

Parameters:
weight - the weight to add

forwardSelect

public void forwardSelect(boolean withReplacement,
                          Instances instances,
                          int metric)
                   throws java.lang.Exception
Forward select one model: add the model that has the best effect on performance. If withReplacement is false and every model has already been chosen, no action is taken. Otherwise a model is always added, even if doing so hurts performance.

Parameters:
withReplacement - whether a model can be added more than once.
instances - The dataset, for calculating performance.
metric - The metric to optimize for. See EnsembleMetricHelper.
Throws:
java.lang.Exception - if something goes wrong
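A minimal sketch of a single forward-selection step with the semantics described above. It assumes accuracy as a stand-in for the EnsembleMetricHelper metric, int[] labels in place of Instances, and an int[] of selection counts as the model weights; ForwardSelectSketch is a hypothetical name and the bag restriction is omitted.

```java
// Hypothetical sketch of one forward-selection step; not Weka's code.
class ForwardSelectSketch {
    // Fraction of instances whose arg-max weighted prediction matches the label.
    static double accuracy(double[][][] models, int[] weights, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            double[] sum = new double[models[0][i].length];
            for (int m = 0; m < models.length; m++) {
                for (int c = 0; c < sum.length; c++) {
                    sum[c] += weights[m] * models[m][i][c];
                }
            }
            int best = 0;
            for (int c = 1; c < sum.length; c++) {
                if (sum[c] > sum[best]) best = c;
            }
            if (best == labels[i]) correct++;
        }
        return (double) correct / labels.length;
    }

    // Adds the model whose inclusion yields the best accuracy and returns its
    // index, or returns -1 if no model may be added (all chosen, no replacement).
    static int forwardSelect(double[][][] models, int[] labels, int[] weights,
                             boolean withReplacement) {
        int bestModel = -1;
        double bestPerf = Double.NEGATIVE_INFINITY;
        for (int m = 0; m < models.length; m++) {
            if (!withReplacement && weights[m] > 0) continue; // already chosen
            weights[m]++;                                      // try adding m
            double p = accuracy(models, weights, labels);
            weights[m]--;                                      // undo the trial
            if (p > bestPerf) {
                bestPerf = p;
                bestModel = m;
            }
        }
        if (bestModel >= 0) {
            weights[bestModel]++; // always add, even if performance drops
        }
        return bestModel;
    }
}
```

Ties go to the lowest model index here; that tie-breaking rule is an arbitrary choice of this sketch.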

backwardEliminate

public void backwardEliminate(Instances instances,
                              int metric)
                       throws java.lang.Exception
Find the model whose removal will help the ensemble's performance the most, and remove it. If there is only one model left, we leave it in. If we can remove a model, we always do, even if it hurts performance.

Parameters:
instances - The data set, for calculating performance
metric - Metric to optimize for. See EnsembleMetricHelper.
Throws:
java.lang.Exception - if something goes wrong
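A minimal sketch of one backward-elimination step, typically run after every model has been given a starting weight (as weightAll does). It assumes accuracy as a stand-in metric, int[] labels in place of Instances, and an int[] of selection counts as the weights; it also interprets "only one model left" as a total count of one, which is an assumption. BackwardEliminateSketch is hypothetical.

```java
// Hypothetical sketch of one backward-elimination step; not Weka's code.
class BackwardEliminateSketch {
    // Fraction of instances whose arg-max weighted prediction matches the label.
    static double accuracy(double[][][] models, int[] weights, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            double[] sum = new double[models[0][i].length];
            for (int m = 0; m < models.length; m++) {
                for (int c = 0; c < sum.length; c++) {
                    sum[c] += weights[m] * models[m][i][c];
                }
            }
            int best = 0;
            for (int c = 1; c < sum.length; c++) {
                if (sum[c] > sum[best]) best = c;
            }
            if (best == labels[i]) correct++;
        }
        return (double) correct / labels.length;
    }

    // Removes one copy of the model whose removal yields the best accuracy and
    // returns its index. Returns -1 without changes when only one copy is left
    // (assumed interpretation of "only one model left").
    static int backwardEliminate(double[][][] models, int[] labels, int[] weights) {
        int total = 0;
        for (int w : weights) total += w;
        if (total <= 1) return -1; // keep the last model in the ensemble

        int bestModel = -1;
        double bestPerf = Double.NEGATIVE_INFINITY;
        for (int m = 0; m < models.length; m++) {
            if (weights[m] == 0) continue; // nothing to remove
            weights[m]--;                  // try removing one copy of m
            double p = accuracy(models, weights, labels);
            weights[m]++;                  // undo the trial
            if (p > bestPerf) {
                bestPerf = p;
                bestModel = m;
            }
        }
        weights[bestModel]--; // always remove, even if performance drops
        return bestModel;
    }
}
```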

forwardSelectOrBackwardEliminate

public void forwardSelectOrBackwardEliminate(boolean with_replacement,
                                             Instances instances,
                                             int metric)
                                      throws java.lang.Exception
Find the best action to perform, be it adding a model or removing a model, and perform it. Some action is always performed, even if it hurts performance.

Parameters:
with_replacement - whether we can add a model more than once
instances - The dataset, for determining performance.
metric - The metric for which to optimize. See EnsembleMetricHelper.
Throws:
java.lang.Exception - if something goes wrong
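A minimal sketch of the combined step: every legal addition and every legal removal is evaluated, and the single best one is applied. The same assumptions as a plain forward or backward step apply (accuracy as the metric, int[] labels for Instances, int[] selection counts as weights, at least one model always kept); SelectOrEliminateSketch and its {delta, model} return value are hypothetical.

```java
// Hypothetical sketch of a combined forward-select-or-backward-eliminate step.
class SelectOrEliminateSketch {
    // Fraction of instances whose arg-max weighted prediction matches the label.
    static double accuracy(double[][][] models, int[] weights, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            double[] sum = new double[models[0][i].length];
            for (int m = 0; m < models.length; m++) {
                for (int c = 0; c < sum.length; c++) {
                    sum[c] += weights[m] * models[m][i][c];
                }
            }
            int best = 0;
            for (int c = 1; c < sum.length; c++) {
                if (sum[c] > sum[best]) best = c;
            }
            if (best == labels[i]) correct++;
        }
        return (double) correct / labels.length;
    }

    // Applies the best single addition or removal and returns {delta, model}:
    // delta is +1 for an addition, -1 for a removal, 0 if no action was legal.
    static int[] step(double[][][] models, int[] labels, int[] weights,
                      boolean withReplacement) {
        int total = 0;
        for (int w : weights) total += w;
        int bestModel = -1;
        int bestDelta = 0;
        double bestPerf = Double.NEGATIVE_INFINITY;
        for (int m = 0; m < models.length; m++) {
            if (withReplacement || weights[m] == 0) { // candidate addition
                weights[m]++;
                double p = accuracy(models, weights, labels);
                weights[m]--;
                if (p > bestPerf) { bestPerf = p; bestModel = m; bestDelta = +1; }
            }
            if (weights[m] > 0 && total > 1) {        // candidate removal
                weights[m]--;
                double p = accuracy(models, weights, labels);
                weights[m]++;
                if (p > bestPerf) { bestPerf = p; bestModel = m; bestDelta = -1; }
            }
        }
        if (bestModel >= 0) {
            weights[bestModel] += bestDelta; // always act, even if it hurts
        }
        return new int[] {bestDelta, bestModel};
    }
}
```

Letting removals compete with additions means a weak early choice can later be dropped if that helps more than any addition.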

getModelWeights

public int[] getModelWeights()
Returns the model weights.

Returns:
the model weights

getIndividualPerformance

public double[] getIndividualPerformance(Instances instances,
                                         int metric)
                                  throws java.lang.Exception
Gets the individual performances of all the models in the bag.

Parameters:
instances - The validation data, for which we want performance.
metric - The desired metric (see EnsembleMetricHelper).
Returns:
the performance of each model in the bag, as measured by metric
Throws:
java.lang.Exception - if something goes wrong
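A minimal sketch of per-model performance, assuming accuracy (arg-max of each model's own distribution compared against the label) as a stand-in for the EnsembleMetricHelper metric and int[] labels in place of Instances. IndividualPerformanceSketch is a hypothetical name, and every model is scored rather than only those in the bag.

```java
// Hypothetical sketch: accuracy of each model on its own. Not Weka's code.
class IndividualPerformanceSketch {
    static double[] individualAccuracy(double[][][] models, int[] labels) {
        double[] perf = new double[models.length];
        for (int m = 0; m < models.length; m++) {
            int correct = 0;
            for (int i = 0; i < labels.length; i++) {
                double[] dist = models[m][i];
                int best = 0; // arg-max of this model's predicted distribution
                for (int c = 1; c < dist.length; c++) {
                    if (dist[c] > dist[best]) best = c;
                }
                if (best == labels[i]) correct++;
            }
            perf[m] = (double) correct / labels.length;
        }
        return perf;
    }
}
```

Scores like these are what a sort initialization would rank models by before adding them to the ensemble.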

getRevision

public java.lang.String getRevision()
Returns the revision string.

Specified by:
getRevision in interface RevisionHandler
Returns:
the revision