Skip to content

Commit 7cecce1

Browse files
committed
feat(ai): introduce Schema.fromEnum<E>()
1 parent f2257e4 commit 7cecce1

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

firebase-ai/api.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,9 @@ package com.google.firebase.ai.type {
774774
method public static com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values);
775775
method public static com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null);
776776
method public static com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null, boolean nullable = false);
777+
method public static <E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(Class<E> enumClass);
778+
method public static <E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(Class<E> enumClass, String? description = null);
779+
method public static <E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(Class<E> enumClass, String? description = null, boolean nullable = false);
777780
method public String? getDescription();
778781
method public java.util.List<java.lang.String>? getEnum();
779782
method public String? getFormat();
@@ -823,6 +826,10 @@ package com.google.firebase.ai.type {
823826
method public com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values);
824827
method public com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null);
825828
method public com.google.firebase.ai.type.Schema enumeration(java.util.List<java.lang.String> values, String? description = null, boolean nullable = false);
829+
method public <E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(Class<E> enumClass);
830+
method public <E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(Class<E> enumClass, String? description = null);
831+
method public <E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(Class<E> enumClass, String? description = null, boolean nullable = false);
832+
method public inline <reified E extends java.lang.Enum<E>> com.google.firebase.ai.type.Schema fromEnum(String? description = null, boolean nullable = false);
826833
method public com.google.firebase.ai.type.Schema numDouble();
827834
method public com.google.firebase.ai.type.Schema numDouble(String? description = null);
828835
method public com.google.firebase.ai.type.Schema numDouble(String? description = null, boolean nullable = false);

firebase-ai/src/main/kotlin/com/google/firebase/ai/type/Schema.kt

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.firebase.ai.type
1818

19+
import java.util.EnumSet
1920
import kotlinx.serialization.Serializable
2021

2122
public abstract class StringFormat private constructor(internal val value: String) {
@@ -179,7 +180,11 @@ internal constructor(
179180
): Schema {
180181
if (!properties.keys.containsAll(optionalProperties)) {
181182
throw IllegalArgumentException(
182-
"All optional properties must be present in properties. Missing: ${optionalProperties.minus(properties.keys)}"
183+
"All optional properties must be present in properties. Missing: ${
184+
optionalProperties.minus(
185+
properties.keys
186+
)
187+
}"
183188
)
184189
}
185190
return Schema(
@@ -239,6 +244,49 @@ internal constructor(
239244
enum = values,
240245
type = "STRING",
241246
)
247+
248+
/**
249+
* Returns a [Schema] for the given Kotlin Enum.
250+
*
251+
* For example, the cardinal directions can be represented as:
252+
*
253+
* ```
254+
* enum class CardinalDirection { NORTH, EAST, SOUTH, WEST }
255+
*
256+
* Schema.fromEnum<CardinalDirection>()
257+
* ```
258+
*
259+
* @param description The description of what the parameter should contain or represent
260+
* @param nullable Indicates whether the value can be `null`. Defaults to `false`.
261+
*/
262+
@JvmOverloads
263+
public inline fun <reified E : Enum<E>> fromEnum(
264+
description: String? = null,
265+
nullable: Boolean = false
266+
): Schema = enumeration(enumValues<E>().map { it.toString() }, description, nullable)
267+
268+
/**
269+
* Returns a [Schema] for the given Java Enum.
270+
*
271+
* For example, the cardinal directions can be represented as:
272+
*
273+
* ```
274+
* enum CardinalDirection { NORTH, EAST, SOUTH, WEST }
275+
*
276+
* Schema.fromEnum(CardinalDirection.class);
277+
* ```
278+
*
279+
* @param enumClass The Enum's Java class.
280+
* @param description The description of what the parameter should contain or represent
281+
* @param nullable Indicates whether the value can be `null`. Defaults to `false`.
282+
*/
283+
@JvmStatic
284+
@JvmOverloads
285+
public fun <E : Enum<E>> fromEnum(
286+
enumClass: Class<E>,
287+
description: String? = null,
288+
nullable: Boolean = false
289+
): Schema = enumeration(EnumSet.allOf(enumClass).map { it.name }, description, nullable)
242290
}
243291

244292
internal fun toInternal(): Internal =
@@ -252,6 +300,7 @@ internal constructor(
252300
required,
253301
items?.toInternal(),
254302
)
303+
255304
@Serializable
256305
internal data class Internal(
257306
val type: String,

firebase-ai/src/test/java/com/google/firebase/ai/SchemaTests.kt

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,42 @@ internal class SchemaTests {
218218

219219
Json.encodeToString(schemaDeclaration.toInternal()).shouldEqualJson(expectedJson)
220220
}
221+
222+
enum class TestEnum {
223+
BASIC,
224+
INTERMEDIATE,
225+
ADVANCED
226+
}
227+
228+
@Test
229+
fun `basic Kotlin enum class`() {
230+
val schema = Schema.fromEnum<TestEnum>()
231+
val expectedJson =
232+
"""
233+
{
234+
"type": "STRING",
235+
"format": "enum",
236+
"enum": ["BASIC", "INTERMEDIATE", "ADVANCED"]
237+
}
238+
"""
239+
.trimIndent()
240+
241+
Json.encodeToString(schema.toInternal()).shouldEqualJson(expectedJson)
242+
}
243+
244+
@Test
245+
fun `basic Java enum`() {
246+
val schema = Schema.fromEnum(TestEnum::class.java)
247+
val expectedJson =
248+
"""
249+
{
250+
"type": "STRING",
251+
"format": "enum",
252+
"enum": ["BASIC", "INTERMEDIATE", "ADVANCED"]
253+
}
254+
"""
255+
.trimIndent()
256+
257+
Json.encodeToString(schema.toInternal()).shouldEqualJson(expectedJson)
258+
}
221259
}

0 commit comments

Comments
 (0)