[shark] 06/58: SquaredHingeLinearCSvmTrainer trainer added

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Wed Mar 16 10:05:26 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch master
in repository shark.

commit 013d589f8c3998706cfcbcb44a05f7489f1db5d9
Author: Tobias Glasmachers <tobias.glasmachers at ini.rub.de>
Date:   Mon Jan 18 12:22:21 2016 +0100

    SquaredHingeLinearCSvmTrainer trainer added
---
 include/shark/Algorithms/QP/QpBoxLinear.h          | 54 +++++++++++++---------
 include/shark/Algorithms/Trainers/CSvmTrainer.h    | 32 ++++++++++++-
 .../shark/Algorithms/Trainers/McSvmOVATrainer.h    |  4 +-
 3 files changed, 64 insertions(+), 26 deletions(-)

diff --git a/include/shark/Algorithms/QP/QpBoxLinear.h b/include/shark/Algorithms/QP/QpBoxLinear.h
index 84b4be3..0f72e4a 100644
--- a/include/shark/Algorithms/QP/QpBoxLinear.h
+++ b/include/shark/Algorithms/QP/QpBoxLinear.h
@@ -11,7 +11,7 @@
  * \date        -
  *
  *
- * \par Copyright 1995-2015 Shark Development Team
+ * \par Copyright 1995-2016 Shark Development Team
  * 
  * <BR><HR>
  * This file is part of Shark.
@@ -100,19 +100,22 @@ public:
 	///
 	/// \brief Solve the SVM training problem.
 	///
-	/// \param  C        regularization constant of the SVM
+	/// \param  bound    upper bound for alpha-components, complexity parameter of the hinge loss SVM
+	/// \param  reg      coefficient of the penalty term \f$-\frac{reg}{2} \cdot \|\alpha\|^2\f$, reg=1/C where C is the complexity parameter of the squared hinge loss SVM
 	/// \param  stop     stopping condition(s)
 	/// \param  prop     solution properties
 	/// \param  verbose  if true, the solver prints status information and solution statistics
 	///
 	RealVector solve(
-			double C,
+			double bound,
+			double reg,
 			QpStoppingCondition& stop,
 			QpSolutionProperties* prop = NULL,
 			bool verbose = false)
 	{
 		// sanity checks
-		SHARK_ASSERT(C > 0.0);
+		SHARK_ASSERT(bound > 0.0);
+		SHARK_ASSERT(reg >= 0.0);
 
 		// measure training time
 		Timer timer;
@@ -173,8 +176,8 @@ public:
 				// compute gradient and projected gradient
 				double a = alpha(i);
 				double wyx = y_i * inner_prod(w, e_i.input);
-				double g = 1.0 - wyx;
-				double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == C && g > 0.0 ? 0.0 : g);
+				double g = 1.0 - wyx - reg * a;
+				double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == bound && g > 0.0 ? 0.0 : g);
 
 				// update maximal KKT violation over the epoch
 				max_violation = std::max(max_violation, std::abs(pg));
@@ -184,7 +187,7 @@ public:
 				if (pg != 0.0)
 				{
 					// SMO-style coordinate descent step
-					double q = m_xSquared(i);
+					double q = m_xSquared(i) + reg;
 					double mu = g / q;
 					double new_a = a + mu;
 
@@ -194,10 +197,10 @@ public:
 						mu = -a;
 						new_a = 0.0;
 					}
-					else if (new_a >= C)
+					else if (new_a >= bound)
 					{
-						mu = C - a;
-						new_a = C;
+						mu = bound - a;
+						new_a = bound;
 					}
 
 					// update both representations of the weight vector: alpha and w
@@ -269,10 +272,11 @@ public:
 		for (std::size_t i=0; i<ell; i++)
 		{
 			double a = alpha(i);
-			objective += a;
 			if (a > 0.0)
 			{
-				if (a == C) bounded_SV++;
+				objective += a;
+				objective -= reg/2.0 * a * a;
+				if (a == bound) bounded_SV++;
 				else free_SV++;
 			}
 		}
@@ -375,19 +379,22 @@ public:
 	///
 	/// \brief Solve the SVM training problem.
 	///
-	/// \param  C        regularization constant of the SVM
+	/// \param  bound    upper bound for alpha-components, complexity parameter of the hinge loss SVM
+	/// \param  reg      coefficient of the penalty term \f$-\frac{reg}{2} \cdot \|\alpha\|^2\f$, reg=1/C where C is the complexity parameter of the squared hinge loss SVM
 	/// \param  stop     stopping condition(s)
 	/// \param  prop     solution properties
 	/// \param  verbose  if true, the solver prints status information and solution statistics
 	///
 	RealVector solve(
-			double C,
+			double bound,
+			double reg,
 			QpStoppingCondition& stop,
 			QpSolutionProperties* prop = NULL,
 			bool verbose = false)
 	{
 		// sanity checks
-		SHARK_ASSERT(C > 0.0);
+		SHARK_ASSERT(bound > 0.0);
+		SHARK_ASSERT(reg >= 0.0);
 
 		// measure training time
 		Timer timer;
@@ -447,8 +454,8 @@ public:
 				// compute gradient and projected gradient
 				double a = alpha(i);
 				double wyx = y(i) * inner_prod(w, x_i);
-				double g = 1.0 - wyx;
-				double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == C && g > 0.0 ? 0.0 : g);
+				double g = 1.0 - wyx - reg * a;
+				double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == bound && g > 0.0 ? 0.0 : g);
 
 				// update maximal KKT violation over the epoch
 				max_violation = std::max(max_violation, std::abs(pg));
