Skip to content

chore(client-s3): bucket contextParam customization for schema-serde mode #7250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codegen/sdk-codegen/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ tasks.register("generate-smithy-build") {
// e.g. "S3" - use this as exclusion list if needed.
)
val useSchemaSerde = setOf<String>(
// "CloudWatch Logs"
// "S3"
)
val projectionContents = Node.objectNodeBuilder()
.withMember("imports", Node.fromStrings("${models.getAbsolutePath()}${File.separator}${file.name}"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public AddProtocolConfig() {
SchemaGenerationAllowlist.allow("com.amazonaws.dynamodb#DynamoDB_20120810");
SchemaGenerationAllowlist.allow("com.amazonaws.lambda#AWSGirApiService");
SchemaGenerationAllowlist.allow("com.amazonaws.cloudwatchlogs#Logs_20140328");
SchemaGenerationAllowlist.allow("com.amazonaws.sts#AWSSecurityTokenServiceV20110615");

// protocol tests
SchemaGenerationAllowlist.allow("aws.protocoltests.json10#JsonRpc10");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import software.amazon.smithy.model.knowledge.OperationIndex;
import software.amazon.smithy.model.knowledge.TopDownIndex;
import software.amazon.smithy.model.pattern.SmithyPattern;
import software.amazon.smithy.model.pattern.UriPattern;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
Expand All @@ -47,9 +48,11 @@
import software.amazon.smithy.model.traits.EndpointTrait;
import software.amazon.smithy.model.traits.HttpHeaderTrait;
import software.amazon.smithy.model.traits.HttpPayloadTrait;
import software.amazon.smithy.model.traits.HttpTrait;
import software.amazon.smithy.model.traits.StreamingTrait;
import software.amazon.smithy.model.traits.Trait;
import software.amazon.smithy.model.transform.ModelTransformer;
import software.amazon.smithy.rulesengine.traits.ContextParamTrait;
import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait;
import software.amazon.smithy.typescript.codegen.LanguageTarget;
import software.amazon.smithy.typescript.codegen.TypeScriptDependency;
Expand All @@ -58,6 +61,7 @@
import software.amazon.smithy.typescript.codegen.auth.http.integration.AddHttpSigningPlugin;
import software.amazon.smithy.typescript.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.typescript.codegen.integration.TypeScriptIntegration;
import software.amazon.smithy.typescript.codegen.schema.SchemaGenerationAllowlist;
import software.amazon.smithy.utils.ListUtils;
import software.amazon.smithy.utils.MapUtils;
import software.amazon.smithy.utils.SetUtils;
Expand Down Expand Up @@ -113,6 +117,53 @@ public static Shape removeHostPrefixTrait(Shape shape) {
.orElse(shape);
}

/**
* Remove `/{Bucket}` from the operation endpoint URI IFF
* - it is in a prefix position.
* - input has a member called "Bucket".
* - "Bucket" input member is a contextParam.
*/
public static Shape removeUriBucketPrefix(Shape shape, Model model) {
return shape.asOperationShape()
.map(OperationShape::shapeToBuilder)
.map((Object object) -> {
OperationShape.Builder builder = (OperationShape.Builder) object;
Trait trait = builder.getAllTraits().get(HttpTrait.ID);
if (trait instanceof HttpTrait httpTrait) {
String uri = httpTrait.getUri().toString();

StructureShape input = model.expectShape(
shape.asOperationShape().get().getInputShape()
).asStructureShape().orElseThrow(
() -> new RuntimeException("operation must have input structure")
);

boolean hasBucketPrefix = uri.startsWith("/{Bucket}");
Optional<MemberShape> bucket = input.getMember("Bucket");
boolean inputHasBucketMember = bucket.isPresent();
boolean bucketIsContextParam = bucket
.map(ms -> ms.getTrait(ContextParamTrait.class))
.isPresent();

if (hasBucketPrefix && inputHasBucketMember && bucketIsContextParam) {
String replaced = uri
.replace("/{Bucket}/", "/")
.replace("/{Bucket}", "/");
builder.addTrait(
httpTrait
.toBuilder()
.uri(UriPattern.parse(replaced))
.build()
);
}
}
return builder;
})
.map(OperationShape.Builder::build)
.map(s -> (Shape) s)
.orElse(shape);
}

@Override
public List<String> runAfter() {
return List.of(
Expand Down Expand Up @@ -243,7 +294,12 @@ public Model preprocessModel(Model model, TypeScriptSettings settings) {
Model builtModel = modelBuilder.addShapes(inputShapes).build();
if (hasRuleset) {
return ModelTransformer.create().mapShapes(
builtModel, AddS3Config::removeHostPrefixTrait
builtModel, (shape) -> {
if (SchemaGenerationAllowlist.allows(serviceShape.getId(), settings)) {
return removeUriBucketPrefix(shape, model);
}
return shape;
}
);
}
return builtModel;
Expand Down
31 changes: 20 additions & 11 deletions packages/core/src/submodules/protocols/json/AwsJsonRpcProtocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,26 @@ export abstract class AwsJsonRpcProtocol extends RpcProtocol {
[namespace, errorName] = errorIdentifier.split("#");
}

const errorMetadata = {
$metadata: metadata,
$response: response,
$fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const),
};

const registry = TypeRegistry.for(namespace);
let errorSchema: ErrorSchema;
try {
errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema;
} catch (e) {
if (dataObject.Message) {
dataObject.message = dataObject.Message;
}
const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException();
if (baseExceptionSchema) {
const ErrorCtor = baseExceptionSchema.ctor;
throw Object.assign(new ErrorCtor(errorName), dataObject);
throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject);
}
throw new Error(errorName);
throw Object.assign(new Error(errorName), errorMetadata, dataObject);
}

const ns = NormalizedSchema.of(errorSchema);
Expand All @@ -109,14 +118,14 @@ export abstract class AwsJsonRpcProtocol extends RpcProtocol {
output[name] = this.codec.createDeserializer().readObject(member, dataObject[target]);
}

Object.assign(exception, {
$metadata: metadata,
$response: response,
$fault: ns.getMergedTraits().error,
message,
...output,
});

throw exception;
throw Object.assign(
exception,
errorMetadata,
{
$fault: ns.getMergedTraits().error,
message,
},
output
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,26 @@ export class AwsRestJsonProtocol extends HttpBindingProtocol {
[namespace, errorName] = errorIdentifier.split("#");
}

const errorMetadata = {
$metadata: metadata,
$response: response,
$fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const),
};

const registry = TypeRegistry.for(namespace);
let errorSchema: ErrorSchema;
try {
errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema;
} catch (e) {
if (dataObject.Message) {
dataObject.message = dataObject.Message;
}
const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException();
if (baseExceptionSchema) {
const ErrorCtor = baseExceptionSchema.ctor;
throw Object.assign(new ErrorCtor(errorName), dataObject);
throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject);
}
throw new Error(errorName);
throw Object.assign(new Error(errorName), errorMetadata, dataObject);
}

const ns = NormalizedSchema.of(errorSchema);
Expand All @@ -147,15 +156,15 @@ export class AwsRestJsonProtocol extends HttpBindingProtocol {
output[name] = this.codec.createDeserializer().readObject(member, dataObject[target]);
}

Object.assign(exception, {
$metadata: metadata,
$response: response,
$fault: ns.getMergedTraits().error,
message,
...output,
});

throw exception;
throw Object.assign(
exception,
errorMetadata,
{
$fault: ns.getMergedTraits().error,
message,
},
output
);
}

/**
Expand Down
34 changes: 21 additions & 13 deletions packages/core/src/submodules/protocols/query/AwsQueryProtocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ export class AwsQueryProtocol extends RpcProtocol {
[namespace, errorName] = errorIdentifier.split("#");
}

const errorDataSource = this.loadQueryError(dataObject);
const errorData = this.loadQueryError(dataObject);
const errorMetadata = {
$metadata: metadata,
$response: response,
$fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const),
};

const registry = TypeRegistry.for(namespace);
let errorSchema: ErrorSchema;
Expand All @@ -159,12 +164,15 @@ export class AwsQueryProtocol extends RpcProtocol {
errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema;
}
} catch (e) {
if (errorData.Message) {
errorData.message = errorData.Message;
}
const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException();
if (baseExceptionSchema) {
const ErrorCtor = baseExceptionSchema.ctor;
throw Object.assign(new ErrorCtor(errorName), errorDataSource);
throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject);
}
throw new Error(errorName);
throw Object.assign(new Error(errorName), errorMetadata, errorData);
}

const ns = NormalizedSchema.of(errorSchema);
Expand All @@ -175,19 +183,19 @@ export class AwsQueryProtocol extends RpcProtocol {

for (const [name, member] of ns.structIterator()) {
const target = member.getMergedTraits().xmlName ?? name;
const value = errorDataSource[target] ?? dataObject[target];
const value = errorData[target] ?? dataObject[target];
output[name] = this.deserializer.readSchema(member, value);
}

Object.assign(exception, {
$metadata: metadata,
$response: response,
$fault: ns.getMergedTraits().error,
message,
...output,
});

throw exception;
throw Object.assign(
exception,
errorMetadata,
{
$fault: ns.getMergedTraits().error,
message,
},
output
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ describe(AwsRestXmlProtocol.name, () => {
},
expected: {
request: {
path: "/",
// S3 customization not active here since this is a mock.
// customization does model preprocessing to remove /{Bucket} prefix
// when it is a contextParam.
path: "/{Bucket}",
method: "POST",
headers: {
"content-type": "application/xml",
Expand Down
42 changes: 21 additions & 21 deletions packages/core/src/submodules/protocols/xml/AwsRestXmlProtocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,6 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol {
const ns = NormalizedSchema.of(operationSchema.input);
const members = ns.getMemberSchemas();

request.path =
String(request.path)
.split("/")
.filter((segment) => {
// for legacy reasons,
// Bucket is in the http trait but is handled by endpoints ruleset.
return segment !== "{Bucket}";
})
.join("/") || "/";
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed this hack, replaced by model preprocessing transform


if (!request.headers["content-type"]) {
const httpPayloadMember = Object.values(members).find((m) => {
return !!m.getMergedTraits().httpPayload;
Expand Down Expand Up @@ -136,17 +126,27 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol {
[namespace, errorName] = errorIdentifier.split("#");
}

const errorMetadata = {
$metadata: metadata,
$response: response,
$fault: response.statusCode <= 500 ? ("client" as const) : ("server" as const),
};

const registry = TypeRegistry.for(namespace);

let errorSchema: ErrorSchema;
try {
errorSchema = registry.getSchema(errorIdentifier) as ErrorSchema;
} catch (e) {
if (dataObject.Message) {
dataObject.message = dataObject.Message;
}
const baseExceptionSchema = TypeRegistry.for("smithy.ts.sdk.synthetic." + namespace).getBaseException();
if (baseExceptionSchema) {
const ErrorCtor = baseExceptionSchema.ctor;
throw Object.assign(new ErrorCtor(errorName), dataObject);
throw Object.assign(new ErrorCtor({ name: errorName }), errorMetadata, dataObject);
}
throw new Error(errorName);
throw Object.assign(new Error(errorName), errorMetadata, dataObject);
}

const ns = NormalizedSchema.of(errorSchema);
Expand All @@ -162,15 +162,15 @@ export class AwsRestXmlProtocol extends HttpBindingProtocol {
output[name] = this.codec.createDeserializer().readSchema(member, value);
}

Object.assign(exception, {
$metadata: metadata,
$response: response,
$fault: ns.getMergedTraits().error,
message,
...output,
});

throw exception;
throw Object.assign(
exception,
errorMetadata,
{
$fault: ns.getMergedTraits().error,
message,
},
output
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ export class XmlShapeSerializer extends SerdeContextConfig implements ShapeSeria

const [xmlnsAttr, xmlns] = this.getXmlnsAttribute(ns, parentXmlns);

if (xmlns) {
structXmlNode.addAttribute(xmlnsAttr as string, xmlns);
}

for (const [memberName, memberSchema] of ns.structIterator()) {
const val = (value as any)[memberName];

Expand All @@ -108,6 +104,10 @@ export class XmlShapeSerializer extends SerdeContextConfig implements ShapeSeria
}
}

if (xmlns) {
structXmlNode.addAttribute(xmlnsAttr as string, xmlns);
}

return structXmlNode;
}

Expand Down
2 changes: 1 addition & 1 deletion packages/middleware-logger/src/loggerMiddleware.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import {
import type {
AbsoluteLocation,
HandlerExecutionContext,
InitializeHandler,
Expand Down
Loading
Loading