blob: 0051aa1591dbe177908032d076c6f124c9a02654 [file] [log] [blame]
"""
Unit tests for CLI entry points.
"""
import unittest
import sys
import functools
from contextlib import contextmanager
import os
from io import StringIO, BytesIO
import rsa
import rsa.cli
# In-memory stream type that captured_output() substitutes for
# sys.stdout / sys.stderr: a byte buffer on Python 2, a text buffer on Python 3.
IOClass = BytesIO if sys.version_info[0] < 3 else StringIO
@contextmanager
def captured_output():
    """Temporarily redirect stdout and stderr into fresh in-memory buffers.

    Yields a ``(stdout_buffer, stderr_buffer)`` pair of ``IOClass`` instances.
    The original streams are put back on exit, even if the body raises.
    """
    saved_streams = (sys.stdout, sys.stderr)
    buffers = (IOClass(), IOClass())
    try:
        sys.stdout, sys.stderr = buffers
        yield buffers
    finally:
        sys.stdout, sys.stderr = saved_streams
@contextmanager
def cli_args(*new_argv):
    """Temporarily replace sys.argv[1:] for a single test.

    Each argument is converted with ``str()``, so non-string values such as
    an integer key size can be passed directly. ``sys.argv[0]`` is left as-is.
    The complete original argument list is restored on exit, even if the
    body raises.
    """
    old_args = sys.argv[:]
    sys.argv[1:] = [str(arg) for arg in new_argv]
    try:
        yield
    finally:
        # The snapshot includes argv[0], so restore the whole list in place.
        # Assigning it only to sys.argv[1:] would duplicate argv[0] and grow
        # sys.argv by one element on every use.
        sys.argv[:] = old_args
def cleanup_files(*filenames):
    """Decorator factory: keep the given files absent around a test.

    Each listed file that exists is deleted immediately before the wrapped
    function runs, and again after it finishes, whether it returns or raises.
    The wrapped function's return value is passed through unchanged.
    """
    def _purge():
        for name in filenames:
            if not os.path.exists(name):
                continue
            os.unlink(name)

    def _decorate(test_func):
        @functools.wraps(test_func)
        def _wrapped(*args, **kwargs):
            _purge()
            try:
                outcome = test_func(*args, **kwargs)
            finally:
                _purge()
            return outcome

        return _wrapped

    return _decorate
class AbstractCliTest(unittest.TestCase):
    """Base test case adding an assertion about ``SystemExit`` status codes."""

    def assertExits(self, status_code, func, *args, **kwargs):
        """Assert that ``func(*args, **kwargs)`` raises SystemExit(status_code).

        :param status_code: the exit code the call is expected to produce.
        :param func: the callable under test; called with ``*args``/``**kwargs``.
        :raises AssertionError: (via ``self.fail``) if ``func`` returns without
            raising SystemExit, or exits with a different code.
        """
        try:
            func(*args, **kwargs)
        except SystemExit as ex:
            if status_code == ex.code:
                return
            # %r rather than %i: SystemExit.code may be None (bare sys.exit())
            # or a string message, and '%i' would raise TypeError, hiding the
            # actual assertion failure.
            self.fail('SystemExit() raised by %r, but exited with code %r, expected %r' % (
                func, ex.code, status_code))
        else:
            self.fail('SystemExit() not raised by %r' % func)
class KeygenTest(AbstractCliTest):
    """Tests for the ``rsa.cli.keygen`` entry point.

    Checks use ``assertIn`` rather than ``assertTrue(x in y)`` so that a
    failure reports the actual stderr text instead of "False is not true".
    """

    def test_keygen_no_args(self):
        # With no command-line arguments, keygen must exit with status 1.
        with cli_args():
            self.assertExits(1, rsa.cli.keygen)

    def test_keygen_priv_stdout(self):
        with captured_output() as (out, err):
            with cli_args(128):
                rsa.cli.keygen()

        # Without --out, the PEM-encoded private key is written to stdout.
        lines = out.getvalue().splitlines()
        self.assertEqual('-----BEGIN RSA PRIVATE KEY-----', lines[0])
        self.assertEqual('-----END RSA PRIVATE KEY-----', lines[-1])

        # The key size should be shown on stderr
        self.assertIn('128-bit key', err.getvalue())

    @cleanup_files('test_cli_privkey_out.pem')
    def test_keygen_priv_out_pem(self):
        with captured_output() as (out, err):
            with cli_args('--out=test_cli_privkey_out.pem', '--form=PEM', 128):
                rsa.cli.keygen()

        stderr_text = err.getvalue()
        # The key size should be shown on stderr
        self.assertIn('128-bit key', stderr_text)
        # The output file should be shown on stderr
        self.assertIn('test_cli_privkey_out.pem', stderr_text)

        # If we can load the file as PEM, it's good enough.
        with open('test_cli_privkey_out.pem', 'rb') as pemfile:
            rsa.PrivateKey.load_pkcs1(pemfile.read())

    @cleanup_files('test_cli_privkey_out.der')
    def test_keygen_priv_out_der(self):
        with captured_output() as (out, err):
            with cli_args('--out=test_cli_privkey_out.der', '--form=DER', 128):
                rsa.cli.keygen()

        stderr_text = err.getvalue()
        # The key size should be shown on stderr
        self.assertIn('128-bit key', stderr_text)
        # The output file should be shown on stderr
        self.assertIn('test_cli_privkey_out.der', stderr_text)

        # If we can load the file as DER, it's good enough.
        with open('test_cli_privkey_out.der', 'rb') as derfile:
            rsa.PrivateKey.load_pkcs1(derfile.read(), format='DER')

    @cleanup_files('test_cli_privkey_out.pem', 'test_cli_pubkey_out.pem')
    def test_keygen_pub_out_pem(self):
        with captured_output() as (out, err):
            with cli_args('--out=test_cli_privkey_out.pem',
                          '--pubout=test_cli_pubkey_out.pem',
                          '--form=PEM', 256):
                rsa.cli.keygen()

        stderr_text = err.getvalue()
        # The key size should be shown on stderr
        self.assertIn('256-bit key', stderr_text)
        # The output files should be shown on stderr
        self.assertIn('test_cli_privkey_out.pem', stderr_text)
        self.assertIn('test_cli_pubkey_out.pem', stderr_text)

        # If we can load the public key file as PEM, it's good enough.
        with open('test_cli_pubkey_out.pem', 'rb') as pemfile:
            rsa.PublicKey.load_pkcs1(pemfile.read())