# blob: e98cd06f7a0044bfb4fa2b999f33f0444d376800 [file] [log] [blame]
# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for flagsaver."""
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
# Plain string flags that the tests mutate and expect to find back at these
# defaults afterwards.
flags.DEFINE_string(
    name='flagsaver_test_flag0',
    default='unchanged0',
    help='flag to test with')
flags.DEFINE_string(
    name='flagsaver_test_flag1',
    default='unchanged1',
    help='flag to test with')

# A flag whose validator accepts only falsy values.
flags.DEFINE_string(
    name='flagsaver_test_validated_flag',
    default=None,
    help='flag to test with')
flags.register_validator(
    'flagsaver_test_validated_flag', lambda value: not value)

# Two flags that are validated together by validate_test_flags.
flags.DEFINE_string(
    name='flagsaver_test_validated_flag1',
    default=None,
    help='flag to test with')
flags.DEFINE_string(
    name='flagsaver_test_validated_flag2',
    default=None,
    help='flag to test with')

# Flags kept as module-level holders so tests can pass (holder, value) pairs.
INT_FLAG = flags.DEFINE_integer(
    name='flagsaver_test_int_flag', default=1, help='help')
STR_FLAG = flags.DEFINE_string(
    name='flagsaver_test_str_flag', default='str default', help='help')
