[shark] 57/58: Trees: renamed splitmatrix → tree, and splitinfo → nodeinfo

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Wed Mar 16 10:05:34 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository shark.

commit f32349daf0e780881ed1480f52c3309e4bb8fe93
Author: Jakob Wrigley <wrigleyster at gmail.com>
Date:   Sat Nov 28 21:00:43 2015 +0100

    Trees: renamed splitmatrix → tree, and splitinfo → nodeinfo
---
 include/shark/Algorithms/Trainers/CARTTrainer.h |  18 +--
 include/shark/Algorithms/Trainers/RFTrainer.h   |   4 +-
 include/shark/Models/Trees/CARTClassifier.h     |  76 ++++-----
 include/shark/Models/Trees/RFClassifier.h       |   6 +-
 src/Algorithms/CARTTrainer.cpp                  | 202 ++++++++++++------------
 src/Algorithms/RFTrainer.cpp                    | 126 +++++++--------
 6 files changed, 216 insertions(+), 216 deletions(-)

diff --git a/include/shark/Algorithms/Trainers/CARTTrainer.h b/include/shark/Algorithms/Trainers/CARTTrainer.h
index 4fc61a4..b810566 100644
--- a/include/shark/Algorithms/Trainers/CARTTrainer.h
+++ b/include/shark/Algorithms/Trainers/CARTTrainer.h
@@ -108,7 +108,7 @@ protected:
 	typedef std::vector < TableEntry > AttributeTable;
 	typedef std::vector < AttributeTable > AttributeTables;
 
-	typedef ModelType::SplitMatrixType SplitMatrixType;
+	typedef ModelType::TreeType TreeType;
 
 
 	///Number of attributes in the dataset
@@ -129,7 +129,7 @@ protected:
 	//Classification functions
 	///Builds a single decision tree from a classification dataset
 	///The method requires the attribute tables,
-	SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId );
+	SHARK_EXPORT_SYMBOL TreeType buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId );
 
 	///Calculates the Gini impurity of a node. The impurity is defined as
 	///1-sum_j p(j|t)^2
@@ -139,22 +139,22 @@ protected:
 	SHARK_EXPORT_SYMBOL RealVector hist(boost::unordered_map<std::size_t, std::size_t> countMatrix);
 
 	///Regression functions
-	SHARK_EXPORT_SYMBOL SplitMatrixType buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize);
+	SHARK_EXPORT_SYMBOL TreeType buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize);
 	///Calculates the total sum of squares
 	SHARK_EXPORT_SYMBOL double totalSumOfSquares(std::vector<RealVector> const& labels, std::size_t start, std::size_t length, RealVector const& sumLabel);
 	///Calculates the mean of a vector of labels
 	SHARK_EXPORT_SYMBOL RealVector mean(std::vector<RealVector> const& labels);
 
 	///Pruning
-	///Prunes decision tree, represented by a split matrix
-	SHARK_EXPORT_SYMBOL void pruneMatrix(SplitMatrixType& splitMatrix);
+	///Prunes decision tree
+	SHARK_EXPORT_SYMBOL void pruneTree(TreeType & tree);
 	///Prunes a single node, including the child nodes of the decision tree
-	SHARK_EXPORT_SYMBOL void pruneNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
+	SHARK_EXPORT_SYMBOL void pruneNode(TreeType & tree, std::size_t nodeId);
 	///Updates the node variables used in the cost complexity pruning stage
-	SHARK_EXPORT_SYMBOL void measureStrenght(SplitMatrixType& splitMatrix, std::size_t nodeId, std::size_t parentNodeId);
+	SHARK_EXPORT_SYMBOL void measureStrength(TreeType & tree, std::size_t nodeId, std::size_t parentNodeId);
 
-	///Returns the index of the node with node id in splitMatrix.
-	SHARK_EXPORT_SYMBOL std::size_t findNode(SplitMatrixType& splitMatrix, std::size_t nodeId);
+	///Returns the index of the node with node id in tree.
+	SHARK_EXPORT_SYMBOL std::size_t findNode(TreeType & tree, std::size_t nodeId);
 
 	///Attribute table functions
 	///Create the attribute tables used by the SPRINT algorithm
diff --git a/include/shark/Algorithms/Trainers/RFTrainer.h b/include/shark/Algorithms/Trainers/RFTrainer.h
index e51aa46..bf2ca35 100644
--- a/include/shark/Algorithms/Trainers/RFTrainer.h
+++ b/include/shark/Algorithms/Trainers/RFTrainer.h
@@ -146,10 +146,10 @@ protected:
 	SHARK_EXPORT_SYMBOL void splitAttributeTables(AttributeTables const& tables, std::size_t index, std::size_t valIndex, AttributeTables& LAttributeTables, AttributeTables& RAttributeTables);
 
 	/// Build a decision tree for classification
-	SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);
+	SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::TreeType buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId);
 
 	/// Builds a decision tree for regression
-	SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::SplitMatrixType buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId);
+	SHARK_EXPORT_SYMBOL CARTClassifier<RealVector>::TreeType buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId);
 
 	/// comparison function for sorting an attributeTable
 	SHARK_EXPORT_SYMBOL static bool tableSort(RFAttribute const& v1, RFAttribute const& v2);
