[shark] 06/58: SquaredHingeLinearCSvmTrainer trainer added
Ghislain Vaillant
ghisvail-guest at moszumanska.debian.org
Wed Mar 16 10:05:26 UTC 2016
This is an automated email from the git hooks/post-receive script.
ghisvail-guest pushed a commit to branch master
in repository shark.
commit 013d589f8c3998706cfcbcb44a05f7489f1db5d9
Author: Tobias Glasmachers <tobias.glasmachers at ini.rub.de>
Date: Mon Jan 18 12:22:21 2016 +0100
SquaredHingeLinearCSvmTrainer trainer added
---
include/shark/Algorithms/QP/QpBoxLinear.h | 54 +++++++++++++---------
include/shark/Algorithms/Trainers/CSvmTrainer.h | 32 ++++++++++++-
.../shark/Algorithms/Trainers/McSvmOVATrainer.h | 4 +-
3 files changed, 64 insertions(+), 26 deletions(-)
diff --git a/include/shark/Algorithms/QP/QpBoxLinear.h b/include/shark/Algorithms/QP/QpBoxLinear.h
index 84b4be3..0f72e4a 100644
--- a/include/shark/Algorithms/QP/QpBoxLinear.h
+++ b/include/shark/Algorithms/QP/QpBoxLinear.h
@@ -11,7 +11,7 @@
* \date -
*
*
- * \par Copyright 1995-2015 Shark Development Team
+ * \par Copyright 1995-2016 Shark Development Team
*
* <BR><HR>
* This file is part of Shark.
@@ -100,19 +100,22 @@ public:
///
/// \brief Solve the SVM training problem.
///
- /// \param C regularization constant of the SVM
+ /// \param bound upper bound for alpha-components, complexity parameter of the hinge loss SVM
+ /// \param reg coefficient of the penalty term \f$-\frac{reg}{2} \cdot \|\alpha\|^2\f$, reg=1/C where C is the complexity parameter of the squared hinge loss SVM
/// \param stop stopping condition(s)
/// \param prop solution properties
/// \param verbose if true, the solver prints status information and solution statistics
///
RealVector solve(
- double C,
+ double bound,
+ double reg,
QpStoppingCondition& stop,
QpSolutionProperties* prop = NULL,
bool verbose = false)
{
// sanity checks
- SHARK_ASSERT(C > 0.0);
+ SHARK_ASSERT(bound > 0.0);
+ SHARK_ASSERT(reg >= 0.0);
// measure training time
Timer timer;
@@ -173,8 +176,8 @@ public:
// compute gradient and projected gradient
double a = alpha(i);
double wyx = y_i * inner_prod(w, e_i.input);
- double g = 1.0 - wyx;
- double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == C && g > 0.0 ? 0.0 : g);
+ double g = 1.0 - wyx - reg * a;
+ double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == bound && g > 0.0 ? 0.0 : g);
// update maximal KKT violation over the epoch
max_violation = std::max(max_violation, std::abs(pg));
@@ -184,7 +187,7 @@ public:
if (pg != 0.0)
{
// SMO-style coordinate descent step
- double q = m_xSquared(i);
+ double q = m_xSquared(i) + reg;
double mu = g / q;
double new_a = a + mu;
@@ -194,10 +197,10 @@ public:
mu = -a;
new_a = 0.0;
}
- else if (new_a >= C)
+ else if (new_a >= bound)
{
- mu = C - a;
- new_a = C;
+ mu = bound - a;
+ new_a = bound;
}
// update both representations of the weight vector: alpha and w
@@ -269,10 +272,11 @@ public:
for (std::size_t i=0; i<ell; i++)
{
double a = alpha(i);
- objective += a;
if (a > 0.0)
{
- if (a == C) bounded_SV++;
+ objective += a;
+ objective -= reg/2.0 * a * a;
+ if (a == bound) bounded_SV++;
else free_SV++;
}
}
@@ -375,19 +379,22 @@ public:
///
/// \brief Solve the SVM training problem.
///
- /// \param C regularization constant of the SVM
+ /// \param bound upper bound for alpha-components, complexity parameter of the hinge loss SVM
+ /// \param reg coefficient of the penalty term \f$-\frac{reg}{2} \cdot \|\alpha\|^2\f$, reg=1/C where C is the complexity parameter of the squared hinge loss SVM
/// \param stop stopping condition(s)
/// \param prop solution properties
/// \param verbose if true, the solver prints status information and solution statistics
///
RealVector solve(
- double C,
+ double bound,
+ double reg,
QpStoppingCondition& stop,
QpSolutionProperties* prop = NULL,
bool verbose = false)
{
// sanity checks
- SHARK_ASSERT(C > 0.0);
+ SHARK_ASSERT(bound > 0.0);
+ SHARK_ASSERT(reg >= 0.0);
// measure training time
Timer timer;
@@ -447,8 +454,8 @@ public:
// compute gradient and projected gradient
double a = alpha(i);
double wyx = y(i) * inner_prod(w, x_i);
- double g = 1.0 - wyx;
- double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == C && g > 0.0 ? 0.0 : g);
+ double g = 1.0 - wyx - reg * a;
+ double pg = (a == 0.0 && g < 0.0) ? 0.0 : (a == bound && g > 0.0 ? 0.0 : g);
// update maximal KKT violation over the epoch
max_violation = std::max(max_violation, std::abs(pg));
@@ -458,7 +465,7 @@ public:
if (pg != 0.0)
{
// SMO-style coordinate descent step
- double q = diagonal(i);
+ double q = diagonal(i) + reg;
double mu = g / q;
double new_a = a + mu;
@@ -468,10 +475,10 @@ public:
mu = -a;
new_a = 0.0;
}
- else if (new_a >= C)
+ else if (new_a >= bound)
{
- mu = C - a;
- new_a = C;
+ mu = bound - a;
+ new_a = bound;
}
// update both representations of the weight vector: alpha and w
@@ -544,10 +551,11 @@ public:
for (std::size_t i=0; i<ell; i++)
{
double a = alpha(i);
- objective += a;
if (a > 0.0)
{
- if (a == C) bounded_SV++;
+ objective += a;
+ objective -= reg/2.0 * a * a;
+ if (a == bound) bounded_SV++;
else free_SV++;
}
}
diff --git a/include/shark/Algorithms/Trainers/CSvmTrainer.h b/include/shark/Algorithms/Trainers/CSvmTrainer.h
index 4a21f32..0618e82 100644
--- a/include/shark/Algorithms/Trainers/CSvmTrainer.h
+++ b/include/shark/Algorithms/Trainers/CSvmTrainer.h
@@ -17,7 +17,7 @@
* \date -
*
*
- * \par Copyright 1995-2015 Shark Development Team
+ * \par Copyright 1995-2016 Shark Development Team
*
* <BR><HR>
* This file is part of Shark.
@@ -395,6 +395,7 @@ public:
RealMatrix w(1, dim, 0.0);
row(w, 0) = solver.solve(
base_type::C(),
+ 0.0,
QpConfig::stoppingCondition(),
&QpConfig::solutionProperties(),
QpConfig::verbosity() > 0);
@@ -507,5 +508,34 @@ private:
};
+template <class InputType>
+class SquaredHingeLinearCSvmTrainer : public AbstractLinearSvmTrainer<InputType>
+{
+public:
+ typedef AbstractLinearSvmTrainer<InputType> base_type;
+
+ SquaredHingeLinearCSvmTrainer(double C, bool unconstrained = false)
+ : AbstractLinearSvmTrainer<InputType>(C, unconstrained){}
+
+ /// \brief From INameable: return the class name.
+ std::string name() const
+ { return "SquaredHingeLinearCSvmTrainer"; }
+
+ void train(LinearClassifier<InputType>& model, LabeledData<InputType, unsigned int> const& dataset)
+ {
+ std::size_t dim = inputDimension(dataset);
+ QpBoxLinear<InputType> solver(dataset, dim);
+ RealMatrix w(1, dim, 0.0);
+ row(w, 0) = solver.solve(
+ 1e100,
+ 1.0 / base_type::C(),
+ QpConfig::stoppingCondition(),
+ &QpConfig::solutionProperties(),
+ QpConfig::verbosity() > 0);
+ model.decisionFunction().setStructure(w);
+ }
+};
+
+
}
#endif
diff --git a/include/shark/Algorithms/Trainers/McSvmOVATrainer.h b/include/shark/Algorithms/Trainers/McSvmOVATrainer.h
index a1bed52..2474b62 100644
--- a/include/shark/Algorithms/Trainers/McSvmOVATrainer.h
+++ b/include/shark/Algorithms/Trainers/McSvmOVATrainer.h
@@ -11,7 +11,7 @@
* \date -
*
*
- * \par Copyright 1995-2015 Shark Development Team
+ * \par Copyright 1995-2016 Shark Development Team
*
* <BR><HR>
* This file is part of Shark.
@@ -168,7 +168,7 @@ public:
LabeledData<InputType, unsigned int> bindata = oneVersusRestProblem(dataset, c);
QpBoxLinear<InputType> solver(bindata, dim);
QpSolutionProperties prop;
- row(w, c) = solver.solve(this->C(), base_type::m_stoppingcondition, &prop, base_type::m_verbosity > 0);
+ row(w, c) = solver.solve(this->C(), 0.0, base_type::m_stoppingcondition, &prop, base_type::m_verbosity > 0);
base_type::m_solutionproperties.iterations += prop.iterations;
base_type::m_solutionproperties.seconds += prop.seconds;
base_type::m_solutionproperties.accuracy = std::max(base_type::solutionProperties().accuracy, prop.accuracy);
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/shark.git
More information about the debian-science-commits mailing list