@flags.multi_flags_validator(
    ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
def validate_test_flags(flag_dict):
  """Returns whether the two cross-validated test flags hold equal values.

  Args:
    flag_dict: Mapping indexed by flag name; read here for
      'flagsaver_test_validated_flag1' and 'flagsaver_test_validated_flag2'.
  """
  first = flag_dict['flagsaver_test_validated_flag1']
  second = flag_dict['flagsaver_test_validated_flag2']
  return first == second
# Module-level alias for flags.FLAGS, used by every test below to read and
# assign flag values.
FLAGS = flags.FLAGS
@flags.validator('flagsaver_test_flag0')
def check_no_upper_case(value):
  """Returns True when lower-casing `value` leaves it unchanged."""
  lowered = value.lower()
  return value == lowered
class _TestError(Exception):
  """Private exception raised on purpose by tests in this module."""
class FlagSaverTest(absltest.TestCase):
  """Checks that flagsaver undoes flag changes made inside its scope."""

  def test_context_manager_without_parameters(self):
    """A bare flagsaver() context reverts an assignment made inside it."""
    with flagsaver.flagsaver():
      FLAGS.flagsaver_test_flag0 = 'new value'

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_overrides(self):
    """Keyword overrides apply inside the context and are undone after it."""
    overrides = {'flagsaver_test_flag0': 'new value'}
    with flagsaver.flagsaver(**overrides):
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
      # A flag that was not overridden is mutated too; it must also revert.
      FLAGS.flagsaver_test_flag1 = 'another value'

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_flagholders(self):
    """(holder, value) pairs override flags for the span of the context."""
    holder_overrides = ((INT_FLAG, 3), (STR_FLAG, 'new value'))
    with flagsaver.flagsaver(*holder_overrides):
      self.assertEqual('new value', STR_FLAG.value)
      self.assertEqual(3, INT_FLAG.value)
      FLAGS.flagsaver_test_flag1 = 'another value'

    self.assertEqual(INT_FLAG.value, INT_FLAG.default)
    self.assertEqual(STR_FLAG.value, STR_FLAG.default)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_overrides_and_flagholders(self):
    """Holder pairs and keyword overrides may be mixed in a single call."""
    keyword_overrides = {'flagsaver_test_flag0': 'new value'}
    with flagsaver.flagsaver((INT_FLAG, 3), **keyword_overrides):
      # STR_FLAG was not part of the call, so it still reads as its default.
      self.assertEqual(STR_FLAG.default, STR_FLAG.value)
      self.assertEqual(3, INT_FLAG.value)
      FLAGS.flagsaver_test_flag0 = 'new value'

    self.assertEqual(INT_FLAG.value, INT_FLAG.default)
    self.assertEqual(STR_FLAG.value, STR_FLAG.default)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_cross_validated_overrides_set_together(self):
    """Matching values for both cross-validated flags are accepted."""
    # Both flags are handed to one flagsaver call, so their shared validator
    # only runs after both have been set.
    matching = {
        'flagsaver_test_validated_flag1': 'new_value',
        'flagsaver_test_validated_flag2': 'new_value',
    }
    with flagsaver.flagsaver(**matching):
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
      self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_cross_validated_overrides_set_badly(self):
    """Mismatched values for the cross-validated flags are rejected."""
    mismatched = {
        'flagsaver_test_validated_flag1': 'new_value',
        'flagsaver_test_validated_flag2': 'other_value',
    }
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      with flagsaver.flagsaver(**mismatched):
        pass

    # The failed attempt must leave both flags at their defaults.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_cross_validated_overrides_set_separately(self):
    """Overriding only one of the cross-validated flags trips the validator."""
    only_first = {'flagsaver_test_validated_flag1': 'new_value'}
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      with flagsaver.flagsaver(**only_first):
        pass

    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_context_manager_with_exception(self):
    """Flags are restored when the body of the context raises."""
    with self.assertRaises(_TestError):
      with flagsaver.flagsaver(flagsaver_test_flag0='new value'):
        self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
        FLAGS.flagsaver_test_flag1 = 'another value'
        raise _TestError('oops')

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

  def test_context_manager_with_validation_exception(self):
    """A validator failure on entry leaves every flag at its old value."""
    # The second override violates that flag's validator, which only accepts
    # falsy values.
    overrides = {
        'flagsaver_test_flag0': 'new value',
        'flagsaver_test_validated_flag': 'new value',
    }
    with self.assertRaises(flags.IllegalFlagValueError):
      with flagsaver.flagsaver(**overrides):
        pass

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag)

  def test_decorator_without_call(self):
    """@flagsaver.flagsaver (no parentheses) reverts changes on return."""

    @flagsaver.flagsaver
    def set_then_read(value):
      """Assigns `value` to --flagsaver_test_flag0 and returns the flag."""
      # The returned value is read before the decorator restores the flag,
      # so it shows what the flag held inside the function.
      FLAGS.flagsaver_test_flag0 = value
      return FLAGS.flagsaver_test_flag0

    # Inside the function the flag really did change...
    seen_inside = set_then_read('new value')
    self.assertEqual('new value', seen_inside)
    # ...and it is back at its default once the function has returned.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_without_parameters(self):
    """@flagsaver.flagsaver() (empty call) reverts changes on return."""

    @flagsaver.flagsaver()
    def set_then_read(value):
      FLAGS.flagsaver_test_flag0 = value
      return FLAGS.flagsaver_test_flag0

    seen_inside = set_then_read('new value')
    self.assertEqual('new value', seen_inside)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_overrides(self):
    """Decorator overrides are visible only while the function runs."""

    @flagsaver.flagsaver(flagsaver_test_flag0='new value')
    def read_flag():
      """Returns the flag as seen under the decorator's override."""
      return FLAGS.flagsaver_test_flag0

    # The override was in effect inside the function...
    seen_inside = read_flag()
    self.assertEqual('new value', seen_inside)
    # ...and is gone once the function has returned.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_cross_validated_overrides_set_together(self):
    """Matching cross-validated overrides in one decorator are accepted."""

    # Both flags are handed to one flagsaver call, so their shared validator
    # only runs after both have been set.
    @flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='new_value')
    def read_both():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    seen_inside = read_both()
    self.assertEqual(('new_value', 'new_value'), seen_inside)
    # Outside the function the flags are untouched.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_decorator_with_cross_validated_overrides_set_badly(self):
    """Mismatched cross-validated overrides fail when the function runs."""

    @flagsaver.flagsaver(
        flagsaver_test_validated_flag1='new_value',
        flagsaver_test_validated_flag2='other_value')
    def read_both_mismatched():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      read_both_mismatched()

    # The failed call must leave both flags at their defaults.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_decorator_with_cross_validated_overrides_set_separately(self):
    """Stacked decorators set the flags one at a time and so fail."""

    # Each flagsaver call validates when it finishes setting its own flags,
    # so the validator sees one flag set while the other is still unset.
    @flagsaver.flagsaver(flagsaver_test_validated_flag1='new_value')
    @flagsaver.flagsaver(flagsaver_test_validated_flag2='new_value')
    def read_both_stacked():
      return (FLAGS.flagsaver_test_validated_flag1,
              FLAGS.flagsaver_test_validated_flag2)

    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Flag validation failed'):
      read_both_stacked()

    # The failed call must leave both flags at their defaults.
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)

  def test_save_flag_value(self):
    """save_flag_values/restore_flag_values round-trip a flag's value."""
    snapshot = flagsaver.save_flag_values()

    # Change the value and confirm the change took effect.
    FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    # Restoring the snapshot brings the original value back.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_save_flag_default(self):
    """save_flag_values/restore_flag_values round-trip a flag's default."""
    flag_name = 'flagsaver_test_flag0'
    snapshot = flagsaver.save_flag_values()

    # Change the default and confirm the change took effect.
    FLAGS.set_default(flag_name, 'new_default')
    self.assertEqual('new_default', FLAGS[flag_name].default)

    # Restoring the snapshot brings the original default back.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual('unchanged0', FLAGS[flag_name].default)

  def test_restore_after_parse(self):
    """Restoring undoes both the value and `present` set by parse()."""
    flag_name = 'flagsaver_test_flag0'
    snapshot = flagsaver.save_flag_values()

    # Sanity check (would fail if called with --flagsaver_test_flag0).
    self.assertEqual(0, FLAGS[flag_name].present)

    # Parse a value into the flag and confirm value and `present` changed.
    FLAGS[flag_name].parse('new value')
    self.assertEqual('new value', FLAGS[flag_name].value)
    self.assertEqual(1, FLAGS[flag_name].present)

    # Restoring the snapshot resets both fields.
    flagsaver.restore_flag_values(snapshot)
    self.assertEqual('unchanged0', FLAGS[flag_name].value)
    self.assertEqual(0, FLAGS[flag_name].present)

  def test_decorator_with_exception(self):
    """The decorator restores flags when the wrapped function raises."""

    @flagsaver.flagsaver
    def mutate_then_fail():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Simulate a failed test.
      raise _TestError('something happened')

    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    with self.assertRaises(_TestError):
      mutate_then_fail()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_validator_list_is_restored(self):
    """A validator registered inside a saved scope is dropped afterwards."""
    flag_name = 'flagsaver_test_flag0'
    self.assertLen(FLAGS[flag_name].validators, 1)
    validators_before = list(FLAGS[flag_name].validators)

    @flagsaver.flagsaver
    def add_validator():

      def no_space(value):
        return ' ' not in value

      flags.register_validator('flagsaver_test_flag0', no_space)
      self.assertLen(FLAGS[flag_name].validators, 2)

    add_validator()
    self.assertEqual(validators_before, FLAGS[flag_name].validators)
