Merge "Correctly record capture scope of local functions" into androidx-main
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt
index 2918915..e140173 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/ComposerParamSignatureTests.kt
@@ -168,6 +168,25 @@
)
@Test
+ fun testCaptureIssue23(): Unit = codegen(
+ """
+ import androidx.compose.animation.AnimatedContent
+ import androidx.compose.animation.ExperimentalAnimationApi
+ import androidx.compose.runtime.Composable
+
+ @OptIn(ExperimentalAnimationApi::class)
+ @Composable
+ fun SimpleAnimatedContentSample() {
+ @Composable fun Foo() {}
+
+ AnimatedContent(1f) {
+ Foo()
+ }
+ }
+ """
+ )
+
+ @Test
fun test32Params(): Unit = codegen(
"""
@Composable
diff --git a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt
index 732e8bf..d37af66 100644
--- a/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt
+++ b/compose/compiler/compiler-hosted/integration-tests/src/test/java/androidx/compose/compiler/plugins/kotlin/LambdaMemoizationTransformTests.kt
@@ -109,6 +109,171 @@
"""
)
+ // Fixes b/201252574
+ @Test
+ fun testLocalFunCaptures(): Unit = verifyComposeIrTransform(
+ """
+ import androidx.compose.runtime.NonRestartableComposable
+ import androidx.compose.runtime.Composable
+
+ @NonRestartableComposable
+ @Composable
+ fun Err() {
+ // `x` is not a capture of handler, but is treated as such.
+ fun handler() {
+ { x: Int -> x }
+ }
+ // Lambda calling handler. To find captures, we need captures of `handler`.
+ {
+ handler()
+ }
+ }
+ """,
+ """
+ @NonRestartableComposable
+ @Composable
+ fun Err(%composer: Composer?, %changed: Int) {
+ %composer.startReplaceableGroup(<>)
+ sourceInformation(%composer, "C(Err):Test.kt")
+ fun handler() {
+ { x: Int ->
+ x
+ }
+ }
+ {
+ handler()
+ }
+ %composer.endReplaceableGroup()
+ }
+ """,
+ """
+ """
+ )
+
+ @Test
+ fun testLocalClassCaptures1(): Unit = verifyComposeIrTransform(
+ """
+ import androidx.compose.runtime.NonRestartableComposable
+ import androidx.compose.runtime.Composable
+
+ @NonRestartableComposable
+ @Composable
+ fun Err(y: Int, z: Int) {
+ class Local {
+ val w = z
+ fun something(x: Int): Int { return x + y + w }
+ }
+ {
+ Local().something(2)
+ }
+ }
+ """,
+ """
+ @NonRestartableComposable
+ @Composable
+ fun Err(y: Int, z: Int, %composer: Composer?, %changed: Int) {
+ %composer.startReplaceableGroup(<>)
+ sourceInformation(%composer, "C(Err)<{>:Test.kt")
+ class Local {
+ val w: Int = z
+ fun something(x: Int): Int {
+ return x + y + w
+ }
+ }
+ remember(y, z, {
+ {
+ Local().something(2)
+ }
+ }, %composer, 0)
+ %composer.endReplaceableGroup()
+ }
+ """,
+ """
+ """
+ )
+
+ @Test
+ fun testLocalClassCaptures2(): Unit = verifyComposeIrTransform(
+ """
+ import androidx.compose.runtime.Composable
+ import androidx.compose.runtime.NonRestartableComposable
+
+ @NonRestartableComposable
+ @Composable
+ fun Example(z: Int) {
+ class Foo(val x: Int) { val y = z }
+ val lambda: () -> Any = {
+ Foo(1)
+ }
+ }
+ """,
+ """
+ @NonRestartableComposable
+ @Composable
+ fun Example(z: Int, %composer: Composer?, %changed: Int) {
+ %composer.startReplaceableGroup(<>)
+ sourceInformation(%composer, "C(Example)<{>:Test.kt")
+ class Foo(val x: Int) {
+ val y: Int = z
+ }
+ val lambda = remember(z, {
+ {
+ Foo(1)
+ }
+ }, %composer, 0)
+ %composer.endReplaceableGroup()
+ }
+ """,
+ """
+ """
+ )
+
+ @Test
+ fun testLocalFunCaptures3(): Unit = verifyComposeIrTransform(
+ """
+ import androidx.compose.animation.AnimatedContent
+ import androidx.compose.animation.ExperimentalAnimationApi
+ import androidx.compose.runtime.Composable
+
+ @OptIn(ExperimentalAnimationApi::class)
+ @Composable
+ fun SimpleAnimatedContentSample() {
+ @Composable fun Foo() {}
+
+ AnimatedContent(1f) {
+ Foo()
+ }
+ }
+ """,
+ """
+ @OptIn(markerClass = ExperimentalAnimationApi::class)
+ @Composable
+ fun SimpleAnimatedContentSample(%composer: Composer?, %changed: Int) {
+ %composer = %composer.startRestartGroup(<>)
+ sourceInformation(%composer, "C(SimpleAnimatedContentSample)<Animat...>:Test.kt")
+ if (%changed !== 0 || !%composer.skipping) {
+ @Composable
+ fun Foo(%composer: Composer?, %changed: Int) {
+ %composer.startReplaceableGroup(<>)
+ sourceInformation(%composer, "C(Foo):Test.kt")
+ %composer.endReplaceableGroup()
+ }
+ AnimatedContent(1.0f, null, null, null, composableLambda(%composer, <>, false) { it: Float, %composer: Composer?, %changed: Int ->
+ sourceInformation(%composer, "C<Foo()>:Test.kt")
+ Foo(%composer, 0)
+ }, %composer, 0b0110000000000110, 0b1110)
+ } else {
+ %composer.skipToGroupEnd()
+ }
+ %composer.endRestartGroup()?.updateScope { %composer: Composer?, %force: Int ->
+ SimpleAnimatedContentSample(%composer, %changed or 0b0001)
+ }
+ }
+ """,
+ """
+ """
+ )
+
@Test
fun testStateDelegateCapture(): Unit = verifyComposeIrTransform(
"""
diff --git a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt
index 04d8902..bb88a82 100644
--- a/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt
+++ b/compose/compiler/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/lower/ComposerLambdaMemoization.kt
@@ -53,6 +53,7 @@
import org.jetbrains.kotlin.ir.builders.irTemporary
import org.jetbrains.kotlin.ir.declarations.IrAttributeContainer
import org.jetbrains.kotlin.ir.declarations.IrClass
+import org.jetbrains.kotlin.ir.declarations.IrConstructor
import org.jetbrains.kotlin.ir.declarations.IrDeclarationBase
import org.jetbrains.kotlin.ir.declarations.IrDeclarationOrigin
import org.jetbrains.kotlin.ir.declarations.IrFile
@@ -63,6 +64,7 @@
import org.jetbrains.kotlin.ir.declarations.IrVariable
import org.jetbrains.kotlin.ir.declarations.copyAttributes
import org.jetbrains.kotlin.ir.expressions.IrCall
+import org.jetbrains.kotlin.ir.expressions.IrConstructorCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression
import org.jetbrains.kotlin.ir.expressions.IrFunctionExpression
@@ -97,53 +99,85 @@
private class CaptureCollector {
val captures = mutableSetOf<IrValueDeclaration>()
- val capturedFunctions = mutableSetOf<IrFunction>()
- val hasCaptures: Boolean get() = captures.isNotEmpty() || capturedFunctions.isNotEmpty()
+ val capturedDeclarations = mutableSetOf<IrSymbolOwner>()
+ val hasCaptures: Boolean get() = captures.isNotEmpty() || capturedDeclarations.isNotEmpty()
fun recordCapture(local: IrValueDeclaration) {
captures.add(local)
}
- fun recordCapture(local: IrFunction) {
- capturedFunctions.add(local)
+ fun recordCapture(local: IrSymbolOwner) {
+ capturedDeclarations.add(local)
}
}
private abstract class DeclarationContext {
+ val localDeclarationCaptures = mutableMapOf<IrSymbolOwner, Set<IrValueDeclaration>>()
+ fun recordLocalDeclaration(local: DeclarationContext) {
+ localDeclarationCaptures[local.declaration] = local.captures
+ }
abstract val composable: Boolean
abstract val symbol: IrSymbol
+ abstract val declaration: IrSymbolOwner
+ abstract val captures: Set<IrValueDeclaration>
abstract val functionContext: FunctionContext?
abstract fun declareLocal(local: IrValueDeclaration?)
- abstract fun recordLocalFunction(local: FunctionContext)
- abstract fun recordCapture(local: IrValueDeclaration?)
- abstract fun recordCapture(local: IrFunction?)
+ abstract fun recordCapture(local: IrValueDeclaration?): Boolean
+ abstract fun recordCapture(local: IrSymbolOwner?)
abstract fun pushCollector(collector: CaptureCollector)
abstract fun popCollector(collector: CaptureCollector)
}
-private class SymbolOwnerContext(val declaration: IrSymbolOwner) : DeclarationContext() {
+private fun List<DeclarationContext>.recordCapture(value: IrValueDeclaration) {
+    for (dec in reversed()) {
+        val shouldBreak = dec.recordCapture(value) // true at the scope owning `value`
+        if (shouldBreak) break
+    }
+}
+
+private fun List<DeclarationContext>.recordLocalDeclaration(local: DeclarationContext) {
+    for (dec in reversed()) {
+        dec.recordLocalDeclaration(local) // stored in every enclosing scope's map
+    }
+}
+
+private fun List<DeclarationContext>.recordLocalCapture(local: IrSymbolOwner) {
+    val capturesForLocal = reversed().firstNotNullOfOrNull { it.localDeclarationCaptures[local] }
+    if (capturesForLocal != null) {
+        capturesForLocal.forEach { recordCapture(it) }
+        for (dec in reversed()) {
+            dec.recordCapture(local)
+            if (dec.localDeclarationCaptures.containsKey(local)) {
+                // this is the scope the local function or class was declared in; scopes above
+                // it cannot reference the declaration, so nothing more needs to be recorded
+                break
+            }
+        }
+    }
+}
+
+private class SymbolOwnerContext(override val declaration: IrSymbolOwner) : DeclarationContext() {
override val composable get() = false
override val functionContext: FunctionContext? get() = null
override val symbol get() = declaration.symbol
+ override val captures: Set<IrValueDeclaration> get() = emptySet()
override fun declareLocal(local: IrValueDeclaration?) { }
- override fun recordLocalFunction(local: FunctionContext) { }
- override fun recordCapture(local: IrValueDeclaration?) { }
- override fun recordCapture(local: IrFunction?) { }
+ override fun recordCapture(local: IrValueDeclaration?): Boolean { return false }
+ override fun recordCapture(local: IrSymbolOwner?) { }
override fun pushCollector(collector: CaptureCollector) { }
override fun popCollector(collector: CaptureCollector) { }
}
private class FunctionLocalSymbol(
- val declaration: IrSymbolOwner,
+ override val declaration: IrSymbolOwner,
override val functionContext: FunctionContext
) : DeclarationContext() {
override val composable: Boolean get() = functionContext.composable
override val symbol: IrSymbol get() = declaration.symbol
+ override val captures: Set<IrValueDeclaration> get() = functionContext.captures
override fun declareLocal(local: IrValueDeclaration?) = functionContext.declareLocal(local)
- override fun recordLocalFunction(local: FunctionContext) =
- functionContext.recordLocalFunction(local)
override fun recordCapture(local: IrValueDeclaration?) = functionContext.recordCapture(local)
- override fun recordCapture(local: IrFunction?) = functionContext.recordCapture(local)
+ override fun recordCapture(local: IrSymbolOwner?) = functionContext.recordCapture(local)
override fun pushCollector(collector: CaptureCollector) =
functionContext.pushCollector(collector)
override fun popCollector(collector: CaptureCollector) =
@@ -151,16 +185,15 @@
}
private class FunctionContext(
- val declaration: IrFunction,
+ override val declaration: IrFunction,
override val composable: Boolean,
val canRemember: Boolean
) : DeclarationContext() {
override val symbol get() = declaration.symbol
override val functionContext: FunctionContext? get() = this
val locals = mutableSetOf<IrValueDeclaration>()
- val captures = mutableSetOf<IrValueDeclaration>()
+ override val captures: MutableSet<IrValueDeclaration> = mutableSetOf()
var collectors = mutableListOf<CaptureCollector>()
- val localFunctionCaptures = mutableMapOf<IrFunction, Set<IrValueDeclaration>>()
init {
declaration.valueParameters.forEach {
@@ -176,26 +209,22 @@
}
}
- override fun recordLocalFunction(local: FunctionContext) {
- if (local.captures.isNotEmpty() && local.declaration.isLocal) {
- localFunctionCaptures[local.declaration] = local.captures
- }
- }
-
- override fun recordCapture(local: IrValueDeclaration?) {
- if (local != null && collectors.isNotEmpty() && locals.contains(local)) {
+ override fun recordCapture(local: IrValueDeclaration?): Boolean {
+ val containsLocal = locals.contains(local)
+ if (local != null && collectors.isNotEmpty() && containsLocal) {
for (collector in collectors) {
collector.recordCapture(local)
}
}
- if (local != null && declaration.isLocal && !locals.contains(local)) {
+ if (local != null && declaration.isLocal && !containsLocal) {
captures.add(local)
}
+ return containsLocal
}
- override fun recordCapture(local: IrFunction?) {
+ override fun recordCapture(local: IrSymbolOwner?) {
if (local != null) {
- val captures = localFunctionCaptures[local]
+ val captures = localDeclarationCaptures[local]
for (collector in collectors) {
collector.recordCapture(local)
if (captures != null) {
@@ -217,22 +246,28 @@
}
}
-private class ClassContext(val declaration: IrClass) : DeclarationContext() {
+private class ClassContext(override val declaration: IrClass) : DeclarationContext() {
override val composable: Boolean = false
override val symbol get() = declaration.symbol
override val functionContext: FunctionContext? = null
+ override val captures: MutableSet<IrValueDeclaration> = mutableSetOf()
val thisParam: IrValueDeclaration? = declaration.thisReceiver!!
var collectors = mutableListOf<CaptureCollector>()
override fun declareLocal(local: IrValueDeclaration?) { }
- override fun recordLocalFunction(local: FunctionContext) { }
- override fun recordCapture(local: IrValueDeclaration?) {
- if (local != null && collectors.isNotEmpty() && local == thisParam) {
+ override fun recordCapture(local: IrValueDeclaration?): Boolean {
+ val isThis = local == thisParam
+ val isCtorParam = (local?.parent as? IrConstructor)?.parent === declaration
+ if (local != null && collectors.isNotEmpty() && isThis) {
for (collector in collectors) {
collector.recordCapture(local)
}
}
+ if (local != null && declaration.isLocal && !isThis && !isCtorParam) {
+ captures.add(local)
+ }
+ return isThis || isCtorParam
}
- override fun recordCapture(local: IrFunction?) { }
+ override fun recordCapture(local: IrSymbolOwner?) { }
override fun pushCollector(collector: CaptureCollector) {
collectors.add(collector)
}
@@ -383,7 +418,7 @@
val result = super.visitFunction(declaration)
declarationContextStack.pop()
if (declaration.isLocal) {
- declarationContextStack.peek()?.recordLocalFunction(context)
+ declarationContextStack.recordLocalDeclaration(context)
}
return result
}
@@ -393,6 +428,9 @@
declarationContextStack.push(context)
val result = super.visitClass(declaration)
declarationContextStack.pop()
+ if (declaration.isLocal) {
+ declarationContextStack.recordLocalDeclaration(context)
+ }
return result
}
@@ -402,9 +440,7 @@
}
override fun visitValueAccess(expression: IrValueAccessExpression): IrExpression {
- declarationContextStack.forEach {
- it.recordCapture(expression.symbol.owner)
- }
+ declarationContextStack.recordCapture(expression.symbol.owner)
return super.visitValueAccess(expression)
}
@@ -514,13 +550,20 @@
override fun visitCall(expression: IrCall): IrExpression {
val fn = expression.symbol.owner
if (fn.isLocal) {
- declarationContextStack.forEach {
- it.recordCapture(fn)
- }
+ declarationContextStack.recordLocalCapture(fn)
}
return super.visitCall(expression)
}
+ override fun visitConstructorCall(expression: IrConstructorCall): IrExpression {
+ val fn = expression.symbol.owner
+ val cls = fn.parent as? IrClass
+ if (cls != null && fn.isLocal) {
+ declarationContextStack.recordLocalCapture(cls)
+ }
+ return super.visitConstructorCall(expression)
+ }
+
@ObsoleteDescriptorBasedAPI
private fun visitComposableFunctionExpression(
expression: IrFunctionExpression,