[shark] 57/58: Trees: renamed splitmatrix → tree, and splitinfo → nodeinfo
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Wed Mar 16 10:05:34 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository shark.
commit f32349daf0e780881ed1480f52c3309e4bb8fe93
Author: Jakob Wrigley <wrigleyster at gmail.com>
Date: Sat Nov 28 21:00:43 2015 +0100
Trees: renamed splitmatrix → tree, and splitinfo → nodeinfo
---
include/shark/Algorithms/Trainers/CARTTrainer.h | 18 +--
include/shark/Algorithms/Trainers/RFTrainer.h | 4 +-
include/shark/Models/Trees/CARTClassifier.h | 76 ++++-----
include/shark/Models/Trees/RFClassifier.h | 6 +-
src/Algorithms/CARTTrainer.cpp | 202 ++++++++++++------------
src/Algorithms/RFTrainer.cpp | 126 +++++++--------
6 files changed, 216 insertions(+), 216 deletions(-)
diff --git a/include/shark/Algorithms/Trainers/CARTTrainer.h b/include/shark/Algorithms/Trainers/CARTTrainer.h
index 4fc61a4..b810566 100644
--- a/include/shark/Algorithms/Trainers/CARTTrainer.h
+++ b/include/shark/Algorithms/Trainers/CARTTrainer.h
@@ -108,7 +108,7 @@ protected:
typedef std::vector < TableEntry > AttributeTable;
typedef std::vector < AttributeTable > AttributeTables;
- typedef ModelType::SplitMatrixType SplitMatrixType;
+ typedef ModelType::TreeType TreeType;
///Number of attributes in the dataset
@@ -129,7 +129,7 @@ protected:
//Classification functions
///Builds a single decision tree from a classification dataset
///The method requires the attribute tables,
- SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId );
+ SHARK_EXPORT_SYMBOL TreeType buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId );
///Calculates the Gini impurity of a node. The impurity is defined as
///1-sum_j p(j|t)^2
@@ -139,22 +139,22 @@ protected:
SHARK_EXPORT_SYMBOL RealVector hist(boost::unordered_map<std::size_t, std::size_t> countMatrix);
///Regression functions
- SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize);
+ SHARK_EXPORT_SYMBOL TreeType buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize);
///Calculates the total sum of squares
SHARK_EXPORT_SYMBOL double totalSumOfSquares(std::vector<RealVector> const& labels, std::size_t start, std::size_t length, RealVector const& sumLabel);
///Calculates the mean of a vector of labels
SHARK_EXPORT_SYMBOL RealVector mean(std::vector<RealVector> const& labels);
///Pruning
- ///Prunes decision tree, represented by a split matrix
- SHARK_EXPORT_SYMBOL void pruneMatrix(SplitMatrixType& splitMatrix);
+ ///Prunes decision tree
+ SHARK_EXPORT_SYMBOL void pruneTree(TreeType & tree);
///Prunes a single node, including the child nodes of the decision tree
- SHARK_EXPORT_SYMBOL void pruneNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
+ SHARK_EXPORT_SYMBOL void pruneNode(TreeType & tree, std::size_t nodeId);
///Updates the node variables used in the cost complexity pruning stage
- SHARK_EXPORT_SYMBOL void measureStrenght(SplitMatrixType& splitMatrix, std::size_t nodeId, std::size_t parentNodeId);
+ SHARK_EXPORT_SYMBOL void measureStrength(TreeType & tree, std::size_t nodeId, std::size_t parentNodeId);
- ///Returns the index of the node with node id in splitMatrix.
- SHARK_EXPORT_SYMBOL std::size_t findNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
+ ///Returns the index of the node with node id in tree.
+ SHARK_EXPORT_SYMBOL std::size_t findNode(TreeType & tree, std::size_t nodeId);
///Attribute table functions
///Create the attribute tables used by the SPRINT algorithm
diff --git a/include/shark/Algorithms/Trainers/RFTrainer.h b/include/shark/Algorithms/Trainers/RFTrainer.h
index e51aa46..bf2ca35 100644
--- a/include/shark/Algorithms/Trainers/RFTrainer.h
+++ b/include/shark/Algorithms/Trainers/RFTrainer.h
@@ -146,10 +146,10 @@ protected:
SHARK_EXPORT_SYMBOL void splitAttributeTables(AttributeTables const& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables);
/// Build a decision tree for classification
- SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);
+ SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::TreeType buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);
/// Builds a decision tree for regression
- SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId);
+ SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::TreeType buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId);
/// comparison function for sorting an attributeTable
SHARK_EXPORT_SYMBOL static bool tableSort(RFAttribute const& v1, RFAttribute const& v2);
diff --git a/include/shark/Models/Trees/CARTClassifier.h b/include/shark/Models/Trees/CARTClassifier.h
index 6755781..e4da36d 100644
--- a/include/shark/Models/Trees/CARTClassifier.h
+++ b/include/shark/Models/Trees/CARTClassifier.h
@@ -63,7 +63,7 @@ public:
typedef typename base_type::BatchInputType BatchInputType;
typedef typename base_type::BatchOutputType BatchOutputType;
// Information about a single split. misclassProp, r and g are variables used in the cost complexity step
- struct SplitInfo{
+ struct NodeInfo {
std::size_t nodeId;
std::size_t attributeIndex;
double attributeValue;
@@ -91,31 +91,31 @@ public:
/// Vector of structs that contains the splitting information and the labels.
/// The class label is a normalized histogram in the classification case.
/// In the regression case, the label is the regression value.
- typedef std::vector<SplitInfo> SplitMatrixType;
+ typedef std::vector<NodeInfo> TreeType;
/// Constructor
CARTClassifier()
{}
- /// Constructor taking the splitMatrix as argument
- CARTClassifier(SplitMatrixType const& splitMatrix)
+ /// Constructor taking the tree as argument
+ CARTClassifier(TreeType const& tree)
{
- m_splitMatrix=splitMatrix;
+ m_tree=tree;
}
- /// Constructor taking the splitMatrix as argument and optimize it if requested
- CARTClassifier(SplitMatrixType const& splitMatrix, bool optimize)
+ /// Constructor taking the tree as argument and optimize it if requested
+ CARTClassifier(TreeType const& tree, bool optimize)
{
if (optimize)
- setSplitMatrix(splitMatrix);
+ setTree(tree);
else
- m_splitMatrix=splitMatrix;
+ m_tree=tree;
}
- /// Constructor taking the splitMatrix as argument as well as maximum number of attributes
- CARTClassifier(SplitMatrixType const& splitMatrix, std::size_t d)
+ /// Constructor taking the tree as argument as well as maximum number of attributes
+ CARTClassifier(TreeType const& tree, std::size_t d)
{
- setSplitMatrix(splitMatrix);
+ setTree(tree);
m_inputDimension = d;
}
@@ -150,15 +150,15 @@ public:
output = evalPattern(pattern);
}
- /// Set the model split matrix.
- void setSplitMatrix(SplitMatrixType const& splitMatrix){
- m_splitMatrix = splitMatrix;
- optimizeSplitMatrix(m_splitMatrix);
+ /// Set the model tree.
+ void setTree(TreeType const& tree){
+ m_tree = tree;
+ optimizeTree(m_tree);
}
/// Get the model split matrix.
- SplitMatrixType getSplitMatrix() const {
- return m_splitMatrix;
+ TreeType getTree() const {
+ return m_tree;
}
/// \brief The model does not have any parameters.
@@ -178,12 +178,12 @@ public:
/// from ISerializable, reads a model from an archive
void read(InArchive& archive){
- archive >> m_splitMatrix;
+ archive >> m_tree;
}
/// from ISerializable, writes a model to an archive
void write(OutArchive& archive) const {
- archive << m_splitMatrix;
+ archive << m_tree;
}
@@ -191,8 +191,8 @@ public:
UIntVector countAttributes() const {
SHARK_ASSERT(m_inputDimension > 0);
UIntVector r(m_inputDimension, 0);
- typename SplitMatrixType::const_iterator it;
- for(it = m_splitMatrix.begin(); it != m_splitMatrix.end(); ++it) {
+ typename TreeType::const_iterator it;
+ for(it = m_tree.begin(); it != m_tree.end(); ++it) {
//std::cout << "NodeId: " <<it->leftNodeId << std::endl;
if(it->leftNodeId != 0) { // not a label
r(it->attributeIndex)++;
@@ -316,27 +316,27 @@ public:
}
protected:
- /// split matrix of the model
- SplitMatrixType m_splitMatrix;
+ /// tree of the model
+ TreeType m_tree;
- /// \brief Finds the index of the node with a certain nodeID in an unoptimized split matrix.
+ /// \brief Finds the index of the node with a certain nodeID in an unoptimized tree.
std::size_t findNode(std::size_t nodeId) const{
std::size_t index = 0;
- for(; nodeId != m_splitMatrix[index].nodeId; ++index);
+ for(; nodeId != m_tree[index].nodeId; ++index);
return index;
}
- /// Optimize a split matrix, so constant lookup can be used.
+ /// Optimize a tree, so constant lookup can be used.
/// The optimization is done by changing the index of the children
/// to use indices instead of node ID.
/// Furthermore, the node IDs are converted to index numbers.
- void optimizeSplitMatrix(SplitMatrixType& splitMatrix) const{
- for(std::size_t i = 0; i < splitMatrix.size(); i++){
- splitMatrix[i].leftNodeId = findNode(splitMatrix[i].leftNodeId);
- splitMatrix[i].rightNodeId = findNode(splitMatrix[i].rightNodeId);
+ void optimizeTree(TreeType & tree) const{
+ for(std::size_t i = 0; i < tree.size(); i++){
+ tree[i].leftNodeId = findNode(tree[i].leftNodeId);
+ tree[i].rightNodeId = findNode(tree[i].rightNodeId);
}
- for(std::size_t i = 0; i < splitMatrix.size(); i++){
- splitMatrix[i].nodeId = i;
+ for(std::size_t i = 0; i < tree.size(); i++){
+ tree[i].nodeId = i;
}
}
@@ -344,16 +344,16 @@ protected:
template<class Vector>
LabelType const& evalPattern(Vector const& pattern) const{
std::size_t nodeId = 0;
- while(m_splitMatrix[nodeId].leftNodeId != 0){
- if(pattern[m_splitMatrix[nodeId].attributeIndex]<=m_splitMatrix[nodeId].attributeValue){
+ while(m_tree[nodeId].leftNodeId != 0){
+ if(pattern[m_tree[nodeId].attributeIndex]<= m_tree[nodeId].attributeValue){
//Branch on left node
- nodeId = m_splitMatrix[nodeId].leftNodeId;
+ nodeId = m_tree[nodeId].leftNodeId;
}else{
//Branch on right node
- nodeId = m_splitMatrix[nodeId].rightNodeId;
+ nodeId = m_tree[nodeId].rightNodeId;
}
}
- return m_splitMatrix[nodeId].label;
+ return m_tree[nodeId].label;
}
diff --git a/include/shark/Models/Trees/RFClassifier.h b/include/shark/Models/Trees/RFClassifier.h
index b1987af..544d473 100644
--- a/include/shark/Models/Trees/RFClassifier.h
+++ b/include/shark/Models/Trees/RFClassifier.h
@@ -40,8 +40,8 @@
namespace shark {
-typedef CARTClassifier<RealVector>::SplitMatrixType SplitMatrixType;
-typedef std::vector<SplitMatrixType> ForestInfo;
+typedef CARTClassifier<RealVector>::TreeType TreeType;
+typedef std::vector<TreeType> ForestInfo;
///
/// \brief Random Forest Classifier.
@@ -119,7 +119,7 @@ public:
ForestInfo getForestInfo() const {
ForestInfo finfo(m_models.size());
for (std::size_t i=0; i<m_models.size(); ++i)
- finfo[i]=m_models[i].getSplitMatrix();
+ finfo[i]=m_models[i].getTree();
return finfo;
}
diff --git a/src/Algorithms/CARTTrainer.cpp b/src/Algorithms/CARTTrainer.cpp
index 3ae5379..c123aee 100644
--- a/src/Algorithms/CARTTrainer.cpp
+++ b/src/Algorithms/CARTTrainer.cpp
@@ -29,7 +29,7 @@ void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset)
RegressionDataset set=dataset;
CVFolds<RegressionDataset > folds = createCVSameSize(set, m_numberOfFolds);
double bestErrorRate = std::numeric_limits<double>::max();
- CARTClassifier<RealVector>::SplitMatrixType bestSplitMatrix;
+ CARTClassifier<RealVector>::TreeType bestTree;
for (unsigned fold = 0; fold < m_numberOfFolds; ++fold){
//Run through all the cross validation sets
@@ -42,10 +42,10 @@ void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset)
std::vector < RealVector > labels(numTrainElements);
boost::copy(dataTrain.labels().elements(),labels.begin());
//Build tree form this fold
- CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, labels, 0, dataTrain.numberOfElements());
+ CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, labels, 0, dataTrain.numberOfElements());
//Add the tree to the model and prune
- model.setSplitMatrix(splitMatrix);
- while(splitMatrix.size()!=1){
+ model.setTree(tree);
+ while(tree.size()!=1){
//evaluate the error of current tree
SquaredLoss<> loss;
double error = loss.eval(dataTest.labels(), model(dataTest.inputs()));
@@ -53,13 +53,13 @@ void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset)
if(error < bestErrorRate){
//We have found a subtree that has a smaller error rate when tested!
bestErrorRate = error;
- bestSplitMatrix = splitMatrix;
+ bestTree = tree;
}
- pruneMatrix(splitMatrix);
- model.setSplitMatrix(splitMatrix);
+ pruneTree(tree);
+ model.setTree(tree);
}
}
- model.setSplitMatrix(bestSplitMatrix);
+ model.setTree(bestTree);
}
@@ -79,7 +79,7 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(set, m_numberOfFolds);
//find the best tree for the cv folds
double bestErrorRate = std::numeric_limits<double>::max();
- CARTClassifier<RealVector>::SplitMatrixType bestSplitMatrix;
+ CARTClassifier<RealVector>::TreeType bestTree;
//Run through all the cross validation sets
for (unsigned fold = 0; fold < m_numberOfFolds; ++fold) {
@@ -92,24 +92,24 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
AttributeTables tables = createAttributeTables(dataTrain.inputs());
- //create initial split matrix for the fold
- CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, cAbove, 0);
- model.setSplitMatrix(splitMatrix);
+ //create initial tree for the fold
+ CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
+ model.setTree(tree);
- while(splitMatrix.size()!=1){
+ while(tree.size()!=1){
ZeroOneLoss<unsigned int, RealVector> loss;
double errorRate = loss.eval(dataTest.labels(), model(dataTest.inputs()));
if(errorRate < bestErrorRate){
//We have found a subtree that has a smaller error rate when tested!
bestErrorRate = errorRate;
- bestSplitMatrix = splitMatrix;
+ bestTree = tree;
}
- pruneMatrix(splitMatrix);
- model.setSplitMatrix(splitMatrix);
+ pruneTree(tree);
+ model.setTree(tree);
}
}
- model.setSplitMatrix(bestSplitMatrix);
+ model.setTree(bestTree);
}
@@ -127,7 +127,7 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
//~ CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(set, m_numberOfFolds);
//~ //find the best tree for the cv folds
//~ double bestErrorRate = std::numeric_limits<double>::max();
- //~ CARTClassifier<RealVector>::SplitMatrixType bestSplitMatrix;
+ //~ CARTClassifier<RealVector>::TreeType bestTree;
//~ //Run through all the cross validation sets
//~ for (unsigned fold = 0; fold < m_numberOfFolds; ++fold) {
@@ -137,57 +137,57 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
//~ boost::unordered_map<size_t, size_t> cAbove = createCountMatrix(dataTrain);
//~ AttributeTables tables = createAttributeTables(dataTrain.inputs());
- //~ //create initial split matrix for the fold
- //~ CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, cAbove, 0);
- //~ model.setSplitMatrix(splitMatrix);
+ //~ //create initial tree for the fold
+ //~ CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
+ //~ model.setTree(tree);
- //~ while(splitMatrix.size()!=1){
+ //~ while(tree.size()!=1){
//~ double errorRate = evalWeightedError(model, dataTest, weights);
//~ if(errorRate < bestErrorRate){
//~ //We have found a subtree that has a smaller error rate when tested!
//~ bestErrorRate = errorRate;
- //~ bestSplitMatrix = splitMatrix;
+ //~ bestTree = tree;
//~ }
- //~ pruneMatrix(splitMatrix);
- //~ model.setSplitMatrix(splitMatrix);
+ //~ pruneTree(tree);
+ //~ model.setTree(tree);
//~ }
//~ }
//~ error = bestErrorRate;
- //~ model.setSplitMatrix(bestSplitMatrix);
+ //~ model.setTree(bestTree);
//~ }
-void CARTTrainer::pruneMatrix(SplitMatrixType& splitMatrix){
+void CARTTrainer::pruneTree(TreeType & tree){
//Calculate g of all the nodes
- measureStrenght(splitMatrix, 0, 0);
+ measureStrength(tree, 0, 0);
//Find the lowest g of the internal nodes
double g = std::numeric_limits<double>::max();
- for(std::size_t i = 0; i != splitMatrix.size(); i++){
- if(splitMatrix[i].leftNodeId > 0 && splitMatrix[i].g < g){
+ for(std::size_t i = 0; i != tree.size(); i++){
+ if(tree[i].leftNodeId > 0 && tree[i].g < g){
//Update g
- g = splitMatrix[i].g;
+ g = tree[i].g;
}
}
//Prune the nodes with lowest g and make them terminal
- for(std::size_t i=0; i != splitMatrix.size(); i++){
+ for(std::size_t i=0; i != tree.size(); i++){
//Make the internal nodes with the smallest g terminal nodes and prune their children!
- if( splitMatrix[i].leftNodeId > 0 && splitMatrix[i].g == g){
- pruneNode(splitMatrix, splitMatrix[i].leftNodeId);
- pruneNode(splitMatrix, splitMatrix[i].rightNodeId);
+ if( tree[i].leftNodeId > 0 && tree[i].g == g){
+ pruneNode(tree, tree[i].leftNodeId);
+ pruneNode(tree, tree[i].rightNodeId);
// //Make the node terminal
- splitMatrix[i].leftNodeId = 0;
- splitMatrix[i].rightNodeId = 0;
+ tree[i].leftNodeId = 0;
+ tree[i].rightNodeId = 0;
}
}
}
-std::size_t CARTTrainer::findNode(SplitMatrixType& splitMatrix, std::size_t nodeId){
+std::size_t CARTTrainer::findNode(TreeType & tree, std::size_t nodeId){
std::size_t i = 0;
- //while(i<splitMatrix.size() && splitMatrix[i].nodeId!=nodeId){
- while(splitMatrix[i].nodeId != nodeId){
+ //while(i<tree.size() && tree[i].nodeId!=nodeId){
+ while(tree[i].nodeId != nodeId){
i++;
}
return i;
@@ -196,63 +196,63 @@ std::size_t CARTTrainer::findNode(SplitMatrixType& splitMatrix, std::size_t node
/*
Removes branch with root node id nodeId, incl. the node itself
*/
-void CARTTrainer::pruneNode(SplitMatrixType& splitMatrix, std::size_t nodeId){
- std::size_t i = findNode(splitMatrix,nodeId);
+void CARTTrainer::pruneNode(TreeType & tree, std::size_t nodeId){
+ std::size_t i = findNode(tree,nodeId);
- if(splitMatrix[i].leftNodeId>0){
+ if(tree[i].leftNodeId>0){
//Prune left branch
- pruneNode(splitMatrix, splitMatrix[i].leftNodeId);
+ pruneNode(tree, tree[i].leftNodeId);
//Prune right branch
- pruneNode(splitMatrix, splitMatrix[i].rightNodeId);
+ pruneNode(tree, tree[i].rightNodeId);
}
//Remove node
- splitMatrix.erase(splitMatrix.begin()+i);
+ tree.erase(tree.begin()+i);
}
-void CARTTrainer::measureStrenght(SplitMatrixType& splitMatrix, std::size_t nodeId, std::size_t parentNode){
- std::size_t i = findNode(splitMatrix,nodeId);
+void CARTTrainer::measureStrength(TreeType & tree, std::size_t nodeId, std::size_t parentNode){
+ std::size_t i = findNode(tree,nodeId);
//Reset the entries
- splitMatrix[i].r = 0;
- splitMatrix[i].g = 0;
+ tree[i].r = 0;
+ tree[i].g = 0;
- if(splitMatrix[i].leftNodeId==0){
+ if(tree[i].leftNodeId==0){
//Leaf node
//Update number of leafs
- splitMatrix[parentNode].r+=1;
+ tree[parentNode].r+=1;
//update R(T) from r(t) of node. R(T) is the sum of all the leaf's r(t)
- splitMatrix[parentNode].g+=splitMatrix[i].misclassProp;
+ tree[parentNode].g+= tree[i].misclassProp;
}else{
//Left recursion
- measureStrenght(splitMatrix, splitMatrix[i].leftNodeId, i);
+ measureStrength(tree, tree[i].leftNodeId, i);
//Right recursion
- measureStrenght(splitMatrix, splitMatrix[i].rightNodeId, i);
+ measureStrength(tree, tree[i].rightNodeId, i);
if(parentNode != i){
- splitMatrix[parentNode].r+=splitMatrix[i].r;
- splitMatrix[parentNode].g+=splitMatrix[i].g;
+ tree[parentNode].r+= tree[i].r;
+ tree[parentNode].g+= tree[i].g;
}
//Final calculation of g
- splitMatrix[i].g = (splitMatrix[i].misclassProp-splitMatrix[i].g)/(splitMatrix[i].r-1);
+ tree[i].g = (tree[i].misclassProp- tree[i].g)/(tree[i].r-1);
}
}
//Classification case
-CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
- //Construct split matrix
- ModelType::SplitInfo splitInfo;
- splitInfo.nodeId = nodeId;
- splitInfo.leftNodeId = 0;
- splitInfo.rightNodeId = 0;
+CARTTrainer::TreeType CARTTrainer::buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
+ //Construct tree
+ ModelType::NodeInfo nodeInfo;
+ nodeInfo.nodeId = nodeId;
+ nodeInfo.leftNodeId = 0;
+ nodeInfo.rightNodeId = 0;
// calculate the label of the node, which is the propability of class c
// given all points in this split for every class
- splitInfo.label = hist(cAbove);
+ nodeInfo.label = hist(cAbove);
// calculate the misclassification propability,
// 1-p(j*|t) where j* is the class the node t is most likely to belong to;
- splitInfo.misclassProp = 1- *std::max_element(splitInfo.label.begin(),splitInfo.label.end());
+ nodeInfo.misclassProp = 1- *std::max_element(nodeInfo.label.begin(), nodeInfo.label.end());
//calculate leaves from the data
@@ -304,27 +304,27 @@ CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& table
AttributeTables rTables, lTables;
splitAttributeTables(tables, bestAttributeIndex, bestAttributeValIndex, lTables, rTables);
//Continue recursively
- splitInfo.attributeIndex = bestAttributeIndex;
- splitInfo.attributeValue = bestAttributeVal;
+ nodeInfo.attributeIndex = bestAttributeIndex;
+ nodeInfo.attributeValue = bestAttributeVal;
- //Store entry in the splitMatrix table
- splitInfo.leftNodeId = nodeId+1;
- SplitMatrixType lSplitMatrix = buildTree(lTables, dataset, cBestBelow, splitInfo.leftNodeId);
- splitInfo.rightNodeId = splitInfo.leftNodeId+lSplitMatrix.size();
- SplitMatrixType rSplitMatrix = buildTree(rTables, dataset, cBestAbove, splitInfo.rightNodeId);
+ //Store entry in the tree
+ nodeInfo.leftNodeId = nodeId+1;
+ TreeType lTree = buildTree(lTables, dataset, cBestBelow, nodeInfo.leftNodeId);
+ nodeInfo.rightNodeId = nodeInfo.leftNodeId+ lTree.size();
+ TreeType rTree = buildTree(rTables, dataset, cBestAbove, nodeInfo.rightNodeId);
- SplitMatrixType splitMatrix;
- splitMatrix.push_back(splitInfo);
- splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
- splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
- return splitMatrix;
+ TreeType tree;
+ tree.push_back(nodeInfo);
+ tree.insert(tree.end(), lTree.begin(), lTree.end());
+ tree.insert(tree.end(), rTree.begin(), rTree.end());
+ return tree;
}
}
- SplitMatrixType splitMatrix;
- splitMatrix.push_back(splitInfo);
- return splitMatrix;
+ TreeType tree;
+ tree.push_back(nodeInfo);
+ return tree;
}
RealVector CARTTrainer::hist(boost::unordered_map<std::size_t, std::size_t> countMatrix){
@@ -345,15 +345,15 @@ RealVector CARTTrainer::hist(boost::unordered_map<std::size_t, std::size_t> coun
//Build CART tree in the regression case
-CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize){
+CARTTrainer::TreeType CARTTrainer::buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize){
- //Construct split matrix
- CARTClassifier<RealVector>::SplitInfo splitInfo;
+ //Construct tree
+ CARTClassifier<RealVector>::NodeInfo nodeInfo;
- splitInfo.nodeId = nodeId;
- splitInfo.label = mean(labels);
- splitInfo.leftNodeId = 0;
- splitInfo.rightNodeId = 0;
+ nodeInfo.nodeId = nodeId;
+ nodeInfo.label = mean(labels);
+ nodeInfo.leftNodeId = 0;
+ nodeInfo.rightNodeId = 0;
//Store the Total Sum of Squares (TSS)
RealVector labelSum = labels[0];
@@ -361,9 +361,9 @@ CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& table
labelSum += labels[0];
}
- splitInfo.misclassProp = totalSumOfSquares(labels, 0, labels.size(), labelSum)*((double)dataset.numberOfElements()/trainSize);
+ nodeInfo.misclassProp = totalSumOfSquares(labels, 0, labels.size(), labelSum)*((double)dataset.numberOfElements()/trainSize);
- SplitMatrixType splitMatrix, lSplitMatrix, rSplitMatrix;
+ TreeType tree, lTree, rTree;
//n = Total number of cases in the dataset
//n1 = Number of cases to the left child node
@@ -440,23 +440,23 @@ CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& table
}
//Continue recursively
- splitInfo.attributeIndex = bestAttributeIndex;
- splitInfo.attributeValue = bestAttributeVal;
- splitInfo.leftNodeId = 2*nodeId+1;
- splitInfo.rightNodeId = 2*nodeId+2;
+ nodeInfo.attributeIndex = bestAttributeIndex;
+ nodeInfo.attributeValue = bestAttributeVal;
+ nodeInfo.leftNodeId = 2*nodeId+1;
+ nodeInfo.rightNodeId = 2*nodeId+2;
- lSplitMatrix = buildTree(lTables, dataset, lLabels, splitInfo.leftNodeId, trainSize);
- rSplitMatrix = buildTree(rTables, dataset, rLabels, splitInfo.rightNodeId, trainSize);
+ lTree = buildTree(lTables, dataset, lLabels, nodeInfo.leftNodeId, trainSize);
+ rTree = buildTree(rTables, dataset, rLabels, nodeInfo.rightNodeId, trainSize);
}
}
- splitMatrix.push_back(splitInfo);
- splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
- splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
+ tree.push_back(nodeInfo);
+ tree.insert(tree.end(), lTree.begin(), lTree.end());
+ tree.insert(tree.end(), rTree.begin(), rTree.end());
- //Store entry in the splitMatrix table
- return splitMatrix;
+ //Store entry in the tree
+ return tree;
}
diff --git a/src/Algorithms/RFTrainer.cpp b/src/Algorithms/RFTrainer.cpp
index c3ecdce..514843f 100644
--- a/src/Algorithms/RFTrainer.cpp
+++ b/src/Algorithms/RFTrainer.cpp
@@ -131,8 +131,8 @@ void RFTrainer::train(RFClassifier& model, RegressionDataset const& dataset)
labels.push_back(dataTrain.element(i).label);
}
- CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, labels, 0);
- CARTClassifier<RealVector> tree(splitMatrix, m_inputDimension);
+ CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, labels, 0);
+ CARTClassifier<RealVector> cart(tree, m_inputDimension);
// if oob error or importances have to be computed, create an oob sample
if(m_computeOOBerror || m_computeFeatureImportances){
@@ -141,15 +141,15 @@ void RFTrainer::train(RFClassifier& model, RegressionDataset const& dataset)
// if importances should be computed, oob errors are computed implicitly
if(m_computeFeatureImportances){
- tree.computeFeatureImportances(dataOOB);
+ cart.computeFeatureImportances(dataOOB);
} // if importances should not be computed, only compute the oob errors
else{
- tree.computeOOBerror(dataOOB);
+ cart.computeOOBerror(dataOOB);
}
}
SHARK_CRITICAL_REGION{
- model.addModel(tree);
+ model.addModel(cart);
}
}
@@ -205,8 +205,8 @@ void RFTrainer::train(RFClassifier& model, ClassificationDataset const& dataset)
createAttributeTables(dataTrain.inputs(), tables);
createCountMatrix(dataTrain, cAbove);
- CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, cAbove, 0);
- CARTClassifier<RealVector> tree(splitMatrix, m_inputDimension);
+ CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
+ CARTClassifier<RealVector> cart(tree, m_inputDimension);
// if oob error or importances have to be computed, create an oob sample
if(m_computeOOBerror || m_computeFeatureImportances){
@@ -215,15 +215,15 @@ void RFTrainer::train(RFClassifier& model, ClassificationDataset const& dataset)
// if importances should be computed, oob errors are computed implicitly
if(m_computeFeatureImportances){
- tree.computeFeatureImportances(dataOOB);
+ cart.computeFeatureImportances(dataOOB);
} // if importances should not be computed, only compute the oob errors
else{
- tree.computeOOBerror(dataOOB);
+ cart.computeOOBerror(dataOOB);
}
}
SHARK_CRITICAL_REGION{
- model.addModel(tree);
+ model.addModel(cart);
}
}
@@ -256,20 +256,20 @@ void RFTrainer::setOOBratio(double ratio){
-CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
- CARTClassifier<RealVector>::SplitMatrixType lSplitMatrix, rSplitMatrix;
+CARTClassifier<RealVector>::TreeType RFTrainer::buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
+ CARTClassifier<RealVector>::TreeType lTree, rTree;
- //Construct split matrix
- CARTClassifier<RealVector>::SplitInfo splitInfo;
+ //Construct tree
+ CARTClassifier<RealVector>::NodeInfo nodeInfo;
- splitInfo.nodeId = nodeId;
- splitInfo.attributeIndex = 0;
- splitInfo.attributeValue = 0.0;
- splitInfo.leftNodeId = 0;
- splitInfo.rightNodeId = 0;
- splitInfo.misclassProp = 0.0;
- splitInfo.r = 0;
- splitInfo.g = 0.0;
+ nodeInfo.nodeId = nodeId;
+ nodeInfo.attributeIndex = 0;
+ nodeInfo.attributeValue = 0.0;
+ nodeInfo.leftNodeId = 0;
+ nodeInfo.rightNodeId = 0;
+ nodeInfo.misclassProp = 0.0;
+ nodeInfo.r = 0;
+ nodeInfo.g = 0.0;
//n = Total number of cases in the dataset
std::size_t n = tables[0].size();
@@ -330,13 +330,13 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
tables.clear();
//Continue recursively
- splitInfo.attributeIndex = bestAttributeIndex;
- splitInfo.attributeValue = bestAttributeVal;
- splitInfo.leftNodeId = 2*nodeId+1;
- splitInfo.rightNodeId = 2*nodeId+2;
+ nodeInfo.attributeIndex = bestAttributeIndex;
+ nodeInfo.attributeValue = bestAttributeVal;
+ nodeInfo.leftNodeId = 2*nodeId+1;
+ nodeInfo.rightNodeId = 2*nodeId+2;
- lSplitMatrix = buildTree(lTables, dataset, cBestBelow, splitInfo.leftNodeId);
- rSplitMatrix = buildTree(rTables, dataset, cBestAbove, splitInfo.rightNodeId);
+ lTree = buildTree(lTables, dataset, cBestBelow, nodeInfo.leftNodeId);
+ rTree = buildTree(rTables, dataset, cBestAbove, nodeInfo.rightNodeId);
}else{
//Leaf node
isLeaf = true;
@@ -344,20 +344,20 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
}
- //Store entry in the splitMatrix table
- CARTClassifier<RealVector>::SplitMatrixType splitMatrix;
+ //Store entry in the tree
+ CARTClassifier<RealVector>::TreeType tree;
if(isLeaf){
- splitInfo.label = hist(cAbove);
- splitMatrix.push_back(splitInfo);
- return splitMatrix;
+ nodeInfo.label = hist(cAbove);
+ tree.push_back(nodeInfo);
+ return tree;
}
- splitMatrix.push_back(splitInfo);
- splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
- splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
+ tree.push_back(nodeInfo);
+ tree.insert(tree.end(), lTree.begin(), lTree.end());
+ tree.insert(tree.end(), rTree.begin(), rTree.end());
- return splitMatrix;
+ return tree;
}
RealVector RFTrainer::hist(boost::unordered_map<std::size_t, std::size_t> countMatrix){
@@ -376,22 +376,22 @@ RealVector RFTrainer::hist(boost::unordered_map<std::size_t, std::size_t> countM
return histogram;
}
-CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId ){
+CARTClassifier<RealVector>::TreeType RFTrainer::buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId ){
- //Construct split matrix
- CARTClassifier<RealVector>::SplitInfo splitInfo;
+ //Construct tree
+ CARTClassifier<RealVector>::NodeInfo nodeInfo;
- splitInfo.nodeId = nodeId;
- splitInfo.attributeIndex = 0;
- splitInfo.attributeValue = 0.0;
- splitInfo.leftNodeId = 0;
- splitInfo.rightNodeId = 0;
- splitInfo.label = average(labels);
- splitInfo.misclassProp = 0.0;
- splitInfo.r = 0;
- splitInfo.g = 0.0;
+ nodeInfo.nodeId = nodeId;
+ nodeInfo.attributeIndex = 0;
+ nodeInfo.attributeValue = 0.0;
+ nodeInfo.leftNodeId = 0;
+ nodeInfo.rightNodeId = 0;
+ nodeInfo.label = average(labels);
+ nodeInfo.misclassProp = 0.0;
+ nodeInfo.r = 0;
+ nodeInfo.g = 0.0;
- CARTClassifier<RealVector>::SplitMatrixType splitMatrix, lSplitMatrix, rSplitMatrix;
+ CARTClassifier<RealVector>::TreeType tree, lTree, rTree;
//n = Total number of cases in the dataset
std::size_t n = tables[0].size();
@@ -473,13 +473,13 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
}
//Continue recursively
- splitInfo.attributeIndex = bestAttributeIndex;
- splitInfo.attributeValue = bestAttributeVal;
- splitInfo.leftNodeId = 2*nodeId+1;
- splitInfo.rightNodeId = 2*nodeId+2;
+ nodeInfo.attributeIndex = bestAttributeIndex;
+ nodeInfo.attributeValue = bestAttributeVal;
+ nodeInfo.leftNodeId = 2*nodeId+1;
+ nodeInfo.rightNodeId = 2*nodeId+2;
- lSplitMatrix = buildTree(lTables, dataset, lLabels, splitInfo.leftNodeId);
- rSplitMatrix = buildTree(rTables, dataset, rLabels, splitInfo.rightNodeId);
+ lTree = buildTree(lTables, dataset, lLabels, nodeInfo.leftNodeId);
+ rTree = buildTree(rTables, dataset, rLabels, nodeInfo.rightNodeId);
}else{
//Leaf node
isLeaf = true;
@@ -488,16 +488,16 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
}
if(isLeaf){
- splitMatrix.push_back(splitInfo);
- return splitMatrix;
+ tree.push_back(nodeInfo);
+ return tree;
}
- splitMatrix.push_back(splitInfo);
- splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
- splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
+ tree.push_back(nodeInfo);
+ tree.insert(tree.end(), lTree.begin(), lTree.end());
+ tree.insert(tree.end(), rTree.begin(), rTree.end());
- //Store entry in the splitMatrix table
- return splitMatrix;
+ //Store entry in the tree
+ return tree;
}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git
More information about the debian-science-commits
mailing list