[opengm] 141/386: improved lunary

Ghislain Vaillant ghisvail-guest at moszumanska.debian.org
Wed Aug 31 08:36:35 UTC 2016


This is an automated email from the git hooks/post-receive script.

ghisvail-guest pushed a commit to branch debian/master
in repository opengm.

commit 0492f92a277765f6b057a2246a70fa8072cd5993
Author: DerThorsten <thorsten.beier at iwr.uni-heidelberg.de>
Date:   Thu Dec 18 19:15:55 2014 +0100

    improved lunary
---
 src/interfaces/python/opengm/learning/__init__.py  | 23 +++++++++++++++++++++-
 .../python/opengm/opengmcore/pyFunctionTypes.cxx   |  6 ++++--
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/interfaces/python/opengm/learning/__init__.py b/src/interfaces/python/opengm/learning/__init__.py
index b1293b0..89f639f 100644
--- a/src/interfaces/python/opengm/learning/__init__.py
+++ b/src/interfaces/python/opengm/learning/__init__.py
@@ -2,7 +2,7 @@ from _learning import *
 import numpy
 import struct
 from opengm import index_type,value_type, label_type
-from opengm import configuration as opengmConfig
+from opengm import configuration as opengmConfig, LUnaryFunction
 
 DatasetWithHammingLoss.lossType = 'hamming'
 DatasetWithGeneralizedHammingLoss.lossType = 'generalized-hamming'
@@ -110,6 +110,27 @@ def lPottsFunctions(nFunctions, numberOfLabels, features, weightIds):
 
     raise RuntimeError("not yet implemented")
 
+
+def lunaryFunction(weights, numberOfLabels, features, weightIds):
+
+    features = numpy.require(features, dtype=value_type)
+    weightIds = numpy.require(weightIds, dtype=index_type)
+    
+    assert features.ndim == weightIds.ndim
+    if features.ndim == 1 or weightIds.ndim == 1:
+        assert numberOfLabels ==2
+        features  = features.reshape(1,-1)
+        weightIds = weightIds.reshape(1,-1)
+
+    assert features.shape[0] in [numberOfLabels, numberOfLabels-1]
+    assert weightIds.shape[0] in [numberOfLabels, numberOfLabels-1]
+
+
+
+    return LUnaryFunction(weights=weights, numberOfLabels=numberOfLabels, 
+                          features=features, weightIds=weightIds)
+
+
 def lUnaryFunctions(nFunctions, numberOfLabels, features, weightIds):
     raise RuntimeError("not yet implemented")
 
diff --git a/src/interfaces/python/opengm/opengmcore/pyFunctionTypes.cxx b/src/interfaces/python/opengm/opengmcore/pyFunctionTypes.cxx
index a727362..3615564 100644
--- a/src/interfaces/python/opengm/opengmcore/pyFunctionTypes.cxx
+++ b/src/interfaces/python/opengm/opengmcore/pyFunctionTypes.cxx
@@ -207,13 +207,15 @@ namespace pyfunction{
 
         size_t fPerL = weightIds.shape(1);
 
-        OPENGM_CHECK_OP(weightIds.shape(0), ==, numberOfLabels,   "wrong shapes");
+        OPENGM_CHECK_OP(weightIds.shape(0), <=, numberOfLabels,   "wrong shapes");
+        OPENGM_CHECK_OP(weightIds.shape(0), >=, numberOfLabels-1,   "wrong shapes");
         OPENGM_CHECK_OP(weightIds.shape(0), ==, features.shape(0),"wrong shapes");
         OPENGM_CHECK_OP(weightIds.shape(1), ==, features.shape(1),"wrong shapes");
 
         FI_VEC fiVec(numberOfLabels);
 
-        for(size_t l=0; l<numberOfLabels; ++l){
+        const size_t weightShape0 =  weightIds.shape(0);
+        for(size_t l=0; l<weightShape0; ++l){
             fiVec[l].indices.resize(fPerL);
             fiVec[l].features.resize(fPerL);
             for(size_t i=0; i<fPerL; ++i){

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/opengm.git



More information about the debian-science-commits mailing list