Skip to content

Commit

Permalink
override default retry partition for AWS SDK clients to include region
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd committed Jan 9, 2025
1 parent 9f5b1c8 commit 0f15144
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeConfig
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsCustomization
import software.amazon.smithy.rust.codegen.core.smithy.generators.LibRsSection
Expand Down Expand Up @@ -58,6 +59,7 @@ class AwsFluentClientDecorator : ClientCodegenDecorator {
listOf(
AwsPresignedFluentBuilderMethod(codegenContext),
AwsFluentClientDocs(codegenContext),
AwsFluentClientRetryPartition(codegenContext),
).letIf(codegenContext.serviceShape.id == ShapeId.from("com.amazonaws.s3#AmazonS3")) {
it + S3ExpressFluentClientCustomization(codegenContext)
},
Expand Down Expand Up @@ -166,3 +168,27 @@ private class AwsFluentClientDocs(private val codegenContext: ClientCodegenConte
}
}
}

/**
* Replaces the default retry partition for all operations to include the AWS region if set
*/
private class AwsFluentClientRetryPartition(private val codegenContext: ClientCodegenContext) : FluentClientCustomization() {
override fun section(section: FluentClientSection): Writable {
return when (section) {
is FluentClientSection.BeforeBaseClientPluginSetup ->
writable {
rustTemplate(
"""
let default_retry_partition = match config.region() {
Some(region) => #{Cow}::from(format!("{default_retry_partition}-{}", region)),
None => #{Cow}::from(default_retry_partition),
};
""",
*preludeScope,
"Cow" to RuntimeType.Cow,
)
}
else -> emptySection
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.core.rustlang.CargoDependency
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest

class RetryPartitionTest {
@Test
fun `default retry partition`() {
awsSdkIntegrationTest(SdkCodegenIntegrationTest.model) { ctx, rustCrate ->
val rc = ctx.runtimeConfig
val codegenScope =
arrayOf(
*RuntimeType.preludeScope,
"capture_request" to RuntimeType.captureRequest(rc),
"capture_test_logs" to
CargoDependency.smithyRuntimeTestUtil(rc).toType()
.resolve("test_util::capture_test_logs::capture_test_logs"),
"Credentials" to
AwsRuntimeType.awsCredentialTypesTestUtil(rc)
.resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
)

rustCrate.integrationTest("default_retry_partition") {
tokioTest("default_retry_partition_includes_region") {
val moduleName = ctx.moduleUseName()
rustTemplate(
"""
let (_logs, logs_rx) = #{capture_test_logs}();
let (http_client, _rx) = #{capture_request}(#{None});
let client_config = $moduleName::Config::builder()
.http_client(http_client)
.region(#{Region}::new("us-west-2"))
.credentials_provider(#{Credentials}::for_tests())
.build();
let client = $moduleName::Client::from_conf(client_config);
let _ = client
.some_operation()
.send()
.await
.expect("success");
let log_contents = logs_rx.contents();
assert!(log_contents.contains("token bucket for RetryPartition { name: \"dontcare-us-west-2\" } added to config bag"));
""",
*codegenScope,
)
}

tokioTest("user_config_retry_partition") {
val moduleName = ctx.moduleUseName()
rustTemplate(
"""
let (_logs, logs_rx) = #{capture_test_logs}();
let (http_client, _rx) = #{capture_request}(#{None});
let client_config = $moduleName::Config::builder()
.http_client(http_client)
.region(#{Region}::new("us-west-2"))
.credentials_provider(#{Credentials}::for_tests())
.retry_partition(#{RetryPartition}::new("user-partition"))
.build();
let client = $moduleName::Client::from_conf(client_config);
let _ = client
.some_operation()
.send()
.await
.expect("success");
let log_contents = logs_rx.contents();
assert!(log_contents.contains("token bucket for RetryPartition { name: \"user-partition\" } added to config bag"));
""",
*codegenScope,
"RetryPartition" to RuntimeType.smithyRuntime(ctx.runtimeConfig).resolve("client::retries::RetryPartition"),
)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ sealed class FluentClientSection(name: String) : Section(name) {
/** Write custom code for adding additional client plugins to base_client_runtime_plugins */
data class AdditionalBaseClientPlugins(val plugins: String, val config: String) :
FluentClientSection("AdditionalBaseClientPlugins")

/** Write additional code before plugins are configured */
data class BeforeBaseClientPluginSetup(val config: String) :
FluentClientSection("BeforeBaseClientPluginSetup")
}

abstract class FluentClientCustomization : NamedCustomization<FluentClientSection>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,14 @@ private fun baseClientRuntimePluginsFn(
::std::mem::swap(&mut config.runtime_plugins, &mut configured_plugins);
#{update_bmv}
let default_retry_partition = ${codegenContext.serviceShape.sdkId().dq()};
#{before_plugin_setup}
let mut plugins = #{RuntimePlugins}::new()
// defaults
.with_client_plugins(#{default_plugins}(
#{DefaultPluginParams}::new()
.with_retry_partition_name(${codegenContext.serviceShape.sdkId().dq()})
.with_retry_partition_name(default_retry_partition)
.with_behavior_version(config.behavior_version.expect(${behaviorVersionError.dq()}))
))
// user config
Expand Down Expand Up @@ -299,6 +302,13 @@ private fun baseClientRuntimePluginsFn(
FluentClientSection.AdditionalBaseClientPlugins("plugins", "config"),
)
},
"before_plugin_setup" to
writable {
writeCustomizations(
customizations,
FluentClientSection.BeforeBaseClientPluginSetup("config"),
)
},
"DefaultPluginParams" to rt.resolve("client::defaults::DefaultPluginParams"),
"default_plugins" to rt.resolve("client::defaults::default_plugins"),
"NoAuthRuntimePlugin" to rt.resolve("client::auth::no_auth::NoAuthRuntimePlugin"),
Expand Down

0 comments on commit 0f15144

Please sign in to comment.