From 0c8c0b2ccbe113fcbdb0d96af8e4fafc6618c186 Mon Sep 17 00:00:00 2001 From: Sebastien Ros Date: Fri, 1 May 2026 17:57:40 -0700 Subject: [PATCH 1/3] Fix polyglot AppHost callback and builder generation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../AtsGoCodeGenerator.cs | 30 ++++- .../Resources/base.go | 61 +++++++++ .../AtsJavaCodeGenerator.cs | 2 +- .../Resources/Transport.java | 69 +++++++++- .../PythonModuleBuilder.cs | 9 ++ .../Ats/AtsCallbackProxyFactory.cs | 7 +- .../Ats/AtsMarshaller.cs | 70 +++++++++++ .../AtsGoCodeGeneratorTests.cs | 50 ++++++++ .../Snapshots/AtsGeneratedAspire.verified.go | 8 +- ...TwoPassScanningGeneratedAspire.verified.go | 85 +++++++++---- .../AtsJavaCodeGeneratorTests.cs | 27 ++++ .../AtsGeneratedAspire.verified.java | 91 ++++++++++++-- ...oPassScanningGeneratedAspire.verified.java | 119 ++++++++++++++---- .../AtsPythonCodeGeneratorTests.cs | 13 ++ .../Snapshots/AtsGeneratedAspire.verified.py | 11 +- ...TwoPassScanningGeneratedAspire.verified.py | 9 ++ .../CallbackProxyTests.cs | 39 +++++- 17 files changed, 634 insertions(+), 66 deletions(-) diff --git a/src/Aspire.Hosting.CodeGeneration.Go/AtsGoCodeGenerator.cs b/src/Aspire.Hosting.CodeGeneration.Go/AtsGoCodeGenerator.cs index b76b79b8e58..63ff8a4b3b8 100644 --- a/src/Aspire.Hosting.CodeGeneration.Go/AtsGoCodeGenerator.cs +++ b/src/Aspire.Hosting.CodeGeneration.Go/AtsGoCodeGenerator.cs @@ -1401,6 +1401,28 @@ private void EmitCallbackRegistration(string indent, AtsParameterInfo p, string // Legacy untyped callback returning any — preserve return value. WriteLine($"{indent}\t\treturn {callExpr}"); } + else if (p.CallbackParameters is { Count: > 0 } callbackParameters && callbackParameters.Any(cp => cp.Type.Category == AtsTypeCategory.Dto)) + { + var argNames = new List(callbackParameters.Count); + for (var i = 0; i < callbackParameters.Count; i++) + { + var argName = $"arg{i}"; + argNames.Add(argName); + var goType = MapTypeRefToGo(callbackParameters[i].Type, false); + WriteLine($"{indent}\t\t{argName} := callbackArg[{goType}](args, {i})"); + } + + WriteLine($"{indent}\t\tcb({string.Join(", ", argNames)})"); + WriteLine($"{indent}\t\treturn map[string]any{{"); + for (var i = 0; i < callbackParameters.Count; i++) + { + if (callbackParameters[i].Type.Category == AtsTypeCategory.Dto) + { + WriteLine($"{indent}\t\t\t\"p{i}\": serializeValue({argNames[i]}),"); + } + } + WriteLine($"{indent}\t\t}}"); + } else { WriteLine($"{indent}\t\t{callExpr}"); @@ -1840,9 +1862,15 @@ private void GenerateCreateBuilder(AtsContext context) WriteLine("\t}"); } WriteLine("\tif _, ok := resolved[\"Args\"]; !ok { resolved[\"Args\"] = os.Args[1:] }"); - WriteLine("\tif _, ok := resolved[\"ProjectDirectory\"]; !ok {"); + WriteLine("\tif projectDirectory, ok := resolved[\"ProjectDirectory\"].(string); !ok || projectDirectory == \"\" {"); WriteLine("\t\tif pwd, err := os.Getwd(); err == nil { resolved[\"ProjectDirectory\"] = pwd }"); WriteLine("\t}"); + WriteLine("\tif appHostFilePath, ok := resolved[\"AppHostFilePath\"].(string); !ok || appHostFilePath == \"\" {"); + WriteLine("\t\tif appHostFilePath := os.Getenv(\"ASPIRE_APPHOST_FILEPATH\"); appHostFilePath != \"\" { resolved[\"AppHostFilePath\"] = appHostFilePath }"); + WriteLine("\t}"); + WriteLine("\tif dashboardApplicationName, ok := resolved[\"DashboardApplicationName\"].(string); ok && dashboardApplicationName == \"\" {"); + WriteLine("\t\tdelete(resolved, \"DashboardApplicationName\")"); + WriteLine("\t}"); WriteLine(); WriteLine($"\tresult, err := c.invokeCapability(context.Background(), \"{AtsConstants.CreateBuilderCapability}\", map[string]any{{\"argsOrOptions\": resolved}})"); WriteLine("\tif err != nil { return nil, err }"); diff --git a/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go b/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go index e68a285bf35..6d3a59dc493 100644 --- a/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go +++ b/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "reflect" + "strings" "sync" ) @@ -712,11 +713,71 @@ func decodeAs[T any](raw any) (T, error) { } var out T if err := json.Unmarshal(bytes, &out); err != nil { + if decoded, ok := decodeStructFields[T](raw); ok { + return decoded, nil + } return zero, err } return out, nil } +func decodeStructFields[T any](raw any) (T, bool) { + var zero T + rawMap, ok := raw.(map[string]any) + if !ok { + return zero, false + } + + targetType := reflect.TypeOf((*T)(nil)).Elem() + isPointer := targetType.Kind() == reflect.Pointer + if isPointer { + targetType = targetType.Elem() + } + if targetType.Kind() != reflect.Struct { + return zero, false + } + + targetValue := reflect.New(targetType) + structValue := targetValue.Elem() + for i := 0; i < targetType.NumField(); i++ { + fieldInfo := targetType.Field(i) + fieldValue := structValue.Field(i) + if !fieldValue.CanSet() { + continue + } + + fieldName := fieldInfo.Name + if tag := fieldInfo.Tag.Get("json"); tag != "" { + name, _, _ := strings.Cut(tag, ",") + if name == "-" { + continue + } + if name != "" { + fieldName = name + } + } + + rawFieldValue, ok := rawMap[fieldName] + if !ok { + continue + } + + bytes, err := json.Marshal(rawFieldValue) + if err != nil { + continue + } + if err := json.Unmarshal(bytes, fieldValue.Addr().Interface()); err != nil { + continue + } + } + + if isPointer { + return targetValue.Interface().(T), true + } + + return structValue.Interface().(T), true +} + // ── deepUpdate ─────────────────────────────────────────────────────────────── // // Used by generated code to merge variadic Options structs: diff --git a/src/Aspire.Hosting.CodeGeneration.Java/AtsJavaCodeGenerator.cs b/src/Aspire.Hosting.CodeGeneration.Java/AtsJavaCodeGenerator.cs index 16def94d8d4..17ad83febbe 100644 --- a/src/Aspire.Hosting.CodeGeneration.Java/AtsJavaCodeGenerator.cs +++ b/src/Aspire.Hosting.CodeGeneration.Java/AtsJavaCodeGenerator.cs @@ -527,7 +527,7 @@ private void GenerateDtoTypes(IReadOnlyList dtoTypes) var dtoName = _dtoNames[dto.TypeId]; WriteLine($"/** {dto.Name} DTO. */"); - WriteLine($"class {dtoName} {{"); + WriteLine($"class {dtoName} implements JsonSerializable {{"); // Fields foreach (var property in dto.Properties) diff --git a/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java b/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java index 4f19e533736..506be4e1adc 100644 --- a/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java +++ b/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java @@ -80,6 +80,10 @@ void onCancel(Runnable listener) { } } +interface JsonSerializable { + Map toMap(); +} + /** * AspireClient handles JSON-RPC communication with the AppHost server. */ @@ -314,7 +318,7 @@ private String readLine() throws IOException { private void handleServerRequest(Map request) throws IOException { String method = (String) request.get("method"); Object idObj = request.get("id"); - Map params = (Map) request.get("params"); + Object params = request.get("params"); debug("Received server request: " + method); @@ -323,8 +327,8 @@ private void handleServerRequest(Map request) throws IOException try { if ("invokeCallback".equals(method)) { - String callbackId = (String) params.get("callbackId"); - List args = (List) params.get("args"); + String callbackId = getCallbackId(params); + List args = getCallbackArgs(params); Function callback = callbacks.get(callbackId); if (callback != null) { @@ -336,7 +340,7 @@ private void handleServerRequest(Map request) throws IOException error = createError(-32601, "Callback not found: " + callbackId); } } else if ("cancel".equals(method)) { - String cancellationId = (String) params.get("cancellationId"); + String cancellationId = getCancellationId(params); Consumer handler = cancellations.get(cancellationId); if (handler != null) { handler.accept(null); @@ -362,6 +366,60 @@ private void handleServerRequest(Map request) throws IOException sendMessage(response); } + @SuppressWarnings("unchecked") + private String getCallbackId(Object params) { + if (params instanceof List list && !list.isEmpty()) { + return (String) list.get(0); + } + + if (params instanceof Map map) { + return (String) map.get("callbackId"); + } + + return null; + } + + @SuppressWarnings("unchecked") + private List getCallbackArgs(Object params) { + Object args = null; + if (params instanceof List list && list.size() > 1) { + args = list.get(1); + } else if (params instanceof Map map) { + args = map.get("args"); + } + + if (args instanceof Map map) { + List positionalArgs = new ArrayList<>(); + for (var i = 0; ; i++) { + var key = "p" + i; + if (map.containsKey(key)) { + positionalArgs.add(map.get(key)); + } else { + break; + } + } + return positionalArgs; + } + + if (args instanceof List list) { + return (List) list; + } + + return args == null ? List.of() : List.of(args); + } + + private String getCancellationId(Object params) { + if (params instanceof List list && !list.isEmpty()) { + return (String) list.get(0); + } + + if (params instanceof Map map) { + return (String) map.get("cancellationId"); + } + + return null; + } + private Map createError(int code, String message) { Map error = new HashMap<>(); error.put("code", code); @@ -462,6 +520,9 @@ public static Object serializeValue(Object value) { if (value instanceof AspireUnion union) { return serializeValue(union.getValue()); } + if (value instanceof JsonSerializable jsonSerializable) { + return jsonSerializable.toMap(); + } if (value instanceof Map) { @SuppressWarnings("unchecked") Map map = (Map) value; diff --git a/src/Aspire.Hosting.CodeGeneration.Python/PythonModuleBuilder.cs b/src/Aspire.Hosting.CodeGeneration.Python/PythonModuleBuilder.cs index b0e525a9b85..6bfae92c5b9 100644 --- a/src/Aspire.Hosting.CodeGeneration.Python/PythonModuleBuilder.cs +++ b/src/Aspire.Hosting.CodeGeneration.Python/PythonModuleBuilder.cs @@ -1795,6 +1795,7 @@ def create_builder( *, args: typing.Iterable[str] | None = None, project_directory: str | None = None, + app_host_file_path: str | None = None, container_registry_override: str | None = None, disable_dashboard: bool | None = None, dashboard_application_name: str | None = None, @@ -1813,6 +1814,8 @@ def create_builder( passed to the Aspire command line (arguments specified after '--'). Specifying them here will override that default. project_directory (str): The directory containing the AppHost project file. By default, this will use the ASPIRE_PROJECT_DIRECTORY environment variable if set, otherwise it will use the current working directory. + app_host_file_path (str): The path to the AppHost source file. By default, this will use the ASPIRE_APPHOST_FILEPATH + environment variable if set. container_registry_override (str): When containers are used, use this value to override the container registry. disable_dashboard (bool): Determines whether the dashboard is disabled. dashboard_application_name (str): The application name to display in the dashboard. @@ -1842,6 +1845,12 @@ elif not effective_options.get('Args'): effective_options['ProjectDirectory'] = project_directory elif not effective_options.get('ProjectDirectory'): effective_options['ProjectDirectory'] = os.environ.get('ASPIRE_PROJECT_DIRECTORY', os.getcwd()) + if app_host_file_path is not None: + effective_options['AppHostFilePath'] = app_host_file_path + elif not effective_options.get('AppHostFilePath'): + app_host_file_path = os.environ.get('ASPIRE_APPHOST_FILEPATH') + if app_host_file_path: + effective_options['AppHostFilePath'] = app_host_file_path if container_registry_override is not None: effective_options['ContainerRegistryOverride'] = container_registry_override if disable_dashboard is not None: diff --git a/src/Aspire.Hosting.RemoteHost/Ats/AtsCallbackProxyFactory.cs b/src/Aspire.Hosting.RemoteHost/Ats/AtsCallbackProxyFactory.cs index f99d702630d..ffada37c832 100644 --- a/src/Aspire.Hosting.RemoteHost/Ats/AtsCallbackProxyFactory.cs +++ b/src/Aspire.Hosting.RemoteHost/Ats/AtsCallbackProxyFactory.cs @@ -157,7 +157,8 @@ private Expression BuildMarshalArgs(ParameterExpression[] paramExprs, ParameterI var marshalCall = Expression.Call( Expression.Constant(this), marshalMethod, - Expression.Convert(paramExpr, typeof(object))); + Expression.Convert(paramExpr, typeof(object)), + Expression.Constant(param.ParameterType, typeof(Type))); // Use positional key (p0, p1, p2, ...) instead of param.Name var addCall = Expression.Call(jsonObjVar, addMethod!, Expression.Constant($"p{paramIndex}"), marshalCall); @@ -169,9 +170,9 @@ private Expression BuildMarshalArgs(ParameterExpression[] paramExprs, ParameterI return Expression.Block(new[] { jsonObjVar }, expressions); } - private JsonNode? MarshalArg(object? value) + private JsonNode? MarshalArg(object? value, Type declaredType) { - return _marshaller.MarshalToJson(value); + return _marshaller.MarshalToJson(value, declaredType); } private Expression BuildSyncVoidCall(string callbackId, Expression? argsExpr, Expression? ctExpr, int ctParamIndex) diff --git a/src/Aspire.Hosting.RemoteHost/Ats/AtsMarshaller.cs b/src/Aspire.Hosting.RemoteHost/Ats/AtsMarshaller.cs index e77fa0a4f47..006352e3056 100644 --- a/src/Aspire.Hosting.RemoteHost/Ats/AtsMarshaller.cs +++ b/src/Aspire.Hosting.RemoteHost/Ats/AtsMarshaller.cs @@ -159,6 +159,45 @@ public static bool IsSimpleType(Type type) }; } + /// + /// Marshals a .NET object to JSON for sending to the guest using a declared CLR type. + /// + /// The value to marshal. + /// The declared type that should be exposed to the guest. + /// The JSON representation, or null if the value is null. + public JsonNode? MarshalToJson(object? value, Type declaredType) + { + if (value == null) + { + return null; + } + + if (declaredType == typeof(object)) + { + return MarshalToJson(value); + } + + if (declaredType == typeof(CancellationToken)) + { + return SerializeCancellationToken((CancellationToken)value); + } + + var typeId = AtsTypeMapping.DeriveTypeId(declaredType); + var category = _context.GetCategory(declaredType); + + return category switch + { + AtsTypeCategory.Primitive => SerializePrimitive(value), + AtsTypeCategory.Enum => JsonValue.Create(value.ToString()), + AtsTypeCategory.Dto => SerializeDto(value), + AtsTypeCategory.Array => SerializeArray(value, CreateElementTypeRef(declaredType)), + AtsTypeCategory.List => _handles.Marshal(value, typeId), + AtsTypeCategory.Dict => _handles.Marshal(value, typeId), + AtsTypeCategory.Handle => _handles.Marshal(value, typeId), + _ => _handles.Marshal(value, typeId) + }; + } + private static JsonNode? SerializePrimitive(object value) { var type = value.GetType(); @@ -207,6 +246,37 @@ public static bool IsSimpleType(Type type) return jsonArray; } + private AtsTypeRef? CreateElementTypeRef(Type declaredType) + { + if (declaredType.IsArray) + { + return CreateTypeRef(declaredType.GetElementType()!); + } + + if (declaredType.IsGenericType) + { + var genericTypeDefinition = declaredType.GetGenericTypeDefinition(); + if (genericTypeDefinition == typeof(IReadOnlyList<>) + || genericTypeDefinition == typeof(IReadOnlyCollection<>) + || genericTypeDefinition == typeof(IEnumerable<>)) + { + return CreateTypeRef(declaredType.GetGenericArguments()[0]); + } + } + + return null; + } + + private AtsTypeRef CreateTypeRef(Type type) + { + return new AtsTypeRef + { + TypeId = AtsTypeMapping.DeriveTypeId(type), + ClrType = type, + Category = _context.GetCategory(type) + }; + } + /// /// Marshals a .NET object to JSON for sending to the guest. /// Uses runtime type inspection based on scanned AtsContext. diff --git a/tests/Aspire.Hosting.CodeGeneration.Go.Tests/AtsGoCodeGeneratorTests.cs b/tests/Aspire.Hosting.CodeGeneration.Go.Tests/AtsGoCodeGeneratorTests.cs index 520335695e5..6b065d4cad1 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Go.Tests/AtsGoCodeGeneratorTests.cs +++ b/tests/Aspire.Hosting.CodeGeneration.Go.Tests/AtsGoCodeGeneratorTests.cs @@ -266,6 +266,56 @@ public void GeneratedCode_HasCreateBuilderFunction() Assert.Contains("func CreateBuilder", aspireGo); } + [Fact] + public void GeneratedCode_CreateBuilderDefaultsAppHostFilePathFromEnvironment() + { + var atsContext = CreateContextFromBothAssemblies(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var aspireGo = files["aspire.go"]; + + Assert.Contains("if appHostFilePath, ok := resolved[\"AppHostFilePath\"].(string); !ok || appHostFilePath == \"\"", aspireGo); + Assert.Contains("os.Getenv(\"ASPIRE_APPHOST_FILEPATH\")", aspireGo); + Assert.Contains("resolved[\"AppHostFilePath\"] = appHostFilePath", aspireGo); + } + + [Fact] + public void GeneratedCode_CreateBuilderOmitsEmptyDashboardApplicationName() + { + var atsContext = CreateContextFromBothAssemblies(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var aspireGo = files["aspire.go"]; + + Assert.Contains("if dashboardApplicationName, ok := resolved[\"DashboardApplicationName\"].(string); ok && dashboardApplicationName == \"\"", aspireGo); + Assert.Contains("delete(resolved, \"DashboardApplicationName\")", aspireGo); + } + + [Fact] + public void GeneratedCode_DtoCallbacksReturnMutatedArguments() + { + var atsContext = CreateContextFromBothAssemblies(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var aspireGo = files["aspire.go"]; + + Assert.Contains("arg0 := callbackArg[*ResourceUrlAnnotation](args, 0)", aspireGo); + Assert.Contains("cb(arg0)", aspireGo); + Assert.Contains("\"p0\": serializeValue(arg0)", aspireGo); + } + + [Fact] + public void GeneratedCode_CallbackArgsSkipUndecodableStructFields() + { + var atsContext = CreateContextFromBothAssemblies(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var baseGo = files["base.go"]; + + Assert.Contains("func decodeStructFields[T any](raw any) (T, bool)", baseGo); + Assert.Contains("fieldInfo.Tag.Get(\"json\")", baseGo); + } + [Fact] public void GeneratedCode_HasGoModFile() { diff --git a/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/AtsGeneratedAspire.verified.go b/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/AtsGeneratedAspire.verified.go index 1485d2d9c03..c0aa60cb43a 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/AtsGeneratedAspire.verified.go +++ b/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/AtsGeneratedAspire.verified.go @@ -2337,9 +2337,15 @@ func CreateBuilder() (IDistributedApplicationBuilder, error) { resolved := map[string]any{} if _, ok := resolved["Args"]; !ok { resolved["Args"] = os.Args[1:] } - if _, ok := resolved["ProjectDirectory"]; !ok { + if projectDirectory, ok := resolved["ProjectDirectory"].(string); !ok || projectDirectory == "" { if pwd, err := os.Getwd(); err == nil { resolved["ProjectDirectory"] = pwd } } + if appHostFilePath, ok := resolved["AppHostFilePath"].(string); !ok || appHostFilePath == "" { + if appHostFilePath := os.Getenv("ASPIRE_APPHOST_FILEPATH"); appHostFilePath != "" { resolved["AppHostFilePath"] = appHostFilePath } + } + if dashboardApplicationName, ok := resolved["DashboardApplicationName"].(string); ok && dashboardApplicationName == "" { + delete(resolved, "DashboardApplicationName") + } result, err := c.invokeCapability(context.Background(), "Aspire.Hosting/createBuilder", map[string]any{"argsOrOptions": resolved}) if err != nil { return nil, err } diff --git a/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.go b/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.go index db00728646b..0dea5fee448 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.go +++ b/tests/Aspire.Hosting.CodeGeneration.Go.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.go @@ -2429,8 +2429,11 @@ func (s *aspire_Hosting_CodeGeneration_Go_TestsTestVaultResource) WithUrlForEndp if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -4047,8 +4050,11 @@ func (s *cSharpAppResource) WithUrlForEndpoint(endpointName string, callback fun if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -5384,8 +5390,11 @@ func (s *containerRegistryResource) WithUrlForEndpoint(endpointName string, call if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -7082,8 +7091,11 @@ func (s *containerResource) WithUrlForEndpoint(endpointName string, callback fun if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -10060,8 +10072,11 @@ func (s *dotnetToolResource) WithUrlForEndpoint(endpointName string, callback fu if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -12488,8 +12503,11 @@ func (s *executableResource) WithUrlForEndpoint(endpointName string, callback fu if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -13640,8 +13658,11 @@ func (s *externalServiceResource) WithUrlForEndpoint(endpointName string, callba if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -14869,8 +14890,11 @@ func (s *parameterResource) WithUrlForEndpoint(endpointName string, callback fun if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -17027,8 +17051,11 @@ func (s *projectResource) WithUrlForEndpoint(endpointName string, callback func( if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -20108,8 +20135,11 @@ func (s *testDatabaseResource) WithUrlForEndpoint(endpointName string, callback if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -22282,8 +22312,11 @@ func (s *testRedisResource) WithUrlForEndpoint(endpointName string, callback fun if callback != nil { cb := callback shim := func(args ...any) any { - cb(callbackArg[*ResourceUrlAnnotation](args, 0)) - return nil + arg0 := callbackArg[*ResourceUrlAnnotation](args, 0) + cb(arg0) + return map[string]any{ + "p0": serializeValue(arg0), + } } reqArgs["callback"] = s.client.registerCallback(shim) } @@ -23926,9 +23959,15 @@ func CreateBuilder(options ...*CreateBuilderOptions) (DistributedApplicationBuil for k, v := range merged.ToMap() { resolved[k] = v } } if _, ok := resolved["Args"]; !ok { resolved["Args"] = os.Args[1:] } - if _, ok := resolved["ProjectDirectory"]; !ok { + if projectDirectory, ok := resolved["ProjectDirectory"].(string); !ok || projectDirectory == "" { if pwd, err := os.Getwd(); err == nil { resolved["ProjectDirectory"] = pwd } } + if appHostFilePath, ok := resolved["AppHostFilePath"].(string); !ok || appHostFilePath == "" { + if appHostFilePath := os.Getenv("ASPIRE_APPHOST_FILEPATH"); appHostFilePath != "" { resolved["AppHostFilePath"] = appHostFilePath } + } + if dashboardApplicationName, ok := resolved["DashboardApplicationName"].(string); ok && dashboardApplicationName == "" { + delete(resolved, "DashboardApplicationName") + } result, err := c.invokeCapability(context.Background(), "Aspire.Hosting/createBuilder", map[string]any{"argsOrOptions": resolved}) if err != nil { return nil, err } diff --git a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/AtsJavaCodeGeneratorTests.cs b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/AtsJavaCodeGeneratorTests.cs index 283d1375fdc..52fdec0f088 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/AtsJavaCodeGeneratorTests.cs +++ b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/AtsJavaCodeGeneratorTests.cs @@ -314,6 +314,33 @@ public void GeneratedCode_HasPublicAspireClass() Assert.Contains("public class Aspire", aspireJava); } + [Fact] + public void GeneratedTransport_HandlesJsonRpcArrayCallbackParameters() + { + var atsContext = CreateContextFromBothAssemblies(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var aspireClientJava = files["AspireClient.java"]; + + Assert.Contains("private String getCallbackId(Object params)", aspireClientJava); + Assert.Contains("if (params instanceof List list && !list.isEmpty())", aspireClientJava); + Assert.Contains("var key = \"p\" + i;", aspireClientJava); + } + + [Fact] + public void GeneratedDtoValues_AreSerializedAsMaps() + { + var atsContext = CreateContextFromTestAssembly(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var aspireClientJava = files["AspireClient.java"]; + var testConfigDtoJava = files["TestConfigDto.java"]; + + Assert.Contains("interface JsonSerializable", files["JsonSerializable.java"]); + Assert.Contains("if (value instanceof JsonSerializable jsonSerializable)", aspireClientJava); + Assert.Contains("public class TestConfigDto implements JsonSerializable", testConfigDtoJava); + } + private static string JoinGeneratedFiles(Dictionary files) { return string.Join( diff --git a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java index eb1e128fa18..575522420de 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java +++ b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java @@ -1,4 +1,4 @@ -// ===== Aspire.java ===== +// ===== Aspire.java ===== // Aspire.java - GENERATED CODE - DO NOT EDIT package aspire; @@ -360,7 +360,7 @@ private String readLine() throws IOException { private void handleServerRequest(Map request) throws IOException { String method = (String) request.get("method"); Object idObj = request.get("id"); - Map params = (Map) request.get("params"); + Object params = request.get("params"); debug("Received server request: " + method); @@ -369,8 +369,8 @@ private void handleServerRequest(Map request) throws IOException try { if ("invokeCallback".equals(method)) { - String callbackId = (String) params.get("callbackId"); - List args = (List) params.get("args"); + String callbackId = getCallbackId(params); + List args = getCallbackArgs(params); Function callback = callbacks.get(callbackId); if (callback != null) { @@ -382,7 +382,7 @@ private void handleServerRequest(Map request) throws IOException error = createError(-32601, "Callback not found: " + callbackId); } } else if ("cancel".equals(method)) { - String cancellationId = (String) params.get("cancellationId"); + String cancellationId = getCancellationId(params); Consumer handler = cancellations.get(cancellationId); if (handler != null) { handler.accept(null); @@ -408,6 +408,60 @@ private void handleServerRequest(Map request) throws IOException sendMessage(response); } + @SuppressWarnings("unchecked") + private String getCallbackId(Object params) { + if (params instanceof List list && !list.isEmpty()) { + return (String) list.get(0); + } + + if (params instanceof Map map) { + return (String) map.get("callbackId"); + } + + return null; + } + + @SuppressWarnings("unchecked") + private List getCallbackArgs(Object params) { + Object args = null; + if (params instanceof List list && list.size() > 1) { + args = list.get(1); + } else if (params instanceof Map map) { + args = map.get("args"); + } + + if (args instanceof Map map) { + List positionalArgs = new ArrayList<>(); + for (var i = 0; ; i++) { + var key = "p" + i; + if (map.containsKey(key)) { + positionalArgs.add(map.get(key)); + } else { + break; + } + } + return positionalArgs; + } + + if (args instanceof List list) { + return (List) list; + } + + return args == null ? List.of() : List.of(args); + } + + private String getCancellationId(Object params) { + if (params instanceof List list && !list.isEmpty()) { + return (String) list.get(0); + } + + if (params instanceof Map map) { + return (String) map.get("cancellationId"); + } + + return null; + } + private Map createError(int code, String message) { Map error = new HashMap<>(); error.put("code", code); @@ -508,6 +562,9 @@ public static Object serializeValue(Object value) { if (value instanceof AspireUnion union) { return serializeValue(union.getValue()); } + if (value instanceof JsonSerializable jsonSerializable) { + return jsonSerializable.toMap(); + } if (value instanceof Map) { @SuppressWarnings("unchecked") Map map = (Map) value; @@ -1324,6 +1381,23 @@ public class ITestVaultResource extends ResourceBuilderBase { } +// ===== JsonSerializable.java ===== +// JsonSerializable.java - GENERATED CODE - DO NOT EDIT + +package aspire; + +import java.io.*; +import java.net.*; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.*; +import java.util.function.*; + +public interface JsonSerializable { + Map toMap(); +} + // ===== ReferenceExpression.java ===== // ReferenceExpression.java - GENERATED CODE - DO NOT EDIT @@ -1574,7 +1648,7 @@ public AspireDict metadata() { import java.util.function.*; /** TestConfigDto DTO. */ -public class TestConfigDto { +public class TestConfigDto implements JsonSerializable { private String name; private double port; private boolean enabled; @@ -2000,7 +2074,7 @@ public TestDatabaseResource withMergeRouteMiddleware(String path, String method, import java.util.function.*; /** TestDeeplyNestedDto DTO. */ -public class TestDeeplyNestedDto { +public class TestDeeplyNestedDto implements JsonSerializable { private AspireDict> nestedData; private AspireDict[] metadataArray; @@ -2137,7 +2211,7 @@ public TestMutableCollectionContext setCounts(AspireDict value) import java.util.function.*; /** TestNestedDto DTO. */ -public class TestNestedDto { +public class TestNestedDto implements JsonSerializable { private String id; private TestConfigDto config; private AspireList tags; @@ -3292,6 +3366,7 @@ public WithOptionalStringOptions enabled(Boolean value) { .modules/IResourceWithConnectionString.java .modules/IResourceWithEnvironment.java .modules/ITestVaultResource.java +.modules/JsonSerializable.java .modules/ReferenceExpression.java .modules/ResourceBuilderBase.java .modules/TestCallbackContext.java diff --git a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java index 7abe7db08b1..eed62a5987a 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java +++ b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java @@ -7,7 +7,7 @@ import java.util.function.*; /** AddContainerOptions DTO. */ -public class AddContainerOptions { +public class AddContainerOptions implements JsonSerializable { private String image; private String tag; @@ -531,7 +531,7 @@ private String readLine() throws IOException { private void handleServerRequest(Map request) throws IOException { String method = (String) request.get("method"); Object idObj = request.get("id"); - Map params = (Map) request.get("params"); + Object params = request.get("params"); debug("Received server request: " + method); @@ -540,8 +540,8 @@ private void handleServerRequest(Map request) throws IOException try { if ("invokeCallback".equals(method)) { - String callbackId = (String) params.get("callbackId"); - List args = (List) params.get("args"); + String callbackId = getCallbackId(params); + List args = getCallbackArgs(params); Function callback = callbacks.get(callbackId); if (callback != null) { @@ -553,7 +553,7 @@ private void handleServerRequest(Map request) throws IOException error = createError(-32601, "Callback not found: " + callbackId); } } else if ("cancel".equals(method)) { - String cancellationId = (String) params.get("cancellationId"); + String cancellationId = getCancellationId(params); Consumer handler = cancellations.get(cancellationId); if (handler != null) { handler.accept(null); @@ -579,6 +579,60 @@ private void handleServerRequest(Map request) throws IOException sendMessage(response); } + @SuppressWarnings("unchecked") + private String getCallbackId(Object params) { + if (params instanceof List list && !list.isEmpty()) { + return (String) list.get(0); + } + + if (params instanceof Map map) { + return (String) map.get("callbackId"); + } + + return null; + } + + @SuppressWarnings("unchecked") + private List getCallbackArgs(Object params) { + Object args = null; + if (params instanceof List list && list.size() > 1) { + args = list.get(1); + } else if (params instanceof Map map) { + args = map.get("args"); + } + + if (args instanceof Map map) { + List positionalArgs = new ArrayList<>(); + for (var i = 0; ; i++) { + var key = "p" + i; + if (map.containsKey(key)) { + positionalArgs.add(map.get(key)); + } else { + break; + } + } + return positionalArgs; + } + + if (args instanceof List list) { + return (List) list; + } + + return args == null ? List.of() : List.of(args); + } + + private String getCancellationId(Object params) { + if (params instanceof List list && !list.isEmpty()) { + return (String) list.get(0); + } + + if (params instanceof Map map) { + return (String) map.get("cancellationId"); + } + + return null; + } + private Map createError(int code, String message) { Map error = new HashMap<>(); error.put("code", code); @@ -679,6 +733,9 @@ public static Object serializeValue(Object value) { if (value instanceof AspireUnion union) { return serializeValue(union.getValue()); } + if (value instanceof JsonSerializable jsonSerializable) { + return jsonSerializable.toMap(); + } if (value instanceof Map) { @SuppressWarnings("unchecked") Map map = (Map) value; @@ -2937,7 +2994,7 @@ public class CapabilityError extends RuntimeException { import java.util.function.*; /** CertificateTrustExecutionConfigurationContext DTO. */ -public class CertificateTrustExecutionConfigurationContext { +public class CertificateTrustExecutionConfigurationContext implements JsonSerializable { private ReferenceExpression certificateBundlePath; private ReferenceExpression certificateDirectoriesPath; private String rootCertificatesPath; @@ -2971,7 +3028,7 @@ public Map toMap() { import java.util.function.*; /** CertificateTrustExecutionConfigurationExportData DTO. */ -public class CertificateTrustExecutionConfigurationExportData { +public class CertificateTrustExecutionConfigurationExportData implements JsonSerializable { private CertificateTrustScope scope; private String[] certificateSubjects; private String[] customBundlePaths; @@ -3132,7 +3189,7 @@ public void add(AspireUnion value) { import java.util.function.*; /** CommandOptions DTO. */ -public class CommandOptions { +public class CommandOptions implements JsonSerializable { private String description; private Object parameter; private String confirmationMessage; @@ -3178,7 +3235,7 @@ public Map toMap() { import java.util.function.*; /** CommandResultData DTO. */ -public class CommandResultData { +public class CommandResultData implements JsonSerializable { private String value; private CommandResultFormat format; private boolean displayImmediately; @@ -5912,7 +5969,7 @@ public ContainerResource withMergeRouteMiddleware(String path, String method, St import java.util.function.*; /** CreateBuilderOptions DTO. */ -public class CreateBuilderOptions { +public class CreateBuilderOptions implements JsonSerializable { private String[] args; private String projectDirectory; private String appHostFilePath; @@ -10021,7 +10078,7 @@ public ExecuteCommandContext setLogger(HandleWrapperBase value) { import java.util.function.*; /** ExecuteCommandResult DTO. */ -public class ExecuteCommandResult { +public class ExecuteCommandResult implements JsonSerializable { private boolean success; private boolean canceled; private String errorMessage; @@ -10768,7 +10825,7 @@ public ExternalServiceResource withMergeRouteMiddleware(String path, String meth import java.util.function.*; /** GenerateParameterDefault DTO. */ -public class GenerateParameterDefault { +public class GenerateParameterDefault implements JsonSerializable { private double minLength; private boolean lower; private boolean upper; @@ -10891,7 +10948,7 @@ AspireClient getClient() { import java.util.function.*; /** HttpCommandExportOptions DTO. */ -public class HttpCommandExportOptions { +public class HttpCommandExportOptions implements JsonSerializable { private String description; private String confirmationMessage; private String iconName; @@ -10976,7 +11033,7 @@ public static HttpCommandResultMode fromValue(String value) { import java.util.function.*; /** HttpsCertificateExecutionConfigurationContext DTO. */ -public class HttpsCertificateExecutionConfigurationContext { +public class HttpsCertificateExecutionConfigurationContext implements JsonSerializable { private ReferenceExpression certificatePath; private ReferenceExpression keyPath; private ReferenceExpression pfxPath; @@ -11006,7 +11063,7 @@ public Map toMap() { import java.util.function.*; /** HttpsCertificateExecutionConfigurationExportData DTO. */ -public class HttpsCertificateExecutionConfigurationExportData { +public class HttpsCertificateExecutionConfigurationExportData implements JsonSerializable { private String subject; private String thumbprint; private String keyPathExpression; @@ -11052,7 +11109,7 @@ public Map toMap() { import java.util.function.*; /** HttpsCertificateInfo DTO. */ -public class HttpsCertificateInfo { +public class HttpsCertificateInfo implements JsonSerializable { private String subject; private String issuer; private String thumbprint; @@ -12634,6 +12691,23 @@ public IServiceProvider services() { } +// ===== JsonSerializable.java ===== +// JsonSerializable.java - GENERATED CODE - DO NOT EDIT + +package aspire; + +import java.io.*; +import java.net.*; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.concurrent.*; +import java.util.concurrent.atomic.*; +import java.util.function.*; + +public interface JsonSerializable { + Map toMap(); +} + // ===== LogFacade.java ===== // LogFacade.java - GENERATED CODE - DO NOT EDIT @@ -15392,7 +15466,7 @@ public PublishResourceUpdateOptions stateStyle(String value) { import java.util.function.*; /** ReferenceEnvironmentInjectionOptions DTO. */ -public class ReferenceEnvironmentInjectionOptions { +public class ReferenceEnvironmentInjectionOptions implements JsonSerializable { private boolean connectionString; private boolean connectionProperties; private boolean serviceDiscovery; @@ -15668,7 +15742,7 @@ public IServiceProvider services() { import java.util.function.*; /** ResourceEventDto DTO. */ -public class ResourceEventDto { +public class ResourceEventDto implements JsonSerializable { private String resourceName; private String resourceId; private String state; @@ -15907,7 +15981,7 @@ public IServiceProvider services() { import java.util.function.*; /** ResourceUrlAnnotation DTO. */ -public class ResourceUrlAnnotation { +public class ResourceUrlAnnotation implements JsonSerializable { private String url; private String displayText; private EndpointReference endpoint; @@ -16160,7 +16234,7 @@ public AspireDict metadata() { import java.util.function.*; /** TestConfigDto DTO. */ -public class TestConfigDto { +public class TestConfigDto implements JsonSerializable { private String name; private double port; private boolean enabled; @@ -17879,7 +17953,7 @@ public TestDatabaseResource withMergeRouteMiddleware(String path, String method, import java.util.function.*; /** TestDeeplyNestedDto DTO. */ -public class TestDeeplyNestedDto { +public class TestDeeplyNestedDto implements JsonSerializable { private AspireDict> nestedData; private AspireDict[] metadataArray; @@ -18016,7 +18090,7 @@ public TestMutableCollectionContext setCounts(AspireDict value) import java.util.function.*; /** TestNestedDto DTO. */ -public class TestNestedDto { +public class TestNestedDto implements JsonSerializable { private String id; private TestConfigDto config; private AspireList tags; @@ -22630,6 +22704,7 @@ public WithVolumeOptions isReadOnly(Boolean value) { .modules/IconVariant.java .modules/ImagePullPolicy.java .modules/InitializeResourceEvent.java +.modules/JsonSerializable.java .modules/LogFacade.java .modules/OtlpProtocol.java .modules/ParameterResource.java diff --git a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs index 796b04c921f..ac4df6ea292 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs +++ b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs @@ -266,6 +266,19 @@ public void GeneratedCode_HasCreateBuilderFunction() Assert.Contains("def create_builder", aspirePy); } + [Fact] + public void GeneratedCode_CreateBuilderDefaultsAppHostFilePathFromEnvironment() + { + var atsContext = CreateContextFromBothAssemblies(); + + var files = _generator.GenerateDistributedApplication(atsContext); + var aspirePy = files["aspire_app.py"]; + + Assert.Contains("app_host_file_path: str | None = None", aspirePy); + Assert.Contains("effective_options['AppHostFilePath'] = app_host_file_path", aspirePy); + Assert.Contains("app_host_file_path = os.environ.get('ASPIRE_APPHOST_FILEPATH')", aspirePy); + } + [Fact] public void GeneratedCode_UsesTypeHints() { diff --git a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/AtsGeneratedAspire.verified.py b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/AtsGeneratedAspire.verified.py index a1e1a57d93f..525ec6d6bd8 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/AtsGeneratedAspire.verified.py +++ b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/AtsGeneratedAspire.verified.py @@ -1,4 +1,4 @@ -# ------------------------------------------------------------- +# ------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See LICENSE in project root for information. # @@ -2838,6 +2838,7 @@ def create_builder( *, args: typing.Iterable[str] | None = None, project_directory: str | None = None, + app_host_file_path: str | None = None, container_registry_override: str | None = None, disable_dashboard: bool | None = None, dashboard_application_name: str | None = None, @@ -2856,6 +2857,8 @@ def create_builder( passed to the Aspire command line (arguments specified after '--'). Specifying them here will override that default. project_directory (str): The directory containing the AppHost project file. By default, this will use the ASPIRE_PROJECT_DIRECTORY environment variable if set, otherwise it will use the current working directory. + app_host_file_path (str): The path to the AppHost source file. By default, this will use the ASPIRE_APPHOST_FILEPATH + environment variable if set. container_registry_override (str): When containers are used, use this value to override the container registry. disable_dashboard (bool): Determines whether the dashboard is disabled. dashboard_application_name (str): The application name to display in the dashboard. @@ -2885,6 +2888,12 @@ def create_builder( effective_options['ProjectDirectory'] = project_directory elif not effective_options.get('ProjectDirectory'): effective_options['ProjectDirectory'] = os.environ.get('ASPIRE_PROJECT_DIRECTORY', os.getcwd()) + if app_host_file_path is not None: + effective_options['AppHostFilePath'] = app_host_file_path + elif not effective_options.get('AppHostFilePath'): + app_host_file_path = os.environ.get('ASPIRE_APPHOST_FILEPATH') + if app_host_file_path: + effective_options['AppHostFilePath'] = app_host_file_path if container_registry_override is not None: effective_options['ContainerRegistryOverride'] = container_registry_override if disable_dashboard is not None: diff --git a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py index 94ade19a23c..3427927302b 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py +++ b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py @@ -10471,6 +10471,7 @@ def create_builder( *, args: typing.Iterable[str] | None = None, project_directory: str | None = None, + app_host_file_path: str | None = None, container_registry_override: str | None = None, disable_dashboard: bool | None = None, dashboard_application_name: str | None = None, @@ -10489,6 +10490,8 @@ def create_builder( passed to the Aspire command line (arguments specified after '--'). Specifying them here will override that default. project_directory (str): The directory containing the AppHost project file. By default, this will use the ASPIRE_PROJECT_DIRECTORY environment variable if set, otherwise it will use the current working directory. + app_host_file_path (str): The path to the AppHost source file. By default, this will use the ASPIRE_APPHOST_FILEPATH + environment variable if set. container_registry_override (str): When containers are used, use this value to override the container registry. disable_dashboard (bool): Determines whether the dashboard is disabled. dashboard_application_name (str): The application name to display in the dashboard. @@ -10518,6 +10521,12 @@ def create_builder( effective_options['ProjectDirectory'] = project_directory elif not effective_options.get('ProjectDirectory'): effective_options['ProjectDirectory'] = os.environ.get('ASPIRE_PROJECT_DIRECTORY', os.getcwd()) + if app_host_file_path is not None: + effective_options['AppHostFilePath'] = app_host_file_path + elif not effective_options.get('AppHostFilePath'): + app_host_file_path = os.environ.get('ASPIRE_APPHOST_FILEPATH') + if app_host_file_path: + effective_options['AppHostFilePath'] = app_host_file_path if container_registry_override is not None: effective_options['ContainerRegistryOverride'] = container_registry_override if disable_dashboard is not None: diff --git a/tests/Aspire.Hosting.RemoteHost.Tests/CallbackProxyTests.cs b/tests/Aspire.Hosting.RemoteHost.Tests/CallbackProxyTests.cs index 4a97a898e1e..6e8be54ea5b 100644 --- a/tests/Aspire.Hosting.RemoteHost.Tests/CallbackProxyTests.cs +++ b/tests/Aspire.Hosting.RemoteHost.Tests/CallbackProxyTests.cs @@ -146,6 +146,31 @@ public async Task InvokedProxy_PassesMultipleArgumentsAsJson() Assert.Equal(42, args["p1"]?.GetValue()); } + [Fact] + public async Task InvokedProxy_MarshalsHandleArgumentWithDeclaredType() + { + var invoker = new TestCallbackInvoker(); + using var factory = CreateFactory(invoker, handleTypes: + [ + new AtsTypeInfo + { + AtsTypeId = AtsTypeMapping.DeriveTypeId(typeof(ITestCallbackHandle)), + ClrType = typeof(ITestCallbackHandle), + IsInterface = true + } + ]); + + var proxy = (TestCallbackWithHandle)factory.CreateProxy("test-callback", typeof(TestCallbackWithHandle))!; + + await proxy(new TestCallbackHandle()); + + Assert.Single(invoker.Invocations); + var args = Assert.IsAssignableFrom(invoker.Invocations[0].Args); + var handle = Assert.IsAssignableFrom(args["p0"]); + Assert.NotNull(handle["$handle"]?.GetValue()); + Assert.Equal(AtsTypeMapping.DeriveTypeId(typeof(ITestCallbackHandle)), handle["$type"]?.GetValue()); + } + [Fact] public async Task InvokedProxy_WithResultReturnsCorrectValue() { @@ -359,7 +384,7 @@ public void InvokedSyncVoidProxy_AppliesWritebackToMultipleDtos() Assert.Equal(20, dto2.Count); } - private static AtsCallbackProxyFactory CreateFactory(ICallbackInvoker? invoker = null, bool registerDtoTypes = false) + private static AtsCallbackProxyFactory CreateFactory(ICallbackInvoker? invoker = null, bool registerDtoTypes = false, IReadOnlyList? handleTypes = null) { var handles = new HandleRegistry(); var ctRegistry = new CancellationTokenRegistry(); @@ -369,7 +394,7 @@ private static AtsCallbackProxyFactory CreateFactory(ICallbackInvoker? invoker = new() { TypeId = "test/TestCallbackDto", Name = "TestCallbackDto", ClrType = typeof(TestCallbackDto), Properties = [] } } : []; - var context = new AtsContext { Capabilities = [], HandleTypes = [], DtoTypes = dtoTypes, EnumTypes = [] }; + var context = new AtsContext { Capabilities = [], HandleTypes = handleTypes ?? [], DtoTypes = dtoTypes, EnumTypes = [] }; var marshaller = new AtsMarshaller(handles, context, ctRegistry, new Lazy(() => throw new NotImplementedException())); return new AtsCallbackProxyFactory(invoker ?? new TestCallbackInvoker(), handles, ctRegistry, marshaller); } @@ -385,6 +410,8 @@ private static AtsCallbackProxyFactory CreateFactory(ICallbackInvoker? invoker = public delegate Task TestCallbackWithMultipleParams(string name, int count); + public delegate Task TestCallbackWithHandle(ITestCallbackHandle handle); + public delegate Task TestCallbackWithStringResult(string input); public delegate Task TestCallbackWithCancellation(string value, CancellationToken cancellationToken); @@ -405,6 +432,14 @@ public sealed class TestCallbackDto public string? Name { get; set; } public int Count { get; set; } } + + public interface ITestCallbackHandle + { + } + + private sealed class TestCallbackHandle : ITestCallbackHandle + { + } } internal sealed class TestCallbackInvoker : ICallbackInvoker From 54cc0b2503548156fb8ef28b4e81f8ff742dce2a Mon Sep 17 00:00:00 2001 From: Sebastien Ros Date: Mon, 4 May 2026 07:36:44 -0700 Subject: [PATCH 2/3] Address polyglot review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../Resources/base.go | 2 +- .../Resources/Transport.java | 46 ++++++++++++------- .../AtsGeneratedAspire.verified.java | 46 ++++++++++++------- ...oPassScanningGeneratedAspire.verified.java | 46 ++++++++++++------- 4 files changed, 88 insertions(+), 52 deletions(-) diff --git a/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go b/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go index 6d3a59dc493..216ae335f01 100644 --- a/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go +++ b/src/Aspire.Hosting.CodeGeneration.Go/Resources/base.go @@ -729,7 +729,7 @@ func decodeStructFields[T any](raw any) (T, bool) { } targetType := reflect.TypeOf((*T)(nil)).Elem() - isPointer := targetType.Kind() == reflect.Pointer + isPointer := targetType.Kind() == reflect.Ptr if isPointer { targetType = targetType.Elem() } diff --git a/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java b/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java index 506be4e1adc..9aff38c00ff 100644 --- a/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java +++ b/src/Aspire.Hosting.CodeGeneration.Java/Resources/Transport.java @@ -328,24 +328,32 @@ private void handleServerRequest(Map request) throws IOException try { if ("invokeCallback".equals(method)) { String callbackId = getCallbackId(params); - List args = getCallbackArgs(params); - - Function callback = callbacks.get(callbackId); - if (callback != null) { - Object[] unwrappedArgs = args.stream() - .map(this::unwrapResult) - .toArray(); - result = awaitValue(callback.apply(unwrappedArgs)); + if (callbackId == null) { + error = createError(-32602, "Invalid params: callbackId is required."); } else { - error = createError(-32601, "Callback not found: " + callbackId); + List args = getCallbackArgs(params); + + Function callback = callbacks.get(callbackId); + if (callback != null) { + Object[] unwrappedArgs = args.stream() + .map(this::unwrapResult) + .toArray(); + result = awaitValue(callback.apply(unwrappedArgs)); + } else { + error = createError(-32601, "Callback not found: " + callbackId); + } } } else if ("cancel".equals(method)) { String cancellationId = getCancellationId(params); - Consumer handler = cancellations.get(cancellationId); - if (handler != null) { - handler.accept(null); + if (cancellationId == null) { + error = createError(-32602, "Invalid params: cancellationId is required."); + } else { + Consumer handler = cancellations.get(cancellationId); + if (handler != null) { + handler.accept(null); + } + result = true; } - result = true; } else { error = createError(-32601, "Unknown method: " + method); } @@ -369,11 +377,11 @@ private void handleServerRequest(Map request) throws IOException @SuppressWarnings("unchecked") private String getCallbackId(Object params) { if (params instanceof List list && !list.isEmpty()) { - return (String) list.get(0); + return asString(list.get(0)); } if (params instanceof Map map) { - return (String) map.get("callbackId"); + return asString(map.get("callbackId")); } return null; @@ -410,16 +418,20 @@ private List getCallbackArgs(Object params) { private String getCancellationId(Object params) { if (params instanceof List list && !list.isEmpty()) { - return (String) list.get(0); + return asString(list.get(0)); } if (params instanceof Map map) { - return (String) map.get("cancellationId"); + return asString(map.get("cancellationId")); } return null; } + private String asString(Object value) { + return value instanceof String string ? string : null; + } + private Map createError(int code, String message) { Map error = new HashMap<>(); error.put("code", code); diff --git a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java index 575522420de..848ca1971ab 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java +++ b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/AtsGeneratedAspire.verified.java @@ -370,24 +370,32 @@ private void handleServerRequest(Map request) throws IOException try { if ("invokeCallback".equals(method)) { String callbackId = getCallbackId(params); - List args = getCallbackArgs(params); - - Function callback = callbacks.get(callbackId); - if (callback != null) { - Object[] unwrappedArgs = args.stream() - .map(this::unwrapResult) - .toArray(); - result = awaitValue(callback.apply(unwrappedArgs)); + if (callbackId == null) { + error = createError(-32602, "Invalid params: callbackId is required."); } else { - error = createError(-32601, "Callback not found: " + callbackId); + List args = getCallbackArgs(params); + + Function callback = callbacks.get(callbackId); + if (callback != null) { + Object[] unwrappedArgs = args.stream() + .map(this::unwrapResult) + .toArray(); + result = awaitValue(callback.apply(unwrappedArgs)); + } else { + error = createError(-32601, "Callback not found: " + callbackId); + } } } else if ("cancel".equals(method)) { String cancellationId = getCancellationId(params); - Consumer handler = cancellations.get(cancellationId); - if (handler != null) { - handler.accept(null); + if (cancellationId == null) { + error = createError(-32602, "Invalid params: cancellationId is required."); + } else { + Consumer handler = cancellations.get(cancellationId); + if (handler != null) { + handler.accept(null); + } + result = true; } - result = true; } else { error = createError(-32601, "Unknown method: " + method); } @@ -411,11 +419,11 @@ private void handleServerRequest(Map request) throws IOException @SuppressWarnings("unchecked") private String getCallbackId(Object params) { if (params instanceof List list && !list.isEmpty()) { - return (String) list.get(0); + return asString(list.get(0)); } if (params instanceof Map map) { - return (String) map.get("callbackId"); + return asString(map.get("callbackId")); } return null; @@ -452,16 +460,20 @@ private List getCallbackArgs(Object params) { private String getCancellationId(Object params) { if (params instanceof List list && !list.isEmpty()) { - return (String) list.get(0); + return asString(list.get(0)); } if (params instanceof Map map) { - return (String) map.get("cancellationId"); + return asString(map.get("cancellationId")); } return null; } + private String asString(Object value) { + return value instanceof String string ? string : null; + } + private Map createError(int code, String message) { Map error = new HashMap<>(); error.put("code", code); diff --git a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java index eed62a5987a..ced54dcb9c5 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java +++ b/tests/Aspire.Hosting.CodeGeneration.Java.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.java @@ -541,24 +541,32 @@ private void handleServerRequest(Map request) throws IOException try { if ("invokeCallback".equals(method)) { String callbackId = getCallbackId(params); - List args = getCallbackArgs(params); - - Function callback = callbacks.get(callbackId); - if (callback != null) { - Object[] unwrappedArgs = args.stream() - .map(this::unwrapResult) - .toArray(); - result = awaitValue(callback.apply(unwrappedArgs)); + if (callbackId == null) { + error = createError(-32602, "Invalid params: callbackId is required."); } else { - error = createError(-32601, "Callback not found: " + callbackId); + List args = getCallbackArgs(params); + + Function callback = callbacks.get(callbackId); + if (callback != null) { + Object[] unwrappedArgs = args.stream() + .map(this::unwrapResult) + .toArray(); + result = awaitValue(callback.apply(unwrappedArgs)); + } else { + error = createError(-32601, "Callback not found: " + callbackId); + } } } else if ("cancel".equals(method)) { String cancellationId = getCancellationId(params); - Consumer handler = cancellations.get(cancellationId); - if (handler != null) { - handler.accept(null); + if (cancellationId == null) { + error = createError(-32602, "Invalid params: cancellationId is required."); + } else { + Consumer handler = cancellations.get(cancellationId); + if (handler != null) { + handler.accept(null); + } + result = true; } - result = true; } else { error = createError(-32601, "Unknown method: " + method); } @@ -582,11 +590,11 @@ private void handleServerRequest(Map request) throws IOException @SuppressWarnings("unchecked") private String getCallbackId(Object params) { if (params instanceof List list && !list.isEmpty()) { - return (String) list.get(0); + return asString(list.get(0)); } if (params instanceof Map map) { - return (String) map.get("callbackId"); + return asString(map.get("callbackId")); } return null; @@ -623,16 +631,20 @@ private List getCallbackArgs(Object params) { private String getCancellationId(Object params) { if (params instanceof List list && !list.isEmpty()) { - return (String) list.get(0); + return asString(list.get(0)); } if (params instanceof Map map) { - return (String) map.get("cancellationId"); + return asString(map.get("cancellationId")); } return null; } + private String asString(Object value) { + return value instanceof String string ? string : null; + } + private Map createError(int code, String message) { Map error = new HashMap<>(); error.put("code", code); From 5d8e7fd2ef1dcc3da3418ff8099db7871792c7f2 Mon Sep 17 00:00:00 2001 From: Sebastien Ros Date: Mon, 4 May 2026 13:07:54 -0700 Subject: [PATCH 3/3] Use JSON snake case naming for Python codegen Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../AtsPythonCodeGenerator.cs | 19 +---- .../AtsPythonCodeGeneratorTests.cs | 69 +++++++++++++++++++ ...TwoPassScanningGeneratedAspire.verified.py | 12 ++-- 3 files changed, 77 insertions(+), 23 deletions(-) diff --git a/src/Aspire.Hosting.CodeGeneration.Python/AtsPythonCodeGenerator.cs b/src/Aspire.Hosting.CodeGeneration.Python/AtsPythonCodeGenerator.cs index a1c8041ed5a..83b1797e5aa 100644 --- a/src/Aspire.Hosting.CodeGeneration.Python/AtsPythonCodeGenerator.cs +++ b/src/Aspire.Hosting.CodeGeneration.Python/AtsPythonCodeGenerator.cs @@ -4,6 +4,7 @@ using System.Globalization; using System.Reflection; using System.Text; +using System.Text.Json; using System.Text.Json.Nodes; using Aspire.TypeSystem; @@ -933,23 +934,7 @@ private static string ToSnakeCase(string name) return name; } - var result = new System.Text.StringBuilder(); - result.Append(char.ToLowerInvariant(name[0])); - - for (int i = 1; i < name.Length; i++) - { - var c = name[i]; - if (char.IsUpper(c)) - { - result.Append('_'); - result.Append(char.ToLowerInvariant(c)); - } - else - { - result.Append(c); - } - } - var resultStr = result.ToString(); + var resultStr = JsonNamingPolicy.SnakeCaseLower.ConvertName(name); resultStr = resultStr.Replace("environment", "env"); resultStr = resultStr.Replace("configuration", "config"); resultStr = resultStr.Replace("application", "app"); diff --git a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs index ac4df6ea292..d61e18adb35 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs +++ b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/AtsPythonCodeGeneratorTests.cs @@ -305,6 +305,17 @@ public void GeneratedCode_SanitizesPythonKeywordIdentifiers() Assert.DoesNotContain("\n from: str", aspirePy); } + [Fact] + public void GeneratedCode_PreservesAcronymsInSnakeCaseIdentifiers() + { + var files = _generator.GenerateDistributedApplication(CreateContextWithAcronymIdentifiers()); + var aspirePy = files["aspire_app.py"]; + + Assert.Contains("def with_something_ai(self, something_ai: str)", aspirePy); + Assert.DoesNotContain("with_something_a_i", aspirePy); + Assert.DoesNotContain("something_a_i", aspirePy); + } + [Fact] public void GeneratedCode_SanitizesClrGenericNamesInInheritance() { @@ -428,6 +439,62 @@ private static AtsContext CreateContextWithKeywordParameter() }; } + private static AtsContext CreateContextWithAcronymIdentifiers() + { + var resourceType = new AtsTypeRef + { + TypeId = "Tests/AcronymResource", + ClrType = typeof(AcronymResource), + Category = AtsTypeCategory.Handle + }; + + return new AtsContext + { + Capabilities = + [ + new AtsCapabilityInfo + { + CapabilityId = "Tests/withSomethingAI", + MethodName = "withSomethingAI", + Parameters = + [ + new AtsParameterInfo + { + Name = "builder", + Type = resourceType + }, + new AtsParameterInfo + { + Name = "somethingAI", + Type = new AtsTypeRef + { + TypeId = AtsConstants.String, + Category = AtsTypeCategory.Primitive + } + } + ], + ReturnType = resourceType, + TargetTypeId = resourceType.TypeId, + TargetType = resourceType, + TargetParameterName = "builder", + ExpandedTargetTypes = [resourceType], + ReturnsBuilder = true, + CapabilityKind = AtsCapabilityKind.Method + } + ], + HandleTypes = + [ + new AtsTypeInfo + { + AtsTypeId = resourceType.TypeId, + ClrType = typeof(AcronymResource) + } + ], + DtoTypes = [], + EnumTypes = [] + }; + } + private static AtsContext CreateContextWithGenericInheritance() { var genericBaseType = typeof(GenericBaseResource>); @@ -513,6 +580,8 @@ private static AtsContext CreateContextWithGenericInheritance() private sealed class KeywordResource; + private sealed class AcronymResource; + private interface IGenericResource; private abstract class GenericBaseResource; diff --git a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py index 3427927302b..e41eb6d78cb 100644 --- a/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py +++ b/tests/Aspire.Hosting.CodeGeneration.Python.Tests/Snapshots/TwoPassScanningGeneratedAspire.verified.py @@ -6930,7 +6930,7 @@ class ContainerResourceKwargs(_BaseResourceKwargs, total=False): image_tag: str image_registry: str image: str | tuple[str, str] - image_s_h_a256: str + image_sha256: str container_runtime_args: typing.Iterable[str] lifetime: ContainerLifetime image_pull_policy: ImagePullPolicy @@ -7044,7 +7044,7 @@ def with_image(self, image: str, *, tag: str | None = None) -> typing.Self: self._handle = self._wrap_builder(result) return self - def with_image_s_h_a256(self, sha256: str) -> typing.Self: + def with_image_sha256(self, sha256: str) -> typing.Self: """Sets the image SHA256 digest""" rpc_args: dict[str, typing.Any] = {'builder': self._handle} rpc_args['sha256'] = sha256 @@ -7716,13 +7716,13 @@ def __init__(self, handle: Handle, client: AspireClient, **kwargs: typing.Unpack handle = self._wrap_builder(client.invoke_capability('Aspire.Hosting/withImage', rpc_args)) else: raise TypeError("Invalid type for option 'image'. Expected: str or (str, str)") - if _image_s_h_a256 := kwargs.pop("image_s_h_a256", None): - if _validate_type(_image_s_h_a256, str): + if _image_sha256 := kwargs.pop("image_sha256", None): + if _validate_type(_image_sha256, str): rpc_args: dict[str, typing.Any] = {"builder": handle} - rpc_args["sha256"] = typing.cast(str, _image_s_h_a256) + rpc_args["sha256"] = typing.cast(str, _image_sha256) handle = self._wrap_builder(client.invoke_capability('Aspire.Hosting/withImageSHA256', rpc_args)) else: - raise TypeError("Invalid type for option 'image_s_h_a256'. Expected: str") + raise TypeError("Invalid type for option 'image_sha256'. Expected: str") if _container_runtime_args := kwargs.pop("container_runtime_args", None): if _validate_type(_container_runtime_args, typing.Iterable[str]): rpc_args: dict[str, typing.Any] = {"builder": handle}