Skip to content
6 changes: 4 additions & 2 deletions hooks/api/hooks.api
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ public abstract class com/intuit/hooks/AsyncParallelHook : com/intuit/hooks/Asyn

public abstract class com/intuit/hooks/AsyncSeriesBailHook : com/intuit/hooks/AsyncBaseHook {
public fun <init> ()V
protected final fun call (Lkotlin/jvm/functions/Function3;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
protected final fun call (Lkotlin/jvm/functions/Function3;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun call$default (Lcom/intuit/hooks/AsyncSeriesBailHook;Lkotlin/jvm/functions/Function3;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
}

public abstract class com/intuit/hooks/AsyncSeriesHook : com/intuit/hooks/AsyncBaseHook {
Expand Down Expand Up @@ -97,7 +98,8 @@ public final class com/intuit/hooks/LoopResult$Companion {

public abstract class com/intuit/hooks/SyncBailHook : com/intuit/hooks/SyncBaseHook {
public fun <init> ()V
protected final fun call (Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
protected final fun call (Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public static synthetic fun call$default (Lcom/intuit/hooks/SyncBailHook;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Ljava/lang/Object;
}

public abstract class com/intuit/hooks/SyncBaseHook : com/intuit/hooks/BaseHook {
Expand Down
5 changes: 3 additions & 2 deletions hooks/src/main/kotlin/com/intuit/hooks/AsyncSeriesBailHook.kt
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package com.intuit.hooks

public abstract class AsyncSeriesBailHook<F : Function<BailResult<R>>, R> : AsyncBaseHook<F>("AsyncSeriesBailHook") {
protected suspend fun call(invokeWithContext: suspend (F, HookContext) -> BailResult<R>): R? {
protected suspend fun call(invokeWithContext: suspend (F, HookContext) -> BailResult<R>, default: (suspend (HookContext) -> R)? = null): R? {
val context = setup(invokeWithContext)

taps.forEach { tapInfo ->
when (val result = invokeWithContext(tapInfo.f, context)) {
is BailResult.Bail<R> -> return@call result.value
is BailResult.Continue -> {}
}
}

return null
return default?.invoke(context)
}
}
5 changes: 3 additions & 2 deletions hooks/src/main/kotlin/com/intuit/hooks/SyncBailHook.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ public sealed class BailResult<T> {
}

public abstract class SyncBailHook<F : Function<BailResult<R>>, R> : SyncBaseHook<F>("SyncBailHook") {
protected fun call(invokeWithContext: (F, HookContext) -> BailResult<R>): R? {
protected fun call(invokeWithContext: (F, HookContext) -> BailResult<R>, default: ((HookContext) -> R)? = null): R? {
val context = setup(invokeWithContext)

taps.forEach { tapInfo ->
when (val result = invokeWithContext(tapInfo.f, context)) {
is BailResult.Bail<R> -> return@call result.value
is BailResult.Continue -> {}
}
}

return null
return default?.invoke(context)
}
}
40 changes: 37 additions & 3 deletions hooks/src/test/kotlin/com/intuit/hooks/AsyncSeriesBailHookTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@ import org.junit.jupiter.api.Test

class AsyncSeriesBailHookTests {
class Hook1<T1, R : Any?> : AsyncSeriesBailHook<suspend (HookContext, T1) -> BailResult<R>, R>() {
suspend fun call(p1: T1): R? = super.call { f, context -> f(context, p1) }
suspend fun call(p1: T1, default: (suspend (HookContext, T1) -> R)? = null): R? = super.call(
{ f, context -> f(context, p1) },
default?.let {
{ context -> default(context, p1) }
}
)

suspend fun call(p1: T1, default: (suspend (T1) -> R)) = call(p1) { _, arg1 ->
default.invoke(arg1)
}
}

@Test
Expand Down Expand Up @@ -41,12 +50,37 @@ class AsyncSeriesBailHookTests {
}

@Test
fun `bail taps can bail without return value`() {
val h = SyncBailHookTests.Hook1<String, Unit>()
fun `bail taps can bail without return value`() = runBlocking {
val h = Hook1<String, Unit>()
h.tap("continue") { _, _ -> BailResult.Continue() }
h.tap("bail") { _, _ -> BailResult.Bail(Unit) }
h.tap("continue again") { _, _ -> Assertions.fail("Should never have gotten here!") }

Assertions.assertEquals(Unit, h.call("David"))
}

@Test
fun `bail call with default handler invokes without taps bailing`() = runBlocking {
val h = Hook1<String, String>()
h.tap("continue") { _, _ -> BailResult.Continue() }
h.tap("continue again") { _, _ -> BailResult.Continue() }

val result = h.call("David") { _, str ->
str
}

Assertions.assertEquals("David", result)
}

@Test
fun `bail call with default handler does not invoke with bail`() = runBlocking {
val h = Hook1<String, String>()
h.tap("continue") { _, _ -> BailResult.Continue() }
h.tap("bail") { _, _ -> BailResult.Bail("bailing") }
h.tap("continue again") { _, _ -> Assertions.fail("Should never have gotten here!") }

val result = h.call("David") { str -> str }

Assertions.assertEquals("bailing", result)
}
}
36 changes: 35 additions & 1 deletion hooks/src/test/kotlin/com/intuit/hooks/SyncBailHookTests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@ import org.junit.jupiter.api.Test

class SyncBailHookTests {
class Hook1<T1, R : Any?> : SyncBailHook<(HookContext, T1) -> BailResult<R>, R>() {
fun call(p1: T1) = super.call { f, context -> f(context, p1) }
fun call(p1: T1, default: ((HookContext, T1) -> R)? = null) = super.call(
{ f, context -> f(context, p1) },
default?.let {
{ context -> default(context, p1) }
}
)

fun call(p1: T1, default: ((T1) -> R)) = call(p1) { _, arg1 ->
default.invoke(arg1)
}
}

@Test
Expand Down Expand Up @@ -51,4 +60,29 @@ class SyncBailHookTests {

Assertions.assertEquals(Unit, h.call("David"))
}

@Test
fun `bail call with default handler invokes without taps bailing`() {
val h = Hook1<String, String>()
h.tap("continue") { _, _ -> BailResult.Continue() }
h.tap("continue again") { _, _ -> BailResult.Continue() }

val result = h.call("David") { _, str ->
str
}

Assertions.assertEquals("David", result)
}

@Test
fun `bail call with default handler does not invoke with bail`() {
val h = Hook1<String, String>()
h.tap("continue") { _, _ -> BailResult.Continue() }
h.tap("bail") { _, _ -> BailResult.Bail("bailing") }
h.tap("continue again") { _, _ -> Assertions.fail("Should never have gotten here!") }

val result = h.call("David") { str -> str }

Assertions.assertEquals("bailing", result)
}
}
59 changes: 40 additions & 19 deletions processor/src/main/kotlin/com/intuit/hooks/plugin/codegen/Poet.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,24 @@ private fun HooksContainer.generateContainerClass(): TypeSpec {
}.build()
}

internal val HookInfo.callBuilder get() = FunSpec.builder("call")
.addParameters(parameterSpecs)
.apply {
if (isAsync)
addModifiers(KModifier.SUSPEND)
}

internal fun HookInfo.generateClass(): TypeSpec {
val callBuilder = FunSpec.builder("call")
.addParameters(parameterSpecs)
.apply {
if ([email protected])
addModifiers(KModifier.SUSPEND)
}

val (superclass, call) = when (hookType) {
val (superclass, calls) = when (hookType) {
HookType.SyncHook, HookType.AsyncSeriesHook, HookType.AsyncParallelHook -> {
val superclass = createSuperClass()

val call = callBuilder
.returns(UNIT)
.addStatement("return super.call { f, context -> f(context, $paramsWithoutTypes) }")

Pair(superclass, call)
Pair(superclass, listOf(call))
}
HookType.SyncLoopHook, HookType.AsyncSeriesLoopHook -> {
val superclass = createSuperClass(interceptParameter)
Expand All @@ -69,7 +70,7 @@ internal fun HookInfo.generateClass(): TypeSpec {
CodeBlock.of("{ f, context -> f(context, $paramsWithoutTypes) }")
)

Pair(superclass, call)
Pair(superclass, listOf(call))
}
HookType.SyncWaterfallHook, HookType.AsyncSeriesWaterfallHook -> {
val superclass = createSuperClass(params.first().type)
Expand All @@ -84,31 +85,51 @@ internal fun HookInfo.generateClass(): TypeSpec {
CodeBlock.of("{ f, context -> f(context, $paramsWithoutTypes) }")
)

Pair(superclass, call)
Pair(superclass, listOf(call))
}
HookType.SyncBailHook, HookType.AsyncSeriesBailHook -> {
requireNotNull(hookSignature.nullableReturnTypeType)
val superclass = createSuperClass(hookSignature.returnTypeType)

val call = callBuilder
.addParameter(
ParameterSpec.builder(
"default",
LambdaTypeName.get(
parameters = parameterSpecs,
returnType = hookSignature.returnTypeType!!
)
).build()
)
.returns(hookSignature.nullableReturnTypeType)
.addStatement("return super.call { f, context -> f(context, $paramsWithoutTypes) }")
.addStatement("return call ($paramsWithoutTypes) { _, arg1 -> default.invoke(arg1) }")

val contextCall = callBuilder
.addParameter(
ParameterSpec.builder(
"default",
createHookContextLambda(hookSignature.returnTypeType).copy(nullable = true)
).defaultValue(CodeBlock.of("null")).build()
)
.returns(hookSignature.nullableReturnTypeType)
.addStatement("return super.call ({ f, context -> f(context, $paramsWithoutTypes) }, default?.let { { context -> default(context, $paramsWithoutTypes) } } )")

Pair(superclass, call)
Pair(superclass, listOf(call, contextCall))
}
// parallel bail requires the concurrency parameter, otherwise it would be just like the other bail hooks
HookType.AsyncParallelBailHook -> {
requireNotNull(hookSignature.nullableReturnTypeType)
val superclass = createSuperClass(hookSignature.returnTypeType)

// force the concurrency parameter to be first
callBuilder.parameters.add(0, ParameterSpec("concurrency", INT))
val call = with(callBuilder) {
// force the concurrency parameter to be first
parameters.add(0, ParameterSpec("concurrency", INT))

val call = callBuilder
.returns(hookSignature.nullableReturnTypeType)
.addStatement("return super.call(concurrency) { f, context -> f(context, $paramsWithoutTypes) }")
returns(hookSignature.nullableReturnTypeType)
.addStatement("return super.call(concurrency) { f, context -> f(context, $paramsWithoutTypes) }")
}

Pair(superclass, call)
Pair(superclass, listOf(call))
}
}

Expand All @@ -117,7 +138,7 @@ internal fun HookInfo.generateClass(): TypeSpec {
addFunctions(tapMethods)
hookType.addedAnnotation?.let(::addAnnotation)
superclass(superclass)
addFunction(call.build())
addFunctions(calls.map { it.build() })
}.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,44 @@ class HooksProcessorTest {
result.runCompiledAssertions()
}

@Test fun `generates bail hook class`() {
val testHooks = SourceFile.kotlin(
"TestBailHooks.kt",
"""
import com.intuit.hooks.BailResult
import com.intuit.hooks.Hook
import com.intuit.hooks.dsl.Hooks

internal abstract class TestBailHooks : Hooks() {
@SyncBail<(String) -> BailResult<String>>
abstract val testSyncBailHook: Hook
}
"""
)

val assertions = SourceFile.kotlin(
"Assertions.kt",
"""
import com.intuit.hooks.BailResult
import org.junit.jupiter.api.Assertions.*

fun testHook() {
val hooks = TestBailHooksImpl()
hooks.testSyncBailHook.tap("test") { _, _ -> BailResult.Continue() }
val result = hooks.testSyncBailHook.call("hello") { ctx, str ->
str + " world"
}
assertEquals("hello world", result)
}
"""
)

val (compilation, result) = compile(testHooks, assertions)
result.assertOk()
compilation.assertKspGeneratedSources("TestBailHooksHooks.kt")
result.runCompiledAssertions()
}

@Test fun `generates nested hook class`() {
val testHooks = SourceFile.kotlin(
"TestHooks.kt",
Expand Down