| /* |
| * Copyright (C) 2015 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| |
| #include "RenderScript.h" |
| #include "rsCppInternal.h" |
| |
| #define NELEM(m) (sizeof(m) / sizeof((m)[0])) |
| |
| using android::RSC::Allocation; |
| using android::RSC::Element; |
| using android::RSC::RS; |
| using android::RSC::RS_ERROR_INVALID_ELEMENT; |
| using android::RSC::RS_ERROR_INVALID_PARAMETER; |
| using android::RSC::RS_SUCCESS; |
| using android::RSC::ScriptIntrinsicBLAS; |
| using android::RSC::sp; |
| |
| // ScriptIntrinsicBLAS APIS |
| ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e) |
| : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) { |
| |
| } |
| |
| sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(const sp<RS>& rs) { |
| return new ScriptIntrinsicBLAS(rs, Element::U32(rs)); |
| } |
| |
| enum RsBlasDataType { |
| SINGLE, |
| DOUBLE, |
| SINGLE_COMPLEX, |
| DOUBLE_COMPLEX |
| }; |
| |
| static RsBlasCall |
| setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func, |
| int TransA, int TransB, int Side, int Uplo, int Diag, |
| int M, int N, int K, int incX, int incY, int KL, int KU, |
| float alphaF, float betaF, double alphaD, double betaD, |
| float alphaCX, float alphaCY, float betaCX, float betaCY, |
| double alphaZX, double alphaZY, double betaZX, double betaZY |
| ) { |
| RsBlasCall call; |
| memset(&call, 0, sizeof(call)); |
| call.func = func; |
| call.transA = (RsBlasTranspose)TransA; |
| call.transB = (RsBlasTranspose)TransB; |
| call.side = (RsBlasSide)Side; |
| call.uplo = (RsBlasUplo)Uplo; |
| call.diag = (RsBlasDiag)Diag; |
| call.M = M; |
| call.N = N; |
| call.K = K; |
| |
| switch (dataType) { |
| case SINGLE: |
| // For Single-precision BLAS. |
| call.alpha.f = alphaF; |
| call.beta.f = betaF; |
| break; |
| case DOUBLE: |
| // For Double-precision BLAS. |
| call.alpha.d = alphaD; |
| call.beta.d = betaD; |
| break; |
| case SINGLE_COMPLEX: |
| // For Single-precision complex BLAS. |
| call.alpha.c.r = alphaCX; |
| call.alpha.c.i = alphaCY; |
| call.beta.c.r = betaCX; |
| call.beta.c.i = betaCY; |
| break; |
| case DOUBLE_COMPLEX: |
| // For Double-precision complex BLAS. |
| call.alpha.z.r = alphaZX; |
| call.alpha.z.i = alphaZY; |
| call.beta.z.r = betaZX; |
| call.beta.z.i = betaZY; |
| break; |
| default: |
| break; |
| } |
| |
| call.incX = incX; |
| call.incY = incY; |
| call.KL = KL; |
| call.KU = KU; |
| |
| return call; |
| } |
| |
| static void |
| nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, |
| int TransB, int Side, int Uplo, int Diag, int M, int N, int K, |
| float alpha, RsAllocation A, RsAllocation B, |
| float beta, RsAllocation C, int incX, int incY, int KL, int KU) { |
| RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag, |
| M, N, K, incX, incY, KL, KU, alpha, beta, 0.0, 0.0, |
| 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); |
| RsAllocation in_allocs[3] = {A, B, C}; |
| tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, |
| &call, sizeof(call), nullptr, 0)); |
| } |
| |
| |
| static void |
| nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, |
| int TransB, int Side, int Uplo, int Diag, int M, int N, int K, |
| double alpha, RsAllocation A, RsAllocation B, |
| double beta, RsAllocation C, int incX, int incY, int KL, int KU) { |
| RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag, |
| M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, alpha, beta, |
| 0.0f, 0.0f, 0.0f, 0.0f, 0.0, 0.0, 0.0, 0.0); |
| RsAllocation in_allocs[3] = {A, B, C}; |
| tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, |
| &call, sizeof(call), nullptr, 0)); |
| } |
| |
| static void |
| nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, |
| int TransB, int Side, int Uplo, int Diag, int M, int N, int K, |
| float alphaX, float alphaY, RsAllocation A, RsAllocation B, |
| float betaX, float betaY, RsAllocation C, int incX, int incY, int KL, int KU) { |
| RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, |
| M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, |
| alphaX, alphaY, betaX, betaY, 0.0, 0.0, 0.0, 0.0); |
| RsAllocation in_allocs[3] = {A, B, C}; |
| tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, |
| &call, sizeof(call), nullptr, 0)); |
| } |
| |
| static void |
| nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func, int TransA, |
| int TransB, int Side, int Uplo, int Diag, int M, int N, int K, |
| double alphaX, double alphaY, RsAllocation A, RsAllocation B, |
| double betaX, double betaY, RsAllocation C, int incX, int incY, int KL, int KU) { |
| RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag, |
| M, N, K, incX, incY, KL, KU, 0.0f, 0.0f, 0.0, 0.0, |
| 0.0f, 0.0f, 0.0f, 0.0f, alphaX, alphaY, betaX, betaY); |
| RsAllocation in_allocs[3] = {A, B, C}; |
| tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, |
| &call, sizeof(call), nullptr, 0)); |
| } |
| |
| |
| static void |
| nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K, |
| RsAllocation A, int a_offset, RsAllocation B, int b_offset, |
| RsAllocation C, int c_offset, int c_mult_int) { |
| RsBlasCall call; |
| memset(&call, 0, sizeof(call)); |
| call.func = RsBlas_bnnm; |
| call.M = M; |
| call.N = N; |
| call.K = K; |
| call.a_offset = a_offset & 0xFF; |
| call.b_offset = b_offset & 0xFF; |
| call.c_offset = c_offset; |
| call.c_mult_int = c_mult_int; |
| |
| RsAllocation in_allocs[3] = {A, B, C}; |
| tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs), nullptr, |
| &call, sizeof(call), nullptr, 0)); |
| } |
| |
| /** |
| * Level 2 BLAS |
| */ |
| static void validateGEMV(RS* mRS, const sp<const Element>& e, RsBlasTranspose TransA, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) { |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = -1, expectedYDim = -1; |
| if (TransA == RsBlasNoTrans) { |
| expectedXDim = 1 + (N - 1) * incX; |
| expectedYDim = 1 + (M - 1) * incY; |
| } else { |
| expectedXDim = 1 + (M - 1) * incX; |
| expectedYDim = 1 + (N - 1) * incY; |
| } |
| if ((int)X->getType()->getX() != expectedXDim || |
| (int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV"); |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, float beta, const sp<Allocation>& Y, int incY) { |
| validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha, A->getID(), X->getID(), |
| beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, double beta, const sp<Allocation>& Y, int incY) { |
| validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha, A->getID(), X->getID(), |
| beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, Float2 beta, const sp<Allocation>& Y, int incY) { |
| validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, Double2 beta, const sp<Allocation>& Y, int incY) { |
| validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, float beta, const sp<Allocation>& Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); |
| } |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha, A->getID(), X->getID(), |
| beta, Y->getID(), incX, incY, KL, KU); |
| } |
| |
| void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, double beta, const sp<Allocation>& Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); |
| } |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha, A->getID(), X->getID(), |
| beta, Y->getID(), incX, incY, KL, KU); |
| } |
| |
| void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); |
| } |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, KL, KU); |
| } |
| |
| void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) { |
| // GBMV has the same validation requirements as GEMV + KL and KU >= 0 |
| validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY); |
| if (KL < 0 || KU < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0"); |
| } |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv, |
| TransA, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, KL, KU); |
| } |
| |
| static void validateTRMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA, |
| RsBlasDiag Diag, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| int N = A->getType()->getY(); |
| if ((int)A->getType()->getX() != N) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV"); |
| } |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (incX <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV"); |
| } |
| } |
| |
| static int validateTPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, RsBlasTranspose TransA, |
| RsBlasDiag Diag, const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| if (!Ap->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = sqrt((double)Ap->getType()->getX() * 2); |
| if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); |
| } |
| if (incX <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV"); |
| } |
| |
| return N; |
| } |
| |
| |
| void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBMV has the same requirements as TRMV + K >= 0 |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); |
| } |
| validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBMV has the same requirements as TRMV + K >= 0 |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); |
| } |
| validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBMV has the same requirements as TRMV + K >= 0 |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); |
| } |
| validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBMV has the same requirements as TRMV + K >= 0 |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); |
| } |
| validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| |
| } |
| |
| void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| |
| } |
| |
| void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TRSV is the same as TRMV |
| validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| |
| } |
| |
| void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBSV is the same as TRMV + K >= 0 |
| validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); |
| } |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBSV is the same as TRMV + K >= 0 |
| validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); |
| } |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, |
| A->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBSV is the same as TRMV + K >= 0 |
| validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, |
| 0, 0, A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| int K, const sp<Allocation>& A, const sp<Allocation>& X, int incX) { |
| // TBSV is the same as TRMV + K >= 0 |
| validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX); |
| int N = A->getType()->getY(); |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Number of diagonals must be positive"); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0, |
| A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) { |
| // TPSV is same as TPMV |
| int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv, |
| TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0, |
| Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0); |
| } |
| |
| /** |
| * Level 2, S and D only |
| */ |
| static int validateSYMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& A, |
| const sp<Allocation>& X, const sp<Allocation>& Y, int incX, int incY) { |
| int N = A->getType()->getY(); |
| if ((int)A->getType()->getX() != N) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV"); |
| } |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e) ) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if ((int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV"); |
| } |
| return N; |
| } |
| static int validateSPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& Ap, |
| const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY) { |
| if (!Ap->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = sqrt((double)Ap->getType()->getX() * 2); |
| if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); |
| } |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if ((int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV"); |
| } |
| |
| return N; |
| } |
| static void validateGER(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e) ) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| |
| if (N < 1 || M < 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER"); |
| } |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (M - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if ((int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER"); |
| } |
| |
| |
| } |
| static int validateSYR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, |
| const sp<Allocation>& X, int incX, const sp<Allocation>& A) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| |
| int N = A->getType()->getX(); |
| |
| if (X->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| if (N != (int)A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); |
| } |
| if (incX <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); |
| } |
| return N; |
| } |
| static int validateSPR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, |
| const sp<Allocation>& X, int incX, const sp<Allocation>& Ap) { |
| if (!Ap->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = sqrt((double)Ap->getType()->getX() * 2); |
| if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); |
| } |
| if (incX <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR"); |
| } |
| |
| return N; |
| } |
| |
| static int validateSYR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| int N = A->getType()->getX(); |
| |
| if (N != (int)A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix"); |
| } |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| int expectedYDim = 1 + (N - 1) * incY; |
| if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR"); |
| } |
| return N; |
| |
| } |
| static int validateSPR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { |
| if (!Ap->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| if (Ap->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1"); |
| } |
| |
| int N = sqrt((double)Ap->getType()->getX() * 2); |
| if ((int)Ap->getType()->getX() != ((N * (N+1)) / 2)) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap"); |
| } |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (N - 1) * incX; |
| int expectedYDim = 1 + (N - 1) * incY; |
| if ((int)X->getType()->getX() != expectedXDim || (int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2"); |
| } |
| |
| return N; |
| } |
| |
| void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, float beta, const sp<Allocation>& Y, int incY) { |
| int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, float beta, const sp<Allocation>& Y, int incY) { |
| // SBMV is the same as SYMV + K >= 0 |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); |
| } |
| int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv, |
| 0, 0, 0, Uplo, 0, 0, N, K, alpha, |
| A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& Ap, const sp<Allocation>& X, |
| int incX, float beta, const sp<Allocation>& Y, int incY) { |
| int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SGER(float alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger, |
| 0, 0, 0, 0, 0, M, N, 0, alpha, |
| X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& A) { |
| int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& Ap) { |
| int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { |
| int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, double beta, const sp<Allocation>& Y, int incY) { |
| int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, double beta, const sp<Allocation>& Y, int incY) { |
| // SBMV is the same as SYMV + K >= 0 |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); |
| } |
| int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv, |
| 0, 0, 0, Uplo, 0, 0, N, K, alpha, |
| A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, const sp<Allocation>& Ap, const sp<Allocation>& X, |
| int incX, double beta, const sp<Allocation>& Y, int incY) { |
| int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DGER(double alpha, const sp<Allocation>& X, int incX, const sp<Allocation>& Y, |
| int incY, const sp<Allocation>& A) { |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger, |
| 0, 0, 0, 0, 0, M, N, 0, alpha, |
| X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& A) { |
| int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& Ap) { |
| int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { |
| int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, alpha, |
| X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); |
| } |
| |
| |
| /** |
| * Level 2, C and Z only |
| */ |
| |
| static void validateGERU(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !X->getType()->getElement()->isCompatible(e) || |
| !Y->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); |
| } |
| |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| if (incX <= 0 || incY <= 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); |
| } |
| int expectedXDim = 1 + (M - 1) * incX; |
| if ((int)X->getType()->getX() != expectedXDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); |
| } |
| int expectedYDim = 1 + (N - 1) * incY; |
| if ((int)Y->getType()->getX() != expectedYDim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); |
| } |
| |
| } |
| |
| void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { |
| // HEMV is the same as SYR2 validation-wise |
| int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { |
| // HBMV is the same as SYR2 validation-wise |
| int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv, |
| 0, 0, 0, Uplo, 0, 0, N, K, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& Ap, |
| const sp<Allocation>& X, int incX, Float2 beta, const sp<Allocation>& Y, int incY) { |
| // HPMV is the same as SPR2 |
| int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, Ap->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CGERU(Float2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru, |
| 0, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CGERC(Float2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| // Same as GERU |
| validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc, |
| 0, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& A) { |
| // Same as SYR |
| int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha, 0, X->getID(), 0, |
| 0, 0, A->getID(), incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& Ap) { |
| // Equivalent to SPR for validation |
| int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha, 0, X->getID(), 0, |
| 0, 0, Ap->getID(), incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| // Same as SYR2 |
| int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { |
| // Same as SPR2 |
| int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, Ap->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& A, |
| const sp<Allocation>& X, int incX, Double2 beta, const sp<Allocation>& Y, int incY) { |
| // HEMV is the same as SYR2 validation-wise |
| int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& X, |
| int incX, Double2 beta, const sp<Allocation>& Y, int incY) { |
| // HBMV is the same as SYR2 validation-wise |
| int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); |
| if (K < 0) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv, |
| 0, 0, 0, Uplo, 0, 0, N, K, |
| alpha.x, alpha.y, A->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& Ap, const sp<Allocation>& X, |
| int incX, Double2 beta, const sp<Allocation>& Y, int incY) { |
| // HPMV is the same as SPR2 |
| int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, Ap->getID(), X->getID(), |
| beta.x, beta.y, Y->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru, |
| 0, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| // Same as GERU |
| validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); |
| int M = A->getType()->getY(); |
| int N = A->getType()->getX(); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc, |
| 0, 0, 0, 0, 0, M, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& A) { |
| // Same as SYR |
| int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha, 0, X->getID(), 0, |
| 0, 0, A->getID(), incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, const sp<Allocation>& X, |
| int incX, const sp<Allocation>& Ap) { |
| // Equivalent to SPR for validation |
| int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha, 0, X->getID(), 0, |
| 0, 0, Ap->getID(), incX, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& A) { |
| // Same as SYR2 |
| int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, A->getID(), incX, incY, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, const sp<Allocation>& X, int incX, |
| const sp<Allocation>& Y, int incY, const sp<Allocation>& Ap) { |
| // Same as SPR2 |
| int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2, |
| 0, 0, 0, Uplo, 0, 0, N, 0, |
| alpha.x, alpha.y, X->getID(), Y->getID(), |
| 0, 0, Ap->getID(), incX, incY, 0, 0); |
| } |
| |
| |
| /** |
| * Level 3 BLAS |
| */ |
| |
| static void validateL3(RS* mRS, const sp<const Element>& e, int TransA, int TransB, int Side, |
| const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { |
| int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; |
| if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) || |
| (B != nullptr && !B->getType()->getElement()->isCompatible(e)) || |
| (C != nullptr && !C->getType()->getElement()->isCompatible(e))) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (C == nullptr) { |
| // Since matrix C is used to store the result, it cannot be null. |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null"); |
| } |
| cM = C->getType()->getY(); |
| cN = C->getType()->getX(); |
| |
| if (Side == RsBlasRight) { |
| if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa"); |
| } |
| if (B != nullptr) { |
| bM = A->getType()->getY(); |
| bN = A->getType()->getX(); |
| } |
| if (A != nullptr) { |
| aM = B->getType()->getY(); |
| aN = B->getType()->getX(); |
| } |
| } else { |
| if (A != nullptr) { |
| if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) { |
| aN = A->getType()->getY(); |
| aM = A->getType()->getX(); |
| } else { |
| aM = A->getType()->getY(); |
| aN = A->getType()->getX(); |
| } |
| } |
| if (B != nullptr) { |
| if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) { |
| bN = B->getType()->getY(); |
| bM = B->getType()->getX(); |
| } else { |
| bM = B->getType()->getY(); |
| bN = B->getType()->getX(); |
| } |
| } |
| } |
| if (A != nullptr && B != nullptr && C != nullptr) { |
| if (aN != bM || aM != cM || bN != cN) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); |
| } |
| } else if (A != nullptr && C != nullptr) { |
| // A and C only, for SYRK |
| if (cM != cN) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric"); |
| } |
| if (aM != cM) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); |
| } |
| } else if (A != nullptr && B != nullptr) { |
| // A and B only |
| if (aN != bM) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); |
| } |
| } |
| |
| } |
| |
| void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, float alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C); |
| |
| int M = -1, N = -1, K = -1; |
| if (TransA != RsBlasNoTrans) { |
| M = A->getType()->getX(); |
| K = A->getType()->getY(); |
| } else { |
| M = A->getType()->getY(); |
| K = A->getType()->getX(); |
| } |
| if (TransB != RsBlasNoTrans) { |
| N = B->getType()->getY(); |
| } else { |
| N = B->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm, |
| TransA, TransB, 0, 0, 0, M, N, K, |
| alpha, A->getID(), B->getID(), |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C); |
| int M = -1, N = -1, K = -1; |
| if (TransA != RsBlasNoTrans) { |
| M = A->getType()->getX(); |
| K = A->getType()->getY(); |
| } else { |
| M = A->getType()->getY(); |
| K = A->getType()->getX(); |
| } |
| if (TransB != RsBlasNoTrans) { |
| N = B->getType()->getY(); |
| } else { |
| N = B->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm, |
| TransA, TransB, 0, 0, 0, M, N, K, |
| alpha, A->getID(), B->getID(), |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C); |
| int M = -1, N = -1, K = -1; |
| if (TransA != RsBlasNoTrans) { |
| M = A->getType()->getX(); |
| K = A->getType()->getY(); |
| } else { |
| M = A->getType()->getY(); |
| K = A->getType()->getX(); |
| } |
| if (TransB != RsBlasNoTrans) { |
| N = B->getType()->getY(); |
| } else { |
| N = B->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm, |
| TransA, TransB, 0, 0, 0, M, N, K, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C); |
| int M = -1, N = -1, K = -1; |
| if (TransA != RsBlasNoTrans) { |
| M = A->getType()->getX(); |
| K = A->getType()->getY(); |
| } else { |
| M = A->getType()->getY(); |
| K = A->getType()->getX(); |
| } |
| if (TransB != RsBlasNoTrans) { |
| N = B->getType()->getY(); |
| } else { |
| N = B->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm, |
| TransA, TransB, 0, 0, 0, M, N, K, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { |
| //For SYMM, Matrix A should be symmetric |
| if (A->getType()->getX() != A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); |
| } |
| validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm, |
| 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, |
| alpha, A->getID(), B->getID(), |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { |
| if (A->getType()->getX() != A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); |
| } |
| validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm, |
| 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, |
| alpha, A->getID(), B->getID(), |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { |
| if (A->getType()->getX() != A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); |
| } |
| validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm, |
| 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { |
| if (A->getType()->getX() != A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); |
| } |
| validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm, |
| 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, |
| const sp<Allocation>& A, float beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha, A->getID(), 0, |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, |
| const sp<Allocation>& A, double beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha, A->getID(), 0, |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, |
| const sp<Allocation>& A, Float2 beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha.x, alpha.y, A->getID(), 0, |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, |
| const sp<Allocation>& A, Double2 beta, const sp<Allocation>& C) { |
| validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha.x, alpha.y, A->getID(), 0, |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| static void validateSYR2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans, |
| const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !B->getType()->getElement()->isCompatible(e) || |
| !C->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| int Cdim = -1; |
| // A is n x k if no transpose, k x n if transpose |
| // C is n x n |
| if (Trans == RsBlasTrans) { |
| // check columns versus C |
| Cdim = A->getType()->getX(); |
| } else { |
| // check rows versus C |
| Cdim = A->getType()->getY(); |
| } |
| if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K"); |
| } |
| // A dims == B dims |
| if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K"); |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { |
| validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha, A->getID(), B->getID(), |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { |
| validateSYR2K(mRS, Element::F64(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2k, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha, A->getID(), B->getID(), |
| beta, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { |
| validateSYR2K(mRS, Element::F32_2(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyr2k, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { |
| validateSYR2K(mRS, Element::F64_2(mRS), Trans, A, B, C); |
| int K = -1; |
| if (Trans != RsBlasNoTrans) { |
| K = A->getType()->getY(); |
| } else { |
| K = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyr2k, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| static void validateTRMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA, |
| const sp<Allocation>& A, const sp<Allocation>& B) { |
| int aM = -1, aN = -1, bM = -1, bN = -1; |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !B->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| |
| aM = A->getType()->getY(); |
| aN = A->getType()->getX(); |
| if (aM != aN) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with a non-symmetric matrix A"); |
| } |
| |
| bM = B->getType()->getY(); |
| bN = B->getType()->getX(); |
| if (Side == RsBlasLeft) { |
| if (aN != bM) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); |
| } |
| } else { |
| if (bN != aM) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRMM with invalid matrices"); |
| } |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::STRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| float alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRMM(mRS, Element::F32(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmm, |
| TransA, 0, Side, Uplo, Diag,\ |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha, A->getID(), B->getID(), 0.f, 0, 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| double alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRMM(mRS, Element::F64(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRMM(mRS, Element::F32_2(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTRMM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRMM(mRS, Element::F64_2(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); |
| } |
| |
| static void validateTRSM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, RsBlasTranspose TransA, |
| const sp<Allocation>& A, const sp<Allocation>& B) { |
| int adim = -1, bM = -1, bN = -1; |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !B->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| adim = A->getType()->getX(); |
| if (adim != (int)A->getType()->getY()) { |
| // This may be unnecessary, the restriction could potentially be relaxed. |
| // Allocation A needs to contain at least that symmetric matrix but could theoretically |
| // be larger for now we assume adapters are sufficient, will reevaluate in the future. |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with a non-symmetric matrix A"); |
| } |
| bM = B->getType()->getY(); |
| bN = B->getType()->getX(); |
| if (Side == RsBlasLeft) { |
| // A is M*M |
| if (adim != bM) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); |
| } |
| } else { |
| // A is N*N |
| if (adim != bN) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called TRSM with invalid matrix dimensions"); |
| } |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::STRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| float alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRSM(mRS, Element::F32(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::DTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| double alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRSM(mRS, Element::F64(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::CTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| Float2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRSM(mRS, Element::F32_2(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZTRSM(RsBlasSide Side, RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, |
| Double2 alpha, const sp<Allocation>& A, const sp<Allocation>& B) { |
| validateTRSM(mRS, Element::F64_2(mRS), Side, TransA, A, B); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsm, |
| TransA, 0, Side, Uplo, Diag, |
| B->getType()->getY(), B->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), 0, 0, 0, 0, 0, 0, 0); |
| } |
| |
| static void validateHEMM(RS* mRS, const sp<const Element>& e, RsBlasSide Side, |
| const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !B->getType()->getElement()->isCompatible(e) || |
| !C->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| |
| // A must be square; can potentially be relaxed similar to TRSM |
| int adim = A->getType()->getX(); |
| if (adim != (int)A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with non-square A"); |
| } |
| if ((Side == RsBlasLeft && adim != (int)B->getType()->getY()) || |
| (Side == RsBlasRight && adim != (int)B->getType()->getX())) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with invalid B"); |
| } |
| if (B->getType()->getX() != C->getType()->getX() || |
| B->getType()->getY() != C->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HEMM with mismatched B and C"); |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::CHEMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Float2 beta, const sp<Allocation>& C) { |
| validateHEMM(mRS, Element::F32_2(mRS), Side, A, B, C); |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemm, |
| 0, 0, Side, Uplo, 0, |
| C->getType()->getY(), C->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHEMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, Double2 beta, const sp<Allocation>& C) { |
| validateHEMM(mRS, Element::F64_2(mRS), Side, A, B, C); |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemm, |
| 0, 0, Side, Uplo, 0, |
| C->getType()->getY(), C->getType()->getX(), 0, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta.x, beta.y, C->getID(), 0, 0, 0, 0); |
| } |
| |
| static void validateHERK(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans, |
| const sp<Allocation>& A, const sp<Allocation>& C) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !C->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); |
| } |
| int cdim = C->getType()->getX(); |
| if (cdim != (int)C->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with non-square C"); |
| } |
| if (Trans == RsBlasNoTrans) { |
| if (cdim != (int)A->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); |
| } |
| } else { |
| if (cdim != (int)A->getType()->getX()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HERK with invalid A"); |
| } |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::CHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, |
| const sp<Allocation>& A, float beta, const sp<Allocation>& C) { |
| validateHERK(mRS, Element::F32_2(mRS), Trans, A, C); |
| int k = 0; |
| if (Trans == RsBlasConjTrans) { |
| k = A->getType()->getY(); |
| } else { |
| k = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cherk, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, |
| alpha, 0, A->getID(), 0, |
| beta, 0, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHERK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, |
| const sp<Allocation>& A, double beta, const sp<Allocation>& C) { |
| validateHERK(mRS, Element::F64_2(mRS), Trans, A, C); |
| int k = 0; |
| if (Trans == RsBlasConjTrans) { |
| k = A->getType()->getY(); |
| } else { |
| k = A->getType()->getX(); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zherk, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, |
| alpha, 0, A->getID(), 0, |
| beta, 0, C->getID(), 0, 0, 0, 0); |
| } |
| |
| static void validateHER2K(RS* mRS, const sp<const Element>& e, RsBlasTranspose Trans, |
| const sp<Allocation>& A, const sp<Allocation>& B, const sp<Allocation>& C) { |
| if (!A->getType()->getElement()->isCompatible(e) || |
| !B->getType()->getElement()->isCompatible(e) || |
| !C->getType()->getElement()->isCompatible(e)) { |
| mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); |
| } |
| if (Trans != RsBlasNoTrans && Trans != RsBlasConjTrans) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Call HERK with invalid Transpose"); |
| } |
| int cdim = C->getType()->getX(); |
| if (cdim != (int)C->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with non-square C"); |
| } |
| if (Trans == RsBlasNoTrans) { |
| if ((int)A->getType()->getY() != cdim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); |
| } |
| } else { |
| if ((int)A->getType()->getX() != cdim) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid matrices"); |
| } |
| } |
| if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called HER2K with invalid A and B matrices"); |
| } |
| } |
| |
| void ScriptIntrinsicBLAS::CHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, float beta, const sp<Allocation>& C) { |
| validateHER2K(mRS, Element::F32_2(mRS), Trans, A, B, C); |
| int k = 0; |
| if (Trans == RsBlasNoTrans) { |
| k = A->getType()->getX(); |
| } else { |
| k = A->getType()->getY(); |
| } |
| nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2k, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta, 0, C->getID(), 0, 0, 0, 0); |
| } |
| |
| void ScriptIntrinsicBLAS::ZHER2K(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, |
| const sp<Allocation>& A, const sp<Allocation>& B, double beta, const sp<Allocation>& C) { |
| validateHER2K(mRS, Element::F64_2(mRS), Trans, A, B, C); |
| int k = 0; |
| if (Trans == RsBlasNoTrans) { |
| k = A->getType()->getX(); |
| } else { |
| k = A->getType()->getY(); |
| } |
| nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2k, |
| Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), k, |
| alpha.x, alpha.y, A->getID(), B->getID(), |
| beta, 0, C->getID(), 0, 0, 0, 0); |
| } |
| |
| |
| |
| void ScriptIntrinsicBLAS::BNNM(const sp<Allocation>& A, int a_offset, const sp<Allocation>& B, int b_offset, |
| const sp<Allocation>& C, int c_offset, int c_mult) { |
| validateL3(mRS, Element::U8(mRS), RsBlasNoTrans, RsBlasTrans, 0, A, B, C); |
| |
| if (a_offset < 0 || a_offset > 255) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid a_offset passed to BNNM"); |
| } |
| if (b_offset < 0 || b_offset > 255) { |
| mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid b_offset passed to BNNM"); |
| } |
| int M = -1, N = -1, K = -1; |
| M = A->getType()->getY(); |
| N = B->getType()->getY(); |
| K = A->getType()->getX(); |
| |
| nScriptIntrinsicBLAS_BNNM(mRS, mRS->getContext(), getID(), M, N, K, A->getID(), a_offset, |
| B->getID(), b_offset, C->getID(), c_offset, c_mult); |
| } |