[lua-torch-nngraph] 01/02: Imported Upstream version 0~20160804-g40e4207

Zhou Mo cdluminate-guest at moszumanska.debian.org
Sat Aug 20 15:08:01 UTC 2016


This is an automated email from the git hooks/post-receive script.

cdluminate-guest pushed a commit to branch master
in repository lua-torch-nngraph.

commit 8818a43cd8d52d3091d1cdbc0706d2c3dd1298c1
Author: Zhou Mo <cdluminate at gmail.com>
Date:   Mon Aug 15 12:21:30 2016 +0000

    Imported Upstream version 0~20160804-g40e4207
---
 .gitignore                        |   1 +
 .travis.yml                       |  66 +++++
 CMakeLists.txt                    |   8 +
 COPYRIGHT.txt                     |  35 +++
 JustElement.lua                   |  18 ++
 JustTable.lua                     |  17 ++
 ModuleFromCriterion.lua           |  42 ++++
 README.md                         | 250 ++++++++++++++++++
 doc/annotation_bg.png             | Bin 0 -> 215497 bytes
 doc/annotation_fg.png             | Bin 0 -> 215171 bytes
 doc/mlp.png                       | Bin 0 -> 133223 bytes
 doc/mlp2.png                      | Bin 0 -> 215751 bytes
 doc/mlp3_backward.png             | Bin 0 -> 201917 bytes
 doc/mlp3_forward.png              | Bin 0 -> 202244 bytes
 doc/mlp4_backward.png             | Bin 0 -> 271970 bytes
 doc/mlp4_forward.png              | Bin 0 -> 270958 bytes
 doc/my_bad_linear_net.png         | Bin 0 -> 41272 bytes
 gmodule.lua                       | 514 ++++++++++++++++++++++++++++++++++++++
 graphinspecting.lua               | 159 ++++++++++++
 init.lua                          |  80 ++++++
 nest.lua                          |  46 ++++
 nesting.lua                       |  85 +++++++
 nngraph-scm-1.rockspec            |  30 +++
 node.lua                          | 178 +++++++++++++
 simple_print.lua                  | 124 +++++++++
 test/speed.lua                    | 111 ++++++++
 test/test_JustElement.lua         |  28 +++
 test/test_JustTable.lua           |  19 ++
 test/test_ModuleFromCriterion.lua |  57 +++++
 test/test_connectivity.lua        |  26 ++
 test/test_debug.lua               |  80 ++++++
 test/test_nest.lua                |  33 +++
 test/test_nngraph.lua             | 481 +++++++++++++++++++++++++++++++++++
 test/test_old.lua                 | 227 +++++++++++++++++
 utils.lua                         |  42 ++++
 35 files changed, 2757 insertions(+)

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..567609b
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+build/
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..5ec70aa
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,66 @@
+language: c
+compiler:
+  - gcc
+  - clang
+cache:
+  directories:
+  - $HOME/OpenBlasInstall
+  - $HOME/GraphViz
+sudo: false
+env:
+  - TORCH_LUA_VERSION=LUAJIT21
+  - TORCH_LUA_VERSION=LUA51
+  - TORCH_LUA_VERSION=LUA52
+addons:
+  apt:
+    packages:
+    - cmake
+    - gfortran
+    - gcc-multilib
+    - gfortran-multilib
+    - liblapack-dev
+    - build-essential
+    - gcc
+    - g++
+    - curl
+    - libreadline-dev
+    - git-core
+    - libqt4-core
+    - libqt4-gui
+    - libqt4-dev
+    - libjpeg-dev
+    - libpng-dev
+    - ncurses-dev
+    - imagemagick
+    - libzmq3-dev
+    - unzip
+    - gnuplot
+    - gnuplot-x11
+# OpenBLAS and GraphViz are built from source only when their cached install dirs are missing.
+before_script:
+- export ROOT_TRAVIS_DIR=$(pwd)
+- export INSTALL_PREFIX=~/torch/install
+-  ls $HOME/OpenBlasInstall/lib || (cd /tmp/ && git clone https://github.com/xianyi/OpenBLAS.git -b master && cd OpenBLAS && (make NO_AFFINITY=1 -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make PREFIX=$HOME/OpenBlasInstall install)
+-  ls $HOME/GraphViz/lib || (cd /tmp/ && wget -c http://www.graphviz.org/pub/graphviz/stable/SOURCES/graphviz-2.38.0.tar.gz && tar -xvf graphviz-2.38.0.tar.gz && cd graphviz-2.38.0 && (./configure prefix=$HOME/GraphViz/ 2>/dev/null >/dev/null) && (make -j$(getconf _NPROCESSORS_ONLN) 2>/dev/null >/dev/null) && make install)
+- export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH
+- git clone https://github.com/torch/distro.git ~/torch --recursive
+- cd ~/torch && git submodule update --init --recursive
+- mkdir build && cd build
+- export CMAKE_LIBRARY_PATH=$HOME/OpenBlasInstall/include:$HOME/OpenBlasInstall/lib:$CMAKE_LIBRARY_PATH
+- cmake .. -DCMAKE_INSTALL_PREFIX="${INSTALL_PREFIX}" -DCMAKE_BUILD_TYPE=Release -DWITH_${TORCH_LUA_VERSION}=ON
+- make && make install
+- ${INSTALL_PREFIX}/bin/luarocks install totem
+- if [[ $TORCH_LUA_VERSION != 'LUAJIT21' && $TORCH_LUA_VERSION != 'LUAJIT20' ]]; then ${INSTALL_PREFIX}/bin/luarocks install luaffi; fi
+- cd $ROOT_TRAVIS_DIR
+- export LD_LIBRARY_PATH=${INSTALL_PREFIX}/lib:$LD_LIBRARY_PATH
+# Install nngraph itself, check that it loads, then run the test suite.
+script:
+- ${INSTALL_PREFIX}/bin/luarocks make
+- export PATH=${INSTALL_PREFIX}/bin:$PATH
+- export LD_LIBRARY_PATH=$HOME/GraphViz/lib:$LD_LIBRARY_PATH
+- export TESTLUA=$(which luajit lua | head -n 1)
+- ${TESTLUA} -lnngraph -e "print('nngraph loaded successfully')"
+- cd test
+- ${TESTLUA} test_ModuleFromCriterion.lua
+- ${TESTLUA} test_nest.lua
+- ${TESTLUA} test_nngraph.lua
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..cd283bc
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,8 @@
+# CMake build for the nngraph Torch package.
+CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
+CMAKE_POLICY(VERSION 2.6)
+FIND_PACKAGE(Torch REQUIRED)
+# Gather every *.lua file in this directory and pass the list to ADD_TORCH_PACKAGE.
+FILE(GLOB luasrc *.lua)
+
+ADD_TORCH_PACKAGE(nngraph ""  "${luasrc}" "Neural Net Graph Package")
diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt
new file mode 100644
index 0000000..2e4118c
--- /dev/null
+++ b/COPYRIGHT.txt
@@ -0,0 +1,35 @@
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of NEC Laboratories American and IDIAP Research
+   Institute nor the names of its contributors may be used to endorse or
+   promote products derived from this software without specific prior
+   written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
diff --git a/JustElement.lua b/JustElement.lua
new file mode 100644
index 0000000..0f18972
--- /dev/null
+++ b/JustElement.lua
@@ -0,0 +1,18 @@
+-- Unwraps a one-element table: {x} -> x (gradient: g -> {g}).
+local JustElement, parent = torch.class('nngraph.JustElement', 'nn.Module')
+function JustElement:__init()
+   parent.__init(self)
+   self.gradInput = {}
+end
+
+-- The input is a table with one element; the output is that element.
+function JustElement:updateOutput(input)
+   assert(#input == 1, "expecting one element")
+   self.output = input[1]
+   return self.output
+end
+
+function JustElement:updateGradInput(input, gradOutput)
+   self.gradInput[1] = gradOutput
+   return self.gradInput
+end
diff --git a/JustTable.lua b/JustTable.lua
new file mode 100644
index 0000000..1fc8434
--- /dev/null
+++ b/JustTable.lua
@@ -0,0 +1,17 @@
+-- Wraps a single input into a table: x -> {x} (gradient: {g} -> g).
+local JustTable, parent = torch.class('nngraph.JustTable', 'nn.Module')
+function JustTable:__init()
+   parent.__init(self)
+   self.output = {}
+end
+
+-- The input is one element; the output is the table {element}.
+function JustTable:updateOutput(input)
+   self.output[1] = input
+   return self.output
+end
+
+function JustTable:updateGradInput(input, gradOutput)
+   self.gradInput = gradOutput[1]
+   return self.gradInput
+end
diff --git a/ModuleFromCriterion.lua b/ModuleFromCriterion.lua
new file mode 100644
index 0000000..8717ca5
--- /dev/null
+++ b/ModuleFromCriterion.lua
@@ -0,0 +1,42 @@
+--[[ A wrapper to turn a criterion into a module.
+
+The input is a {prediction, target} pair; the output is a one-element tensor
+holding the criterion value. The gradient with respect to the target is zero.
+--]]
+local ModuleFromCriterion, parent = torch.class('nn.ModuleFromCriterion','nn.Module')
+function ModuleFromCriterion:__init(criterion)
+   parent.__init(self)
+   self.criterion = criterion
+   self.output = torch.Tensor(1)
+   self.gradInput = {torch.Tensor(), torch.Tensor()}
+end
+
+local unpack = unpack or table.unpack -- lua52 compat
+
+function ModuleFromCriterion:updateOutput(input)
+   local prediction, target = unpack(input)
+   self.output[1] = self.criterion:updateOutput(prediction, target)
+   return self.output
+end
+
+-- gradInput[1]: the criterion gradient (tensor or flat table of tensors) times
+-- gradOutput[1]; buffers are rebuilt when that structure changes between calls.
+function ModuleFromCriterion:updateGradInput(input, gradOutput)
+   local prediction, target = unpack(input)
+   local gradPrediction = self.criterion:updateGradInput(prediction, target)
+   if type(gradPrediction) == 'table' then
+      if type(self.gradInput[1]) ~= 'table' then self.gradInput[1] = {} end
+      local buffers = self.gradInput[1]
+      for i = #buffers, #gradPrediction + 1, -1 do buffers[i] = nil end
+      for i, grad in ipairs(gradPrediction) do
+         buffers[i] = buffers[i] or grad.new()
+         buffers[i]:resizeAs(grad):copy(grad):mul(gradOutput[1])
+      end
+   else
+      if not torch.isTensor(self.gradInput[1]) then self.gradInput[1] = gradPrediction.new() end
+      self.gradInput[1]:resizeAs(gradPrediction):copy(gradPrediction):mul(gradOutput[1])
+   end
+   -- The target receives no gradient.
+   self.gradInput[2]:resizeAs(target):zero()
+   return self.gradInput
+end
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..10c8ad0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,250 @@
+# Neural Network Graph Package
+
+[![Build Status](https://travis-ci.org/torch/nngraph.svg)](https://travis-ci.org/torch/nngraph) 
+
+This package provides graphical computation for `nn` library in [Torch](https://github.com/torch/torch7/blob/master/README.md).
+
+## Requirements
+
+You do *not* need `graphviz` to be able to use this library, but if you have it you will be able to display the graphs that you have created. To install graphviz, run the appropriate command below:
+
+```bash
+# Mac users
+brew install graphviz
+# Debian/Ubuntu users
+sudo apt-get install graphviz -y
+```
+
+## Usage
+
+[Plug: A more explanatory nngraph tutorial by Nando De Freitas of  Oxford](https://www.cs.ox.ac.uk/people/nando.defreitas/machinelearning/practicals/practical5.pdf)
+
+The aim of this library is to provide users of the `nn` package with tools to easily create complicated architectures.
+Any given `nn` `module` is going to be bundled into a *graph node*.
+The `__call__` operator of an instance of `nn.Module` is used to create architectures as if one is writing function calls.
+
+### Two hidden layers MLP
+
+```lua
+h1 = nn.Linear(20, 10)()
+h2 = nn.Linear(10, 1)(nn.Tanh()(nn.Linear(10, 10)(nn.Tanh()(h1))))
+mlp = nn.gModule({h1}, {h2})
+
+x = torch.rand(20)
+dx = torch.rand(1)
+mlp:updateOutput(x)
+mlp:updateGradInput(x, dx)
+mlp:accGradParameters(x, dx)
+
+-- draw graph (the forward graph, '.fg')
+graph.dot(mlp.fg, 'MLP')
+```
+
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp.png" width="300px"/>
+
+Read this diagram from top to bottom, with the first and last nodes being *dummy nodes* that regroup all inputs and outputs of the graph.
+The `module` entry describes the function of the node, as applied to `input`, producing a result of the shape `gradOutput`; `mapindex` contains pointers to the parent nodes.
+
+To save the *graph* to a file, specify the file name; both a `dot` and an `svg` file will be saved. For example, you can type:
+
+```lua
+graph.dot(mlp.fg, 'MLP', 'myMLP')
+```
+
+You can also use the `__unm__` and `__sub__` operators to replace all `__call__`:
+```lua
+h1 = - nn.Linear(20,10)
+h2 = h1
+     - nn.Tanh()
+     - nn.Linear(10,10)
+     - nn.Tanh()
+     - nn.Linear(10, 1)
+mlp = nn.gModule({h1}, {h2})
+```
+
+
+### A network with 2 inputs and 2 outputs
+
+```lua
+h1 = nn.Linear(20, 20)()
+h2 = nn.Linear(10, 10)()
+hh1 = nn.Linear(20, 1)(nn.Tanh()(h1))
+hh2 = nn.Linear(10, 1)(nn.Tanh()(h2))
+madd = nn.CAddTable()({hh1, hh2})
+oA = nn.Sigmoid()(madd)
+oB = nn.Tanh()(madd)
+gmod = nn.gModule({h1, h2}, {oA, oB})
+
+x1 = torch.rand(20)
+x2 = torch.rand(10)
+
+gmod:updateOutput({x1, x2})
+gmod:updateGradInput({x1, x2}, {torch.rand(1), torch.rand(1)})
+graph.dot(gmod.fg, 'Big MLP')
+```
+
+Alternatively, you can use `-` to make your code look like the data flow:
+
+```lua
+h1 = - nn.Linear(20,20)
+h2 = - nn.Linear(10,10)
+hh1 = h1 - nn.Tanh() - nn.Linear(20,1)
+hh2 = h2 - nn.Tanh() - nn.Linear(10,1)
+madd = {hh1,hh2} - nn.CAddTable()
+oA = madd - nn.Sigmoid()
+oB = madd - nn.Tanh()
+gmod = nn.gModule( {h1,h2}, {oA,oB} )
+```
+
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp2.png" width="300px"/>
+
+
+### A network with containers
+
+Another net that uses container modules (like `ParallelTable`) that output a table of outputs.
+
+```lua
+m = nn.Sequential()
+m:add(nn.SplitTable(1))
+m:add(nn.ParallelTable():add(nn.Linear(10, 20)):add(nn.Linear(10, 30)))
+input = nn.Identity()()
+input1, input2 = m(input):split(2)
+m3 = nn.JoinTable(1)({input1, input2})
+
+g = nn.gModule({input}, {m3})
+
+indata = torch.rand(2, 10)
+gdata = torch.rand(50)
+g:forward(indata)
+g:backward(indata, gdata)
+
+graph.dot(g.fg, 'Forward Graph')
+graph.dot(g.bg, 'Backward Graph')
+```
+
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp3_forward.png" width="300px"/>
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp3_backward.png" width="300px"/>
+
+
+### More fun with graphs
+
+A multi-layer network where each layer takes output of previous two layers as input.
+
+```lua
+input = nn.Identity()()
+L1 = nn.Tanh()(nn.Linear(10, 20)(input))
+L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1})))
+L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2})))
+
+g = nn.gModule({input}, {L3})
+
+indata = torch.rand(10)
+gdata = torch.rand(160)
+g:forward(indata)
+g:backward(indata, gdata)
+
+graph.dot(g.fg, 'Forward Graph')
+graph.dot(g.bg, 'Backward Graph')
+```
+
+As your graph gets bigger and more complicated, the nested parentheses may become confusing. In this case, using `-` to chain the modules is a clearer and easier way:
+```lua
+input = - nn.Identity()
+L1 =  input 
+     - nn.Linear(10, 20) 
+     - nn.Tanh()
+L2 =  { input, L1 }
+     -  nn.JoinTable(1)
+     -  nn.Linear(30,60) 
+     -  nn.Tanh()
+L3 = { L1,L2 }
+     - nn.JoinTable(1)
+     - nn.Linear(80,160)
+     - nn.Tanh()
+g = nn.gModule({input},{L3})
+```
+
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp4_forward.png" width="300px"/>
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/mlp4_backward.png" width="300px"/>
+
+
+## Annotations
+
+It is possible to add annotations to your network, such as labeling nodes with names or attributes which will show up when you graph the network.
+This can be helpful in large graphs.
+
+For the full list of graph attributes see the
+[graphviz documentation](http://www.graphviz.org/doc/info/attrs.html).
+
+```lua
+input = nn.Identity()()
+L1 = nn.Tanh()(nn.Linear(10, 20)(input)):annotate{
+   name = 'L1', description = 'Level 1 Node',
+   graphAttributes = {color = 'red'}
+}
+L2 = nn.Tanh()(nn.Linear(30, 60)(nn.JoinTable(1)({input, L1}))):annotate{
+   name = 'L2', description = 'Level 2 Node',
+   graphAttributes = {color = 'blue', fontcolor = 'green'}
+}
+L3 = nn.Tanh()(nn.Linear(80, 160)(nn.JoinTable(1)({L1, L2}))):annotate{
+   name = 'L3', description = 'Level 3 Node',
+   graphAttributes = {color = 'green',
+   style = 'filled', fillcolor = 'yellow'}
+}
+
+g = nn.gModule({input},{L3})
+
+indata = torch.rand(10)
+gdata = torch.rand(160)
+g:forward(indata)
+g:backward(indata, gdata)
+
+graph.dot(g.fg, 'Forward Graph', '/tmp/fg')
+graph.dot(g.bg, 'Backward Graph', '/tmp/bg')
+```
+
+In this case, the graphs are saved in the following 4 files: `/tmp/{fg,bg}.{dot,svg}`.
+
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/annotation_fg.png" width="300px"/>
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/annotation_bg.png" width="300px"/>
+
+## Debugging
+
+With nngraph, one can create very complicated networks. In these cases, finding errors can be hard. For that purpose, nngraph provides several useful utilities. The following code snippet shows how to use local variable names for annotating the nodes in a graph and how to enable debugging mode, which automatically creates an svg file with the failing node marked in case of a runtime error.
+
+```lua
+
+require 'nngraph'
+
+-- generate SVG of the graph with the problem node highlighted
+-- and hover over the nodes in svg to see the filename:line_number info
+-- nodes will be annotated with local variable names even if debug mode is not enabled.
+nngraph.setDebug(true)
+
+local function get_net(from, to)
+	local from = from or 10
+	local to = to or 10
+	local input_x = nn.Identity()()
+	local linear_module = nn.Linear(from, to)(input_x)
+
+	-- Annotate nodes with local variable names
+	nngraph.annotateNodes()
+	return nn.gModule({input_x},{linear_module})
+end
+
+local net = get_net(10,10)
+
+-- if you give a name to the net, it will use that name to produce the
+-- svg in case of error, if not, it will come up with a name
+-- that is derived from number of inputs and outputs to the graph
+net.name = 'my_bad_linear_net'
+
+-- prepare an input that is of the wrong size to force an error
+local input = torch.rand(11)
+pcall(function() net:updateOutput(input) end)
+-- it should have produced an error and spit out a graph
+-- just run Safari to display the svg
+os.execute('open -a  Safari my_bad_linear_net.svg')
+```
+<img src= "https://raw.github.com/koraykv/torch-nngraph/master/doc/my_bad_linear_net.png" width="300px"/>
+
diff --git a/doc/annotation_bg.png b/doc/annotation_bg.png
new file mode 100644
index 0000000..292a0d0
Binary files /dev/null and b/doc/annotation_bg.png differ
diff --git a/doc/annotation_fg.png b/doc/annotation_fg.png
new file mode 100644
index 0000000..ece5b92
Binary files /dev/null and b/doc/annotation_fg.png differ
diff --git a/doc/mlp.png b/doc/mlp.png
new file mode 100644
index 0000000..76a58be
Binary files /dev/null and b/doc/mlp.png differ
diff --git a/doc/mlp2.png b/doc/mlp2.png
new file mode 100644
index 0000000..a0179d2
Binary files /dev/null and b/doc/mlp2.png differ
diff --git a/doc/mlp3_backward.png b/doc/mlp3_backward.png
new file mode 100644
index 0000000..2507701
Binary files /dev/null and b/doc/mlp3_backward.png differ
diff --git a/doc/mlp3_forward.png b/doc/mlp3_forward.png
new file mode 100644
index 0000000..866b4d5
Binary files /dev/null and b/doc/mlp3_forward.png differ
diff --git a/doc/mlp4_backward.png b/doc/mlp4_backward.png
new file mode 100644
index 0000000..5710a69
Binary files /dev/null and b/doc/mlp4_backward.png differ
diff --git a/doc/mlp4_forward.png b/doc/mlp4_forward.png
new file mode 100644
index 0000000..17fe10b
Binary files /dev/null and b/doc/mlp4_forward.png differ
diff --git a/doc/my_bad_linear_net.png b/doc/my_bad_linear_net.png
new file mode 100644
index 0000000..2b9de1e
Binary files /dev/null and b/doc/my_bad_linear_net.png differ
diff --git a/gmodule.lua b/gmodule.lua
new file mode 100644
index 0000000..6e118d8
--- /dev/null
+++ b/gmodule.lua
@@ -0,0 +1,514 @@
+local nesting = require('nngraph.nesting')
+local utils = require('nngraph.utils')
+local istensor = torch.isTensor
+local istable = utils.istable
+local istorchclass = utils.istorchclass
+
+-- True when `t` is a tensor whose storage holds exactly one element equal to 0,
+-- i.e. a gradient that contributes nothing to a sum.
+local function isZeroScalarGrad(t)
+   if not torch.isTensor(t) then return false end
+   local storage = t:storage()
+   return storage ~= nil and storage:size() == 1 and storage[1] == 0
+end
+
+-- Return the total gradient w.r.t. `node`'s output, given the list of partial
+-- gradients node.data.gradOutput. A single entry is returned as is. With several,
+-- if no buffer has been allocated yet and at most one entry is not a zero scalar
+-- tensor, that entry (or the first, if all are zero) is returned without
+-- allocating; otherwise all entries are summed into the reusable nested buffer
+-- node.data.gradOutputBuffer, which is returned.
+local function getTotalGradOutput(node)
+   local grads = node.data.gradOutput
+   assert(istable(grads), "expecting gradients to sum")
+   local nGrads = #grads
+   if nGrads <= 1 then
+      return grads[1]
+   end
+   if not node.data.gradOutputBuffer then
+      local nonZeroCount, lastNonZero = 0, 1
+      for i = 1, nGrads do
+         if not isZeroScalarGrad(grads[i]) then
+            nonZeroCount, lastNonZero = nonZeroCount + 1, i
+         end
+      end
+      if nonZeroCount < 2 then
+         return grads[lastNonZero]
+      end
+   end
+   local buffer = node.data.gradOutputBuffer or nesting.cloneNested(grads[1])
+   node.data.gradOutputBuffer = buffer
+   nesting.resizeNestedAs(buffer, grads[1])
+   nesting.copyNested(buffer, grads[1])
+   for i = 2, nGrads do
+      nesting.addNestedTo(buffer, grads[i])
+   end
+   return buffer
+end
+
+-- The gModule allows building a general acyclic graph of modules.
+--
+-- Each node of the graph can have multiple inputs.
+-- The order of inputs is remembered in node.data.mapindex.
+--
+-- Each node has only one output.
+-- The output can also be a table.
+-- To route parts of the outputted table to different modules,
+-- use the node:split(nOutputs) function.
+-- The split will create subnodes with narrowed output.
+--
+-- Implementation details:
+-- The node.data.input holds a list of inputs.
+-- If a module expects only one input, the node.data.input[1] is used.
+--
+-- The node.data.gradOutput holds the to-be-summed gradOutputs.
+-- Each node has only one output. So we need only one gradOutput.
+local gModule, parent = torch.class('nn.gModule','nn.Container')
+
+function gModule:__init(inputs,outputs) -- inputs, outputs: arrays of nngraph.Node
+   parent.__init(self)
+   -- the graph is defined backwards, we have the output modules as input here
+   -- we will define a dummy output node that connects all output modules
+   -- into itself. This will be the output for the forward graph and
+   -- input point for the backward graph
+   local node
+   local outnode = nngraph.Node({input={}})
+   for i = 1, utils.tableMaxN(outputs) do
+      node = outputs[i]
+      if torch.typename(node) ~= 'nngraph.Node' then
+         error(utils.expectingNodeErrorMessage(node, 'outputs', i))
+      end
+      outnode:add(node, true)
+   end
+   for i = 1, utils.tableMaxN(inputs) do
+      node = inputs[i]
+      if torch.typename(node) ~= 'nngraph.Node' then
+         error(utils.expectingNodeErrorMessage(node, 'inputs', i))
+      end
+   end
+   -- We add also a dummy input node.
+   -- The input node will be split to feed the passed input nodes.
+   local innode = nngraph.Node({input={}})
+   assert(#inputs > 0, "no inputs are not supported")
+   if #inputs == 1 then
+      inputs[1]:add(innode,true)
+   else
+      local splits = {innode:split(#inputs)}
+      for i = 1, #inputs do
+         assert(#inputs[i].children == 0, "an input should have no inputs")
+      end
+      for i = 1, #inputs do
+         inputs[i]:add(splits[i],true)
+      end
+   end
+
+   -- the backward graph (bg) is for gradients
+   -- the forward graph (fg) is for function evaluation
+   self.bg = outnode:graph()
+   self.fg = self.bg:reverse()
+
+   -- the complete graph is constructed
+   -- now regenerate the graphs with the additional nodes
+
+   local roots = self.fg:roots()
+   -- if there are more than one root in the forward graph, then make sure that
+   -- extra roots are parameter nodes
+   if #roots > 1 then
+      local innodeRoot = nil
+      -- first find our innode
+      for _, root in ipairs(roots) do
+         if root.data == innode.data then
+            assert(innodeRoot == nil, 'more than one matching input node found in leaves')
+            innodeRoot = root
+         else
+            assert(root.data.module, 'Expected nnop.Parameters node, module not found in node')
+            assert(torch.typename(root.data.module) == 'nnop.Parameters',
+                  'Expected nnop.Parameters node, found : ' ..torch.typename(root.data.module))
+         end
+      end
+      assert(innodeRoot ~= nil, 'input node not found among roots')
+      self.innode = innodeRoot
+   else
+      assert(#self.fg:roots() == 1, "expecting only one start")
+      self.innode = self.fg:roots()[1]
+   end
+
+   assert(self.innode.data == innode.data, "expecting the forward innode")
+   self.outnode = outnode
+   self.verbose = false
+   self.nInputs = #inputs
+
+   -- computation on the graph is done through topsort of forward and backward graphs
+   self.forwardnodes = self.fg:topsort()
+   self.backwardnodes = self.bg:topsort()
+
+   -- iterate over all nodes: check, tag and add to container
+   for i,node in ipairs(self.forwardnodes) do
+      -- check for unused inputs or unused split() outputs
+      if node.data.nSplitOutputs and node.data.nSplitOutputs ~=  #node.children then
+         local nUnused = node.data.nSplitOutputs - #node.children
+         local debugLabel = node.data.annotations._debugLabel
+         local errStr =
+            "%s of split(%s) outputs from the node declared at %s are unused"
+         error(string.format(errStr, nUnused, node.data.nSplitOutputs,
+                             debugLabel))
+      end
+
+      -- Check whether any nodes were defined as taking this node as an input,
+      -- but then left dangling and don't connect to the output. If this is
+      -- the case, then they won't be present in forwardnodes, so error out.
+      for successor, _ in pairs(node.data.reverseMap) do
+         local successorIsInGraph = false
+
+         -- Only need to search forwardnodes after position i; topological
+         -- sort guarantees it cannot be in the first part.
+         for j = i+1, #self.forwardnodes do
+            -- Compare equality of data tables, as new Node objects have been
+            -- created by processes such as topological sort, but the
+            -- underlying .data table is shared.
+            if self.forwardnodes[j].data == successor.data then
+               successorIsInGraph = true
+               break
+            end
+         end
+         local errStr =
+            "node declared on %s does not connect to gmodule output"
+         assert(successorIsInGraph,
+                string.format(errStr, successor.data.annotations._debugLabel))
+      end
+
+      -- set data.forwardNodeId for node:label() output
+      node.data.forwardNodeId = node.id
+
+      -- add module to container
+      if node.data.module then
+         self:add(node.data.module)
+      end
+   end
+
+   self.output = nil
+   self.gradInput = nil
+   if #self.outnode.children > 1 then -- several outputs: output aliases outnode's input list
+      self.output = self.outnode.data.input
+   end
+end
+
+-- Apply `callback` to self, then recursively to every forward node's module,
+-- updating both the node and its slot in self.modules. Returns callback(self).
+function gModule:replace(callback)
+   local result = callback(self)
+   local slotOf = {}
+   for idx, mod in ipairs(self.modules) do slotOf[mod] = idx end
+   for _, fnode in ipairs(self.forwardnodes) do
+      local old = fnode.data.module
+      if old then
+         fnode.data.module = old:replace(callback)
+         self.modules[slotOf[old]] = fnode.data.module
+      end
+   end
+   return result
+end
+
+-- Call func(own, other) on each pair of modules found at matching forward nodes of self and gm.
+function gModule:map(gm, func)
+   for idx, fnode in ipairs(self.forwardnodes) do
+      local other = assert(gm.forwardnodes[idx], 'trying to map another gModule with a different structure')
+      if fnode.data.module then
+         assert(other.data.module, 'trying to map another gModule with a different structure')
+         func(fnode.data.module, other.data.module)
+      end
+   end
+end
+
+--[[ Convert the tensors in `param` to type `type_str`: a tensor is converted via
+param:type(type_str); a plain table has its array elements 1..#param converted
+recursively, in place. Any other value is returned unchanged. ]]
+local function recursiveType(param, type_str)
+   if torch.type(param) == 'table' then
+      for idx = 1, #param do param[idx] = recursiveType(param[idx], type_str) end
+      return param
+   end
+   local name = torch.typename(param)
+   if name and name:find('torch%..+Tensor') then
+      return param:type(type_str)
+   end
+   return param
+end
+
+-- Convert this gModule to the tensor type named by `type`, or return the current
+-- type (self._type) when no type is given. Converts tensors held on self and on
+-- the in/out node data, every child module (sharing tensorCache), and the
+-- gradOutputBuffer / input cached on each backward / forward node. Returns self.
+function gModule:type(type, tensorCache)
+   if not type then return self._type end
+   tensorCache = tensorCache or {}
+   -- Replace every field of `tbl` with its converted value.
+   local function convertFields(tbl)
+      for key, val in pairs(tbl) do
+         tbl[key] = recursiveType(val, type)
+      end
+   end
+
+   -- Data stored directly on self and on the dummy in/out nodes, then child modules.
+   convertFields(self)
+   if self.innode then convertFields(self.innode.data) end
+   if self.outnode then convertFields(self.outnode.data) end
+   for _, mod in ipairs(self.modules) do
+      mod:type(type, tensorCache)
+   end
+
+   -- Backward nodes: cached summed gradient, plus the data of their children.
+   for _, bnode in ipairs(self.backwardnodes) do
+      if bnode.data.gradOutputBuffer ~= nil then
+         bnode.data.gradOutputBuffer = recursiveType(bnode.data.gradOutputBuffer, type)
+      end
+      for _, child in ipairs(bnode.children) do
+         convertFields(child.data)
+      end
+   end
+
+   -- Forward nodes: stored inputs, plus the data of their children.
+   for _, fnode in ipairs(self.forwardnodes) do
+      if fnode.data.input ~= nil then
+         fnode.data.input = recursiveType(fnode.data.input, type)
+      end
+      for _, child in ipairs(fnode.children) do
+         convertFields(child.data)
+      end
+   end
+
+   self._type = type
+   return self
+end
+
+function gModule:updateOutput(input) -- forward pass over the whole graph
+   return self:runForwardFunction('updateOutput',input)
+end
+
+-- Release cached state: the parent's, plus every node's gradOutput(Buffer) and input.
+function gModule:clearState()
+   local result = parent.clearState(self)
+   for _, bnode in ipairs(self.backwardnodes) do
+      bnode.data.gradOutput, bnode.data.gradOutputBuffer = nil, nil
+   end
+   for _, fnode in ipairs(self.forwardnodes) do
+      fnode.data.input = nil
+   end
+   return result
+end
+
+--- Runs a forward-style module method over the whole graph.
+-- Nodes are evaluated in the order of self.forwardnodes; each node's result
+-- is written into its children's data.input at the slot given by the
+-- child's mapindex, and the outnode's collected inputs become self.output.
+-- NOTE: the local function `neteval` and its first parameter `node` are
+-- found by name by findCurrentNode() in graphinspecting.lua; keep them.
+-- @param func name of a module method (called as module[func](module, x))
+--   or a function(module, x) returning that module's output
+-- @param input the input, or a table of nInputs inputs when nInputs > 1
+-- @return self.output (unwrapped when the outnode has a single child)
+function gModule:runForwardFunction(func,input)
+   if type(func) == "string" then
+      local func_name = func
+      func = function(module,input) return module[func_name](module,input) end
+   end
+   -- For backward compatibility, we allow self.nInputs to be missing.
+   local nInputs = self.nInputs or #self.innode.children
+   -- We see the input as a list of inputs.
+   if nInputs <= 1 then
+      input={input}
+   elseif type(input) ~= "table" then
+      error(string.format("expecting table of %s inputs", nInputs))
+   end
+   -- Evaluates one node and hands its result on to its children.
+   local function neteval(node)
+      -- Stores x in each child's input list at this node's mapindex slot;
+      -- every slot may be filled by only one source per pass.
+      local function propagate(node,x)
+         for i,child in ipairs(node.children) do
+            child.data.input = child.data.input or {}
+            local mapindex = child.data.mapindex[node.data]
+            assert(not child.data.input[mapindex], "each input should have one source")
+            child.data.input[mapindex] = x
+         end
+      end
+      if node.data.selectindex then
+         -- An output of :split(): forward element selectindex of the
+         -- single (table) input coming from the split node.
+         assert(not node.data.module, "the selectindex-handling nodes should have no module")
+         local input = node.data.input
+         assert(#input == 1, "only the splitted node should be the input")
+         assert(istable(input[1]), "the input for a split should be a table")
+         input = input[1][node.data.selectindex]
+         propagate(node,input)
+      else
+         local input = node.data.input
+
+         -- a parameter node is captured
+         if input == nil and node.data.module ~= nil then
+            input = {}
+         end
+         -- A single input is passed to the module unwrapped.
+         if #input == 1 then
+            input = input[1]
+         end
+         -- forward through this node
+         -- If no module is present, the node behaves like nn.Identity.
+         local output
+         if not node.data.module then
+            output = input
+         else
+            output = func(node.data.module,input)
+         end
+         -- The node created by :split(n) checks that there are n outputs.
+         if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
+            error(string.format("split(%s) cannot split %s outputs",
+            node.data.nSplitOutputs,
+            #output))
+         end
+         -- propagate the output to children
+         propagate(node,output)
+      end
+      if self.verbose then
+         print(' V : ' .. node:label())
+      end
+   end
+
+   local innode = self.innode
+   if #input ~= nInputs then
+      error(string.format('Got %s inputs instead of %s', #input, nInputs))
+   end
+   -- first clear the input states
+   -- (the lists are emptied in place, so the same tables are reused)
+   for _,node in ipairs(self.forwardnodes) do
+      local input = node.data.input
+      while input and #input>0 do
+         table.remove(input)
+      end
+   end
+   -- Set the starting input.
+   -- We do copy instead of modifying the passed input.
+   innode.data.input = innode.data.input or {}
+   for i, item in ipairs(input) do
+      innode.data.input[i] = item
+   end
+
+   -- the run forward
+   for i,node in ipairs(self.forwardnodes) do
+      neteval(node)
+   end
+
+   -- The outnode's inputs are the graph outputs; a single output is
+   -- returned unwrapped.
+   self.output = self.outnode.data.input
+   if #self.outnode.children == 1 then
+      self.output = self.output[1]
+   end
+   return self.output
+end
+
+--- Backpropagates gradOutput through the graph, visiting
+-- self.backwardnodes in order, and returns the gradient w.r.t. the graph
+-- input(s), also stored in self.gradInput.
+-- Gradients arriving at a node are collected in its node.data.gradOutput
+-- list. Module inputs are taken from node.data.input (filled by the last
+-- forward pass); the `input` argument itself is not read.
+-- NOTE: `neteval` and its first parameter `node` are found by name by
+-- findCurrentNode() in graphinspecting.lua; keep them.
+-- @param input unused here
+-- @param gradOutput gradient w.r.t. the output; a table with one entry per
+--   outnode child when the graph has several outputs
+-- @return self.gradInput
+function gModule:updateGradInput(input,gradOutput)
+   local function neteval(node)
+      if node.data.selectindex then
+         -- An output of :split(): its gradient becomes element selectindex
+         -- of the (table) gradOutput of the node that was split.
+         assert(not node.data.module, "the selectindex-handling nodes should have no module")
+         assert(#node.children == 1, "only the splitted node should be the input")
+         local child = node.children[1]
+         -- getTotalGradOutput is defined outside this view; presumably it
+         -- combines the gradients collected in node.data.gradOutput
+         -- (see gradOutputBuffer in accGradParameters) — confirm.
+         local go = getTotalGradOutput(node)
+         child.data.gradOutput = child.data.gradOutput or {}
+         assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
+         -- The data.gradOutput holds the to-be-summed gradients.
+         child.data.gradOutput[1] = child.data.gradOutput[1] or {}
+         assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
+         child.data.gradOutput[1][node.data.selectindex] = go
+      else
+         local gradOutput = getTotalGradOutput(node)
+         -- updateGradInput through this node
+         -- If no module is present, the node behaves like nn.Identity.
+         local gradInput
+         if not node.data.module then
+            gradInput = gradOutput
+         else
+            local input = node.data.input
+            -- a parameter node is captured
+            if input == nil and node.data.module ~= nil then
+               input = {}
+            end
+            -- A single input was passed unwrapped in the forward pass.
+            if #input == 1 then
+               input = input[1]
+            end
+            local module = node.data.module
+            gradInput = module:updateGradInput(input,gradOutput)
+         end
+         -- propagate the gradInput to children (in the backward graph these
+         -- are this node's inputs); with several inputs, the child's share
+         -- is gradInput[mapindex].
+         for i,child in ipairs(node.children) do
+            child.data.gradOutput = child.data.gradOutput or {}
+            local mapindex = node.data.mapindex[child.data]
+            local gi
+            if #node.children == 1 then
+               gi = gradInput
+            else
+               gi = gradInput[mapindex]
+            end
+            table.insert(child.data.gradOutput,gi)
+         end
+      end
+      if self.verbose then
+         print(' V : ' .. node:label())
+      end
+   end
+   local outnode = self.outnode
+   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
+      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
+   end
+   -- Empty every node's gradOutput list in place before accumulating.
+   for _,node in ipairs(self.backwardnodes) do
+      local gradOutput = node.data.gradOutput
+      while gradOutput and #gradOutput >0 do
+         table.remove(gradOutput)
+      end
+   end
+   -- Set the starting gradOutput.
+   outnode.data.gradOutput = outnode.data.gradOutput or {}
+   outnode.data.gradOutput[1] = gradOutput
+
+   for i,node in ipairs(self.backwardnodes) do
+      neteval(node)
+   end
+
+   assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
+   self.gradInput = self.innode.data.gradOutput[1]
+   return self.gradInput
+end
+
+--- Accumulates parameter gradients for every node that has a module,
+-- visiting self.backwardnodes in order. Module inputs come from
+-- node.data.input and gradients from node.data.gradOutput, i.e. the state
+-- left by the preceding forward pass and updateGradInput call.
+-- NOTE: `neteval` and its first parameter `node` are found by name by
+-- findCurrentNode() in graphinspecting.lua; keep them.
+-- @param input unused here
+-- @param gradOutput only used for the output-count check below
+-- @param lr passed unchanged to each module's accGradParameters
+function gModule:accGradParameters(input,gradOutput,lr)
+   local function neteval(node)
+      if node.data.module then
+         local module = node.data.module
+         local gradOutput = node.data.gradOutput[1]
+         -- With several incoming gradients use gradOutputBuffer, which is
+         -- presumably filled with their sum by getTotalGradOutput during
+         -- updateGradInput — confirm.
+         if #node.data.gradOutput > 1 then
+            gradOutput = node.data.gradOutputBuffer
+         end
+         local input = node.data.input
+         -- a parameter node is captured
+         if input == nil and node.data.module ~= nil then
+            input = {}
+         end
+         -- A single input was passed unwrapped in the forward pass.
+         if #input == 1 then
+            input = input[1]
+         end
+         -- accGradParameters through this node
+         module:accGradParameters(input,gradOutput,lr)
+      end
+      if self.verbose then
+         print(' V : ' .. node:label())
+      end
+   end
+   local outnode = self.outnode
+   if #outnode.children > 1 and #gradOutput ~= #outnode.children then
+      error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
+   end
+   for i,node in ipairs(self.backwardnodes) do
+      neteval(node)
+   end
+end
+
+--- Deserializes a gModule: every field of the stored object is copied onto
+-- self, then self.modules is rebuilt from the forward nodes if the stored
+-- object did not contain it.
+-- @param file object providing readObject() (presumably a torch.File)
+function gModule:read(file)
+   local stored = file:readObject()
+   for key, value in pairs(stored) do
+      self[key] = value
+   end
+
+   if self.modules then
+      return
+   end
+   -- No modules list was stored: collect the module of each forward node.
+   local modules = {}
+   for _, fnode in ipairs(self.forwardnodes) do
+      local module = fnode.data.module
+      if module then
+         modules[#modules + 1] = module
+      end
+   end
+   self.modules = modules
+end
+
+
+--- String form of the graph: its name when set, else its torch type.
+function gModule:__tostring__()
+   if self.name then
+      return self.name
+   end
+   return torch.type(self)
+end
diff --git a/graphinspecting.lua b/graphinspecting.lua
new file mode 100644
index 0000000..8e858cf
--- /dev/null
+++ b/graphinspecting.lua
@@ -0,0 +1,159 @@
+
+-- Walks up the call stack, starting at the caller, for a frame of a
+-- function named "neteval" whose first local is named "node", and returns
+-- that local's value (nil when no such frame exists).
+-- This depends on the local names used in the nngraph.gModule source code.
+local function findCurrentNode()
+   local level = 2
+   local info = debug.getinfo(level, "n")
+   while info ~= nil do
+      if info.name == "neteval" then
+         local localName, localValue = debug.getlocal(level, 1)
+         if localName == "node" then
+            return localValue
+         end
+      end
+      level = level + 1
+      info = debug.getinfo(level, "n")
+   end
+   return nil
+end
+
+-- Runs func() and returns its first result. If func raises, calls
+-- onError(failedNode, ...) and then re-raises the error unchanged.
+-- failedNode is located from inside the xpcall message handler, i.e.
+-- before the stack is unwound; it is nil when the error did not occur
+-- inside a graph node evaluation.
+local function runChecked(func, onError, ...)
+   local failedNode
+   local function errorHandler(message)
+      -- Only string messages get a traceback appended (once). Non-string
+      -- error objects (tables, nil) are passed through untouched; calling
+      -- string.find on them would raise inside the handler and replace the
+      -- original error with "error in error handling".
+      if type(message) == 'string'
+            and not string.find(message, 'stack traceback:\n', 1, true) then
+         message = debug.traceback(message, 2)
+      end
+      failedNode = findCurrentNode()
+      return message
+   end
+
+   local ok, result = xpcall(func, errorHandler)
+   if ok then
+      return result
+   end
+
+   onError(failedNode, ...)
+   -- Passing the level 0 avoids adding an additional error position info
+   -- to the message.
+   error(result, 0)
+end
+
+-- Renders graph as a dot string via graph:todot(title). When failedNode is
+-- given and a node sharing its data table is found in graph.nodes, a line
+-- filling that node red is appended before the closing brace.
+local function customToDot(graph, title, failedNode)
+   local dot = graph:todot(title)
+   if not failedNode then
+      return dot
+   end
+
+   local failedNodeId
+   for _, candidate in ipairs(graph.nodes) do
+      if candidate.data == failedNode.data then
+         failedNodeId = candidate.id
+         break
+      end
+   end
+   if failedNodeId == nil then
+      return dot
+   end
+
+   -- Strip the trailing '}' so one more statement can be appended.
+   local body = dot:gsub('}%s*$', '')
+   return body .. string.format('n%s[style=filled, fillcolor=red];\n}', failedNodeId)
+end
+
+-- Writes dotStr to <svgPathPrefix>.dot and converts it with graphviz `dot`
+-- into <svgPathPrefix>.svg. It is called from onError while a graph error
+-- is about to be re-raised, so a failure to create the .dot file is only
+-- reported on stderr instead of raising and masking that error.
+local function saveSvg(svgPathPrefix, dotStr)
+   io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
+   local dotPath = svgPathPrefix .. '.dot'
+   local dotFile, openErr = io.open(dotPath, 'w')
+   if not dotFile then
+      io.stderr:write(string.format("cannot write %s: %s\n", dotPath, tostring(openErr)))
+      return
+   end
+   dotFile:write(dotStr)
+   dotFile:close()
+
+   local svgPath = svgPathPrefix .. '.svg'
+   local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath)
+   os.execute(cmd)
+end
+
+-- Saves an SVG of gmodule's forward graph (gmodule.fg) with failedNode
+-- highlighted. The file is named after gmodule.name, or after the graph's
+-- input/output counts; a unique suffix is added if that .svg already exists.
+local function onError(failedNode, gmodule)
+   local nInputs = gmodule.nInputs or #gmodule.innode.children
+   local prefix = gmodule.name
+   if not prefix then
+      prefix = string.format('nngraph_%sin_%sout', nInputs, #gmodule.outnode.children)
+   end
+   if paths.filep(prefix .. '.svg') then
+      -- Do not overwrite an earlier dump.
+      prefix = prefix .. '_' .. paths.basename(os.tmpname())
+   end
+   saveSvg(prefix, customToDot(gmodule.fg, prefix, failedNode))
+end
+
+-- The unwrapped gModule methods, kept so that nngraph.setDebug can wrap
+-- them and later restore them.
+local origFuncs = {}
+for _, funcName in ipairs({'runForwardFunction', 'updateGradInput', 'accGradParameters'}) do
+   origFuncs[funcName] = nn.gModule[funcName]
+end
+
+-- When debug is enabled, an exception raised during graph execution makes
+-- onError save an SVG of the forward graph (named after gmodule.name when
+-- set) with the problematic node filled red; the exception is then
+-- re-raised. When debug is disabled, the original gModule methods are
+-- restored.
+-- @param enable truthy to install the checked wrappers, falsy to remove them
+function nngraph.setDebug(enable)
+   if not enable then
+      for funcName, origFunc in pairs(origFuncs) do
+         nn.gModule[funcName] = origFunc
+      end
+      return
+   end
+
+   local unpack = unpack or table.unpack -- Lua 5.2+ compat
+   for funcName, origFunc in pairs(origFuncs) do
+      nn.gModule[funcName] = function(...)
+         -- Keep the real argument count: unpack(args) alone relies on #args,
+         -- which is unspecified when an argument in the middle is nil, and
+         -- could silently drop the arguments after it.
+         local nArgs = select('#', ...)
+         local args = {...}
+         local gmodule = args[1]
+         return runChecked(function()
+            return origFunc(unpack(args, 1, nArgs))
+         end, onError, gmodule)
+      end
+   end
+end
+
+-- Names the nngraph.Node values held in local variables of a function on
+-- the call stack: each such node gets annotations.name = variable name,
+-- unless it already has an explicit name.
+-- @param stackLevel 1 (default) is the function that called annotateNodes()
+function nngraph.annotateNodes(stackLevel)
+   local level = (stackLevel or 1) + 1
+   local index = 1
+   local varName, varValue = debug.getlocal(level, index)
+   while varName do
+      if torch.typename(varValue) == "nngraph.Node"
+            and not varValue.data.annotations.name then
+         varValue:annotate({name = varName})
+      end
+      index = index + 1
+      varName, varValue = debug.getlocal(level, index)
+   end
+end
+
+--[[
+   SVG visualization for gmodule
+   TODO: add custom coloring with node types
+]]
+--- Renders gmodule's forward graph (gmodule.fg) to a temporary SVG file
+-- and opens it with xdg-open (Linux) or Safari (OSX). On other systems the
+-- SVG is still written and its path is reported on stderr.
+-- @param gmodule an nn.gModule
+-- @return path of the generated .svg file
+function nngraph.display(gmodule)
+   local ffi = require 'ffi'
+   local viewers = {
+      Linux = 'xdg-open',
+      OSX = 'open -a Safari',
+   }
+   local cmd = viewers[ffi.os]
+   local fname = os.tmpname()
+   graph.dot(gmodule.fg, fname, fname)
+   local svgPath = fname .. '.svg'
+   if cmd then
+      os.execute(cmd .. ' ' .. svgPath)
+   else
+      -- No known viewer for this OS: keep the file and say where it is
+      -- instead of failing on a nil command.
+      io.stderr:write(string.format('nngraph.display: no SVG viewer known for %s; graph saved to %s\n',
+                                    tostring(ffi.os), svgPath))
+   end
+   return svgPath
+end
diff --git a/init.lua b/init.lua
new file mode 100644
index 0000000..0e354f6
--- /dev/null
+++ b/init.lua
@@ -0,0 +1,80 @@
+require 'nn'
+require 'graph'
+
+nngraph = {}
+
+require('nngraph.nest')
+require('nngraph.node')
+require('nngraph.gmodule')
+require('nngraph.graphinspecting')
+require('nngraph.JustElement')
+require('nngraph.JustTable')
+require('nngraph.ModuleFromCriterion')
+
+-- handy functions
+local utils = require('nngraph.utils')
+local istensor = torch.isTensor
+local istable = utils.istable
+local istorchclass = utils.istorchclass
+
+-- simpler todot functions
+nngraph.simple_print =  require('nngraph.simple_print')
+
+-- Modify the __call function to hack into nn.Module
+-- so that calling a module on graph nodes, module(node) or
+-- module({node1, node2}), returns a new nngraph.Node wrapping the module
+-- whose inputs are those nodes (in the given order). Calling module()
+-- without an argument creates a node with no inputs.
+local Module = torch.getmetatable('nn.Module')
+function Module:__call__(...)
+   local nArgs = select("#", ...)
+   assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
+
+   local input = ...
+   -- An explicit nil argument is rejected (nArgs distinguishes it from
+   -- calling without arguments).
+   if nArgs == 1 and input == nil then
+      error(utils.expectingNodeErrorMessage(input, 'inputs', 1))
+   end
+   -- Disallow passing empty table, in case someone passes a table with some
+   -- typo'd variable name in.
+   if type(input) == 'table' and next(input) == nil then
+      error('cannot pass an empty table of inputs. To indicate no incoming ' ..
+            'connections, leave the second set of parens blank.')
+   end
+   -- Anything utils.istable rejects (presumably a single Node, or nil when
+   -- no argument was given) is treated as a one-element list.
+   if not istable(input) then
+      input = {input}
+   end
+   -- Kept as a direct call: Node:__init (node.lua) labels the node with
+   -- debug.getinfo(6, 'Sl'), which presumably counts the frames between it
+   -- and the user code calling this module — confirm before adding frames.
+   local mnode = nngraph.Node({module=self})
+
+   local dnode
+   -- utils.tableMaxN presumably returns the largest integer key, so a nil
+   -- hole in the list is visited and reported below rather than skipped.
+   for i = 1, utils.tableMaxN(input) do
+      dnode = input[i]
+      if torch.typename(dnode) ~= 'nngraph.Node' then
+         error(utils.expectingNodeErrorMessage(dnode, 'inputs', i))
+      end
+      -- domap=true records dnode as input number i (see nnNode:add).
+      mnode:add(dnode,true)
+   end
+
+   return mnode
+end
+
+-- Calling a criterion on nodes, e.g. criterion({prediction, target}),
+-- wraps it in nn.ModuleFromCriterion and calls that module with the same
+-- arguments, so criteria can be used as graph nodes.
+local Criterion = torch.getmetatable('nn.Criterion')
+function Criterion:__call__(...)
+   return nn.ModuleFromCriterion(self)(...)
+end
+
+
+
+
+-- Operator sugar for building graphs:
+--   -module             calls module() (a node without inputs)
+--   prev - nextModule   calls nextModule(prev)
+-- The metamethods keep the original tail-call form.
+function Module.__unm__(module)
+   return module()
+end
+
+function Module.__sub__(prev, nextModule)
+   return nextModule(prev)
+end
+
+
+do
+   local Node = torch.getmetatable('nngraph.Node')
+   function Node.__sub__(prevNode, nextModule)
+      return nextModule(prevNode)
+   end
+end
+
+return nngraph
diff --git a/nest.lua b/nest.lua
new file mode 100644
index 0000000..a9da62e
--- /dev/null
+++ b/nest.lua
@@ -0,0 +1,46 @@
+
+-- True when input is an nngraph.Node; falsy otherwise (nil for values
+-- that have no torch type name).
+local function isNode(input)
+   local typename = torch.typename(input)
+   if not typename then
+      return typename
+   end
+   return typename == 'nngraph.Node'
+end
+
+-- True when input is a table with at least one array element.
+local function isNonEmptyList(input)
+   if type(input) ~= "table" then
+      return false
+   end
+   return #input > 0
+end
+
+-- Converts a node, or a non-empty (possibly nested) list of nodes, into a
+-- single node whose output has the same nesting. A node is returned as is.
+-- A list of several elements becomes nn.Identity over its converted
+-- elements. A one-element list uses nngraph.JustTable instead, because a
+-- module with one input receives that input unwrapped and would not output
+-- a one-element table.
+local function _nest(input)
+   if isNode(input) then
+      return input
+   end
+   if not isNonEmptyList(input) then
+      error('what is this in the nest input? ' .. tostring(input))
+   end
+
+   if #input == 1 then
+      -- Convert the element first: it may itself be a nested list, as in
+      -- nngraph.nest({{a, b}}), which JustTable could not take directly.
+      return nngraph.JustTable()({_nest(input[1])})
+   end
+
+   local wrappedChildren = {}
+   for i, child in ipairs(input) do
+      wrappedChildren[i] = _nest(child)
+   end
+   return nn.Identity()(wrappedChildren)
+end
+
+-- Returns a nngraph node to represent a nested structure.
+-- Exactly one argument is accepted: a node or a non-empty, possibly
+-- nested, list of nodes.
+-- Usage example:
+--    local in1 = nn.Identity()()
+--    local in2 = nn.Identity()()
+--    local in3 = nn.Identity()()
+--    local ok = nn.CAddTable()(nngraph.nest({in1}))
+--    local in1Again = nngraph.nest(in1)
+--    local state = nngraph.nest({in1, {in2}, in3})
+function nngraph.nest(...)
+   local argCount = select("#", ...)
+   assert(argCount <= 1, 'Use {input1, input2} to pass multiple inputs.')
+   local input = (...)
+   assert(argCount > 0 and input ~= nil, 'Pass an input.')
+   return _nest(input)
+end
diff --git a/nesting.lua b/nesting.lua
new file mode 100644
index 0000000..8c497f8
--- /dev/null
+++ b/nesting.lua
@@ -0,0 +1,85 @@
+
+local nesting = {}
+
+local utils = require('nngraph.utils')
+
+-- Deep copy of a tensor or of a (nested) table of tensors: tensors are
+-- cloned, tables are rebuilt key by key.
+function nesting.cloneNested(obj)
+   if not torch.isTensor(obj) then
+      local copy = {}
+      for key, value in pairs(obj) do
+         copy[key] = nesting.cloneNested(value)
+      end
+      return copy
+   end
+   return obj:clone()
+end
+
+-- Sets every element of a tensor, or of every tensor inside a (nested)
+-- table, to value.
+function nesting.fillNested(obj, value)
+   if not torch.isTensor(obj) then
+      for _, child in pairs(obj) do
+         nesting.fillNested(child, value)
+      end
+      return
+   end
+   obj:fill(value)
+end
+
+-- Makes output match input's structure and tensor sizes: shared entries are
+-- resized, entries missing from output are cloned from input, and entries
+-- absent from input are removed from output.
+function nesting.resizeNestedAs(output, input)
+   if torch.isTensor(output) then
+      output:resizeAs(input)
+      return
+   end
+   for key, child in pairs(input) do
+      local existing = output[key]
+      if existing then
+         nesting.resizeNestedAs(existing, child)
+      else
+         output[key] = nesting.cloneNested(child)
+      end
+   end
+   -- Clearing existing fields while traversing with pairs() is allowed.
+   for key in pairs(output) do
+      if not input[key] then
+         output[key] = nil
+      end
+   end
+end
+
+-- Copies input's tensors into the matching tensors of output (same
+-- nesting). Raises an error when output has a key that input lacks.
+function nesting.copyNested(output, input)
+   if torch.isTensor(output) then
+      output:copy(input)
+      return
+   end
+   for key, child in pairs(input) do
+      nesting.copyNested(output[key], child)
+   end
+   for key in pairs(output) do
+      if not input[key] then
+         error('key ' .. tostring(key) ..
+               ' present in output but not in input')
+      end
+   end
+end
+
+-- Adds input into output element-wise; both share the same (possibly
+-- nested) table structure. Every key of input must exist in output.
+function nesting.addNestedTo(output, input)
+   if not torch.isTensor(output) then
+      for key, child in pairs(input) do
+         local target = output[key]
+         assert(target ~= nil, "missing key")
+         nesting.addNestedTo(target, child)
+      end
+      return
+   end
+   output:add(input)
+end
+
+return nesting
diff --git a/nngraph-scm-1.rockspec b/nngraph-scm-1.rockspec
new file mode 100644
index 0000000..4963ecd
--- /dev/null
+++ b/nngraph-scm-1.rockspec
@@ -0,0 +1,30 @@
+-- LuaRocks specification for the development ("scm") version of nngraph,
+-- built from the master branch of the git repository.
+package = "nngraph"
+version = "scm-1"
+
+source = {
+   url = "git://github.com/torch/nngraph",
+   tag = "master"
+}
+
+description = {
+   summary = "This package provides graphical computation for nn library in Torch7.",
+   homepage = "https://github.com/torch/nngraph",
+   license = "UNKNOWN"
+}
+
+dependencies = {
+   "torch >= 7.0",
+   "graph",
+   "nn"
+}
+
+-- Configures with CMake in ./build (Release, installing into the rock's
+-- prefix), then builds and installs with make.
+build = {
+   type = "command",
+   build_command = [[
+cmake -E make_directory build;
+cd build;
+cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)"; 
+$(MAKE)
+   ]],
+   install_command = "cd build && $(MAKE) install"
+}
diff --git a/node.lua b/node.lua
new file mode 100644
index 0000000..b9cf87b
--- /dev/null
+++ b/node.lua
@@ -0,0 +1,178 @@
+
+local utils = require('nngraph.utils')
+local istensor = torch.isTensor
+local istable = utils.istable
+local istorchclass = utils.istorchclass
+require 'debug'
+
+-- nngraph.Node is a torch class deriving from graph.Node.
+local nnNode,parent = torch.class('nngraph.Node','graph.Node')
+
+--- Creates a node around the data table (e.g. {module=m}) and fills in
+-- the bookkeeping fields it does not already have: annotations, mapindex
+-- (ordered inputs, see nnNode:add), reverseMap (downstream nodes) and a
+-- debug label describing where the node was created.
+function nnNode:__init(data)
+   parent.__init(self,data)
+   self.data.annotations = self.data.annotations or {}
+   self.data.mapindex = self.data.mapindex or {}
+   self.data.reverseMap = self.data.reverseMap or {}
+   if not self.data.annotations._debugLabel then
+      -- Level 6 presumably reaches the user code that created the node via
+      -- Module:__call__ (constructor machinery in between); the number
+      -- depends on that exact call chain — confirm before changing it.
+      self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
+   end
+end
+
+--[[ Builds the '[file]:line_name' string used as the node's tooltip when
+making a graph, from a debug.getinfo table. Does nothing when dinfo is nil.]]
+function nnNode:_makeDebugLabel(dinfo)
+   if not dinfo then
+      return
+   end
+   local funcName = dinfo.name or ''
+   self.data.annotations._debugLabel =
+      string.format('[%s]:%d_%s', dinfo.short_src, dinfo.currentline, funcName)
+end
+
+-- Adds child as a graph child of this node. With domap the insertion order
+-- is recorded in self.data.mapindex in both directions:
+--   index = self.data.mapindex[child.data]
+--   child.data = self.data.mapindex[index]
+-- and this node is marked in child.data.reverseMap (the "child" here is an
+-- input of this node), which enables dangling pointer detection.
+function nnNode:add(child,domap)
+   parent.add(self,child)
+   if not domap then
+      return
+   end
+   local mapindex = self.data.mapindex
+   local childData = child.data
+   assert(not mapindex[childData], "Don't pass the same input twice.")
+   local position = #mapindex + 1
+   mapindex[position] = childData
+   mapindex[childData] = position
+
+   local reverseMap = childData.reverseMap
+   assert(not reverseMap[self], 'this connection has already been made!')
+   reverseMap[self] = true
+end
+
+-- Splits this node's (table) output into noutput new nodes; the i-th
+-- returned node outputs element i. split(1) returns a single
+-- nngraph.JustElement node instead. The new nodes' debug labels record
+-- where :split was called.
+-- @param noutput number of elements the output table must have
+-- @return noutput nngraph.Node values, in element order
+function nnNode:split(noutput)
+   if noutput == 1 then
+     return nngraph.JustElement()(self)
+   end
+   -- _debugLabel stays unset when the node got no debug info at creation
+   -- (nnNode:_makeDebugLabel ignores a nil dinfo); use '' rather than
+   -- failing on the concatenations below.
+   local debugLabel = self.data.annotations._debugLabel or ''
+   -- Specify the source location where :split is called.
+   local dinfo = debug.getinfo(2, 'Sl')
+   local splitLoc = string.format(' split at [%s]:%d',
+                                  dinfo.short_src,
+                                  dinfo.currentline)
+   local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. splitLoc .. '-mnode'}})
+   mnode:add(self,true)
+
+   local selectnodes = {}
+   for i=1,noutput do
+      local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
+      node:add(mnode,true)
+      table.insert(selectnodes,node)
+   end
+
+   local unpack = unpack or table.unpack -- Lua52 compat
+   return unpack(selectnodes)
+end
+
+--- Merges the given key/value pairs into this node's annotations.
+-- @return self, so calls can be chained
+function nnNode:annotate(annotations)
+   local target = self.data.annotations
+   for key, value in pairs(annotations) do
+      target[key] = value
+   end
+   return self
+end
+
+--- Display name: 'name (id)' when annotated with a name, else 'Node<id>'.
+function nnNode:graphNodeName()
+   local name = self.data.annotations.name
+   if not name then
+      return 'Node' .. self.id
+   end
+   return name .. ' (' .. self.id .. ')'
+end
+
+--- Returns (creating if needed) annotations.graphAttributes, defaulting
+-- its tooltip to the node's debug label.
+function nnNode:graphNodeAttributes()
+   local annotations = self.data.annotations
+   local attributes = annotations.graphAttributes or {}
+   annotations.graphAttributes = attributes
+   if not attributes.tooltip then
+      attributes.tooltip = annotations._debugLabel
+   end
+   return attributes
+end
+
+-- Returns 'NaN', 'inf' or '-inf' when the tensor contains such a value
+-- (checked in that order), '' otherwise or for an empty tensor.
+local function getNanFlag(data)
+   if data:nElement() == 0 then
+      return ''
+   end
+   -- NaN is the only value not equal to itself.
+   if data:ne(data):sum() > 0 then
+      return 'NaN'
+   elseif data:max() == math.huge then
+      return 'inf'
+   elseif data:min() == -math.huge then
+      return '-inf'
+   end
+   return ''
+end
+
+--- Builds the node's text label: one 'key = value' entry per data field,
+-- joined with graphviz '\l' line breaks, prefixed by the annotation
+-- description when present. Tensors are shown as Type[size] plus a
+-- NaN/inf flag, tables as {a,b,...}; mapindex is shown only for nodes
+-- with several inputs; forwardNodeId and annotations are omitted.
+function nnNode:label()
+
+   local lbl = {}
+
+   -- Returns exactly one string describing data.
+   local function getstr(data)
+      if not data then return '' end
+      if istensor(data) then
+         local nanFlag = getNanFlag(data)
+         local tensorType = 'Tensor'
+         if data:type() ~= torch.Tensor():type() then
+            tensorType = data:type()
+         end
+         return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+      elseif istable(data) then
+         local tstr = {}
+         for i,v in ipairs(data) do
+            table.insert(tstr, getstr(v))
+         end
+         return '{' .. table.concat(tstr,',') .. '}'
+      else
+         -- gsub returns (string, count); the parentheses keep only the
+         -- string, so table.insert(tstr, getstr(v)) above receives two
+         -- arguments instead of being turned into a positional insert.
+         return (tostring(data):gsub('\n','\\l'))
+      end
+   end
+   local function getmapindexstr(mapindex)
+      local tstr = {}
+      for i,data in ipairs(mapindex) do
+         local inputId = 'Node' .. (data.forwardNodeId or '')
+         table.insert(tstr, inputId)
+      end
+      return '{' .. table.concat(tstr,',') .. '}'
+   end
+
+   for k,v in pairs(self.data) do
+      local vstr = ''
+      if k== 'mapindex' then
+         if #v > 1 then
+            vstr = getmapindexstr(v)
+            table.insert(lbl, k .. ' = ' .. vstr)
+         end
+      elseif k== 'forwardNodeId' or k== 'annotations' then
+         -- the forwardNodeId is not displayed in the label.
+      else
+         vstr = getstr(v)
+         table.insert(lbl, k .. ' = ' .. vstr)
+      end
+   end
+
+   local desc
+   if self.data.annotations.description then
+      desc = 'desc = ' .. self.data.annotations.description .. '\\n'
+   else
+      desc = ''
+   end
+   return desc .. table.concat(lbl,"\\l")
+end
diff --git a/simple_print.lua b/simple_print.lua
new file mode 100644
index 0000000..87bf152
--- /dev/null
+++ b/simple_print.lua
@@ -0,0 +1,124 @@
+-- Removes node_id from the edge list, modifying edges in place: edges
+-- touching the node are dropped (remaining edges keep their order) and
+-- every former predecessor is connected to every former successor by new
+-- edges appended at the end. Returns the same edges table.
+local function removeNodeFromEdges(node_id, edges)
+   local from_nodes, to_nodes = {}, {}
+   local count = #edges
+   local write = 0
+   for read = 1, count do
+      local edge = edges[read]
+      if edge.source == node_id then
+         to_nodes[#to_nodes + 1] = edge.target
+      elseif edge.target == node_id then
+         from_nodes[#from_nodes + 1] = edge.source
+      else
+         write = write + 1
+         edges[write] = edge
+      end
+   end
+   for i = count, write + 1, -1 do
+      edges[i] = nil
+   end
+
+   for _, from in ipairs(from_nodes) do
+      for _, to in ipairs(to_nodes) do
+         edges[#edges + 1] = {source = from, target = to}
+      end
+   end
+   return edges
+end
+
+-- A node is kept in the simplified graph when it wraps a module that is
+-- not an nn.Identity; otherwise a falsy value is returned.
+local function isNodeGood(node)
+   local data = node.data
+   if not data then
+      return data
+   end
+   local module = data.module
+   if not module then
+      return module
+   end
+   return torch.typename(module) ~= 'nn.Identity'
+end
+
+-- Renumbers nodes 1..#nodes in list order and rewrites edge endpoints to
+-- the new ids; both tables are modified in place and returned.
+local function reIndexNodes(nodes, edges)
+   local newId = {}
+   for position, node in ipairs(nodes) do
+      newId[node.id] = position
+      node.id = position
+   end
+   for _, edge in ipairs(edges) do
+      edge.source, edge.target = newId[edge.source], newId[edge.target]
+   end
+   return nodes, edges
+end
+
+-- Drops the nodes rejected by isNodeGood, rewiring edges around each one,
+-- then renumbers what is left (see reIndexNodes).
+local function cleanGraph(nodes, edges)
+   local idx = 1
+   while nodes[idx] do
+      local node = nodes[idx]
+      if not isNodeGood(node.orig_node) then
+         table.remove(nodes, idx)
+         edges = removeNodeFromEdges(node.id, edges)
+      else
+         idx = idx + 1
+      end
+   end
+   return reIndexNodes(nodes, edges)
+end
+
+-- Flattens graph into wrapper nodes {id, orig_node} and parent->child
+-- edges, then simplifies them with cleanGraph.
+-- @return nodes, edges
+local function loadGraph(graph)
+   local nodes, edges = {}, {}
+   for _, node in ipairs(graph.nodes) do
+      nodes[#nodes + 1] = {id = node.id, orig_node = node}
+      for _, child in ipairs(node.children) do
+         edges[#edges + 1] = {source = node.id, target = child.id}
+      end
+   end
+   return cleanGraph(nodes, edges)
+end
+
+local M = {}
+
+--- Renders the simplified graph (modules other than nn.Identity only) as a
+-- graphviz dot string, optionally with a title label at the top.
+function M.todot( graph, title )
+   local nodes, edges = loadGraph(graph)
+   local parts = {'digraph G {\n'}
+   if title then
+      parts[#parts + 1] = 'labelloc="t";\nlabel="' .. title .. '";\n'
+   end
+   parts[#parts + 1] = 'node [shape = oval]; '
+   -- Edge endpoints are positions in `nodes`; dot names use original ids.
+   local dotNames = {}
+   for position, wrapper in ipairs(nodes) do
+      local orig = wrapper.orig_node
+      local label = '"' .. 'Node' .. orig.id .. '\\n' .. orig:label() .. '"'
+      dotNames[position] = 'n' .. orig.id
+      parts[#parts + 1] = '\n' .. dotNames[position] .. '[label=' .. label .. '];'
+   end
+   parts[#parts + 1] = '\n'
+   for _, edge in ipairs(edges) do
+      parts[#parts + 1] = dotNames[edge.source] .. ' -> ' .. dotNames[edge.target] .. ';\n'
+   end
+   parts[#parts + 1] = '}'
+   return table.concat(parts)
+end
+
+--- Writes graph g as <fname>.dot and renders <fname>.svg with graphviz.
+-- Without fname, a temporary base name is used: the SVG is shown in a qt
+-- QSvgWidget (requires qtsvg), all temporary files are removed and the
+-- widget is returned.
+-- @param g graph to render
+-- @param title optional title shown at the top of the graph
+-- @param fname optional output path without extension
+function M.dot(g,title,fname)
+   local gv = M.todot(g, title)
+   -- One base name for both files: os.tmpname() may create the file it
+   -- names, so a second call would leave another stray temporary file.
+   local base = fname or os.tmpname()
+   local fngv = base .. '.dot'
+   local fgv = assert(io.open(fngv,'w'))
+   fgv:write(gv)
+   fgv:close()
+   local fnsvg = base .. '.svg'
+   os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv)
+   if not fname then
+      require 'qtsvg'
+      local qs = qt.QSvgWidget(fnsvg)
+      qs:show()
+      os.remove(fngv)
+      os.remove(fnsvg)
+      -- Also remove the placeholder created by os.tmpname (if any).
+      os.remove(base)
+      return qs
+   end
+end
+
+return M
diff --git a/test/speed.lua b/test/speed.lua
new file mode 100644
index 0000000..7218cbe
--- /dev/null
+++ b/test/speed.lua
@@ -0,0 +1,111 @@
+
+require 'nngraph'
+
+-- Times n forward/backward passes of model on input and returns the
+-- accumulated wall-clock seconds as {forward=, backward=, total=}. The
+-- gradOutput is a random tensor shaped like the first output.
+function time_benchmark(model, input, n)
+   local timers = {}
+   for _, name in ipairs({'forward', 'backward', 'total'}) do
+      timers[name] = torch.Timer():stop():reset()
+   end
+   local gradOut
+   timers.total:resume()
+   for _ = 1, n do
+      timers.forward:resume()
+      local out = model:forward(input)
+      timers.forward:stop()
+      gradOut = gradOut or torch.rand(out:size())
+      timers.backward:resume()
+      model:backward(input, gradOut)
+      timers.backward:stop()
+   end
+   timers.total:stop()
+
+   return {forward = timers.forward:time().real,
+           backward = timers.backward:time().real,
+           total = timers.total:time().real}
+end
+
+-- Prints a banner with the total/forward/backward seconds of result under a
+-- centred title and returns result unchanged.
+function report_benchmark(result, title)
+   local nspace = (80-string.len(title))/2
+   -- local: previously leaked a global `report`.
+   local report = {string.rep('#', 80),
+      string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
+      string.format('Total Time Spent = %.2f s', result.total),
+      string.format('    Forward Time = %.2f s', result.forward),
+      string.format('   Backward Time = %.2f s', result.backward),
+      string.rep('#', 80)
+   }
+   print(table.concat(report, '\n'))
+   return result
+end
+
+-- Prints each time of result as a percentage of the matching time in base,
+-- under a centred title, and returns result unchanged.
+function compare_benchmarks(result, base, title)
+   local nspace = (80-string.len(title))/2
+   -- local: previously leaked a global `report`.
+   local report = {string.rep('#', 80),
+      string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
+      string.format('Total Time Spent = %.2f %%', result.total/base.total*100),
+      string.format('    Forward Time = %.2f %%', result.forward/base.forward*100),
+      string.format('   Backward Time = %.2f %%', result.backward/base.backward*100),
+      string.rep('#', 80)
+   }
+   print(table.concat(report, '\n'))
+   return result
+end
+
+-- Builds the same network twice, as {nn_model=, nng_model=}: a plain
+-- nn.Sequential and an nngraph gModule. Every layer computes
+-- ReLU(Linear(x) + Linear(x)); the widths are ninput -> nhidden, then
+-- nhidden_layers times nhidden -> nhidden, then nhidden -> noutput.
+function get_models(nhidden_layers, ninput, noutput, nhidden)
+   local shapes = {{ninput, nhidden}}
+   for _ = 1, nhidden_layers do
+      shapes[#shapes + 1] = {nhidden, nhidden}
+   end
+   shapes[#shapes + 1] = {nhidden, noutput}
+
+   -- NN: ConcatTable of two Linears, summed by CAddTable, then ReLU.
+   local nn_model = nn.Sequential()
+   for _, shape in ipairs(shapes) do
+      local nfrom, nto = shape[1], shape[2]
+      local layer = nn.Sequential()
+      local branches = nn.ConcatTable()
+      branches:add(nn.Linear(nfrom, nto))
+      branches:add(nn.Linear(nfrom, nto))
+      layer:add(branches)
+      layer:add(nn.CAddTable())
+      layer:add(nn.ReLU())
+      nn_model:add(layer)
+   end
+
+   -- NN graph: the same layers expressed with graph nodes.
+   local input = nn.Identity()()
+   local node = input
+   for _, shape in ipairs(shapes) do
+      local nfrom, nto = shape[1], shape[2]
+      local left = nn.Linear(nfrom, nto)(node)
+      local right = nn.Linear(nfrom, nto)(node)
+      node = nn.ReLU()(nn.CAddTable()({left, right}))
+   end
+   local nng_model = nn.gModule({input},{node})
+
+   return {nn_model = nn_model, nng_model = nng_model}
+end
+
+-- Parses the benchmark's command-line options from arg.
+function get_options(arg)
+   local cmd = torch.CmdLine()
+   cmd:text('nngraph benchmarking')
+   local options = {
+      {'-niter', 10, 'number of iterations of forward/backward for each model'},
+      {'-nhidden_layers', 10, 'number of hidden layers'},
+      {'-input_size', 512, 'size of input'},
+      {'-batch_size', 16, 'size of batch'},
+      {'-hidden_size', 1024, 'size of hidden layer'},
+      {'-output_size', 128, 'size of output layer'},
+   }
+   for _, option in ipairs(options) do
+      cmd:option(option[1], option[2], option[3])
+   end
+   return cmd:parse(arg)
+end
+
+-- Benchmark driver: builds equivalent nn and nngraph models, times both
+-- on random inputs of the same size and prints nngraph time relative to nn.
+local opt = get_options(arg)
+local models = get_models(opt.nhidden_layers, opt.input_size, opt.output_size, opt.hidden_size)
+print(opt)
+
+local nn_bench = report_benchmark(time_benchmark(models.nn_model, torch.rand(opt.batch_size,opt.input_size), opt.niter), 'NN')
+local nng_bench = report_benchmark(time_benchmark(models.nng_model, torch.rand(opt.batch_size,opt.input_size), opt.niter), 'NNGRAPH')
+compare_benchmarks(nng_bench, nn_bench, 'NNGRAPH / NN (%)')
+
diff --git a/test/test_JustElement.lua b/test/test_JustElement.lua
new file mode 100644
index 0000000..d6c49a8
--- /dev/null
+++ b/test/test_JustElement.lua
@@ -0,0 +1,28 @@
+
+require 'totem'
+require 'nngraph'
+-- Test cases for nngraph.JustElement, collected in `test` and run by totem.
+local test = {}
+local tester = totem.Tester()
+
+-- Forwarding a one-element table must return (a value equal to) that element.
+function test.test_output()
+   local tensor = torch.randn(7, 11)
+   local layer = nngraph.JustElement()
+   tester:eq(layer:forward({tensor}), tensor, "output")
+end
+
+-- Run totem's gradient checker on JustElement with a one-element table input.
+function test.test_grad()
+   local inputs = {torch.randn(7, 11)}
+   totem.nn.checkGradients(tester, nngraph.JustElement(), inputs)
+end
+
+-- A gModule whose only output is node:split(1) must output the sole element
+-- of its one-element table input.
+function test.test_split()
+   local inNode = nn.Identity()()
+   local only = inNode:split(1)
+   local net = nn.gModule({inNode}, {only})
+
+   local tensor = torch.randn(7, 11)
+   tester:eq(net:forward({tensor}), tensor, "output of split(1)")
+end
+
+tester:add(test):run()
diff --git a/test/test_JustTable.lua b/test/test_JustTable.lua
new file mode 100644
index 0000000..d24d739
--- /dev/null
+++ b/test/test_JustTable.lua
@@ -0,0 +1,19 @@
+
+require 'totem'
+require 'nngraph'
+-- Test cases for nngraph.JustTable, collected in `test` and run by totem.
+local test = {}
+local tester = totem.Tester()
+
+-- Forwarding a bare tensor must produce a value equal to {tensor}.
+function test.test_output()
+   local tensor = torch.randn(7, 11)
+   local expected = {tensor}
+   tester:eq(nngraph.JustTable():forward(tensor), expected, "output")
+end
+
+-- Run totem's gradient checker on JustTable with a plain tensor input.
+function test.test_grad()
+   local tensor = torch.randn(7, 11)
+   local layer = nngraph.JustTable()
+   totem.nn.checkGradients(tester, layer, tensor)
+end
+
+tester:add(test):run()
diff --git a/test/test_ModuleFromCriterion.lua b/test/test_ModuleFromCriterion.lua
new file mode 100644
index 0000000..9206905
--- /dev/null
+++ b/test/test_ModuleFromCriterion.lua
@@ -0,0 +1,57 @@
+
+require 'totem'
+require 'nngraph'
+-- Tests for using criteria as graph nodes (nn.ModuleFromCriterion).
+local test = {}
+local tester = totem.Tester()
+
+-- A criterion called on {prediction, target} nodes and scaled by 1/log(2)
+-- must match a standalone MSECriterion in both loss and gradients; the
+-- target must receive a zero gradient.
+function test.test_call()
+   local predNode = nn.Identity()()
+   local targetNode = nn.Identity()()
+   local mseNode = nn.MSECriterion()({predNode, targetNode})
+   local bitsNode = nn.MulConstant(1/math.log(2))(mseNode)
+   local net = nn.gModule({predNode, targetNode}, {bitsNode})
+
+   local prediction = torch.randn(3, 5)
+   local target = torch.rand(3, 5)
+   local input = {prediction, target}
+   local reference = nn.MSECriterion()
+   local output = net:forward(input)
+   reference:forward(prediction, target)
+   tester:eq(output[1], reference.output/math.log(2), "output", 1e-14)
+
+   local gradOutput = torch.randn(1)
+   local gradInput = net:backward(input, gradOutput)
+   reference:backward(prediction, target)
+   local expectedGradPrediction =
+      reference.gradInput:clone():mul(gradOutput[1]/math.log(2))
+   tester:eq(gradInput[1], expectedGradPrediction, "gradPrediction", 1e-14)
+   tester:eq(gradInput[2], torch.zeros(target:size()), "gradTarget")
+end
+
+-- Gradient-check an MSE node whose target is computed inside the graph
+-- (0 * prediction + 1.23), so there is no separate target input whose zero
+-- gradient would have to be checked.
+function test.test_grad()
+   local predNode = nn.Identity()()
+   local zeroed = nn.MulConstant(0)(predNode)
+   local targetNode = nn.AddConstant(1.23)(zeroed)
+   local lossNode = nn.MSECriterion()({predNode, targetNode})
+   local net = nn.gModule({predNode}, {lossNode})
+
+   totem.nn.checkGradients(tester, net, torch.randn(4, 7))
+end
+
+-- Fixture: a ModuleFromCriterion(MSECriterion) and a matching
+-- {prediction, target} pair of random tensors.
+-- NOTE: the name shadows Lua 5.1's global `module` within this file.
+local function module()
+   local wrapped = nn.ModuleFromCriterion(nn.MSECriterion())
+   local pair = {torch.randn(3, 5), torch.randn(3, 5)}
+   return wrapped, pair
+end
+
+-- Run totem's serialization check on the wrapped criterion.
+function test.test_serializable()
+   local wrapped, pair = module()
+   totem.nn.checkSerializable(tester, wrapped, pair)
+end
+
+-- Run totem's type-cast check on the wrapped criterion.
+function test.test_typeCastable()
+   local wrapped, pair = module()
+   totem.nn.checkTypeCastable(tester, wrapped, pair)
+end
+
+
+tester:add(test):run()
diff --git a/test/test_connectivity.lua b/test/test_connectivity.lua
new file mode 100644
index 0000000..99b2539
--- /dev/null
+++ b/test/test_connectivity.lua
@@ -0,0 +1,26 @@
+local totem = require 'totem'
+require 'nngraph'
+-- gModule must reject graphs containing nodes that do not reach an output,
+-- and its error must name the offending node's source location.
+local tests = totem.TestSuite()
+local tester = totem.Tester()
+
+function tests.connectivity()
+  -- Store debug info here, need to call debug.getinfo on same line as the
+  -- dangling pointer is declared.
+  local dInfo
+  local input = nn.Identity()()
+  local lin = nn.Linear(20, 10)(input)
+  -- The Sigmoid does not connect to the output, so should cause an error
+  -- when we call gModule.
+  local dangling = nn.Sigmoid()(lin); dInfo = debug.getinfo(1, 'Sl')
+  local actualOutput = nn.Tanh()(lin)
+  -- Build the expected error as a Lua pattern: '%%[' / '%%]' format to the
+  -- escaped brackets '%[' / '%]' around the source name.
+  -- NOTE(review): short_src is inserted unescaped; pattern magic characters
+  -- in the path (e.g. '-') could change what matches — confirm.
+  local errStr = string.format(
+      'node declared on %%[%s%%]:%d_ does not connect to gmodule output',
+      dInfo.short_src, dInfo.currentline)
+  tester:assertErrorPattern(
+      function()
+        return nn.gModule({input}, {actualOutput})
+      end,
+      errStr)
+end
+
+return tester:add(tests):run()
diff --git a/test/test_debug.lua b/test/test_debug.lua
new file mode 100644
index 0000000..f5c8b06
--- /dev/null
+++ b/test/test_debug.lua
@@ -0,0 +1,80 @@
+local totem = require 'totem'
+require 'nngraph'
+-- Tests for the error messages nngraph raises on common graph-building mistakes.
+local tests = totem.TestSuite()
+local tester = totem.Tester()
+
+-- Each entry pairs a faulty graph construction with the Lua pattern its error
+-- message must match; entries are asserted in order.
+function tests.whatIsThisInTheInput()
+  local cases = {
+    {
+      function()
+        local inp1, inp2 = nn.Identity()(), nn.Identity() -- missing 2nd parens
+        local lin = nn.Linear(20, 10)(nn.CMulTable(){inp1, inp2})
+      end,
+      'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
+      'only valid thing to pass is an instance of nngraph%.Node',
+    },
+    {
+      function()
+        -- pass-through module, again with same mistake
+        local graphNode, nnModule = nn.Identity()(), nn.Identity()
+        return nn.gModule({graphNode, nnModule}, {graphNode})
+      end,
+      'inputs%[2%] is an nn%.Module, specifically a nn%.Identity, but the ' ..
+      'only valid thing to pass is an instance of nngraph%.Node',
+    },
+    {
+      function()
+        local input = nn.Identity()()
+        local out1 = nn.Linear(20, 10)(input)
+        local out2 = nn.Sigmoid()(input)
+        local unconnectedOut = nn.Linear(2, 3)
+        return nn.gModule({input}, {out1, out2, unconnectedOut})
+      end,
+      'outputs%[3%] is an nn%.Module, specifically a nn%.Linear, but the ' ..
+      'only valid thing to pass is an instance of nngraph%.Node',
+    },
+    {
+      -- Check for detecting a nil in the middle of a table.
+      function()
+        local input = nn.Identity()()
+        local out1 = nn.Tanh()(input)
+        local out2 = nn.Sigmoid()(input)
+        -- nil here is simulating a mis-spelt variable name
+        return nn.gModule({input}, {out1, nil, out2})
+      end,
+      'outputs%[2%] is nil %(typo / bad index%?%)',
+    },
+    {
+      function()
+        -- Typo variable name returns nil, meaning an empty table
+        local input = nn.Identity()({aNonExistentVariable})
+      end,
+      'cannot pass an empty table of inputs%. To indicate no incoming ' ..
+      'connections, leave the second set of parens blank%.',
+    },
+  }
+
+  for _, case in ipairs(cases) do
+    tester:assertErrorPattern(case[1], case[2])
+  end
+end
+
+-- Wiring only one of split(2)'s outputs into the gModule must fail, and the
+-- error must name where the node was declared and where it was split.
+function tests.splitUnused()
+  -- Need to do debuginfo on the same lines as the other code here to match
+  -- what debug.getinfo inside those calls will return
+  local dInfoDeclare, dInfoSplit
+  local input = nn.Identity()(); dInfoDeclare = debug.getinfo(1, 'Sl')
+  local output, unused = input:split(2); dInfoSplit = debug.getinfo(1, 'Sl')
+
+  local function willCrash()
+    return nn.gModule({input}, {output})
+  end
+
+  -- Work out what strings will be in the error message
+  -- ('%%[' / '%%]' format to the escaped pattern brackets '%[' / '%]'; the
+  -- split location is followed directly by '-mnode' in the expected text).
+  -- NOTE(review): short_src is inserted unescaped, so pattern magic
+  -- characters in the path (e.g. '-') could change what matches — confirm.
+  local declareLoc = string.format('%%[%s%%]:%d_',
+                                   dInfoDeclare.short_src,
+                                   dInfoDeclare.currentline)
+  local splitLoc = string.format('%%[%s%%]:%d',
+                                 dInfoSplit.short_src,
+                                 dInfoSplit.currentline)
+
+  tester:assertErrorPattern(
+      willCrash,
+      '1 of split%(2%) outputs from the node declared at ' ..
+      declareLoc .. ' split at ' .. splitLoc .. '%-mnode are unused')
+end
+
+tester:add(tests):run()
diff --git a/test/test_nest.lua b/test/test_nest.lua
new file mode 100644
index 0000000..2e9b410
--- /dev/null
+++ b/test/test_nest.lua
@@ -0,0 +1,33 @@
+
+require 'totem'
+require 'nngraph'
+
+-- Tests for nngraph.nest (building nested-table outputs from graph nodes).
+local test = {}
+local tester = totem.Tester()
+
+-- Each nest(...) output must reproduce the nesting of its node arguments,
+-- for single-node, flat-table and nested-table arguments alike.
+function test.test_output()
+   local a = nn.Identity()()
+   local b = nn.Identity()()
+   local c = nn.Identity()()
+   local summed = nn.CAddTable()(nngraph.nest({a}))
+   local aAlone = nngraph.nest(a)
+   local nested = nngraph.nest({a, {b}, c})
+   local cOnly = nngraph.nest({c})
+   local aAndB = nngraph.nest({a, b})
+
+   local net = nn.gModule({a, b, c}, {summed, aAlone, nested, cOnly, aAndB})
+
+   local x = torch.randn(7, 3)
+   local y = torch.randn(2)
+   local z = torch.randn(3)
+   local expected = {x, x, {x, {y}, z}, {z}, {x, y}}
+   tester:eq(net:forward({x, y, z}), expected, "output")
+end
+
+
+return tester:add(test):run()
+
+
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
new file mode 100644
index 0000000..340ab24
--- /dev/null
+++ b/test/test_nngraph.lua
@@ -0,0 +1,481 @@
+
+require 'totem'
+require 'nngraph'
+-- Main nngraph test suite: forward/backward results, gradient checks,
+-- malformed-graph detection, type casting, annotations and replace().
+local test = {}
+local tester = totem.Tester()
+
+-- Shorthand for totem.nn.checkGradients bound to this file's tester.
+local function checkGradients(...)
+   totem.nn.checkGradients(tester, ...)
+end
+
+-- A single Identity path: output and gradInput must reflect the latest call,
+-- not values left over from the previous forward/backward.
+function test.test_oneOutput()
+   local inNode = nn.Identity()()
+   local outNode = nn.Identity()(inNode)
+   local net = nn.gModule({inNode}, {outNode})
+
+   local first = torch.Tensor({1})
+   net:forward(first)
+   tester:eq(net.output, torch.Tensor{1}, "output")
+   local grad = net:backward(first, torch.Tensor({-123}))
+   tester:eq(grad, torch.Tensor{-123}, "gradInput")
+
+   local second = torch.Tensor({2})
+   net:forward(second)
+   tester:eq(net.output, torch.Tensor{2}, "output for input2")
+   grad = net:backward(second, torch.Tensor({-2}))
+   tester:eq(grad, torch.Tensor{-2}, "expecting a recomputed gradInput")
+end
+
+
+-- Fanning one input out to two outputs must sum their gradOutputs.
+function test.test_twoOutputs()
+   local inNode = nn.Identity()()
+   local left = nn.Identity()(inNode)
+   local right = nn.Identity()(inNode)
+   local net = nn.gModule({inNode}, {left, right})
+
+   local x = torch.Tensor({1})
+   net:forward(x)
+   local gradOutputs = {torch.Tensor({-2}), torch.Tensor({-3})}
+   tester:eq(net:backward(x, gradOutputs), torch.Tensor{-5}, "gradInput of a fork")
+   checkGradients(net, x)
+end
+
+-- A SplitTable node split into two graph outputs: forward yields two tensors
+-- and backward takes one gradOutput per output.
+function test.test_twoGradOutputs()
+   local inNode = nn.Sigmoid()()
+   local rows = nn.SplitTable(1)({inNode})
+   local first, second = rows:split(2)
+   local net = nn.gModule({inNode}, {first, second})
+
+   local x = torch.randn(2, 3)
+   local y = net:forward(x)
+   assert(#y == 2, "wrong number of outputs")
+   net:backward(x, {torch.randn(3), torch.randn(3)})
+   checkGradients(net, x)
+end
+
+-- The second graph input is itself a table split into two nodes; gradInput
+-- must mirror that nested structure.
+function test.test_twoInputs()
+   local xNode = nn.Identity()()
+   local stateNode = nn.Identity()()
+   local hNode, cellNode = stateNode:split(2)
+
+   local product = nn.CMulTable()({xNode, hNode, cellNode})
+   local net = nn.gModule({xNode, stateNode}, {product})
+
+   local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
+   net:forward(input)
+   local grads = net:backward(input, torch.randn(3))
+   assert(#grads == 2, "wrong number of gradInputs")
+   assert(type(grads[2]) == "table", "wrong gradInput[2] type")
+   checkGradients(net, input)
+end
+
+-- Graph inputs may be returned directly as outputs, next to a node derived
+-- from one of them.
+function test.test_twoInputs2()
+   local a = nn.Sigmoid()()
+   local b = nn.Sigmoid()()
+   local derived = nn.Sigmoid()(a)
+   local net = nn.gModule({a, b}, {a, b, derived})
+
+   local input = {torch.randn(3), torch.randn(3)}
+   net:forward(input)
+   net:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+   checkGradients(net, input)
+end
+
+-- split() children get the parent's debug label suffixed with "-<index>".
+function test.test_splitDebugLabels()
+   local parent = nn.Identity()()
+   parent.data.annotations._debugLabel = "node"
+   local first, second = parent:split(2)
+   local firstLabel = first.data.annotations._debugLabel
+   local secondLabel = second.data.annotations._debugLabel
+   assert(firstLabel == "node-1")
+   assert(secondLabel == "node-2")
+end
+
+-- Identity inputs returned directly as outputs, plus an Identity of the first.
+function test.test_identity()
+   local a = nn.Identity()()
+   local b = nn.Identity()()
+   local copyOfA = nn.Identity()(a)
+   local net = nn.gModule({a, b}, {a, b, copyOfA})
+
+   local input = {torch.randn(3), torch.randn(3)}
+   net:forward(input)
+   net:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+   checkGradients(net, input)
+end
+
+-- LSTM-like cell taking {x, {h, c}}: the gradient for h must have the same
+-- tensor type and size as h itself.
+function test.test_gradInputType()
+   local xInput = torch.randn(3)
+   local h = torch.randn(3)
+
+   local x = nn.Identity()()
+   local prevRnnStateNode = nn.Identity()()
+   local prevH, prevCell = prevRnnStateNode:split(2)
+
+   local cellOut = nn.CAddTable()({
+      nn.CMulTable()({x, prevH}),
+      nn.CMulTable()({prevH, prevCell})})
+   local module = nn.gModule({x, prevRnnStateNode}, {cellOut})
+
+   local c = torch.randn(h:size())
+   local input = {xInput, {h, c}}
+   module:forward(input)
+
+   local gradOutput = torch.randn(h:size())
+   local gradInput = module:backward(input, gradOutput)
+
+   local gradPrevH = gradInput[2][1]
+   -- Lua's type() reports 'userdata' for every tensor, which made the old
+   -- check vacuous; compare torch type names (e.g. 'torch.DoubleTensor').
+   tester:eq(torch.type(gradPrevH), torch.type(h), "wrong gradPrevH type")
+   tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
+   checkGradients(module, input)
+end
+
+   -- A SplitTable node as the graph input, summed back with CAddTable.
+   function test.test_tabularInput()
+      local rows = nn.SplitTable(1)()
+      local summed = nn.CAddTable()(rows)
+      local net = nn.gModule({rows}, {summed})
+
+      checkGradients(net, torch.randn(2, 3))
+   end
+
+   -- The graph accepts its input bare or wrapped in a table and returns the
+   -- output in the matching form.
+   function test.test_extraTable()
+      local inNode = nn.Identity()()
+      local outNode = nn.Identity()(inNode)
+      local net = nn.gModule({inNode}, {outNode})
+
+      local value = torch.Tensor({123})
+      tester:eq(net:forward(value), value, "simple output")
+      tester:eq(net:forward({value}), {value}, "tabular output")
+   end
+
+   -- A learnable CMul input node fanned out to two Identity outputs, run
+   -- through the gradient checker.
+   function test.test_accGradParameters()
+      local x = torch.randn(10)
+
+      local scaled = nn.CMul(x:nElement())()
+      local left = nn.Identity()(scaled)
+      local right = nn.Identity()(scaled)
+      local net = nn.gModule({scaled}, {left, right})
+      checkGradients(net, x)
+   end
+
+   -- A plain Linear-Tanh-Linear-Tanh-Linear chain built from graph nodes.
+   function test.test_example1()
+      local inNode = nn.Linear(20,10)()
+      local outNode = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(inNode))))
+      local mlp = nn.gModule({inNode},{outNode})
+
+      checkGradients(mlp, torch.rand(20))
+   end
+
+   -- Two Linear-Tanh-Linear branches summed, then fed to a Sigmoid head and a
+   -- Tanh head; gradient-checked on a pair of inputs.
+   function test.test_example2()
+      local in20 = nn.Linear(20,20)()
+      local in10 = nn.Linear(10,10)()
+      local branch20 = nn.Linear(20,1)(nn.Tanh()(in20))
+      local branch10 = nn.Linear(10,1)(nn.Tanh()(in10))
+      local sum = nn.CAddTable()({branch20, branch10})
+      local sigmoidHead = nn.Sigmoid()(sum)
+      local tanhHead = nn.Tanh()(sum)
+      local net = nn.gModule({in20, in10}, {sigmoidHead, tanhHead})
+
+      local a = torch.rand(20)
+      local b = torch.rand(10)
+      checkGradients(net, {a, b})
+   end
+
+   -- An nn.Sequential (SplitTable + ParallelTable of Linears) used as one
+   -- node whose two outputs are split and joined back together.
+   function test.test_example3()
+      local seq = nn.Sequential()
+      seq:add(nn.SplitTable(1))
+      seq:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
+      local inNode = nn.Identity()()
+      local left, right = seq(inNode):split(2)
+      local joined = nn.JoinTable(1)({left, right})
+      local net = nn.gModule({inNode},{joined})
+
+      checkGradients(net, torch.rand(2,10))
+   end
+
+   -- Skip connections: later layers also see earlier activations through
+   -- JoinTable.
+   function test.test_example4()
+      local x = nn.Identity()()
+      local h1 = nn.Tanh()(nn.Linear(1,2)(x))
+      local h2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({x, h1})))
+      local h3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({h1, h2})))
+      local net = nn.gModule({x},{h3})
+
+      checkGradients(net, torch.rand(1))
+   end
+
+   -- Internal buffers filled by forward/backward must be converted by
+   -- :float(), and later passes must produce FloatTensors.
+   function test.test_type()
+      local inNode = nn.Linear(20,10)()
+      local outNode = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(inNode))))
+      local net = nn.gModule({inNode}, {outNode})
+
+      -- Assert the tensor type of each internal buffer this test inspects.
+      local function expectBufferTypes(expected)
+         tester:eq(torch.typename(net.output), expected)
+         tester:eq(torch.typename(net.gradInput), expected)
+         tester:eq(torch.typename(net.innode.data.input[1]), expected)
+         tester:eq(torch.typename(net.outnode.data.input[1]), expected)
+         tester:eq(torch.typename(net.forwardnodes[1].data.input[1]), expected)
+         tester:eq(torch.typename(net.forwardnodes[1].children[1].data.input[1]), expected)
+         tester:eq(torch.typename(net.backwardnodes[1].children[1].data.gradOutput[1]), expected)
+      end
+
+      local x = torch.rand(20)
+      local y = net:forward(x)
+      local dy = y:clone():normal()
+      net:backward(x, dy)
+
+      net:backward(x, y)
+      tester:eq(torch.typename(y), "torch.DoubleTensor")
+      expectBufferTypes("torch.DoubleTensor")
+
+      net:float()
+      expectBufferTypes("torch.FloatTensor")
+      local yFloat = net:forward(x:float())
+      tester:eq(torch.typename(yFloat), "torch.FloatTensor")
+      local dxFloat = net:backward(x:float(), dy:float())
+      tester:eq(torch.typename(dxFloat), "torch.FloatTensor")
+
+   end
+
+   -- Two JoinTable-based Sequentials consume the same three-tensor table;
+   -- the same model is gradient-checked at two different batch sizes.
+   function test.test_nestedGradInput()
+      local x = nn.Identity()()
+      local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
+      local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
+      local out = nn.CAddTable()({h1(x), h2(x)})
+
+      local model = nn.gModule({x}, {out})
+
+      for _, rows in ipairs({3, 2}) do
+         local input = {}
+         for i = 1, 3 do
+            input[i] = torch.randn(rows, 3)
+         end
+         checkGradients(model, input)
+      end
+   end
+
+   -- gModule must reject a declared input that no output depends on.
+   function test.test_unusedInput()
+      local used = nn.Identity()()
+      local unused = nn.Identity()()
+      local stray = nn.Identity()()  -- never passed to gModule at all
+
+      local built = pcall(nn.gModule, {used, unused}, {used})
+      assert(not built, "the unused input should be detected")
+   end
+
+   -- Leaving one output of split(2) unused must make gModule fail.
+   function test.test_unusedChild()
+      local state = nn.Identity()()
+      local h, _ = state:split(2)
+
+      assert(not pcall(nn.gModule, {state}, {h}),
+             "the unused cell should be detected")
+   end
+
+   -- Calling a module on nil instead of a node must raise an error.
+   function test.test_nilInput()
+      local function callOnNil()
+         nn.Sigmoid()(nil)
+      end
+      assert(not pcall(callOnNil), "the nil input should be detected")
+   end
+
+   -- A node computed from an input but reaching no output must be rejected.
+   function test.test_unusedNode()
+      local in1 = nn.Identity()()
+      local in2 = nn.Identity()()
+      local orphan = nn.Sigmoid()(in2)
+      local out1 = nn.Sigmoid()(in1)
+
+      local built = pcall(nn.gModule, {in1, in2}, {out1})
+      assert(not built, "the unused middleResult should be detected")
+   end
+
+   -- A node may be used both whole and through its split() children.
+   function test.test_usageAfterSplit()
+      local state = nn.Identity()()
+      local h, cell = state:split(2)
+      local stateCopy = nn.Identity()(state)
+      local squashed = nn.Sigmoid()(cell)
+
+      local net = nn.gModule({state}, {h, stateCopy, squashed})
+      local width = 10
+      checkGradients(net, {torch.randn(width), torch.randn(width)})
+   end
+
+   -- The same graph runs on nested inputs whose inner table grows and then
+   -- shrinks; gradInput must follow the current input's structure.
+   function test.test_resizeNestedAs()
+      local inNode = nn.Identity()()
+      local left = nn.Identity()(inNode)
+      local right = nn.Identity()(inNode)
+      local net = nn.gModule({inNode}, {left, right})
+
+      -- forward, then backward using the output itself as gradOutput
+      local function roundTrip(input)
+         net:forward(input)
+         local gradInput = net:backward(input, net.output)
+         return gradInput
+      end
+
+      local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+      roundTrip(input)
+      checkGradients(net, input)
+
+      input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
+      roundTrip(input)
+      checkGradients(net, input)
+
+      input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+      local gradInput = roundTrip(input)
+      tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
+      checkGradients(net, input)
+   end
+
+
+   -- Annotations (name, description, graphAttributes) must appear in the node
+   -- label and in the dot files rendered for both the fg and the bg graph.
+   function test.test_annotateGraph()
+      local input = nn.Identity()():annotate(
+      {name = 'Input', description = 'DescA',
+      graphAttributes = {color = 'red'}})
+
+      local hidden_a = nn.Linear(10, 10)(input):annotate(
+      {name = 'Hidden A', description = 'DescB',
+      graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
+      local hidden_b = nn.Sigmoid()(hidden_a)
+      local output = nn.Linear(10, 10)(hidden_b)
+      local net = nn.gModule({input}, {output})
+
+      tester:assert(hidden_a:label():match('DescB') ~= nil)
+      local fg_tmpfile = os.tmpname()
+      local bg_tmpfile = os.tmpname()
+      graph.dot(net.fg, 'Test', fg_tmpfile)
+      -- Render the backward graph (not fg a second time) so the bg checks
+      -- below actually exercise bg annotations.
+      graph.dot(net.bg, 'Test BG', bg_tmpfile)
+
+      -- Read <tmpfile>.dot and check every annotation made it into the file.
+      local function checkDotFile(tmpfile)
+         local file = assert(io.open(tmpfile .. '.dot', 'r'))
+         local dotcontent = file:read("*all")
+         file:close()
+         tester:assert(
+         dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]') ~= nil)
+         tester:assert(
+         dotcontent:match(
+         '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]') ~= nil)
+         tester:assert(
+         dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]') ~= nil)
+         tester:assert(
+         dotcontent:match(
+         '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]') ~= nil)
+      end
+
+      checkDotFile(fg_tmpfile)
+      checkDotFile(bg_tmpfile)
+
+      -- Best-effort cleanup of the temp name and of sibling files graph.dot
+      -- may have written next to it; os.remove failures are ignored.
+      for _, base in ipairs({fg_tmpfile, bg_tmpfile}) do
+         os.remove(base)
+         os.remove(base .. '.dot')
+         os.remove(base .. '.svg')
+      end
+   end
+
+   -- split(n) must fail at forward time when the table has more than n items.
+   function test.test_splitMore()
+      local nSplits = 2
+      local inNode = nn.Identity()()
+      local first, second = nn.SplitTable(2)(inNode):split(nSplits)
+      local net = nn.gModule({inNode}, {first, second})
+
+      local tooWide = torch.randn(10, nSplits + 1)
+      local ran = pcall(net.forward, net, tooWide)
+      assert(not ran, "the extra input to split should be detected")
+   end
+
+   -- split(n) must fail at forward time when the table has fewer than n items.
+   function test.test_splitLess()
+      local nSplits = 3
+      local inNode = nn.Identity()()
+      local first, second, third = nn.SplitTable(2)(inNode):split(nSplits)
+      local net = nn.gModule({inNode}, {first, second, third})
+
+      local tooNarrow = torch.randn(10, nSplits - 1)
+      local ran = pcall(net.forward, net, tooNarrow)
+      assert(not ran, "the missing input to split should be detected")
+   end
+
+   -- backward must cope with branches whose gradInput is an expanded zero
+   -- tensor, both when one branch and when all branches produce one.
+   function test.test_gradOutputZeroOptim()
+      -- The old shim indexed _G with the *function* `unpack` (not the string
+      -- 'unpack'), so it always fell through to table.unpack, which plain
+      -- Lua 5.1 does not provide.
+      local unpack = unpack or table.unpack
+      -- Make module that produces an expanded zero gradInput tensor
+      local dummyModule = nn.Module()
+      dummyModule.updateOutput = function(self, input)
+         self.output = torch.Tensor(1, 2, 3):uniform()
+         return self.output
+      end
+      dummyModule.updateGradInput = function(self, input, gradOutput)
+         local zeroTensor = torch.Tensor{0}
+          :view(unpack(torch.ones(input:dim()):totable()))
+          :expandAs(input)
+         self.gradInput = zeroTensor
+         return self.gradInput
+      end
+
+      -- First input and final gradOutput
+      local input = torch.Tensor(1, 2, 3):uniform()
+      local gradOutput = torch.Tensor(1, 2, 3):uniform()
+
+      -- Wire x -> CAddTable{makeBranch1()(x), makeBranch2()(x)} and assert
+      -- that forward and backward both run without raising. Errors are
+      -- reported rather than swallowed.
+      local function assertRuns(makeBranch1, makeBranch2)
+         local x = nn.Identity()()
+         local h1 = makeBranch1()(x)
+         local h2 = makeBranch2()(x)
+         local y = nn.CAddTable()({h1, h2})
+         local mod = nn.gModule({x}, {y})
+
+         local ok, err = pcall(nn.Module.forward, mod, input)
+         assert(ok, "forward should succeed: " .. tostring(err))
+
+         ok, err = pcall(nn.Module.backward, mod, input, gradOutput)
+         assert(ok, "backward should succeed: " .. tostring(err))
+      end
+
+      local function zeroGradBranch() return dummyModule:clone() end
+      local function identityBranch() return nn.Identity() end
+
+      -- First case: one intermediary gradOutput is going to be zero
+      assertRuns(zeroGradBranch, identityBranch)
+      -- Second case: all intermediary gradOutputs are going to be zero
+      assertRuns(zeroGradBranch, zeroGradBranch)
+   end
+
+   -- gModule:replace must swap modules, keep the graph runnable with the new
+   -- sizes, and refresh gModule.modules.
+   function test.test_replace()
+      local i = nn.Identity()()
+      local l1 = nn.Linear(5, 2)(i)
+      local sig = nn.Sigmoid()(l1)
+      local l2  = nn.Linear(2, 5)(sig)
+      local model = nn.gModule({i}, {l2})
+
+      -- Forward and backward once; both results must have the input's size.
+      local function checkShapes(input, gradOutput)
+         tester:eq(model:forward(input):size(), input:size(), "inconsistent output size")
+         tester:eq(model:backward(input, gradOutput):size(), input:size(), "inconsistent output size")
+      end
+
+      local input = torch.randn(4, 5)
+      local gradOutput = torch.randn(4, 5)
+      checkShapes(input, gradOutput)
+
+      -- l1/sig/l2 are graph *nodes*, which never equal entries of
+      -- model.modules, so the old `ne` checks could not fail. Capture the
+      -- original modules up front (replace() presumably rewrites
+      -- node.data.module in place, so read them before replacing).
+      local oldL1 = l1.data.module
+      local oldSig = sig.data.module
+      local oldL2 = l2.data.module
+
+      -- Linear(5,2) -> Linear(10,2), Linear(2,5) -> Linear(2,10),
+      -- Sigmoid -> Tanh; anything else is kept.
+      model:replace(function(m)
+         if torch.type(m) == 'nn.Linear' then
+            if m.weight:size(1) == 5 then
+               return nn.Linear(2, 10)
+            elseif m.weight:size(1) == 2 then
+               return nn.Linear(10, 2)
+            end
+         elseif torch.type(m) == 'nn.Sigmoid' then
+            return nn.Tanh()
+         end
+         return m
+      end)
+
+      local input2 = torch.randn(4, 10)
+      local gradOutput2 = torch.randn(4, 10)
+      checkShapes(input2, gradOutput2)
+
+      tester:assert(model.modules[2] ~= oldL1, "gModule.modules wasn't updated")
+      tester:assert(model.modules[3] ~= oldSig, "gModule.modules wasn't updated")
+      tester:eq(torch.type(model.modules[3]), 'nn.Tanh', "replace didn't update gModule.modules")
+      tester:assert(model.modules[4] ~= oldL2, "gModule.modules wasn't updated")
+   end
+
+   tester:add(test):run()
diff --git a/test/test_old.lua b/test/test_old.lua
new file mode 100644
index 0000000..1a1e862
--- /dev/null
+++ b/test/test_old.lua
@@ -0,0 +1,227 @@
+require 'nngraph'
+
+--- Build a two-input/two-output graph, run one updateOutput and one
+-- updateGradInput pass with gmod.verbose set, and render fg and bg graphs.
+function t1()
+   local x1 = nn.Linear(20,20)()
+   local x2 = nn.Linear(10,10)()
+   local m0=nn.Linear(20,1)(nn.Tanh()(x1))
+   local m1=nn.Linear(10,1)(nn.Tanh()(x2))
+   local madd=nn.CAddTable()({m0,m1})
+   local m2=nn.Sigmoid()(madd)
+   local m3=nn.Tanh()(madd)
+   local x = torch.rand(20)
+   local y = torch.rand(10)
+   -- local: this was an accidental global leaking into _G
+   local gmod = nn.gModule({x1,x2},{m2,m3})
+   gmod.verbose = true
+   print('forward')
+   gmod:updateOutput({x,y})
+   print('updateGradInput')
+   gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)})
+   graph.dot(gmod.fg,'forward')
+   graph.dot(gmod.bg,'backward')
+end
+
+--- Build the same two-input Linear/JoinTable/Linear network with nngraph and
+-- with nn containers, copy the weights across, then print the forward,
+-- gradInput and accumulated-gradient differences (which should be ~0).
+function t2()
+   print('compare')
+   local m0 = nn.Linear(5,10)()
+   local m1 = nn.Linear(10,20)()
+   local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1})
+   -- local: this was an accidental global leaking into _G
+   local gmod = nn.gModule({m0,m1},{m2})
+
+   local nn0 = nn.Linear(5,10)
+   local nn1 = nn.Linear(10,20)
+   local nn2 = nn.Linear(30,50)
+   local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2)
+
+   -- Give the nn model exactly the graph model's parameters.
+   nn0.weight:copy(m0.data.module.weight)
+   nn0.bias:copy(m0.data.module.bias)
+   nn1.weight:copy(m1.data.module.weight)
+   nn1.bias:copy(m1.data.module.bias)
+   nn2.weight:copy(m2.data.module.weight)
+   nn2.bias:copy(m2.data.module.bias)
+
+
+   for i=1,5 do
+      local x,y = torch.rand(5),torch.rand(10)
+      local xx,yy = x:clone(),y:clone()
+
+      gmod:updateOutput({x,y})
+      nnmod:updateOutput({xx,yy})
+      print('fdiff = ', torch.dist(gmod.output,nnmod.output))
+
+      local odx = torch.rand(50)
+      local odxx = odx:clone()
+
+      gmod:updateGradInput({x,y},odx)
+      nnmod:updateGradInput({xx,yy},odxx)
+      graph.dot(gmod.fg,tostring(i))
+      -- k, not i: avoid shadowing the outer iteration counter
+      for k in ipairs(gmod.gradInput) do
+         print('bdiff [' ..k..  '] = ', torch.dist(gmod.gradInput[k],nnmod.gradInput[k]))
+      end
+   end
+
+   local gms = {m0,m1,m2}
+   local nms = {nn0,nn1,nn2}
+
+   for i=1,5 do
+      local x,y = torch.rand(5),torch.rand(10)
+      local xx,yy = x:clone(),y:clone()
+
+      gmod:updateOutput({x,y})
+      nnmod:updateOutput({xx,yy})
+      print('fdiff = ', torch.dist(gmod.output,nnmod.output))
+
+      local odx = torch.rand(50)
+      local odxx = odx:clone()
+
+      gmod:zeroGradParameters()
+      nnmod:zeroGradParameters()
+
+      gmod:updateGradInput({x,y},odx)
+      nnmod:updateGradInput({xx,yy},odxx)
+
+      gmod:accGradParameters({x,y},odx)
+      nnmod:accGradParameters({xx,yy},odxx)
+      graph.dot(gmod.fg)
+      for k, node in ipairs(gms) do
+         print('accdiff [' ..k..  '] = ', torch.dist(node.data.module.gradWeight,nms[k].gradWeight))
+         print('accdiff [' ..k..  '] = ', torch.dist(node.data.module.gradBias,nms[k].gradBias))
+      end
+   end
+end
+
+--- Train a small SplitTable/ParallelTable/JoinTable nn.Sequential for 25
+-- gradient steps on a constant input, printing the loss after each step.
+function t3()
+   -- All scratch variables are local now (they previously leaked as globals
+   -- mlp, c, p, pred, x, y and criterion).
+   local mlp = nn.Sequential()      --Create a network that takes a Tensor as input
+   mlp:add(nn.SplitTable(2))
+   local c = nn.ParallelTable()     --The two Tensors go through two different Linear
+   c:add(nn.Linear(10,3))           --Layers in Parallel
+   c:add(nn.Linear(10,7))
+   mlp:add(c)                       --Outputting a table with 2 elements
+   local p = nn.ParallelTable()     --These tables go through two more linear layers
+   p:add(nn.Linear(3,2))            -- separately.
+   p:add(nn.Linear(7,1))
+   mlp:add(p)
+   mlp:add(nn.JoinTable(1))         --Finally, the tables are joined together and output.
+
+   local pred = mlp:forward(torch.randn(10,2))
+   print(pred)
+
+   -- MSECriterion has no learnable state, so one instance serves every step.
+   local criterion = nn.MSECriterion()
+   for i=1,25 do             -- A few steps of training such a network..
+      local x = torch.ones(10,2)
+      local y = torch.Tensor(3)
+      -- select(dim, index); the former third argument was ignored
+      y:copy(x:select(2,1):narrow(1,1,3))
+      pred = mlp:forward(x)
+
+      local err = criterion:forward(pred,y)
+      local gradCriterion = criterion:backward(pred,y)
+      print(x,y)
+      mlp:zeroGradParameters()
+      mlp:backward(x, gradCriterion)
+      mlp:updateParameters(0.05)
+
+      print(err)
+   end
+end
+
+--- Run forward/backward through a graph that returns its second input
+-- unchanged next to a Tanh of the first, print both gradInputs, render the
+-- forward graph and assert gradInput[1] has input1's element count.
+function t4()
+   local getInput1 = nn.Identity()()
+   local getInput2 = nn.Identity()()
+   local mlp = nn.Tanh()(getInput1)
+   -- local: this was an accidental global leaking into _G
+   local net = nn.gModule({getInput1, getInput2}, {mlp, getInput2})
+
+
+   local input1 = torch.randn(2)
+   local input2 = torch.randn(5)
+
+   net:forward({input1, input2})
+   local gradInput = net:backward({input1, input2},
+   {torch.randn(input1:size()), torch.randn(input2:size())})
+   print("gradInput[1]:", gradInput[1])
+   print("gradInput[2]:", gradInput[2])
+   graph.dot(net.fg)
+   assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch")
+
+end
+
+--- Use an nn.Sequential as a graph node, split its two outputs and join them,
+-- rendering the graphs before and after one forward/backward pass.
+function t5()
+   local m = nn.Sequential()
+   m:add(nn.SplitTable(1))
+   m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
+   local input = nn.Identity()()
+   local input1,input2 = m(input):split(2)
+   local m3 = nn.JoinTable(1)({input1,input2})
+
+   -- local: this was an accidental global leaking into _G
+   local g = nn.gModule({input},{m3})
+   graph.dot(g.fg,'init forward')
+
+   local indata = torch.rand(2,10)
+   local gdata = torch.rand(50)
+   g:forward(indata)
+   g:backward(indata,gdata)
+
+   graph.dot(g.fg,'forward')
+   graph.dot(g.bg,'backward')
+end
+
+--- Experimental topological sort of graph `a` (assumed to provide :reverse(),
+-- a .nodes array, and nodes with :dfs and .id — not checked here).
+-- Walks the reversed graph depth-first from each root, printing node ids and
+-- collecting the corresponding original nodes in visit order.
+-- Prints 'Graph has cycles' when there are no roots or not every node of `a`
+-- was visited.
+-- @return sortednodes, the reversed graph, and its root nodes
+function topsort(a)
+   -- Reverse the graph; `map` sends original nodes to reversed copies, so
+   -- invert it to get back from a reversed node to its original.
+   -- (rg, map, sortednodes and rootnodes used to leak as globals.)
+   local rg, map = a:reverse()
+   local rmap = {}
+   for k,v in pairs(map) do
+      rmap[v] = k
+   end
+
+   local sortednodes = {}
+   local rootnodes = rg:roots()
+
+   if #rootnodes == 0 then
+      print('Graph has cycles')
+   end
+
+   for _, root in ipairs(rootnodes) do
+      root:dfs(function(node)
+         print(node.id,rmap[node].id)
+         table.insert(sortednodes,rmap[node])
+      end)
+   end
+
+   if #sortednodes ~= #a.nodes then
+      print('Graph has cycles')
+   end
+   return sortednodes,rg,rootnodes
+end
+
+   -- Tiny assertion helper (not referenced in this file): my.eq prints 'ok'
+   -- when a:dist(b) == 0, otherwise prints the label s and both tensors.
+   local my = {}
+   function my.eq(a, b, s)
+      if a:dist(b) ~= 0 then
+         print('error : ' .. s)
+         print('a : ');print(a)
+         print('b : ');print(b)
+         return
+      end
+      print('ok')
+   end
+
+   --- Build a small Linear -> {Tanh, Tanh} -> CAddTable graph and render it
+   -- with nngraph.simple_print (text and file) and with graph.dot.
+   function t8()
+      local inNode = nn.Identity()()
+      local lin = nn.Linear(10,10)(inNode)
+      local left = nn.Tanh()(lin)
+      local right = nn.Tanh()(lin)
+      local sum = nn.CAddTable(){left, right}
+      local net = nn.gModule({inNode}, {sum})
+
+      local dotText = nngraph.simple_print.todot(net.fg, 'bogus')
+      print(dotText)
+      nngraph.simple_print.dot(net.fg, 'bogus', 'new')
+      graph.dot(net.fg, 'bogus', 'old')
+   end
+   -- t2()
+   t8()
diff --git a/utils.lua b/utils.lua
new file mode 100644
index 0000000..1b39607
--- /dev/null
+++ b/utils.lua
@@ -0,0 +1,42 @@
+-- Helper functions shared by nngraph: value classification, error messages
+-- for misuse of graph nodes, and a Lua-version compatibility shim.
+local utils = {}
+
+--- Return x's torch type name when x is a table that has one; otherwise a
+-- falsy value (false for non-tables, nil for plain tables).
+function utils.istorchclass(x)
+   if type(x) ~= 'table' then
+      return false
+   end
+   local name = torch.typename(x)
+   return name
+end
+
+--- True when x is a plain Lua table, i.e. a table with no torch type name.
+function utils.istable(x)
+   if type(x) ~= 'table' then
+      return false
+   end
+   return not torch.typename(x)
+end
+
+--[[ Returns a useful error message when a nngraph.Node is expected.
+   badVal: the offending value; array: name of the list it came from (e.g.
+   'inputs' or 'outputs'); idx: its position in that list. ]]
+function utils.expectingNodeErrorMessage(badVal, array, idx)
+   if badVal == nil then
+      return string.format('%s[%d] is nil (typo / bad index?)', array, idx)
+   end
+   if torch.isTypeOf(badVal, 'nn.Module') then
+      -- Most common mistake: a module passed without the second set of parens.
+      local fmt = '%s[%d] is an nn.Module, specifically a %s, but the ' ..
+                  'only valid thing to pass is an instance of ' ..
+                  'nngraph.Node. Did you forget a second set of parens, ' ..
+                  'which convert a nn.Module to a nngraph.Node?'
+      return string.format(fmt, array, idx, torch.typename(badVal))
+   end
+   local fmt = '%s[%d] should be an nngraph.Node but is of type %s'
+   return string.format(fmt, array, idx,
+                        torch.typename(badVal) or type(badVal))
+end
+
+--[[ Lua 5.2+ removed table.maxn, provide fallback implementation.
+   The fallback returns the largest positive numeric key, or 0 if none. ]]
+utils.tableMaxN = table.maxn or function(tbl)
+   local largest = 0
+   for key in pairs(tbl) do
+      if type(key) == 'number' and key > largest then
+         largest = key
+      end
+   end
+   return largest
+end
+
+return utils

-- 
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/lua-torch-nngraph.git



More information about the debian-science-commits mailing list