[mlpack] 05/53: Fix some bugs with the trivial test.

Barak A. Pearlmutter barak+git at pearlmutter.net
Mon Nov 14 00:46:46 UTC 2016


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 452943f059afbcaabe0c143676f35d91f54a7733
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Apr 13 13:04:33 2016 -0700

    Fix some bugs with the trivial test.
---
 qdafn.hpp      |  4 +++-
 qdafn_impl.hpp | 31 ++++++++++++++++++++-----------
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/qdafn.hpp b/qdafn.hpp
index 557b421..860acb8 100644
--- a/qdafn.hpp
+++ b/qdafn.hpp
@@ -23,7 +23,7 @@
 
 namespace qdafn {
 
-template<typename MatType>
+template<typename MatType = arma::mat>
 class QDAFN
 {
  public:
@@ -60,6 +60,8 @@ class QDAFN
   const size_t m;
   //! The random lines we are projecting onto.  Has l columns.
   arma::mat lines;
+  //! Projections of each point onto each random line.
+  arma::mat projections;
 
   //! Indices of the points for each S.
   arma::Mat<size_t> sIndices;
diff --git a/qdafn_impl.hpp b/qdafn_impl.hpp
index 1b8cfaa..368b84c 100644
--- a/qdafn_impl.hpp
+++ b/qdafn_impl.hpp
@@ -10,6 +10,9 @@
 // In case it hasn't been included yet.
 #include "qdafn.hpp"
 
+#include <queue>
+#include <mlpack/methods/neighbor_search/sort_policies/furthest_neighbor_sort.hpp>
+
 namespace qdafn {
 
 // Constructor.
@@ -31,7 +34,7 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
 
   // Now, project each of the reference points onto each line, and collect the
   // top m elements.
-  arma::mat projections = lines.t() * referenceSet;
+  projections = referenceSet.t() * lines;
 
   // Loop over each projection and find the top m elements.
   sIndices.set_size(m, l);
@@ -43,8 +46,8 @@ QDAFN<MatType>::QDAFN(const MatType& referenceSet,
     // Grab the top m elements.
     for (size_t j = 0; j < m; ++j)
     {
-      sIndices[j] = sortedIndices[j];
-      sValues[j] = projections(sortedIndices[j], i);
+      sIndices(j, i) = sortedIndices[j];
+      sValues(j, i) = projections(sortedIndices[j], i);
     }
   }
 }
@@ -61,6 +64,7 @@ void QDAFN<MatType>::Search(const MatType& querySet,
         "value of m!");
 
   neighbors.set_size(k, querySet.n_cols);
+  neighbors.fill(size_t() - 1);
   distances.zeros(k, querySet.n_cols);
 
   // Search for each point.
@@ -103,16 +107,21 @@ void QDAFN<MatType>::Search(const MatType& querySet,
 
       // SortDistance() returns (size_t() - 1) if we shouldn't add it.
       if (insertPosition != (size_t() - 1))
-        InsertNeighbor(distances, neighbors, q, referenceIndex, dist);
+        InsertNeighbor(distances, neighbors, q, insertPosition, referenceIndex,
+            dist);
 
       // Now (line 14) get the next element and insert into the queue.  Do this
-      // by adjusting the previous value.
-      tableLocations[p.second]++;
-      const double val = p.first -
-          projections(tableLocations[p.second] - 1, p.second) +
-          projections(tableLocations[p.second], p.second);
-
-      queue.push(std::make_pair(val, p.second));
+      // by adjusting the previous value.  Don't insert anything if we are at
+      // the end of the search, though.
+      if (i < m - 1)
+      {
+        tableLocations[p.second]++;
+        const double val = p.first -
+            projections(tableLocations[p.second] - 1, p.second) +
+            projections(tableLocations[p.second], p.second);
+
+        queue.push(std::make_pair(val, p.second));
+      }
     }
   }
 }

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list