from caffe2.python import core, workspace
from caffe2.python import test_util as tu
import caffe2.python.nomnigraph as ng
from caffe2.python.nomnigraph_transformations import transpose_network

import numpy as np
from hypothesis import given
import hypothesis.strategies as st


class TestNomnigraphTransformations(tu.TestCase):
    def test_simple_replace(self):
        net = core.Net("name")
        net.FC(["X", "W"], ["Y"])
        nn = ng.NNModule(net)
        fc = nn.controlFlow[0]
        add = nn.createNode(core.CreateOperator("Add", ["X"], ["Y"], engine="CUDNN"))
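        # Swap the FC node for the Add node: replaceNode moves the existing
        # X/W/Y edges onto the Add, and deleteNode removes the detached FC,
        # so the exported net should compute Y = X + W elementwise.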
        nn.replaceNode(fc, add)
        nn.deleteNode(fc)

        # Test it out
        new_netdef = nn.convertToCaffe2Proto()
        workspace.ResetWorkspace()
        workspace.FeedBlob("X", np.array([1, 2, 3]))
        workspace.FeedBlob("W", np.array([1, 2, 3]))
        workspace.RunNetOnce(new_netdef)
        out = workspace.FetchBlob("Y")
        expected_out = np.array([2, 4, 6])
        np.testing.assert_almost_equal(out, expected_out)

    def test_simple_rewire(self):
        net = core.Net("name")
        # Rewire this so that we get
        # c = Add(a, d)
        # e = Mul(c, b)
        #
        # if a = 1, b = 2, d = 3
        # we get 8: (1 + 3) * 2
        # as opposed to 5: (1 * 2) + 3
        net.Mul(["a", "b"], ["c"])
        net.Add(["c", "d"], ["e"])
        nn = ng.NNModule(net)

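        # controlFlow yields the net's operators in the order they were added:
        # index 0 is the Mul above, index 1 is the Add. Grab them and the
        # tensor nodes at their endpoints so the edges can be rewired.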
        mul = nn.controlFlow[0]
        add = nn.controlFlow[1]
        a = mul.inputs[0]
        b = mul.inputs[1]
        c = mul.outputs[0]
        d = add.inputs[1]
        e = add.outputs[0]

        nn.deleteEdge(a, mul)
        nn.deleteEdge(b, mul)
        nn.deleteEdge(mul, c)
        nn.deleteEdge(c, add)
        nn.deleteEdge(d, add)
        nn.deleteEdge(add, e)

        nn.createEdge(a, add)
        nn.createEdge(d, add)
        nn.createEdge(add, c)
        nn.createEdge(c, mul)
        nn.createEdge(b, mul)
        nn.createEdge(mul, e)
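        # The graph now encodes c = Add(a, d) followed by e = Mul(c, b),
        # which is exactly the rewiring described in the comment above.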

        # Test it out
        new_netdef = nn.convertToCaffe2Proto()
        workspace.ResetWorkspace()
        workspace.FeedBlob("a", np.array([1, 1, 1]))
        workspace.FeedBlob("b", np.array([2, 2, 2]))
        workspace.FeedBlob("d", np.array([3, 3, 3]))
        workspace.RunNetOnce(new_netdef)
        out = workspace.FetchBlob("e")
        expected_out = np.array([8, 8, 8])
        np.testing.assert_almost_equal(out, expected_out)

    @given(
        batch_size=st.integers(16, 20),
        channels=st.integers(1, 10),
        height=st.integers(10, 15),
        width=st.integers(10, 15),
        seed=st.integers(0, 65535),
        kernel=st.integers(3, 5),
    )
    def test_transpose_network(self, batch_size, channels, height, width, seed,
                               kernel):
        net = core.Net("net")
        net.Conv(["X", "w1", "b1"], ["c1"], stride=1, pad=0, kernel=kernel)
        net.Conv(["X", "w2", "b2"], ["c2"], stride=1, pad=0, kernel=kernel)
        # c1 and c2: batch_size, 2*channels, height - kernel + 1, width - kernel + 1
        net.Conv(["c1", "w3", "b3"], ["c3"], stride=1, pad=0, kernel=kernel)
        net.Conv(["c1", "w4", "b4"], ["c4"], stride=1, pad=0, kernel=kernel)
        # c3 and c4: batch_size, 4*channels, height - 2*kernel + 2, width - 2*kernel + 2
        net.Flatten(["c3"], "c3f")
        net.Flatten(["c4"], "c4f")
        net.Flatten(["X"], "Xf")
        net.Concat(["c3f", "c4f", "Xf"], ["out", "split_info"], axis=1, add_axis=0)
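        # out: the three flattened blobs joined along axis 1; only the batch
        # dimension is shared, the flattened sizes themselves differ.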
        np.random.seed(seed)
        workspace.ResetWorkspace()
        tu.randBlobFloat32("X", batch_size, channels, height, width)
        tu.randBlobsFloat32(["w1", "w2"], 2 * channels, channels, kernel, kernel)
        tu.randBlobsFloat32(["b1", "b2"], 2 * channels)
        tu.randBlobsFloat32(["w3", "w4"], 4 * channels, 2 * channels, kernel, kernel)
        tu.randBlobsFloat32(["b3", "b4"], 4 * channels)
        all_inp_names = ["X", "w1", "w2", "b1", "b2", "w3", "w4", "b3", "b4"]
        all_input = workspace.FetchBlobs(all_inp_names)
        workspace.RunNetOnce(net)
        preTransformC1 = workspace.FetchBlob("c1")
        preTransformC3 = workspace.FetchBlob("c3")
        preTransformOut = workspace.FetchBlob("out")
        nn = ng.NNModule(net)
        preTransformNumOperators = len(nn.operators)
        preTransformNumTensors = len(nn.tensors)
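        # transpose_network rewrites the graph so the convolutions operate in
        # NHWC order, inserting transpose operators around the channel-based
        # inputs and the convolution outputs (the exact count is checked below).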
        transpose_network(nn)
        new_netdef = nn.convertToCaffe2Proto()
        postTransformNumOperators = len(nn.operators)
        postTransformNumTensors = len(nn.tensors)
        # The minimal number of additional operators and tensors is one
        # NCHW2NHWC operator and tensor for each channel-based input tensor,
        # and one NHWC2NCHW operator and tensor for the output of each
        # convolution. X, w1, w2, w3, w4 are channel-based inputs and
        # c1, c2, c3, c4 are convolution outputs, i.e. a total of 9 each.
        self.assertEqual(postTransformNumOperators,
                         preTransformNumOperators + 9,
                         "expected 9 additional operators")
        self.assertEqual(postTransformNumTensors,
                         preTransformNumTensors + 9,
                         "expected 9 additional tensors")
        workspace.ResetWorkspace()
        for name, val in zip(all_inp_names, all_input):
            workspace.FeedBlob(name, val)
        workspace.RunNetOnce(new_netdef)
        postTransformC1 = workspace.FetchBlob("c1")
        postTransformC3 = workspace.FetchBlob("c3")
        postTransformOut = workspace.FetchBlob("out")
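        # Loose tolerance (decimal=1): the NHWC rewrite can change the order of
        # floating-point accumulation, so results only match approximately.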
        np.testing.assert_almost_equal(postTransformC1, preTransformC1, decimal=1)
        np.testing.assert_almost_equal(postTransformC3, preTransformC3, decimal=1)
        np.testing.assert_almost_equal(postTransformOut, preTransformOut, decimal=1)