[mlpack] 03/53: Add implementation, not yet tested.

Barak A. Pearlmutter barak+git at pearlmutter.net
Mon Nov 14 00:46:46 UTC 2016


This is an automated email from the git hooks/post-receive script.

bap pushed a commit to branch master
in repository mlpack.

commit 06b46e87e22b2fc94ec8e1c1d64d742e4f35236a
Author: Ryan Curtin <ryan at ratml.org>
Date:   Wed Apr 13 12:47:14 2016 -0700

    Add implementation, not yet tested.
---
 qdafn.hpp      |  83 ++++++++++++++++++++++++++++++++
 qdafn_impl.hpp | 147 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 230 insertions(+)

diff --git a/qdafn.hpp b/qdafn.hpp
new file mode 100644
index 0000000..557b421
--- /dev/null
+++ b/qdafn.hpp
@@ -0,0 +1,83 @@
+/**
+ * @file qdafn.hpp
+ * @author Ryan Curtin
+ *
+ * An implementation of the query-dependent approximate furthest neighbor
+ * algorithm specified in the following paper:
+ *
+ * @code
+ * @incollection{pagh2015approximate,
+ *   title={Approximate furthest neighbor in high dimensions},
+ *   author={Pagh, R. and Silvestri, F. and Sivertsen, J. and Skala, M.},
+ *   booktitle={Similarity Search and Applications},
+ *   pages={3--14},
+ *   year={2015},
+ *   publisher={Springer}
+ * }
+ * @endcode
+ */
+#ifndef QDAFN_HPP
+#define QDAFN_HPP
+
+#include <mlpack/core.hpp>
+
+namespace qdafn {
+
+/**
+ * Query-dependent approximate furthest neighbor search over a fixed reference
+ * set.  The constructor builds l random-projection tables with m entries each;
+ * Search() then uses those tables to answer k-furthest-neighbor queries.
+ *
+ * The reference set is held by reference (it is not copied), so the caller
+ * must keep it alive for as long as this object is used.
+ *
+ * @tparam MatType Matrix type for the reference and query sets (one point per
+ *     column).
+ */
+template<typename MatType>
+class QDAFN
+{
+ public:
+  /**
+   * Build the search tables for the given reference set, which is the set of
+   * points that Search() will look through.
+   *
+   * @param referenceSet Reference points to be searched.
+   * @param l How many random projections (tables) to build.
+   * @param m How many points to keep per projection.
+   */
+  QDAFN(const MatType& referenceSet, const size_t l, const size_t m);
+
+  /**
+   * Find the k furthest neighbors of every point in querySet.  A query set of
+   * a single point is fine.  Results are written in the layout used by
+   * mlpack's NeighborSearch and LSHSearch classes.
+   *
+   * @param querySet Query points, one per column.
+   * @param k Number of furthest neighbors wanted per query point.
+   * @param neighbors Output matrix of neighbor indices.
+   * @param distances Output matrix of neighbor distances.
+   * @throws std::invalid_argument if k is larger than m.
+   */
+  void Search(const MatType& querySet, const size_t k,
+              arma::Mat<size_t>& neighbors, arma::mat& distances);
+
+ private:
+  //! Points being searched (borrowed, not owned).
+  const MatType& referenceSet;
+
+  //! How many projections (tables) there are.
+  const size_t l;
+  //! How many points each projection table holds.
+  const size_t m;
+  //! Random projection directions, one per column (l columns in total).
+  arma::mat lines;
+
+  //! For each table S (one column each), the reference indices it holds.
+  arma::Mat<size_t> sIndices;
+  //! For each table S, the projection value a_i * x of each stored point.
+  arma::mat sValues;
+
+  //! Place a neighbor into the result set of one query point.
+  void InsertNeighbor(arma::mat& distances, arma::Mat<size_t>& neighbors,
+                      const size_t queryIndex, const size_t pos,
+                      const size_t neighbor, const double distance) const;
+};
+
+} // namespace qdafn
+
+// Include implementation.
+#include "qdafn_impl.hpp"
+
+#endif
diff --git a/qdafn_impl.hpp b/qdafn_impl.hpp
new file mode 100644
index 0000000..1b8cfaa
--- /dev/null
+++ b/qdafn_impl.hpp
@@ -0,0 +1,147 @@
+/**
+ * @file qdafn_impl.hpp
+ * @author Ryan Curtin
+ *
+ * Implementation of QDAFN class methods.
+ */
+#ifndef QDAFN_IMPL_HPP
+#define QDAFN_IMPL_HPP
+
+// In case it hasn't been included yet.
+#include "qdafn.hpp"
+
+namespace qdafn {
+
+/**
+ * Constructor: build the QDAFN tables.  l random Gaussian directions are
+ * drawn, and for each direction the m reference points with the largest
+ * projection onto it are stored in descending order of projection value.
+ *
+ * @throws std::invalid_argument if m is larger than the number of reference
+ *     points (there would not be enough points to fill each table).
+ */
+template<typename MatType>
+QDAFN<MatType>::QDAFN(const MatType& referenceSet,
+                      const size_t l,
+                      const size_t m) :
+    referenceSet(referenceSet),
+    l(l),
+    m(m)
+{
+  // Each table needs m distinct reference points; without this check the
+  // copy loop below would read past the end of sortedIndices.
+  if (m > referenceSet.n_cols)
+    throw std::invalid_argument("QDAFN::QDAFN(): value of m is greater than "
+        "the number of points in the reference set!");
+
+  // Build tables.  This is done by drawing random vectors from a Gaussian
+  // distribution as the lines we project onto.  The Gaussian should have zero
+  // mean and unit variance.  Each column of 'lines' is one line.
+  mlpack::distribution::GaussianDistribution gd(referenceSet.n_rows);
+  lines.set_size(referenceSet.n_rows, l);
+  for (size_t i = 0; i < l; ++i)
+    lines.col(i) = gd.Random();
+
+  // Project every reference point onto every line.  referenceSet^T * lines is
+  // (number of points) x l, so column i holds the projections of all
+  // reference points onto line i.
+  arma::mat projections = referenceSet.t() * lines;
+
+  // For each line, keep the m points with the largest projection.  Row j of
+  // sIndices/sValues holds the j-th largest, so row 0 is the top of a table.
+  sIndices.set_size(m, l);
+  sValues.set_size(m, l);
+  for (size_t i = 0; i < l; ++i)
+  {
+    const arma::uvec sortedIndices = arma::sort_index(projections.col(i),
+        "descend");
+
+    // Index with (j, i) so that every line fills its own column; linear
+    // indexing here would overwrite column 0 on every iteration.
+    for (size_t j = 0; j < m; ++j)
+    {
+      sIndices(j, i) = sortedIndices[j];
+      sValues(j, i) = projections(sortedIndices[j], i);
+    }
+  }
+}
+
+/**
+ * Search for the k approximate furthest neighbors of each query point.
+ *
+ * For each query point the l tables are visited in a merged order driven by a
+ * max-priority queue keyed on (table entry's projection) - (query's
+ * projection) along that table's line (Algorithm 1 of Pagh et al.).  Exactly
+ * m table entries are examined in total.  Each distinct reference point among
+ * them is compared by Euclidean distance and offered to the result set via
+ * FurthestNeighborSort::SortDistance().
+ *
+ * @param querySet Query points, one per column.
+ * @param k Number of furthest neighbors to return per query point.
+ * @param neighbors Output, k x (number of queries).  Reference indices of the
+ *     neighbors found.  Slots that were never filled hold (size_t() - 1).
+ * @param distances Output, k x (number of queries).  Distances matching
+ *     'neighbors'.  Unfilled slots hold 0.
+ * @throws std::invalid_argument if k > m.
+ */
+template<typename MatType>
+void QDAFN<MatType>::Search(const MatType& querySet,
+                            const size_t k,
+                            arma::Mat<size_t>& neighbors,
+                            arma::mat& distances)
+{
+  if (k > m)
+    throw std::invalid_argument("QDAFN::Search(): requested k is greater than "
+        "value of m!");
+
+  // Start every result slot empty: an invalid index and the smallest
+  // possible furthest-neighbor distance.
+  neighbors.set_size(k, querySet.n_cols);
+  neighbors.fill(size_t() - 1);
+  distances.zeros(k, querySet.n_cols);
+
+  // With no result slots there is nothing to search for.  Continuing would
+  // index the last element of an empty result column.
+  if (k == 0)
+    return;
+
+  // Search for each point.
+  for (size_t q = 0; q < querySet.n_cols; ++q)
+  {
+    // Max-priority queue of (l_i * S_i - l_i * query, table index); see line 6
+    // of Algorithm 1.  The first element is the priority, the second the
+    // table.
+    std::priority_queue<std::pair<double, size_t>> queue;
+    for (size_t i = 0; i < l; ++i)
+    {
+      const double val = sValues(0, i) - arma::dot(querySet.col(q),
+                                                   lines.col(i));
+      queue.push(std::make_pair(val, i));
+    }
+
+    // Next entry to look at in each S table (they all start at 0).
+    arma::Col<size_t> tableLocations = arma::zeros<arma::Col<size_t>>(l);
+
+    // Examine m table entries in total.
+    for (size_t i = 0; i < m; ++i)
+    {
+      // There are l * m entries overall, so this only triggers when l == 0.
+      if (queue.empty())
+        break;
+
+      // Keep the pair's element types in queue order; swapping them would
+      // truncate the priority to an integer.
+      const std::pair<double, size_t> p = queue.top();
+      queue.pop();
+
+      const size_t table = p.second;
+      const size_t location = tableLocations[table];
+
+      // Get index of reference point to look at.
+      const size_t referenceIndex = sIndices(location, table);
+
+      // A reference point may appear in several tables.  Consider it only once
+      // so that it cannot occupy more than one result slot.
+      bool alreadyPresent = false;
+      for (size_t j = 0; j < k; ++j)
+      {
+        if (neighbors(j, q) == referenceIndex)
+        {
+          alreadyPresent = true;
+          break;
+        }
+      }
+
+      if (!alreadyPresent)
+      {
+        // Calculate distance from query point.
+        const double dist = mlpack::metric::EuclideanDistance::Evaluate(
+            querySet.col(q), referenceSet.col(referenceIndex));
+
+        // Is this neighbor good enough to insert into the results?
+        arma::vec queryDist = distances.unsafe_col(q);
+        arma::Col<size_t> queryIndices = neighbors.unsafe_col(q);
+        const size_t insertPosition =
+            mlpack::neighbor::FurthestNeighborSort::SortDistance(queryDist,
+            queryIndices, dist);
+
+        // SortDistance() returns (size_t() - 1) if we shouldn't add it.
+        if (insertPosition != (size_t() - 1))
+          InsertNeighbor(distances, neighbors, q, insertPosition,
+              referenceIndex, dist);
+      }
+
+      // Now (line 14) advance this table and requeue it with its next entry.
+      // The query's projection term is unchanged, so the priority is updated
+      // by swapping the old S value for the new one.  A table with no
+      // entries left is not requeued, which keeps 'location' below m.
+      const size_t next = location + 1;
+      tableLocations[table] = next;
+      if (next < m)
+      {
+        const double val = p.first - sValues(location, table) +
+            sValues(next, table);
+        queue.push(std::make_pair(val, table));
+      }
+    }
+  }
+}
+
+/**
+ * Insert a neighbor at row 'pos' of column 'queryIndex' of the result
+ * matrices.  The entries previously at rows pos .. (n_rows - 2) are shifted
+ * down one row, and the entry in the last row is dropped.
+ *
+ * @param distances Result distances, one column per query point.
+ * @param neighbors Result indices, assumed to have the same shape as
+ *     'distances'.
+ * @param queryIndex Column (query point) to modify.
+ * @param pos Row at which to place the new neighbor.  A position at or past
+ *     the end of the column is ignored.
+ * @param neighbor Reference index of the neighbor being inserted.
+ * @param distance Distance from the query point to that neighbor.
+ */
+template<typename MatType>
+void QDAFN<MatType>::InsertNeighbor(arma::mat& distances,
+                                    arma::Mat<size_t>& neighbors,
+                                    const size_t queryIndex,
+                                    const size_t pos,
+                                    const size_t neighbor,
+                                    const double distance) const
+{
+  const size_t numRows = distances.n_rows;
+
+  // Reject out-of-range positions up front.  This also covers an empty result
+  // matrix, where (n_rows - 1) would underflow to a huge shift length.
+  if (pos >= numRows)
+    return;
+
+  // We only memmove() if there is actually a need to shift something.  The
+  // source and destination ranges overlap, hence memmove() and not memcpy().
+  const size_t len = (numRows - 1) - pos;
+  if (len > 0)
+  {
+    memmove(distances.colptr(queryIndex) + (pos + 1),
+        distances.colptr(queryIndex) + pos,
+        sizeof(double) * len);
+    memmove(neighbors.colptr(queryIndex) + (pos + 1),
+        neighbors.colptr(queryIndex) + pos,
+        sizeof(size_t) * len);
+  }
+
+  // Now put the new information in the right index.
+  distances(pos, queryIndex) = distance;
+  neighbors(pos, queryIndex) = neighbor;
+}
+
+} // namespace qdafn
+
+#endif

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/mlpack.git



More information about the debian-science-commits mailing list