class FlagSaverDecoratorUsageTest(absltest.TestCase):
  """Applies @flagsaver.flagsaver directly to test methods.

  All four tests do the same thing: check the flag is at its default, then
  change it. Each can only pass if the changes made by the others were
  reverted.
  """

  def _expect_default_then_mutate(self):
    """Asserts the flag holds its default, then overwrites it."""
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  @flagsaver.flagsaver
  def test_mutate1(self):
    self._expect_default_then_mutate()

  @flagsaver.flagsaver
  def test_mutate2(self):
    self._expect_default_then_mutate()

  @flagsaver.flagsaver
  def test_mutate3(self):
    self._expect_default_then_mutate()

  @flagsaver.flagsaver
  def test_mutate4(self):
    self._expect_default_then_mutate()
class FlagSaverSetUpTearDownUsageTest(absltest.TestCase):
  """Saves flags in setUp and restores them in tearDown.

  All four tests do the same thing: check the flag is at its default, then
  change it. Each can only pass if the changes made by the others were
  reverted by tearDown.
  """

  def setUp(self):
    """Snapshots all flag values before each test."""
    # Chain to the base class first so whatever per-test state it prepares
    # is in place; overriding setUp without this silently skips it.
    super().setUp()
    self.saved_flag_values = flagsaver.save_flag_values()

  def tearDown(self):
    """Restores the flag values captured in setUp."""
    flagsaver.restore_flag_values(self.saved_flag_values)
    # Undo this fixture's own work before handing off to the base class.
    super().tearDown()

  def test_mutate1(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate2(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate3(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate4(self):
    # Even though other test cases change the flag, it should be
    # restored to 'unchanged0' if the flagsaver is working.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'
class FlagSaverBadUsageTest(absltest.TestCase):
  """Tests that certain kinds of improper usages raise errors."""

  def test_flag_saver_on_class(self):
    """Bare @flagsaver.flagsaver on a class raises TypeError."""
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver
      class DecoratedCase(absltest.TestCase):

        def test_tautology(self):
          pass

      # Not reached when the decorator raises as expected.
      del DecoratedCase

  def test_flag_saver_call_on_class(self):
    """@flagsaver.flagsaver() on a class raises TypeError."""
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver()
      class DecoratedCase(absltest.TestCase):

        def test_tautology(self):
          pass

      # Not reached when the decorator raises as expected.
      del DecoratedCase

  def test_flag_saver_with_overrides_on_class(self):
    """@flagsaver.flagsaver(**overrides) on a class raises TypeError."""
    with self.assertRaises(TypeError):

      # WRONG. Don't do this.
      # Consider the correct usage example in FlagSaverSetUpTearDownUsageTest.
      @flagsaver.flagsaver(foo='bar')
      class DecoratedCase(absltest.TestCase):

        def test_tautology(self):
          pass

      # Not reached when the decorator raises as expected.
      del DecoratedCase

  def test_multiple_positional_parameters(self):
    """Two callables passed positionally are rejected."""
    first = lambda: None
    second = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(first, second)

  def test_both_positional_and_keyword_parameters(self):
    """A callable combined with a keyword override is rejected."""
    target = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(target, flagsaver_test_flag0='new value')

  def test_duplicate_holder_parameters(self):
    """The same holder listed twice is rejected."""
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 45), (INT_FLAG, 45))

  def test_duplicate_holder_and_kw_parameter(self):
    """A flag given both as a holder pair and by keyword is rejected."""
    same_flag_by_name = {INT_FLAG.name: 45}
    with self.assertRaises(ValueError):
      flagsaver.flagsaver((INT_FLAG, 45), **same_flag_by_name)

  def test_both_positional_and_holder_parameters(self):
    """A callable combined with a holder pair is rejected."""
    target = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(target, (INT_FLAG, 45))

  def test_holder_parameters_wrong_shape(self):
    """A bare holder, not wrapped in a (holder, value) tuple, is rejected."""
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(INT_FLAG)

  def test_holder_parameters_tuple_too_long(self):
    """A holder tuple with three elements is rejected."""
    with self.assertRaises(ValueError):
      # One element more than the (holder, value) pair used elsewhere.
      flagsaver.flagsaver((INT_FLAG, 4, 5))

  def test_holder_parameters_tuple_wrong_type(self):
    """A pair whose first element is not a holder is rejected."""
    with self.assertRaises(ValueError):
      # The holder and the value are swapped.
      flagsaver.flagsaver((4, INT_FLAG))

  def test_both_wrong_positional_parameters(self):
    """A callable followed by a loose holder and value is rejected."""
    target = lambda: None
    with self.assertRaises(ValueError):
      flagsaver.flagsaver(target, STR_FLAG, '45')
if __name__ == '__main__':
  # Run the tests only when this file is executed as a script.
  absltest.main()