Merge "Minimize allocations in the Kalman library" into androidx-main
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java
index 207b941..91f2d12 100644
--- a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/KalmanFilter.java
@@ -52,8 +52,16 @@
// Measurement matrix
public @NonNull Matrix H;
- // Kalman gain
- public @NonNull Matrix K;
+ // Work buffers reused by predict() and update() to avoid allocating on every MotionEvent
+ private @NonNull Matrix mBufferXDimOne; // not final: predict() swaps it with x
+ private final @NonNull Matrix mBufferXDimXDim;
+ private final @NonNull Matrix mBufferXDimXDim2;
+ private final @NonNull Matrix mBufferXDimZDim;
+ private final @NonNull Matrix mBufferXDimZDim2;
+ private final @NonNull Matrix mBufferZDimOne;
+ private final @NonNull Matrix mBufferZDimXDim;
+ private final @NonNull Matrix mBufferZDimZDim;
+ private final @NonNull Matrix mBufferZDimTwiceZDim;
public KalmanFilter(int xDim, int zDim) {
x = new Matrix(xDim, 1);
@@ -62,7 +70,15 @@
R = Matrix.identity(zDim);
F = new Matrix(xDim, xDim);
H = new Matrix(zDim, xDim);
- K = new Matrix(xDim, zDim);
+ mBufferXDimOne = new Matrix(xDim, 1);
+ mBufferXDimXDim = new Matrix(xDim, xDim);
+ mBufferXDimXDim2 = new Matrix(xDim, xDim);
+ mBufferXDimZDim = new Matrix(xDim, zDim);
+ mBufferXDimZDim2 = new Matrix(xDim, zDim);
+ mBufferZDimOne = new Matrix(zDim, 1);
+ mBufferZDimXDim = new Matrix(zDim, xDim);
+ mBufferZDimZDim = new Matrix(zDim, zDim);
+ mBufferZDimTwiceZDim = new Matrix(zDim, 2 * zDim); // scratch space for Matrix.inverse
}
/** Resets the internal state of this Kalman filter. */
@@ -70,7 +86,6 @@
// NOTE: It is not necessary to reset Q, R, F, and H matrices.
x.fill(0);
Matrix.setIdentity(P);
- K.fill(0);
}
/**
@@ -78,16 +93,24 @@
* estimate for the current timestep.
*/
public void predict() {
- x = F.dot(x);
- P = F.dot(P).dotTranspose(F).plus(Q);
+ Matrix originalX = x; // NOTE: x refers to a different Matrix instance after each call
+ x = F.dot(x, mBufferXDimOne); // x' = F x, written into the spare buffer
+ mBufferXDimOne = originalX; // the previous state matrix becomes the spare buffer
+ // P' = F P F^T + Q: the product is written back into P, then Q is added in place
+ F.dot(P, mBufferXDimXDim).dotTranspose(F, P).plus(Q);
}
- /** Updates the state estimate to incorporate the new observation z. */
+ /** Updates the state estimate to incorporate the new observation z (left unmodified). */
public void update(@NonNull Matrix z) {
- Matrix y = z.minus(H.dot(x));
- Matrix tS = H.dot(P).dotTranspose(H).plus(R);
- K = P.dotTranspose(H).dot(tS.inverse());
- x = x.plus(K.dot(y));
- P = P.minus(K.dot(H).dot(P));
+ H.dot(x, mBufferZDimOne).minus(z); // negated innovation (H x - z); z is not written to
+ H.dot(P, mBufferZDimXDim)
+ .dotTranspose(H, mBufferZDimZDim)
+ .plus(R)
+ .inverse(mBufferZDimTwiceZDim); // mBufferZDimZDim now holds (H P H^T + R)^-1
+
+ P.dotTranspose(H, mBufferXDimZDim2).dot(mBufferZDimZDim, mBufferXDimZDim); // gain K
+
+ x.minus(mBufferXDimZDim.dot(mBufferZDimOne, mBufferXDimOne)); // x -= K (H x - z)
+ P.minus(mBufferXDimZDim.dot(H, mBufferXDimXDim).dot(P, mBufferXDimXDim2)); // P -= K H P
}
}
diff --git a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java
index 0294b18..399263d 100644
--- a/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java
+++ b/input/input-motionprediction/src/main/java/androidx/input/motionprediction/kalman/matrix/Matrix.java
@@ -230,27 +230,6 @@
* Calculates the matrix product of this matrix and {@code that}.
*
* @param that the other matrix
- * @return newly created matrix representing the matrix product of this and that
- * @throws IllegalArgumentException if the dimensions differ
- */
- public @NonNull Matrix dot(@NonNull Matrix that) {
- try {
- return dot(that, new Matrix(mRows, that.mCols));
- } catch (IllegalArgumentException e) {
- throw new IllegalArgumentException(
- String.format(
- Locale.ROOT,
- "The matrices dimensions are not conformant for a dot matrix "
- + "operation. this:%s that:%s",
- shortString(),
- that.shortString()));
- }
- }
-
- /**
- * Calculates the matrix product of this matrix and {@code that}.
- *
- * @param that the other matrix
* @param result matrix to hold the result
* @return result, filled with the matrix product
* @throws IllegalArgumentException if the dimensions differ
@@ -281,15 +260,26 @@
/**
* Calculates the inverse of a square matrix
*
- * @return newly created matrix representing the matrix inverse
+ * @param scratch the matrix [rows, 2*cols] to hold the temporary information; it is
+ * cleared on entry, so its previous contents do not matter
+ * @return this matrix, overwritten with its inverse
* @throws ArithmeticException if the matrix is not invertible
*/
- public @NonNull Matrix inverse() {
+ public @NonNull Matrix inverse(@NonNull Matrix scratch) {
if (!(mRows == mCols)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "The matrix is not square. this:%s", shortString()));
}
- final Matrix scratch = new Matrix(mRows, 2 * mCols);
+ if (scratch.mRows != mRows || scratch.mCols != 2 * mCols) {
+ throw new IllegalArgumentException(
+ String.format(
+ Locale.ROOT,
+ "The scratch matrix size is not correct. this:%s scratch:%s",
+ shortString(),
+ scratch.shortString()));
+ }
+ // The buffer is reused across calls: restore the all-zero state a fresh allocation had.
+ scratch.fill(0);
for (int i = 0; i < mRows; i++) {
for (int j = 0; j < mCols; j++) {
@@ -349,27 +339,6 @@
* Calculates the matrix product with the transpose of a second matrix.
*
* @param that the other matrix
- * @return newly created matrix representing the matrix product of this and that.transpose()
- * @throws IllegalArgumentException if shapes are not conformant
- */
- public @NonNull Matrix dotTranspose(@NonNull Matrix that) {
- try {
- return dotTranspose(that, new Matrix(mRows, that.mRows));
- } catch (IllegalArgumentException e) {
- throw new IllegalArgumentException(
- String.format(
- Locale.ROOT,
- "The matrices dimensions are not conformant for a transpose "
- + "operation. this:%s that:%s",
- shortString(),
- that.shortString()));
- }
- }
-
- /**
- * Calculates the matrix product with the transpose of a second matrix.
- *
- * @param that the other matrix
* @param result space to hold the result
* @return result, filled with the matrix product of this and that.transpose()
* @throws IllegalArgumentException if shapes are not conformant