blob: 61ec54b667c5728642a3f249fb450c831cff80d2 [file] [log] [blame]
#!/usr/bin/python3
# Copyright (C) 2024 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts a legacy test class (using AndroidTestCase, TestCase or
# InstrumentationTestCase) to a modern style test class, in a best-effort manner.
#
# Usage:
# convert-androidtest.py TARGET-FILE [TARGET-FILE ...]
#
# Caveats:
# - It adds all the extra imports, even if they're not needed.
# - It won't sort imports.
# - It also always adds getContext() and getTestContext().
#
import sys
import fileinput
import re
import subprocess
def log(msg):
    """Write `msg`, followed by a newline, to standard error."""
    sys.stderr.write(str(msg) + "\n")
# Matches `extends AndroidTestCase` (or one of the other supported legacy base
# classes) together with any whitespace that follows it, so that substituting
# the match with '' drops the base class from a class declaration.
# The \b after the class name keeps a longer identifier that merely starts with
# one of these names (e.g. `TestCaseHelper`) from matching.
re_extends = re.compile(
    r''' \b extends \s+ (AndroidTestCase|TestCase|InstrumentationTestCase) \b \s* ''',
    re.S | re.X)
def find_target_files(files):
    """Return the files among `files` whose contents match `re_extends`.

    A file that cannot be opened or read is reported through `log` and
    skipped, so a single bad path does not abort the whole run.

    Args:
        files: Iterable of file paths to inspect.

    Returns:
        A list of the matching paths, in the order they were given.
    """
    ret = []
    for file in files:
        try:
            with open(file, 'r') as f:
                data = f.read()
        # OSError covers a missing file (FileNotFoundError) as well as a
        # directory or a permission problem; UnicodeDecodeError covers a file
        # that is not valid text in the default encoding.
        except (OSError, UnicodeDecodeError) as e:
            log(f'Failed to open file {file}: {e}')
            continue
        if re_extends.search(data):
            ret.append(file)
    return ret
def main(args):
    """Convert the legacy test classes among `args` in place.

    Files that match `re_extends` are rewritten line by line through
    `fileinput` in in-place mode (the original is kept as `<file>.bak`), and a
    unified diff between the backup and the result is shown for each file.

    Args:
        args: List of candidate file paths.

    Returns:
        0 in all cases, including when no target file was found.
    """
    # Find the files that should be processed.
    files = find_target_files(args)
    if not files:
        log("No target files found.")
        return 0

    # Patterns used for every line; compiled once, outside the loop.
    re_override = re.compile(r'''@Override\b''')
    re_class = re.compile(r'''\b class \b''', re.X)
    re_import = re.compile(r'''^import\b''')
    re_test_method = re.compile(r'''^ \s* public \s* void \s* test''', re.X)
    re_setup_teardown = re.compile(
        r''' ^\s+ ( \@Override \s+ ) ? (public|protected) \s+ void \s+ (setUp|tearDown) ''',
        re.X)
    re_override_line = re.compile(r''' \s* \@Override \s* \n ''', re.X)
    re_super_call = re.compile(r''' \b super \. (setUp|tearDown) \b ''', re.X)
    re_mcontext = re.compile(r'''\b mContext \b ''', re.X)

    # NOTE(review): the Java snippets below are emitted with the leading
    # whitespace exactly as found in this copy of the script (none, or a single
    # space). Confirm against the upstream script whether deeper indentation
    # was intended.

    # Imports injected before the first import line of each file. The trailing
    # newline makes print() leave a blank line after the injected imports.
    extra_imports = "\n".join((
        "import android.content.Context;",
        "import androidx.test.platform.app.InstrumentationRegistry;",
        "import static junit.framework.TestCase.assertEquals;",
        "import static junit.framework.TestCase.assertSame;",
        "import static junit.framework.TestCase.assertNotSame;",
        "import static junit.framework.TestCase.assertTrue;",
        "import static junit.framework.TestCase.assertFalse;",
        "import static junit.framework.TestCase.assertNull;",
        "import static junit.framework.TestCase.assertNotNull;",
        "import static junit.framework.TestCase.fail;",
        "import org.junit.After;",
        "import org.junit.Before;",
        "import org.junit.runner.RunWith;",
        "import org.junit.Test;",
        "import androidx.test.ext.junit.runners.AndroidJUnit4;",
    )) + "\n"

    # Helper methods injected right after the opening brace of the class.
    context_helpers = "\n".join((
        "private Context getContext() {",
        "return InstrumentationRegistry.getInstrumentation().getTargetContext();",
        "}",
        "private Context getTestContext() {",
        "return InstrumentationRegistry.getInstrumentation().getContext();",
        "}",
    )) + "\n"

    # Process the files. In in-place mode, everything printed to stdout inside
    # the loop is written to the file currently being read.
    with fileinput.input(files=files, inplace=True, backup='.bak') as f:
        import_seen = False
        carry_over = ''
        class_body_started = False
        class_seen = False

        for line in f:
            if f.isfirstline():
                log(f"Processing: {f.filename()}")
                # Reset the per-file state.
                import_seen = False
                carry_over = ''
                class_body_started = False
                class_seen = False

            line = line.rstrip('\n')

            # Hold back an @Override line and prepend it to the next line, so
            # the setUp/tearDown handling below sees the annotation and the
            # method signature together.
            if re_override.search(line):
                carry_over = carry_over + line + '\n'
                continue
            if carry_over:
                line = carry_over + line
                carry_over = ''

            # Remove the base class from the class definition.
            line = re_extends.sub('', line)

            # Add a @RunWith before the first line that mentions `class`.
            if not class_seen and re_class.search(line):
                class_seen = True
                print("@RunWith(AndroidJUnit4.class)")

            # Inject extra imports before the first import line.
            if not import_seen and re_import.search(line):
                import_seen = True
                print(extra_imports)

            # Add @Test to the test methods.
            if re_test_method.search(line):
                print(" @Test")

            # Convert setUp/tearDown to @Before/@After.
            if re_setup_teardown.search(line):
                if 'setUp' in line:
                    print(' @Before')
                else:
                    print(' @After')
                # Drop the carried-over @Override and make the method public.
                line = re_override_line.sub('', line)
                line = line.replace('protected', 'public')

            # Remove the super setUp / tearDown call.
            if re_super_call.search(line):
                continue

            # Convert mContext to getContext().
            line = re_mcontext.sub('getContext()', line)

            # Print the processed line.
            print(line)

            # Add getContext() / getTestContext() at the beginning of the
            # class. A '{' seen before the class declaration (for example a
            # Javadoc `{@link ...}` tag) cannot open the class body, so wait
            # until the class has been seen.
            if class_seen and not class_body_started and '{' in line:
                class_body_started = True
                print(context_helpers)

    # Show what changed in each file.
    for file in files:
        try:
            subprocess.call(["diff", "-u", "--color=auto", f"{file}.bak", file])
        except OSError as e:
            # The conversion itself already succeeded; a missing or
            # unrunnable `diff` only costs the preview.
            log(f'Failed to run diff: {e}')
            break

    log(f'{len(files)} file(s) converted.')
    return 0
# Script entry point: pass the command-line arguments to main() and use its
# return value as the process exit status.
if __name__ == '__main__':
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)