Implement fmax and fmin (vector) for aarch64 assembler
PiperOrigin-RevId: 423405907
diff --git a/src/jit/aarch64-assembler.cc b/src/jit/aarch64-assembler.cc
index 757585a..7dbd276 100644
--- a/src/jit/aarch64-assembler.cc
+++ b/src/jit/aarch64-assembler.cc
@@ -239,6 +239,24 @@
return emit32(0x0E20D400 | q(vd) | fp_sz(vn) | rm(vm) | rn(vn) | rd(vd));
}
+Assembler& Assembler::fmax(VRegister vd, VRegister vn, VRegister vm) {
+ if (!is_same_shape(vd, vn, vm)) {
+ error_ = Error::kInvalidOperand;
+ return *this;
+ }
+
+ return emit32(0x0E20F400 | q(vd) | fp_sz(vn) | rm(vm) | rn(vn) | rd(vd));
+}
+
+Assembler& Assembler::fmin(VRegister vd, VRegister vn, VRegister vm) {
+ if (!is_same_shape(vd, vn, vm)) {
+ error_ = Error::kInvalidOperand;
+ return *this;
+ }
+
+ return emit32(0x0EA0F400 | q(vd) | fp_sz(vn) | rm(vm) | rn(vn) | rd(vd));
+}
+
Assembler& Assembler::fmla(VRegister vd, VRegister vn, VRegisterLane vm) {
if (!is_same_shape(vd, vn) || !is_same_data_type(vd, vm)) {
error_ = Error::kInvalidOperand;
diff --git a/src/xnnpack/aarch64-assembler.h b/src/xnnpack/aarch64-assembler.h
index 355fb27..098165b 100644
--- a/src/xnnpack/aarch64-assembler.h
+++ b/src/xnnpack/aarch64-assembler.h
@@ -247,6 +247,8 @@
// SIMD instructions
Assembler& fadd(VRegister vd, VRegister vn, VRegister vm);
+ Assembler& fmax(VRegister vd, VRegister vn, VRegister vm);
+ Assembler& fmin(VRegister vd, VRegister vn, VRegister vm);
Assembler& fmla(VRegister vd, VRegister vn, VRegisterLane vm);
Assembler& ld1(VRegisterList vs, MemOperand xn, int32_t imm);
Assembler& ld2r(VRegisterList xs, MemOperand xn);
diff --git a/test/aarch64-assembler.cc b/test/aarch64-assembler.cc
index affdd83..29812b2 100644
--- a/test/aarch64-assembler.cc
+++ b/test/aarch64-assembler.cc
@@ -62,6 +62,12 @@
CHECK_ENCODING(0x4E25D690, a.fadd(v16.v4s(), v20.v4s(), v5.v4s()));
EXPECT_ERROR(Error::kInvalidOperand, a.fadd(v16.v4s(), v20.v4s(), v5.v2s()));
+ CHECK_ENCODING(0x4E30F7E3, a.fmax(v3.v4s(), v31.v4s(), v16.v4s()));
+ EXPECT_ERROR(Error::kInvalidOperand, a.fmax(v3.v8h(), v31.v4s(), v16.v4s()));
+
+ CHECK_ENCODING(0x4EB1F7C2, a.fmin(v2.v4s(), v30.v4s(), v17.v4s()));
+ EXPECT_ERROR(Error::kInvalidOperand, a.fmin(v2.v4s(), v30.v16b(), v17.v4s()));
+
CHECK_ENCODING(0x4F801290, a.fmla(v16.v4s(), v20.v4s(), v0.s()[0]));
EXPECT_ERROR(Error::kInvalidOperand, a.fmla(v16.v4s(), v20.v2s(), v0.s()[0]));
EXPECT_ERROR(Error::kInvalidOperand, a.fmla(v16.v2d(), v20.v2d(), v0.s()[0]));