Pass FunctionType and ReturnType when generating the Target object
http://b/23536224
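While generating tests, derive the FunctionType (from a native_, half_ or fast_
function-name prefix) and the ReturnType (from the f16/f32/f64 return type) and
pass them to Target's constructor. Generated test files now also import
android.renderscript.cts.Target.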
Change-Id: I0e1be6279d21028175571d53a7def7fc197778eb
diff --git a/api/GenerateTestFiles.cpp b/api/GenerateTestFiles.cpp
index 45cf1f3..c5a5b54 100644
--- a/api/GenerateTestFiles.cpp
+++ b/api/GenerateTestFiles.cpp
@@ -128,6 +128,9 @@
*/
void writeJavaVerifyVectorMethod() const;
+ // Generate the line that creates the Target.
+ void writeJavaCreateTarget() const;
+
// Generate the method header of the verify function.
void writeJavaVerifyMethodHeader() const;
@@ -430,7 +433,7 @@
}
mJava->indent() << "// Ask the CoreMathVerifier to validate.\n";
if (hasFloat) {
- mJava->indent() << "Target target = new Target(relaxed);\n";
+ writeJavaCreateTarget();
}
mJava->indent() << "String errorMessage = CoreMathVerifier."
<< mJavaVerifierVerifyMethodName << "(args";
@@ -442,7 +445,7 @@
} else {
mJava->indent() << "// Figure out what the outputs should have been.\n";
if (hasFloat) {
- mJava->indent() << "Target target = new Target(relaxed);\n";
+ writeJavaCreateTarget();
}
mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName << "(args";
if (hasFloat) {
@@ -539,7 +542,7 @@
}
}
}
- mJava->indent() << "Target target = new Target(relaxed);\n";
+ writeJavaCreateTarget();
mJava->indent() << "CoreMathVerifier." << mJavaVerifierComputeMethodName
<< "(args, target);\n\n";
@@ -582,6 +585,39 @@
*mJava << "\n";
}
+
+// Emits the Java statement that declares and constructs `target`:
+//     Target target = new Target(Target.FunctionType.<F>, Target.ReturnType.<R>, relaxed);
+// <F> is NATIVE, HALF or FAST when the permutation name's prefix before the
+// first '_' is "native", "half" or "fast"; otherwise it is NORMAL.
+// <R> is HALF, FLOAT or DOUBLE for a return parameter of spec type f16, f32 or
+// f64 respectively.
+//
+// If there is no return parameter, or its type is not floating point, an error
+// naming the permutation (and the offending type) is written to cerr. The
+// statement is still emitted with an empty ReturnType, which is not valid Java,
+// so the generated test fails to compile at this exact line instead of silently.
+void PermutationWriter::writeJavaCreateTarget() const {
+    // Bind by const reference: avoids a copy whether getName() returns by value
+    // (temporary lifetime is extended) or by reference.
+    const string& name = mPermutation.getName();
+
+    // The function family is encoded as the name prefix before the first '_'.
+    const char* functionType = "NORMAL";
+    const size_t prefixEnd = name.find('_');
+    if (prefixEnd != string::npos) {
+        if (name.compare(0, prefixEnd, "native") == 0) {
+            functionType = "NATIVE";
+        } else if (name.compare(0, prefixEnd, "half") == 0) {
+            functionType = "HALF";
+        } else if (name.compare(0, prefixEnd, "fast") == 0) {
+            functionType = "FAST";
+        }
+    }
+
+    // Map the return parameter's spec type to Target.ReturnType. Guard against a
+    // missing return parameter rather than dereferencing a null pointer.
+    const char* precisionStr = "";
+    if (mReturnParam == nullptr) {
+        cerr << "Error: " << name
+             << " has no return parameter; cannot choose a Target.ReturnType\n";
+    } else {
+        const string& floatType = mReturnParam->specType;
+        if (floatType == "f16") {
+            precisionStr = "HALF";
+        } else if (floatType == "f32") {
+            precisionStr = "FLOAT";
+        } else if (floatType == "f64") {
+            precisionStr = "DOUBLE";
+        } else {
+            cerr << "Error: return type '" << floatType << "' of " << name
+                 << " is not floating point; cannot choose a Target.ReturnType\n";
+        }
+    }
+
+    mJava->indent() << "Target target = new Target(Target.FunctionType." << functionType
+                    << ", Target.ReturnType." << precisionStr << ", relaxed);\n";
+}
+
void PermutationWriter::writeJavaVerifyMethodHeader() const {
mJava->indent() << "private void " << mJavaVerifyMethodName << "(";
for (auto p : mAllInputsAndOutputs) {
@@ -903,7 +939,8 @@
*file << "import android.renderscript.Allocation;\n";
*file << "import android.renderscript.RSRuntimeException;\n";
- *file << "import android.renderscript.Element;\n\n";
+ *file << "import android.renderscript.Element;\n";
+ *file << "import android.renderscript.cts.Target;\n\n";
*file << "import java.util.Arrays;\n\n";
*file << "public class " << testName << " extends RSBaseCompute";