blob: a4ebaa1be67695700b7e608507eac6724546f741 [file] [log] [blame]
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use quote::quote;
use syn::{parse_macro_input, punctuated::Punctuated, spanned::Spanned, token::Comma, Expr, Ident};
/// Working state used while rewriting a `verify_pred!` predicate into a
/// sequence of generated statements plus a final boolean expression.
struct AccumulatePartsState {
    /// Identifier of the `String` local in the generated code that collects
    /// the failure message.
    error_message_ident: Ident,
    /// Generated statements (intermediate `let` bindings and message writes),
    /// in the order they must be emitted.
    statements: Vec<proc_macro2::TokenStream>,
    /// Suffix for the next intermediate variable name; incremented each time
    /// a variable is defined.
    var_num: usize,
}
/// Renders an expression back to text by formatting its token stream. The
/// result is used as the label for a value in the failure message.
fn expr_to_string(expr: &Expr) -> String {
    let tokens = quote!(#expr);
    tokens.to_string()
}
impl AccumulatePartsState {
fn new() -> Self {
Self {
var_num: 0,
error_message_ident: Ident::new(
"__googletest__verify_pred__error_message",
::proc_macro2::Span::call_site(),
),
statements: vec![],
}
}
/// Takes an expression with chained field accesses and method calls and
/// accumulates intermediate expressions used for computing `verify_pred!`'s
/// expression, including intermediate variable assignments to evaluate
/// parts of the expression exactly once, and the format string used to
/// output intermediate values on condition failure. It returns the new form
/// of the input expression with parts of it potentially replaced by the
/// intermediate variables.
fn accumulate_parts(&mut self, expr: Expr) -> Expr {
// Literals don't need to be printed or stored in intermediate variables.
if is_literal(&expr) {
return expr;
}
let expr_string = expr_to_string(&expr);
let new_expr = match expr {
Expr::Group(mut group) => {
// This is an invisible group added for correct precedence in the AST. Just pass
// through without having a separate printing result.
*group.expr = self.accumulate_parts(*group.expr);
return Expr::Group(group);
}
Expr::Field(mut field) => {
// Don't assign field access to an intermediate variable to avoid moving out of
// non-`Copy` fields.
*field.base = self.accumulate_parts(*field.base);
Expr::Field(field)
}
Expr::Call(mut call) => {
// Cache args into intermediate variables.
call.args = self.define_variables_for_args(call.args);
// Cache function value into an intermediate variable.
self.define_variable(&Expr::Call(call))
}
Expr::MethodCall(mut method_call) => {
*method_call.receiver = self.accumulate_parts(*method_call.receiver);
// Cache args into intermediate variables.
method_call.args = self.define_variables_for_args(method_call.args);
// Cache method value into an intermediate variable.
self.define_variable(&Expr::MethodCall(method_call))
}
Expr::Binary(mut binary) => {
*binary.left = self.accumulate_parts(*binary.left);
*binary.right = self.accumulate_parts(*binary.right);
Expr::Binary(binary)
}
Expr::Unary(mut unary) => {
*unary.expr = self.accumulate_parts(*unary.expr);
Expr::Unary(unary)
}
// A path expression doesn't need to be stored in an intermediate variable.
// This avoids moving out of an existing variable.
Expr::Path(_) => expr,
// By default, assume it's some expression that needs to be cached to avoid
// double-evaluation.
_ => self.define_variable(&expr),
};
let error_message_ident = &self.error_message_ident;
self.statements.push(quote! {
::googletest::fmt::internal::__googletest__write_expr_value!(
&mut #error_message_ident,
#expr_string,
#new_expr,
);
});
new_expr
}
// Defines a variable for each argument expression so that it's evaluated
// exactly once.
fn define_variables_for_args(
&mut self,
args: Punctuated<Expr, Comma>,
) -> Punctuated<Expr, Comma> {
args.into_pairs()
.map(|mut pair| {
// Don't need to assign literals to intermediate variables.
if is_literal(pair.value()) {
return pair;
}
let var_expr = self.define_variable(pair.value());
let error_message_ident = &self.error_message_ident;
let expr_string = expr_to_string(pair.value());
self.statements.push(quote! {
::googletest::fmt::internal::__googletest__write_expr_value!(
&mut #error_message_ident,
#expr_string,
#var_expr,
);
});
*pair.value_mut() = var_expr;
pair
})
.collect()
}
/// Defines a new variable assigned to the expression and returns the
/// variable as an expression to be used in place of the passed-in
/// expression.
fn define_variable(&mut self, value: &Expr) -> Expr {
let var_name =
Ident::new(&format!("__googletest__verify_pred__var{}", self.var_num), value.span());
self.var_num += 1;
self.statements.push(quote! {
#[allow(non_snake_case)]
let mut #var_name = #value;
});
syn::parse::<Expr>(quote!(#var_name).into()).unwrap()
}
}
/// Returns `true` for a literal, or for a single unary operator applied
/// directly to a literal (e.g. `1`, `-1`). Anything else returns `false`.
fn is_literal(expr: &Expr) -> bool {
    // Look through at most one unary operator before testing for a literal.
    let operand = if let Expr::Unary(unary) = expr { &*unary.expr } else { expr };
    matches!(operand, Expr::Lit(_))
}
/// Expands a `verify_pred!` predicate into a block that evaluates each
/// intermediate sub-expression once and yields
/// `::core::result::Result::Ok(())` when the predicate is true.
///
/// When the predicate is false, the block yields a `TestAssertionFailure`
/// whose message starts with "`<predicate>` was false with" and is followed
/// by the recorded intermediate values.
pub fn verify_pred_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as Expr);
    let error_message = quote!(#parsed).to_string() + " was false with";

    let mut state = AccumulatePartsState::new();
    let pred_value = state.accumulate_parts(parsed);
    let AccumulatePartsState { error_message_ident, mut statements, .. } = state;

    let _ = statements.pop(); // The last statement prints the full expression itself.
    quote! {
        {
            // Never mutated when no intermediate values were recorded (e.g.
            // a literal or bare-path predicate).
            #[allow(unused_mut)]
            let mut #error_message_ident = #error_message.to_string();
            #(#statements)*
            // The parentheses keep the interpolated expression parsed as a
            // single condition (presumably also guarding against struct
            // literals, which are not allowed bare in `if` conditions).
            if (#pred_value) {
                // Fully qualified for hygiene, matching the `Err` branch: a
                // user-defined `Ok` at the call site must not be picked up.
                ::core::result::Result::Ok(())
            } else {
                ::core::result::Result::Err(
                    ::googletest::internal::test_outcome::TestAssertionFailure::create(
                        #error_message_ident))
            }
        }
    }
    .into()
}