diff --git a/include/shark/Models/Trees/CARTClassifier.h b/include/shark/Models/Trees/CARTClassifier.h
index 6755781..e4da36d 100644
--- a/include/shark/Models/Trees/CARTClassifier.h
+++ b/include/shark/Models/Trees/CARTClassifier.h
@@ -63,7 +63,7 @@ public:
 	typedef typename base_type::BatchInputType BatchInputType;
 	typedef typename base_type::BatchOutputType BatchOutputType;
 //	Information about a single split. misclassProp, r and g are variables used in the cost complexity step
-	struct SplitInfo{
+	struct NodeInfo {
 		std::size_t nodeId;
 		std::size_t attributeIndex;
 		double attributeValue;
@@ -91,31 +91,31 @@ public:
 	/// Vector of structs that contains the splitting information and the labels.
 	/// The class label is a normalized histogram in the classification case.
 	/// In the regression case, the label is the regression value.
-	typedef std::vector<SplitInfo> SplitMatrixType;
+	typedef std::vector<NodeInfo> TreeType;
 
 	/// Constructor
 	CARTClassifier()
 	{}
 
-	/// Constructor taking the splitMatrix as argument
-	CARTClassifier(SplitMatrixType const& splitMatrix)
+	/// Constructor taking the tree as argument
+	CARTClassifier(TreeType const& tree)
 	{
-		m_splitMatrix=splitMatrix;
+		m_tree=tree;
 	}
 
-	/// Constructor taking the splitMatrix as argument and optimize it if requested
-	CARTClassifier(SplitMatrixType const& splitMatrix, bool optimize)
+	/// Constructor taking the tree as argument and optimize it if requested
+	CARTClassifier(TreeType const& tree, bool optimize)
 	{
 		if (optimize)
-			setSplitMatrix(splitMatrix);
+			setTree(tree);
 		else
-			m_splitMatrix=splitMatrix;
+			m_tree=tree;
 	}
 
-	/// Constructor taking the splitMatrix as argument as well as maximum number of attributes
-	CARTClassifier(SplitMatrixType const& splitMatrix, std::size_t d)
+	/// Constructor taking the tree as argument as well as maximum number of attributes
+	CARTClassifier(TreeType const& tree, std::size_t d)
 	{
-		setSplitMatrix(splitMatrix);
+		setTree(tree);
 		m_inputDimension = d;
 	}
 
@@ -150,15 +150,15 @@ public:
 		output = evalPattern(pattern);		
 	}
 
-	/// Set the model split matrix.
-	void setSplitMatrix(SplitMatrixType const& splitMatrix){
-		m_splitMatrix = splitMatrix;
-		optimizeSplitMatrix(m_splitMatrix);
+	/// Set the model tree.
+	void setTree(TreeType const& tree){
+		m_tree = tree;
+		optimizeTree(m_tree);
 	}
 	
 	/// Get the model split matrix.
-	SplitMatrixType getSplitMatrix() const {
-		return m_splitMatrix;
+	TreeType getTree() const {
+		return m_tree;
 	}
 
 	/// \brief The model does not have any parameters.
@@ -178,12 +178,12 @@ public:
 
 	/// from ISerializable, reads a model from an archive
 	void read(InArchive& archive){
-		archive >> m_splitMatrix;
+		archive >> m_tree;
 	}
 
 	/// from ISerializable, writes a model to an archive
 	void write(OutArchive& archive) const {
-		archive << m_splitMatrix;
+		archive << m_tree;
 	}
 
 
@@ -191,8 +191,8 @@ public:
 	UIntVector countAttributes() const {
 		SHARK_ASSERT(m_inputDimension > 0);
 		UIntVector r(m_inputDimension, 0);
-		typename SplitMatrixType::const_iterator it;
-		for(it = m_splitMatrix.begin(); it != m_splitMatrix.end(); ++it) {
+		typename TreeType::const_iterator it;
+		for(it = m_tree.begin(); it != m_tree.end(); ++it) {
 			//std::cout << "NodeId: " <<it->leftNodeId << std::endl;
 			if(it->leftNodeId != 0) { // not a label 
 				r(it->attributeIndex)++;
@@ -316,27 +316,27 @@ public:
 	}
 
 protected:
-	/// split matrix of the model
-	SplitMatrixType m_splitMatrix;
+	/// tree of the model
+	TreeType m_tree;
 	
-	/// \brief Finds the index of the node with a certain nodeID in an unoptimized split matrix.
+	/// \brief Finds the index of the node with a certain nodeID in an unoptimized tree.
 	std::size_t findNode(std::size_t nodeId) const{
 		std::size_t index = 0;
-		for(; nodeId != m_splitMatrix[index].nodeId; ++index);
+		for(; nodeId != m_tree[index].nodeId; ++index);
 		return index;
 	}
 
-	/// Optimize a split matrix, so constant lookup can be used.
+	/// Optimize a tree, so constant lookup can be used.
 	/// The optimization is done by changing the index of the children
 	/// to use indices instead of node ID.
 	/// Furthermore, the node IDs are converted to index numbers.
-	void optimizeSplitMatrix(SplitMatrixType& splitMatrix) const{
-		for(std::size_t i = 0; i < splitMatrix.size(); i++){
-			splitMatrix[i].leftNodeId = findNode(splitMatrix[i].leftNodeId);
-			splitMatrix[i].rightNodeId = findNode(splitMatrix[i].rightNodeId);
+	void optimizeTree(TreeType & tree) const{
+		for(std::size_t i = 0; i < tree.size(); i++){
+			tree[i].leftNodeId = findNode(tree[i].leftNodeId);
+			tree[i].rightNodeId = findNode(tree[i].rightNodeId);
 		}
-		for(std::size_t i = 0; i < splitMatrix.size(); i++){
-			splitMatrix[i].nodeId = i;
+		for(std::size_t i = 0; i < tree.size(); i++){
+			tree[i].nodeId = i;
 		}
 	}
 	
@@ -344,16 +344,16 @@ protected:
 	template<class Vector>
 	LabelType const& evalPattern(Vector const& pattern) const{
 		std::size_t nodeId = 0;
-		while(m_splitMatrix[nodeId].leftNodeId != 0){
-			if(pattern[m_splitMatrix[nodeId].attributeIndex]<=m_splitMatrix[nodeId].attributeValue){
+		while(m_tree[nodeId].leftNodeId != 0){
+			if(pattern[m_tree[nodeId].attributeIndex]<= m_tree[nodeId].attributeValue){
 				//Branch on left node
-				nodeId = m_splitMatrix[nodeId].leftNodeId;
+				nodeId = m_tree[nodeId].leftNodeId;
 			}else{
 				//Branch on right node
-				nodeId = m_splitMatrix[nodeId].rightNodeId;
+				nodeId = m_tree[nodeId].rightNodeId;
 			}
 		}
-		return m_splitMatrix[nodeId].label;
+		return m_tree[nodeId].label;
 	}
 
 
diff --git a/include/shark/Models/Trees/RFClassifier.h b/include/shark/Models/Trees/RFClassifier.h
index b1987af..544d473 100644
--- a/include/shark/Models/Trees/RFClassifier.h
+++ b/include/shark/Models/Trees/RFClassifier.h
@@ -40,8 +40,8 @@
 
 namespace shark {
 
-typedef CARTClassifier<RealVector>::SplitMatrixType SplitMatrixType;
-typedef std::vector<SplitMatrixType> ForestInfo;
+typedef CARTClassifier<RealVector>::TreeType TreeType;
+typedef std::vector<TreeType> ForestInfo;
 
 ///
 /// \brief Random Forest Classifier.
@@ -119,7 +119,7 @@ public:
 	ForestInfo getForestInfo() const {
 		ForestInfo finfo(m_models.size());
 		for (std::size_t i=0; i<m_models.size(); ++i)
-			finfo[i]=m_models[i].getSplitMatrix();
+			finfo[i]=m_models[i].getTree();
 		return finfo;
 	}
 
diff --git a/src/Algorithms/CARTTrainer.cpp b/src/Algorithms/CARTTrainer.cpp
index 3ae5379..c123aee 100644
--- a/src/Algorithms/CARTTrainer.cpp
+++ b/src/Algorithms/CARTTrainer.cpp
@@ -29,7 +29,7 @@ void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset)
 	RegressionDataset set=dataset;
 	CVFolds<RegressionDataset > folds = createCVSameSize(set, m_numberOfFolds);
 	double bestErrorRate = std::numeric_limits<double>::max();
-	CARTClassifier<RealVector>::SplitMatrixType bestSplitMatrix;
+	CARTClassifier<RealVector>::TreeType bestTree;
 	
 	for (unsigned fold = 0; fold < m_numberOfFolds; ++fold){
 		//Run through all the cross validation sets
@@ -42,10 +42,10 @@ void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset)
 		std::vector < RealVector > labels(numTrainElements);
 		boost::copy(dataTrain.labels().elements(),labels.begin());
 		//Build tree form this fold
-		CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, labels, 0, dataTrain.numberOfElements());
+		CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, labels, 0, dataTrain.numberOfElements());
 		//Add the tree to the model and prune
-		model.setSplitMatrix(splitMatrix);
-		while(splitMatrix.size()!=1){
+		model.setTree(tree);
+		while(tree.size()!=1){
 			//evaluate the error of current tree
 			SquaredLoss<> loss;
 			double error = loss.eval(dataTest.labels(), model(dataTest.inputs()));
@@ -53,13 +53,13 @@ void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset)
 			if(error < bestErrorRate){
 				//We have found a subtree that has a smaller error rate when tested!
 				bestErrorRate = error;
-				bestSplitMatrix = splitMatrix;
+				bestTree = tree;
 			}
-			pruneMatrix(splitMatrix);
-			model.setSplitMatrix(splitMatrix);
+			pruneTree(tree);
+			model.setTree(tree);
 		}
 	}
-	model.setSplitMatrix(bestSplitMatrix);
+	model.setTree(bestTree);
 }
 
 
@@ -79,7 +79,7 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
 	CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(set, m_numberOfFolds);
 	//find the best tree for the cv folds
 	double bestErrorRate = std::numeric_limits<double>::max();
-	CARTClassifier<RealVector>::SplitMatrixType bestSplitMatrix;
+	CARTClassifier<RealVector>::TreeType bestTree;
 	
 	//Run through all the cross validation sets
 	for (unsigned fold = 0; fold < m_numberOfFolds; ++fold) {
@@ -92,24 +92,24 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
 		AttributeTables tables = createAttributeTables(dataTrain.inputs());
 		
 
-		//create initial split matrix for the fold
-		CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, cAbove, 0);
-		model.setSplitMatrix(splitMatrix);
+		//create initial tree for the fold
+		CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
+		model.setTree(tree);
 		
-		while(splitMatrix.size()!=1){
+		while(tree.size()!=1){
 			ZeroOneLoss<unsigned int, RealVector> loss;
 			double errorRate = loss.eval(dataTest.labels(), model(dataTest.inputs()));
 			if(errorRate < bestErrorRate){
 				//We have found a subtree that has a smaller error rate when tested!
 				bestErrorRate = errorRate;
-				bestSplitMatrix = splitMatrix;
+				bestTree = tree;
 			}
-			pruneMatrix(splitMatrix);
-			model.setSplitMatrix(splitMatrix);
+			pruneTree(tree);
+			model.setTree(tree);
 		}
 	}
 
-	model.setSplitMatrix(bestSplitMatrix);
+	model.setTree(bestTree);
 
 }
 
@@ -127,7 +127,7 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
 	//~ CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(set, m_numberOfFolds);
 	//~ //find the best tree for the cv folds
 	//~ double bestErrorRate = std::numeric_limits<double>::max();
-	//~ CARTClassifier<RealVector>::SplitMatrixType bestSplitMatrix;
+	//~ CARTClassifier<RealVector>::TreeType bestTree;
 	
 	//~ //Run through all the cross validation sets
 	//~ for (unsigned fold = 0; fold < m_numberOfFolds; ++fold) {
@@ -137,57 +137,57 @@ void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){
 		//~ boost::unordered_map<size_t, size_t> cAbove = createCountMatrix(dataTrain);
 		//~ AttributeTables tables = createAttributeTables(dataTrain.inputs());
 
-		//~ //create initial split matrix for the fold
-		//~ CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, cAbove, 0);
-		//~ model.setSplitMatrix(splitMatrix);
+		//~ //create initial tree for the fold
+		//~ CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
+		//~ model.setTree(tree);
 		
-		//~ while(splitMatrix.size()!=1){
+		//~ while(tree.size()!=1){
 			//~ double errorRate = evalWeightedError(model, dataTest, weights);
 			//~ if(errorRate < bestErrorRate){
 				//~ //We have found a subtree that has a smaller error rate when tested!
 				//~ bestErrorRate = errorRate;
-				//~ bestSplitMatrix = splitMatrix;
+				//~ bestTree = tree;
 			//~ }
-			//~ pruneMatrix(splitMatrix);
-			//~ model.setSplitMatrix(splitMatrix);
+			//~ pruneTree(tree);
+			//~ model.setTree(tree);
 		//~ }
 	//~ }
 
 	//~ error = bestErrorRate;
-	//~ model.setSplitMatrix(bestSplitMatrix);
+	//~ model.setTree(bestTree);
 //~ }
 
 
-void CARTTrainer::pruneMatrix(SplitMatrixType& splitMatrix){
+void CARTTrainer::pruneTree(TreeType & tree){
 
 	//Calculate g of all the nodes
-	measureStrenght(splitMatrix, 0, 0);
+	measureStrength(tree, 0, 0);
 
 	//Find the lowest g of the internal nodes
 	double g = std::numeric_limits<double>::max();
-	for(std::size_t i = 0; i != splitMatrix.size(); i++){
-		if(splitMatrix[i].leftNodeId > 0 && splitMatrix[i].g < g){
+	for(std::size_t i = 0; i != tree.size(); i++){
+		if(tree[i].leftNodeId > 0 && tree[i].g < g){
 			//Update g
-			g = splitMatrix[i].g;
+			g = tree[i].g;
 		}
 	}
 	//Prune the nodes with lowest g and make them terminal
-	for(std::size_t i=0; i != splitMatrix.size(); i++){
+	for(std::size_t i=0; i != tree.size(); i++){
 		//Make the internal nodes with the smallest g terminal nodes and prune their children!
-		if( splitMatrix[i].leftNodeId > 0 && splitMatrix[i].g == g){
-			pruneNode(splitMatrix, splitMatrix[i].leftNodeId);
-			pruneNode(splitMatrix, splitMatrix[i].rightNodeId);
+		if( tree[i].leftNodeId > 0 && tree[i].g == g){
+			pruneNode(tree, tree[i].leftNodeId);
+			pruneNode(tree, tree[i].rightNodeId);
 			// //Make the node terminal
-			splitMatrix[i].leftNodeId = 0;
-			splitMatrix[i].rightNodeId = 0;
+			tree[i].leftNodeId = 0;
+			tree[i].rightNodeId = 0;
 		}
 	}
 }
 
-std::size_t CARTTrainer::findNode(SplitMatrixType& splitMatrix, std::size_t nodeId){
+std::size_t CARTTrainer::findNode(TreeType & tree, std::size_t nodeId){
 	std::size_t i = 0;
-	//while(i<splitMatrix.size() && splitMatrix[i].nodeId!=nodeId){
-	while(splitMatrix[i].nodeId != nodeId){
+	//while(i<tree.size() && tree[i].nodeId!=nodeId){
+	while(tree[i].nodeId != nodeId){
 		i++;
 	}
 	return i;
@@ -196,63 +196,63 @@ std::size_t CARTTrainer::findNode(SplitMatrixType& splitMatrix, std::size_t node
 /*
 	Removes branch with root node id nodeId, incl. the node itself
 */
-void CARTTrainer::pruneNode(SplitMatrixType& splitMatrix, std::size_t nodeId){
-	std::size_t i = findNode(splitMatrix,nodeId);
+void CARTTrainer::pruneNode(TreeType & tree, std::size_t nodeId){
+	std::size_t i = findNode(tree,nodeId);
 
-	if(splitMatrix[i].leftNodeId>0){
+	if(tree[i].leftNodeId>0){
 		//Prune left branch
-		pruneNode(splitMatrix, splitMatrix[i].leftNodeId);
+		pruneNode(tree, tree[i].leftNodeId);
 		//Prune right branch
-		pruneNode(splitMatrix, splitMatrix[i].rightNodeId);
+		pruneNode(tree, tree[i].rightNodeId);
 	}
 	//Remove node
-	splitMatrix.erase(splitMatrix.begin()+i);
+	tree.erase(tree.begin()+i);
 }
 
 
-void CARTTrainer::measureStrenght(SplitMatrixType& splitMatrix, std::size_t nodeId, std::size_t parentNode){
-	std::size_t i = findNode(splitMatrix,nodeId);
+void CARTTrainer::measureStrength(TreeType & tree, std::size_t nodeId, std::size_t parentNode){
+	std::size_t i = findNode(tree,nodeId);
 
 	//Reset the entries
-	splitMatrix[i].r = 0;
-	splitMatrix[i].g = 0;
+	tree[i].r = 0;
+	tree[i].g = 0;
 
-	if(splitMatrix[i].leftNodeId==0){
+	if(tree[i].leftNodeId==0){
 		//Leaf node
 		//Update number of leafs
-		splitMatrix[parentNode].r+=1;
+		tree[parentNode].r+=1;
 		//update R(T) from r(t) of node. R(T) is the sum of all the leaf's r(t)
-		splitMatrix[parentNode].g+=splitMatrix[i].misclassProp;
+		tree[parentNode].g+= tree[i].misclassProp;
 	}else{
 
 		//Left recursion
-		measureStrenght(splitMatrix, splitMatrix[i].leftNodeId, i);
+		measureStrength(tree, tree[i].leftNodeId, i);
 		//Right recursion
-		measureStrenght(splitMatrix, splitMatrix[i].rightNodeId, i);
+		measureStrength(tree, tree[i].rightNodeId, i);
 
 		if(parentNode != i){
-			splitMatrix[parentNode].r+=splitMatrix[i].r;
-			splitMatrix[parentNode].g+=splitMatrix[i].g;
+			tree[parentNode].r+= tree[i].r;
+			tree[parentNode].g+= tree[i].g;
 		}
 
 		//Final calculation of g
-		splitMatrix[i].g = (splitMatrix[i].misclassProp-splitMatrix[i].g)/(splitMatrix[i].r-1);
+		tree[i].g = (tree[i].misclassProp- tree[i].g)/(tree[i].r-1);
 	}
 }
 
 //Classification case
-CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
-	//Construct split matrix
-	ModelType::SplitInfo splitInfo;
-	splitInfo.nodeId = nodeId;
-	splitInfo.leftNodeId = 0;
-	splitInfo.rightNodeId = 0;
+CARTTrainer::TreeType CARTTrainer::buildTree(AttributeTables const& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
+	//Construct tree
+	ModelType::NodeInfo nodeInfo;
+	nodeInfo.nodeId = nodeId;
+	nodeInfo.leftNodeId = 0;
+	nodeInfo.rightNodeId = 0;
 	// calculate the label of the node, which is the propability of class c 
 	// given all points in this split for every class
-	splitInfo.label = hist(cAbove);
+	nodeInfo.label = hist(cAbove);
 	// calculate the misclassification propability,
 	// 1-p(j*|t) where j* is the class the node t is most likely to belong to;
-	splitInfo.misclassProp = 1- *std::max_element(splitInfo.label.begin(),splitInfo.label.end());
+	nodeInfo.misclassProp = 1- *std::max_element(nodeInfo.label.begin(), nodeInfo.label.end());
 	
 	//calculate leaves from the data
 	
@@ -304,27 +304,27 @@ CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& table
 			AttributeTables rTables, lTables;
 			splitAttributeTables(tables, bestAttributeIndex, bestAttributeValIndex, lTables, rTables);
 			//Continue recursively
-			splitInfo.attributeIndex = bestAttributeIndex;
-			splitInfo.attributeValue = bestAttributeVal;
+			nodeInfo.attributeIndex = bestAttributeIndex;
+			nodeInfo.attributeValue = bestAttributeVal;
 			
 
-			//Store entry in the splitMatrix table
-			splitInfo.leftNodeId = nodeId+1;
-			SplitMatrixType lSplitMatrix = buildTree(lTables, dataset, cBestBelow, splitInfo.leftNodeId);
-			splitInfo.rightNodeId = splitInfo.leftNodeId+lSplitMatrix.size();
-			SplitMatrixType rSplitMatrix = buildTree(rTables, dataset, cBestAbove, splitInfo.rightNodeId);
+			//Store entry in the tree
+			nodeInfo.leftNodeId = nodeId+1;
+			TreeType lTree = buildTree(lTables, dataset, cBestBelow, nodeInfo.leftNodeId);
+			nodeInfo.rightNodeId = nodeInfo.leftNodeId+ lTree.size();
+			TreeType rTree = buildTree(rTables, dataset, cBestAbove, nodeInfo.rightNodeId);
 			
-			SplitMatrixType splitMatrix;
-			splitMatrix.push_back(splitInfo);
-			splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
-			splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
-			return splitMatrix;
+			TreeType tree;
+			tree.push_back(nodeInfo);
+			tree.insert(tree.end(), lTree.begin(), lTree.end());
+			tree.insert(tree.end(), rTree.begin(), rTree.end());
+			return tree;
 		}
 	}
 	
-	SplitMatrixType splitMatrix;
-	splitMatrix.push_back(splitInfo);
-	return splitMatrix;
+	TreeType tree;
+	tree.push_back(nodeInfo);
+	return tree;
 }
 
 RealVector CARTTrainer::hist(boost::unordered_map<std::size_t, std::size_t> countMatrix){
@@ -345,15 +345,15 @@ RealVector CARTTrainer::hist(boost::unordered_map<std::size_t, std::size_t> coun
 
 
 //Build CART tree in the regression case
-CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize){
+CARTTrainer::TreeType CARTTrainer::buildTree(AttributeTables const& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId, std::size_t trainSize){
 
-	//Construct split matrix
-	CARTClassifier<RealVector>::SplitInfo splitInfo;
+	//Construct tree
+	CARTClassifier<RealVector>::NodeInfo nodeInfo;
 
-	splitInfo.nodeId = nodeId;
-	splitInfo.label = mean(labels);
-	splitInfo.leftNodeId = 0;
-	splitInfo.rightNodeId = 0;
+	nodeInfo.nodeId = nodeId;
+	nodeInfo.label = mean(labels);
+	nodeInfo.leftNodeId = 0;
+	nodeInfo.rightNodeId = 0;
 
 	//Store the Total Sum of Squares (TSS)
 	RealVector labelSum = labels[0];
@@ -361,9 +361,9 @@ CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& table
 		labelSum += labels[0];
 	}
 
-	splitInfo.misclassProp = totalSumOfSquares(labels, 0, labels.size(), labelSum)*((double)dataset.numberOfElements()/trainSize);
+	nodeInfo.misclassProp = totalSumOfSquares(labels, 0, labels.size(), labelSum)*((double)dataset.numberOfElements()/trainSize);
 
-	SplitMatrixType splitMatrix, lSplitMatrix, rSplitMatrix;
+	TreeType tree, lTree, rTree;
 
 	//n = Total number of cases in the dataset
 	//n1 = Number of cases to the left child node
@@ -440,23 +440,23 @@ CARTTrainer::SplitMatrixType CARTTrainer::buildTree(AttributeTables const& table
 			}
 
 			//Continue recursively
-			splitInfo.attributeIndex = bestAttributeIndex;
-			splitInfo.attributeValue = bestAttributeVal;
-			splitInfo.leftNodeId = 2*nodeId+1;
-			splitInfo.rightNodeId = 2*nodeId+2;
+			nodeInfo.attributeIndex = bestAttributeIndex;
+			nodeInfo.attributeValue = bestAttributeVal;
+			nodeInfo.leftNodeId = 2*nodeId+1;
+			nodeInfo.rightNodeId = 2*nodeId+2;
 
-			lSplitMatrix = buildTree(lTables, dataset, lLabels, splitInfo.leftNodeId, trainSize);
-			rSplitMatrix = buildTree(rTables, dataset, rLabels, splitInfo.rightNodeId, trainSize);
+			lTree = buildTree(lTables, dataset, lLabels, nodeInfo.leftNodeId, trainSize);
+			rTree = buildTree(rTables, dataset, rLabels, nodeInfo.rightNodeId, trainSize);
 		}
 	}
 
 
-	splitMatrix.push_back(splitInfo);
-	splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
-	splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
+	tree.push_back(nodeInfo);
+	tree.insert(tree.end(), lTree.begin(), lTree.end());
+	tree.insert(tree.end(), rTree.begin(), rTree.end());
 
-	//Store entry in the splitMatrix table
-	return splitMatrix;
+	//Store entry in the tree
+	return tree;
 
 }
 
diff --git a/src/Algorithms/RFTrainer.cpp b/src/Algorithms/RFTrainer.cpp
index c3ecdce..514843f 100644
--- a/src/Algorithms/RFTrainer.cpp
+++ b/src/Algorithms/RFTrainer.cpp
@@ -131,8 +131,8 @@ void RFTrainer::train(RFClassifier& model, RegressionDataset const& dataset)
 			labels.push_back(dataTrain.element(i).label);
 		}
 
-		CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, labels, 0);
-		CARTClassifier<RealVector> tree(splitMatrix, m_inputDimension);
+		CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, labels, 0);
+		CARTClassifier<RealVector> cart(tree, m_inputDimension);
 
 		// if oob error or importances have to be computed, create an oob sample
 		if(m_computeOOBerror || m_computeFeatureImportances){
@@ -141,15 +141,15 @@ void RFTrainer::train(RFClassifier& model, RegressionDataset const& dataset)
 
 			// if importances should be computed, oob errors are computed implicitly
 			if(m_computeFeatureImportances){
-				tree.computeFeatureImportances(dataOOB);
+				cart.computeFeatureImportances(dataOOB);
 			} // if importances should not be computed, only compute the oob errors
 			else{
-				tree.computeOOBerror(dataOOB);
+				cart.computeOOBerror(dataOOB);
 			}
 		}
 
 		SHARK_CRITICAL_REGION{
-			model.addModel(tree);
+			model.addModel(cart);
 		}
 	}
 
@@ -205,8 +205,8 @@ void RFTrainer::train(RFClassifier& model, ClassificationDataset const& dataset)
 		createAttributeTables(dataTrain.inputs(), tables);
 		createCountMatrix(dataTrain, cAbove);
 
-		CARTClassifier<RealVector>::SplitMatrixType splitMatrix = buildTree(tables, dataTrain, cAbove, 0);
-		CARTClassifier<RealVector> tree(splitMatrix, m_inputDimension);
+		CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0);
+		CARTClassifier<RealVector> cart(tree, m_inputDimension);
 
 		// if oob error or importances have to be computed, create an oob sample
 		if(m_computeOOBerror || m_computeFeatureImportances){
@@ -215,15 +215,15 @@ void RFTrainer::train(RFClassifier& model, ClassificationDataset const& dataset)
 
 			// if importances should be computed, oob errors are computed implicitly
 			if(m_computeFeatureImportances){
-				tree.computeFeatureImportances(dataOOB);
+				cart.computeFeatureImportances(dataOOB);
 			} // if importances should not be computed, only compute the oob errors
 			else{
-				tree.computeOOBerror(dataOOB);
+				cart.computeOOBerror(dataOOB);
 			}
 		}
 
 		SHARK_CRITICAL_REGION{
-			model.addModel(tree);
+			model.addModel(cart);
 		}
 	}
 
@@ -256,20 +256,20 @@ void RFTrainer::setOOBratio(double ratio){
 
 
 
-CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
-	CARTClassifier<RealVector>::SplitMatrixType lSplitMatrix, rSplitMatrix;
+CARTClassifier<RealVector>::TreeType RFTrainer::buildTree(AttributeTables& tables, ClassificationDataset const& dataset, boost::unordered_map<std::size_t, std::size_t>& cAbove, std::size_t nodeId ){
+	CARTClassifier<RealVector>::TreeType lTree, rTree;
 
-	//Construct split matrix
-	CARTClassifier<RealVector>::SplitInfo splitInfo;
+	//Construct tree
+	CARTClassifier<RealVector>::NodeInfo nodeInfo;
 
-	splitInfo.nodeId = nodeId;
-	splitInfo.attributeIndex = 0;
-	splitInfo.attributeValue = 0.0;
-	splitInfo.leftNodeId = 0;
-	splitInfo.rightNodeId = 0;
-	splitInfo.misclassProp = 0.0;
-	splitInfo.r = 0;
-	splitInfo.g = 0.0;
+	nodeInfo.nodeId = nodeId;
+	nodeInfo.attributeIndex = 0;
+	nodeInfo.attributeValue = 0.0;
+	nodeInfo.leftNodeId = 0;
+	nodeInfo.rightNodeId = 0;
+	nodeInfo.misclassProp = 0.0;
+	nodeInfo.r = 0;
+	nodeInfo.g = 0.0;
 
 	//n = Total number of cases in the dataset
 	std::size_t n = tables[0].size();
@@ -330,13 +330,13 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
 			tables.clear();
 			//Continue recursively
 
-			splitInfo.attributeIndex = bestAttributeIndex;
-			splitInfo.attributeValue = bestAttributeVal;
-			splitInfo.leftNodeId = 2*nodeId+1;
-			splitInfo.rightNodeId = 2*nodeId+2;
+			nodeInfo.attributeIndex = bestAttributeIndex;
+			nodeInfo.attributeValue = bestAttributeVal;
+			nodeInfo.leftNodeId = 2*nodeId+1;
+			nodeInfo.rightNodeId = 2*nodeId+2;
 
-			lSplitMatrix = buildTree(lTables, dataset, cBestBelow, splitInfo.leftNodeId);
-			rSplitMatrix = buildTree(rTables, dataset, cBestAbove, splitInfo.rightNodeId);
+			lTree = buildTree(lTables, dataset, cBestBelow, nodeInfo.leftNodeId);
+			rTree = buildTree(rTables, dataset, cBestAbove, nodeInfo.rightNodeId);
 		}else{
 			//Leaf node
 			isLeaf = true;
@@ -344,20 +344,20 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
 
 	}
 
-	//Store entry in the splitMatrix table
-	CARTClassifier<RealVector>::SplitMatrixType splitMatrix;
+	//Store entry in the tree table
+	CARTClassifier<RealVector>::TreeType tree;
 
 	if(isLeaf){
-		splitInfo.label = hist(cAbove);
-		splitMatrix.push_back(splitInfo);
-		return splitMatrix;
+		nodeInfo.label = hist(cAbove);
+		tree.push_back(nodeInfo);
+		return tree;
 	}
 
-	splitMatrix.push_back(splitInfo);
-	splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
-	splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
+	tree.push_back(nodeInfo);
+	tree.insert(tree.end(), lTree.begin(), lTree.end());
+	tree.insert(tree.end(), rTree.begin(), rTree.end());
 
-	return splitMatrix;
+	return tree;
 }
 
 RealVector RFTrainer::hist(boost::unordered_map<std::size_t, std::size_t> countMatrix){
@@ -376,22 +376,22 @@ RealVector RFTrainer::hist(boost::unordered_map<std::size_t, std::size_t> countM
 	return histogram;
 }
 
-CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId ){
+CARTClassifier<RealVector>::TreeType RFTrainer::buildTree(AttributeTables& tables, RegressionDataset const& dataset, std::vector<RealVector> const& labels, std::size_t nodeId ){
 
-	//Construct split matrix
-	CARTClassifier<RealVector>::SplitInfo splitInfo;
+	//Construct tree
+	CARTClassifier<RealVector>::NodeInfo nodeInfo;
 
-	splitInfo.nodeId = nodeId;
-	splitInfo.attributeIndex = 0;
-	splitInfo.attributeValue = 0.0;
-	splitInfo.leftNodeId = 0;
-	splitInfo.rightNodeId = 0;
-	splitInfo.label = average(labels);
-	splitInfo.misclassProp = 0.0;
-	splitInfo.r = 0;
-	splitInfo.g = 0.0;
+	nodeInfo.nodeId = nodeId;
+	nodeInfo.attributeIndex = 0;
+	nodeInfo.attributeValue = 0.0;
+	nodeInfo.leftNodeId = 0;
+	nodeInfo.rightNodeId = 0;
+	nodeInfo.label = average(labels);
+	nodeInfo.misclassProp = 0.0;
+	nodeInfo.r = 0;
+	nodeInfo.g = 0.0;
 
-	CARTClassifier<RealVector>::SplitMatrixType splitMatrix, lSplitMatrix, rSplitMatrix;
+	CARTClassifier<RealVector>::TreeType tree, lTree, rTree;
 
 	//n = Total number of cases in the dataset
 	std::size_t n = tables[0].size();
@@ -473,13 +473,13 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
 			}
 
 			//Continue recursively
-			splitInfo.attributeIndex = bestAttributeIndex;
-			splitInfo.attributeValue = bestAttributeVal;
-			splitInfo.leftNodeId = 2*nodeId+1;
-			splitInfo.rightNodeId = 2*nodeId+2;
+			nodeInfo.attributeIndex = bestAttributeIndex;
+			nodeInfo.attributeValue = bestAttributeVal;
+			nodeInfo.leftNodeId = 2*nodeId+1;
+			nodeInfo.rightNodeId = 2*nodeId+2;
 
-			lSplitMatrix = buildTree(lTables, dataset, lLabels, splitInfo.leftNodeId);
-			rSplitMatrix = buildTree(rTables, dataset, rLabels, splitInfo.rightNodeId);
+			lTree = buildTree(lTables, dataset, lLabels, nodeInfo.leftNodeId);
+			rTree = buildTree(rTables, dataset, rLabels, nodeInfo.rightNodeId);
 		}else{
 			//Leaf node
 			isLeaf = true;
@@ -488,16 +488,16 @@ CARTClassifier<RealVector>::SplitMatrixType RFTrainer::buildTree(AttributeTables
 	}
 
 	if(isLeaf){
-		splitMatrix.push_back(splitInfo);
-		return splitMatrix;
+		tree.push_back(nodeInfo);
+		return tree;
 	}
 
-	splitMatrix.push_back(splitInfo);
-	splitMatrix.insert(splitMatrix.end(), lSplitMatrix.begin(), lSplitMatrix.end());
-	splitMatrix.insert(splitMatrix.end(), rSplitMatrix.begin(), rSplitMatrix.end());
+	tree.push_back(nodeInfo);
+	tree.insert(tree.end(), lTree.begin(), lTree.end());
+	tree.insert(tree.end(), rTree.begin(), rTree.end());
 
-	//Store entry in the splitMatrix table
-	return splitMatrix;
+	//Store entry in the tree
+	return tree;
 
 }
 

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git



More information about the debian-science-commits mailing list