@@ -458,7 +465,7 @@ public:
 				if (pg != 0.0)
 				{
 					// SMO-style coordinate descent step
-					double q = diagonal(i);
+					double q = diagonal(i) + reg;
 					double mu = g / q;
 					double new_a = a + mu;
 
@@ -468,10 +475,10 @@ public:
 						mu = -a;
 						new_a = 0.0;
 					}
-					else if (new_a >= C)
+					else if (new_a >= bound)
 					{
-						mu = C - a;
-						new_a = C;
+						mu = bound - a;
+						new_a = bound;
 					}
 
 					// update both representations of the weight vector: alpha and w
@@ -544,10 +551,11 @@ public:
 		for (std::size_t i=0; i<ell; i++)
 		{
 			double a = alpha(i);
-			objective += a;
 			if (a > 0.0)
 			{
-				if (a == C) bounded_SV++;
+				objective += a;
+				objective -= reg/2.0 * a * a;
+				if (a == bound) bounded_SV++;
 				else free_SV++;
 			}
 		}
diff --git a/include/shark/Algorithms/Trainers/CSvmTrainer.h b/include/shark/Algorithms/Trainers/CSvmTrainer.h
index 4a21f32..0618e82 100644
--- a/include/shark/Algorithms/Trainers/CSvmTrainer.h
+++ b/include/shark/Algorithms/Trainers/CSvmTrainer.h
@@ -17,7 +17,7 @@
  * \date        -
  *
  *
- * \par Copyright 1995-2015 Shark Development Team
+ * \par Copyright 1995-2016 Shark Development Team
  * 
  * <BR><HR>
  * This file is part of Shark.
@@ -395,6 +395,7 @@ public:
 		RealMatrix w(1, dim, 0.0);
 		row(w, 0) = solver.solve(
 				base_type::C(),
+				0.0,
 				QpConfig::stoppingCondition(),
 				&QpConfig::solutionProperties(),
 				QpConfig::verbosity() > 0);
@@ -507,5 +508,34 @@ private:
 };
 
 
+template <class InputType>
+class SquaredHingeLinearCSvmTrainer : public AbstractLinearSvmTrainer<InputType> // linear C-SVM trainer using the squared hinge loss
+{
+public:
+	typedef AbstractLinearSvmTrainer<InputType> base_type;
+	/// \brief Constructor. C is the complexity parameter of the squared hinge loss SVM (the solver receives reg = 1/C); unconstrained is forwarded unchanged to AbstractLinearSvmTrainer.
+	SquaredHingeLinearCSvmTrainer(double C, bool unconstrained = false) 
+	: AbstractLinearSvmTrainer<InputType>(C, unconstrained){}
+
+	/// \brief From INameable: return the class name.
+	std::string name() const
+	{ return "SquaredHingeLinearCSvmTrainer"; }
+	/// \brief Solve the dual with QpBoxLinear (bound 1e100, reg 1/C) and store the 1 x inputDimension(dataset) weights in model; solver stats go to QpConfig::solutionProperties().
+	void train(LinearClassifier<InputType>& model, LabeledData<InputType, unsigned int> const& dataset)
+	{
+		std::size_t dim = inputDimension(dataset);
+		QpBoxLinear<InputType> solver(dataset, dim);
+		RealMatrix w(1, dim, 0.0);
+		row(w, 0) = solver.solve(
+				1e100,  // alpha upper bound set huge, i.e. effectively none: the loss enters via reg below, not via a box constraint
+				1.0 / base_type::C(),  // reg = 1/C: coefficient of the -reg/2 * ||alpha||^2 dual penalty (squared hinge loss)
+				QpConfig::stoppingCondition(),
+				&QpConfig::solutionProperties(),
+				QpConfig::verbosity() > 0);
+		model.decisionFunction().setStructure(w);
+	}
+};
+
+
 }
 #endif
diff --git a/include/shark/Algorithms/Trainers/McSvmOVATrainer.h b/include/shark/Algorithms/Trainers/McSvmOVATrainer.h
index a1bed52..2474b62 100644
--- a/include/shark/Algorithms/Trainers/McSvmOVATrainer.h
+++ b/include/shark/Algorithms/Trainers/McSvmOVATrainer.h
@@ -11,7 +11,7 @@
  * \date        -
  *
  *
- * \par Copyright 1995-2015 Shark Development Team
+ * \par Copyright 1995-2016 Shark Development Team
  * 
  * <BR><HR>
  * This file is part of Shark.
@@ -168,7 +168,7 @@ public:
 			LabeledData<InputType, unsigned int> bindata = oneVersusRestProblem(dataset, c);
 			QpBoxLinear<InputType> solver(bindata, dim);
 			QpSolutionProperties prop;
-			row(w, c) = solver.solve(this->C(), base_type::m_stoppingcondition, &prop, base_type::m_verbosity > 0);
+			row(w, c) = solver.solve(this->C(), 0.0, base_type::m_stoppingcondition, &prop, base_type::m_verbosity > 0);
 			base_type::m_solutionproperties.iterations += prop.iterations;
 			base_type::m_solutionproperties.seconds += prop.seconds;
 			base_type::m_solutionproperties.accuracy = std::max(base_type::solutionProperties().accuracy, prop.accuracy);

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git



More information about the debian-science-commits mailing list