




from caffe2.python import control, core, test_util, workspace

import logging
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)


class TestControl(test_util.TestCase):
    """Tests for the step builders in ``caffe2.python.control``.

    Each test builds a plan that runs ``init_net_`` followed by the
    control-flow step under test. It then fetches the counter values exposed
    as external outputs of the fixture nets created in ``setUp``.
    """

    def setUp(self):
        """Build the fixture nets shared by every test.

        - ``init_net_`` creates a counter starting at 0 and the INT64 scalar
          constants ``N_`` and 0. It also zero-fills the blobs that hold the
          retrieved counts.
        - ``cnt_net_`` increments the counter once and retrieves the count.
        - ``cnt_2_net_`` increments the counter twice and retrieves the count.
        - ``cond_net_`` / ``not_cond_net_`` compute ``count < N_`` /
          ``count >= N_`` on the count retrieved by ``cnt_net_``.
        - ``true_cond_net_`` / ``false_cond_net_`` compute ``0 < N_`` /
          ``0 > N_``, i.e. a constant true / false condition.
        - ``idle_net_`` only fills a constant that nothing reads. It is used
          as filler in multi-element step lists.
        """
        super().setUp()
        self.N_ = 10

        self.init_net_ = core.Net("init-net")
        cnt = self.init_net_.CreateCounter([], init_count=0)
        const_n = self.init_net_.ConstantFill(
            [], shape=[], value=self.N_, dtype=core.DataType.INT64)
        const_0 = self.init_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

        self.cnt_net_ = core.Net("cnt-net")
        self.cnt_net_.CountUp([cnt])
        curr_cnt = self.cnt_net_.RetrieveCount([cnt])
        # Zero-fill the count blob up front. This makes it readable (as 0) by
        # cond_net_ and by the output checks even when cnt_net_ never runs.
        self.init_net_.ConstantFill(
            [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_net_.AddExternalOutput(curr_cnt)

        self.cnt_2_net_ = core.Net("cnt-2-net")
        self.cnt_2_net_.CountUp([cnt])
        self.cnt_2_net_.CountUp([cnt])
        curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt])
        # Same zero-fill as above, for the double-increment net's output.
        self.init_net_.ConstantFill(
            [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_2_net_.AddExternalOutput(curr_cnt_2)

        self.cond_net_ = core.Net("cond-net")
        cond_blob = self.cond_net_.LT([curr_cnt, const_n])
        self.cond_net_.AddExternalOutput(cond_blob)

        self.not_cond_net_ = core.Net("not-cond-net")
        cond_blob = self.not_cond_net_.GE([curr_cnt, const_n])
        self.not_cond_net_.AddExternalOutput(cond_blob)

        self.true_cond_net_ = core.Net("true-cond-net")
        true_blob = self.true_cond_net_.LT([const_0, const_n])
        self.true_cond_net_.AddExternalOutput(true_blob)

        self.false_cond_net_ = core.Net("false-cond-net")
        false_blob = self.false_cond_net_.GT([const_0, const_n])
        self.false_cond_net_.AddExternalOutput(false_blob)

        self.idle_net_ = core.Net("idle-net")
        self.idle_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

    @staticmethod
    def _LastOutput(net):
        """Return the name of the last external output blob of ``net``."""
        return net.Proto().external_output[-1]

    def CheckNetOutput(self, nets_and_expects):
        """Check that each net's last external output has the expected value.

        Args:
            nets_and_expects: iterable of ``(net, expect)`` tuples.
        """
        for net, expect in nets_and_expects:
            blob = self._LastOutput(net)
            self.assertEqual(
                workspace.FetchBlob(blob), expect,
                '%s (output of net %s)' % (blob, net.Proto().name))

    def CheckNetAllOutput(self, net, expects):
        """Check every external output of ``net`` against ``expects``.

        Args:
            net: net whose external outputs are fetched, in declaration order.
            expects: expected values, one per external output of ``net``.
        """
        outputs = net.Proto().external_output
        self.assertEqual(len(outputs), len(expects))
        for blob, expect in zip(outputs, expects):
            self.assertEqual(
                workspace.FetchBlob(blob), expect,
                '%s (output of net %s)' % (blob, net.Proto().name))

    def BuildAndRunPlan(self, step):
        """Run ``init_net_`` and then ``step`` in a new plan; assert success.

        Because ``init_net_`` runs at the start of every plan, the counter
        starts again from 0. This lets one test run several plans in a row.
        """
        plan = core.Plan("test")
        plan.AddStep(control.Do('init', self.init_net_))
        plan.AddStep(step)
        self.assertEqual(workspace.RunPlan(plan), True)

    # ---- Loops: each variant must run the counting net exactly N_ times.

    def ForLoopTest(self, nets_or_steps):
        step = control.For('myFor', nets_or_steps, self.N_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testForLoopWithNets(self):
        self.ForLoopTest(self.cnt_net_)
        self.ForLoopTest([self.cnt_net_, self.idle_net_])

    def testForLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.ForLoopTest(step)
        self.ForLoopTest([step, self.idle_net_])

    def WhileLoopTest(self, nets_or_steps):
        step = control.While('myWhile', self.cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testWhileLoopWithNet(self):
        self.WhileLoopTest(self.cnt_net_)
        self.WhileLoopTest([self.cnt_net_, self.idle_net_])

    def testWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.WhileLoopTest(step)
        self.WhileLoopTest([step, self.idle_net_])

    def UntilLoopTest(self, nets_or_steps):
        step = control.Until('myUntil', self.not_cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testUntilLoopWithNet(self):
        self.UntilLoopTest(self.cnt_net_)
        self.UntilLoopTest([self.cnt_net_, self.idle_net_])

    def testUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.UntilLoopTest(step)
        self.UntilLoopTest([step, self.idle_net_])

    def DoWhileLoopTest(self, nets_or_steps):
        step = control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testDoWhileLoopWithNet(self):
        self.DoWhileLoopTest(self.cnt_net_)
        self.DoWhileLoopTest([self.idle_net_, self.cnt_net_])

    def testDoWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoWhileLoopTest(step)
        self.DoWhileLoopTest([self.idle_net_, step])

    def DoUntilLoopTest(self, nets_or_steps):
        step = control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def testDoUntilLoopWithNet(self):
        self.DoUntilLoopTest(self.cnt_net_)
        self.DoUntilLoopTest([self.cnt_net_, self.idle_net_])

    def testDoUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoUntilLoopTest(step)
        self.DoUntilLoopTest([self.idle_net_, step])

    # ---- Conditionals. With cond_on_blob=True the condition net runs as its
    # own step first, and the control step reads its output blob by name
    # instead of running the net itself.

    def IfCondTest(self, cond_net, expect, cond_on_blob):
        if cond_on_blob:
            step = control.Do(
                'if-all',
                control.Do('count', cond_net),
                control.If('myIf', self._LastOutput(cond_net),
                           self.cnt_net_))
        else:
            step = control.If('myIf', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfCondTrueOnNet(self):
        self.IfCondTest(self.true_cond_net_, 1, False)

    def testIfCondTrueOnBlob(self):
        self.IfCondTest(self.true_cond_net_, 1, True)

    def testIfCondFalseOnNet(self):
        self.IfCondTest(self.false_cond_net_, 0, False)

    def testIfCondFalseOnBlob(self):
        self.IfCondTest(self.false_cond_net_, 0, True)

    def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        # If runs cnt_net_ on a true condition and cnt_2_net_ otherwise.
        run_net = self.cnt_net_ if cond_value else self.cnt_2_net_
        if cond_on_blob:
            step = control.Do(
                'if-else-all',
                control.Do('count', cond_net),
                control.If('myIfElse', self._LastOutput(cond_net),
                           self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.If('myIfElse', cond_net,
                              self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfElseCondTrueOnNet(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, False)

    def testIfElseCondTrueOnBlob(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, True)

    def testIfElseCondFalseOnNet(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, False)

    def testIfElseCondFalseOnBlob(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, True)

    def IfNotCondTest(self, cond_net, expect, cond_on_blob):
        if cond_on_blob:
            step = control.Do(
                'if-not',
                control.Do('count', cond_net),
                control.IfNot('myIfNot', self._LastOutput(cond_net),
                              self.cnt_net_))
        else:
            step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfNotCondTrueOnNet(self):
        self.IfNotCondTest(self.true_cond_net_, 0, False)

    def testIfNotCondTrueOnBlob(self):
        self.IfNotCondTest(self.true_cond_net_, 0, True)

    def testIfNotCondFalseOnNet(self):
        self.IfNotCondTest(self.false_cond_net_, 1, False)

    def testIfNotCondFalseOnBlob(self):
        self.IfNotCondTest(self.false_cond_net_, 1, True)

    def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        # IfNot runs cnt_net_ on a false condition and cnt_2_net_ otherwise.
        run_net = self.cnt_2_net_ if cond_value else self.cnt_net_
        if cond_on_blob:
            step = control.Do(
                'if-not-else',
                control.Do('count', cond_net),
                control.IfNot('myIfNotElse',
                              self._LastOutput(cond_net),
                              self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.IfNot('myIfNotElse', cond_net,
                                 self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfNotElseCondTrueOnNet(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, False)

    def testIfNotElseCondTrueOnBlob(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, True)

    def testIfNotElseCondFalseOnNet(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, False)

    def testIfNotElseCondFalseOnBlob(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, True)

    # ---- Switch: only the branch whose condition is true runs (SwitchNot:
    # only the branch whose condition is false runs).

    def testSwitch(self):
        step = control.Switch(
            'mySwitch',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])

    def testSwitchNot(self):
        step = control.SwitchNot(
            'mySwitchNot',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])

    def testBoolNet(self):
        # BoolNet accepts (name, value) pairs either as varargs or as a list.
        bool_net = control.BoolNet(('a', True))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True])

        bool_net = control.BoolNet(('a', True), ('b', False))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

        bool_net = control.BoolNet([('a', True), ('b', False)])
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

    def testCombineConditions(self):
        # CombineConditions only combines existing condition blobs, so the
        # condition nets are run in the same step before combine_net.

        # combined by 'Or'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, True)])

        # combined by 'And'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, False)])

    def testMergeConditionNets(self):
        # Unlike CombineConditions, the merged net is run on its own here,
        # so it must compute the conditions itself.

        # merged by 'Or'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, True)])

        # merged by 'And'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, False)])