using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Linq;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Sandbox.Generator
{
	/// <summary>
	/// Rewrites methods and properties whose attributes are themselves marked with
	/// <c>global::Sandbox.CodeGeneratorAttribute</c>, routing the method body / property accessors
	/// through the callback named by that CodeGenerator attribute.
	/// </summary>
	internal class CodeGen
	{
		/// <summary>
		/// Parsed from the CodeGenerator attribute's first ("Type") argument via an int cast,
		/// so the values presumably must match the runtime-side flags enum (TODO confirm).
		/// </summary>
		[Flags]
		internal enum Flags
		{
			WrapPropertyGet = 1,
			WrapPropertySet = 2,
			WrapMethod = 4,
			Static = 8,
			Instance = 16
		}

		/// <summary>
		/// Find anything marked with [CodeGen] and perform the appropriate code generation.
		/// Every attribute on <paramref name="symbol"/> whose class carries a CodeGenerator requesting
		/// <see cref="Flags.WrapMethod"/> wraps the method body in a call to that generator's callback.
		/// </summary>
		/// <param name="node">The method declaration; replaced in place when a wrapper is applied.</param>
		/// <param name="symbol">The semantic symbol for <paramref name="node"/>.</param>
		/// <param name="master">Receives generated class members and diagnostics.</param>
		internal static void VisitMethod( ref MethodDeclarationSyntax node, IMethodSymbol symbol, Worker master )
		{
			// This will be true for abstract methods (and anything else without a body)
			if ( (node.Body is null && node.ExpressionBody is null) || symbol.IsAbstract )
				return;

			var hasTarget = false;
			var attributesToWrite = new List<string>();

			foreach ( var attribute in symbol.GetAttributes() )
			{
				foreach ( var cg in GetCodeGeneratorAttributes( attribute ) )
				{
					var type = (Flags)int.Parse( cg.GetArgumentValue( 0, "Type", "0" ) );
					var callbackName = cg.GetArgumentValue( 1, "CallbackName", string.Empty );

					if ( !type.Contains( Flags.WrapMethod ) )
						continue;

					// Evaluate HandleWrapCall first so every wrapper is applied, not just the first one.
					hasTarget = HandleWrapCall( attribute, type, callbackName, ref node, symbol, master ) || hasTarget;
				}

				// Include ALL the attributes when writing the static accessor
				AddAttributeString( attribute, attributesToWrite );
			}

			if ( !hasTarget || attributesToWrite.Count == 0 )
				return;

			// The wrapped body references this field by name (see BuildWrappedMethodExpression).
			var methodIdentity = MakeMethodIdentitySafe( GetUniqueMethodIdentity( symbol ) );
			master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] static readonly global::System.Attribute[] __{methodIdentity}__Attrs = new global::System.Attribute[] {{ {string.Join( ", ", attributesToWrite )} }};\n", false );
		}

		/// <summary>
		/// One property get/set wrapper request collected from a CodeGenerator attribute.
		/// </summary>
		private struct PropertyWrapperData
		{
			/// <summary>The attribute applied to the property.</summary>
			public AttributeData Attribute { get; set; }

			/// <summary>Callback to invoke; either a bare method name or "Fully.Qualified.Type.Method".</summary>
			public string CallbackName { get; set; }

			/// <summary>Higher priorities are applied first.</summary>
			public int Priority { get; set; }

			/// <summary>Which accessors to wrap and for which (static/instance) properties.</summary>
			public Flags Type { get; set; }
		}

		/// <summary>
		/// Wraps the get and/or set accessors of <paramref name="node"/> for every CodeGenerator attribute
		/// on <paramref name="symbol"/> that requests it, and emits the <c>__Name__Attrs</c> array the
		/// wrappers reference.
		/// </summary>
		/// <param name="node">The property declaration; replaced in place when wrappers are applied.</param>
		/// <param name="symbol">The semantic symbol for <paramref name="node"/>.</param>
		/// <param name="master">Receives generated class members and diagnostics.</param>
		internal static void VisitProperty( ref PropertyDeclarationSyntax node, IPropertySymbol symbol, Worker master )
		{
			var attributesToWrite = new List<string>();
			var generatedFields = new HashSet<string>();
			var data = new List<PropertyWrapperData>();

			foreach ( var attribute in symbol.GetAttributes() )
			{
				var isCodeGenerator = false;

				foreach ( var cg in GetCodeGeneratorAttributes( attribute ) )
				{
					isCodeGenerator = true;

					var type = (Flags)int.Parse( cg.GetArgumentValue( 0, "Type", "0" ) );
					var callbackName = cg.GetArgumentValue( 1, "CallbackName", string.Empty );
					var priority = int.Parse( cg.GetArgumentValue( 2, "Priority", "0" ) );

					if ( type.Contains( Flags.WrapPropertySet ) || type.Contains( Flags.WrapPropertyGet ) )
					{
						data.Add( new() { Attribute = attribute, CallbackName = callbackName, Priority = priority, Type = type } );
					}
				}

				// Only code-generator attributes are written for properties. Write each one once,
				// even when its class carries several [CodeGenerator]s (e.g. one for get, one for set).
				if ( isCodeGenerator )
					AddAttributeString( attribute, attributesToWrite );
			}

			// Highest priority first. OrderByDescending is a stable sort, so wrappers with equal
			// priority keep their declaration order (List.Sort gives no such guarantee).
			foreach ( var w in data.OrderByDescending( d => d.Priority ) )
			{
				if ( w.Type.Contains( Flags.WrapPropertySet ) )
				{
					HandleWrapSet( w.Attribute, w.Type, w.CallbackName, ref node, symbol, master, generatedFields );
				}

				if ( w.Type.Contains( Flags.WrapPropertyGet ) )
				{
					HandleWrapGet( w.Attribute, w.Type, w.CallbackName, ref node, symbol, master, generatedFields );
				}
			}

			if ( attributesToWrite.Count > 0 )
			{
				master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] static readonly global::System.Attribute[] __{symbol.Name}__Attrs = new global::System.Attribute[] {{ {string.Join( ", ", attributesToWrite )} }};\n", false );
			}
		}

		/// <summary>
		/// Appends C# source that re-creates <paramref name="attribute"/> (e.g. <c>new global::FooAttribute( 1 ) { Bar = 2 }</c>)
		/// to <paramref name="list"/>. Attributes without application syntax are skipped.
		/// </summary>
		private static void AddAttributeString( AttributeData attribute, List<string> list )
		{
			if ( attribute.ApplicationSyntaxReference?.GetSyntax() is not AttributeSyntax sn )
				return;

			var attributeClassName = attribute.AttributeClass.FullName();
			if ( !attributeClassName.EndsWith( "Attribute" ) )
				attributeClassName += "Attribute";

			var arguments = sn.ArgumentList?.Arguments.ToArray() ?? [];
			if ( arguments.Length == 0 )
			{
				list.Add( $"new {attributeClassName}()" );
				return;
			}

			var constructorArguments = new List<string>();
			var propertyArguments = new List<string>();

			foreach ( var syntax in arguments )
			{
				if ( syntax.NameEquals is not null )
				{
					// [Foo( Bar = x )] sets a property/field -> object initializer member
					propertyArguments.Add( $"{syntax.NameEquals.Name} = {syntax.Expression}" );
				}
				else if ( syntax.NameColon is not null )
				{
					// [Foo( bar: x )] is a named *constructor* argument, so it stays in the argument list
					constructorArguments.Add( $"{syntax.NameColon.Name}: {syntax.Expression}" );
				}
				else
				{
					constructorArguments.Add( syntax.Expression.ToString() );
				}
			}

			list.Add( $"new {attributeClassName}( {string.Join( ",", constructorArguments )} ) {{ {string.Join( ", ", propertyArguments )} }}" );
		}

		/// <summary>
		/// Resolves where a callback lives. A bare name ("OnChanged") is looked up on
		/// <paramref name="containingType"/>; a dotted name ("My.Type.OnChanged") is a static method on the named type.
		/// Reports an error and returns false when the named type cannot be found.
		/// </summary>
		private static bool TryResolveCallbackTarget( AttributeData attribute, string callbackName, INamedTypeSymbol containingType, Location location, Worker master,
			out INamedTypeSymbol typeToInvokeOn, out string methodToInvoke, out bool isStaticCallback )
		{
			typeToInvokeOn = containingType;
			methodToInvoke = callbackName;
			isStaticCallback = false;

			var parts = callbackName.Split( '.' );
			if ( parts.Length <= 1 )
				return true;

			isStaticCallback = true;
			methodToInvoke = parts[parts.Length - 1];

			var typeToLookFor = string.Join( ".", parts.Take( parts.Length - 1 ) );
			typeToInvokeOn = master.GetOrCreateTypeByMetadataName( typeToLookFor );

			if ( typeToInvokeOn is not null )
				return true;

			master.AddError( location, $"Unable to find {typeToLookFor} required for {attribute.AttributeClass?.Name}. Ensure that a fully qualified callback name is used." );
			return false;
		}

		/// <summary>
		/// Builds <c>Name = value</c> for use inside an object initializer.
		/// </summary>
		private static ExpressionSyntax InitializerAssignment( string name, ExpressionSyntax value )
		{
			return AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, IdentifierName( name ), value );
		}

		/// <summary>
		/// <c>true</c>/<c>false</c> literal for <paramref name="value"/>.
		/// </summary>
		private static ExpressionSyntax BoolLiteral( bool value )
		{
			return LiteralExpression( value ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression );
		}

		/// <summary>
		/// <c>null</c> for static members, <c>this</c> otherwise.
		/// </summary>
		private static ExpressionSyntax ObjectOrNull( ISymbol symbol )
		{
			return symbol.IsStatic ? (ExpressionSyntax)LiteralExpression( SyntaxKind.NullLiteralExpression ) : ThisExpression();
		}

		#region Property Wrapping

		/// <summary>
		/// Rewrites all references to the implicit 'value' parameter to the specified parameter name.
		/// This is needed because the original setter body uses 'value', but our lambda uses a different parameter.
		/// </summary>
		private static BlockSyntax RewriteValueToParameter( BlockSyntax body, string parameterName )
		{
			var rewriter = new ValueIdentifierRewriter( parameterName );
			return (BlockSyntax)rewriter.Visit( body );
		}

		/// <summary>
		/// Replaces identifiers that refer to the setter's implicit 'value' parameter. Identifiers that merely
		/// *name* something called 'value' (x.value, x?.value, value: 1, new T { value = 1 }) are left alone.
		/// </summary>
		private class ValueIdentifierRewriter : CSharpSyntaxRewriter
		{
			private readonly string _parameterName;

			public ValueIdentifierRewriter( string parameterName )
			{
				_parameterName = parameterName;
			}

			public override SyntaxNode VisitIdentifierName( IdentifierNameSyntax node )
			{
				if ( node.Identifier.ValueText == "value" && !IsMemberName( node ) )
				{
					return node.WithIdentifier( Identifier( node.Identifier.LeadingTrivia, _parameterName, node.Identifier.TrailingTrivia ) );
				}

				return base.VisitIdentifierName( node );
			}

			// The rewriter visits the original tree, so Parent and child references are the originals.
			private static bool IsMemberName( IdentifierNameSyntax node )
			{
				return node.Parent switch
				{
					MemberAccessExpressionSyntax memberAccess => memberAccess.Name == node,
					MemberBindingExpressionSyntax => true,
					QualifiedNameSyntax qualified => qualified.Right == node,
					NameColonSyntax => true,
					NameEqualsSyntax => true,
					AssignmentExpressionSyntax assignment => assignment.Left == node
						&& assignment.Parent is InitializerExpressionSyntax initializer
						&& (initializer.IsKind( SyntaxKind.ObjectInitializerExpression ) || initializer.IsKind( SyntaxKind.WithInitializerExpression )),
					_ => false
				};
			}
		}

		/// <summary>
		/// Returns <paramref name="preferred"/>, prefixed with underscores until it matches no identifier
		/// token inside <paramref name="scope"/>, so it can be introduced without shadowing or clashing.
		/// </summary>
		private static string GetUnusedIdentifier( SyntaxNode scope, string preferred )
		{
			var used = new HashSet<string>( scope.DescendantTokens()
				.Where( t => t.IsKind( SyntaxKind.IdentifierToken ) )
				.Select( t => t.ValueText ) );

			var name = preferred;
			while ( used.Contains( name ) )
				name = "_" + name;

			return name;
		}

		/// <summary>
		/// Gets the expression or body to use for reading the property value directly,
		/// bypassing the wrapped getter to prevent infinite recursion.
		/// </summary>
		private static CSharpSyntaxNode GetDirectGetterBody( AccessorDeclarationSyntax existingGetter )
		{
			if ( existingGetter?.ExpressionBody is not null )
			{
				// Expression body: get => _backingField;
				return existingGetter.ExpressionBody.Expression;
			}

			if ( existingGetter?.Body is not null )
			{
				// Block body: get { return _backingField; }
				return existingGetter.Body;
			}

			// Auto-getter: use field keyword
			return FieldExpression();
		}

		/// <summary>
		/// The WrappedPropertySet/WrappedPropertyGet initializer members shared by both wrappers, in this order:
		/// IsStatic, TypeName, PropertyName, MemberIdent, Attributes.
		/// </summary>
		private static IEnumerable<ExpressionSyntax> PropertyIdentityInitializers( IPropertySymbol symbol )
		{
			var memberIdentity = $"{symbol.ContainingType.GetFullMetadataName().Replace( "global::", "" )}.{symbol.Name}";
			var memberHash = memberIdentity.FastHash();

			yield return InitializerAssignment( "IsStatic", BoolLiteral( symbol.IsStatic ) );
			yield return InitializerAssignment( "TypeName", ParseExpression( symbol.ContainingType.FullName().Replace( "global::", "" ).QuoteSafe() ) );
			yield return InitializerAssignment( "PropertyName", ParseExpression( symbol.Name.QuoteSafe() ) );
			yield return InitializerAssignment( "MemberIdent", LiteralExpression( SyntaxKind.NumericLiteralExpression, Literal( memberHash ) ) );
			yield return InitializerAssignment( "Attributes", IdentifierName( $"__{symbol.Name}__Attrs" ) );
		}

		/// <summary>
		/// Builds <c>Callback( new WrapperType { ...initializers } )</c>.
		/// </summary>
		private static InvocationExpressionSyntax BuildWrapperCallback( string callbackName, string wrapperTypeName, IEnumerable<ExpressionSyntax> initializers )
		{
			var parameterStructExpr = ObjectCreationExpression( ParseTypeName( wrapperTypeName ) )
				.WithInitializer( InitializerExpression( SyntaxKind.ObjectInitializerExpression, SeparatedList( initializers ) ) );

			return InvocationExpression( ParseExpression( callbackName ), ArgumentList( SingletonSeparatedList( Argument( parameterStructExpr ) ) ) );
		}

		/// <summary>
		/// Replaces the property's set accessor with one that calls <paramref name="callbackName"/> with a
		/// WrappedPropertySet carrying the new value, a cached delegate that runs the original setter body,
		/// and a cached delegate that reads the property.
		/// Does nothing if the property has no setter or the static/instance flags don't apply.
		/// </summary>
		private static void HandleWrapSet( AttributeData attribute, Flags type, string callbackName, ref PropertyDeclarationSyntax node, IPropertySymbol symbol, Worker master, HashSet<string> generatedFields )
		{
			if ( symbol.IsStatic && !type.Contains( Flags.Static ) ) return;
			if ( !symbol.IsStatic && !type.Contains( Flags.Instance ) ) return;

			if ( !TryResolveCallbackTarget( attribute, callbackName, symbol.ContainingType, node.GetLocation(), master, out var typeToInvokeOn, out var methodToInvoke, out var isStaticCallback ) )
				return;

			if ( typeToInvokeOn is null || !ValidatePropertyCallback( symbol.ContainingType, typeToInvokeOn, methodToInvoke, isStaticCallback, symbol.Type, "global::Sandbox.WrappedPropertySet<" ) )
			{
				master.AddError( node.GetLocation(), $"A method {methodToInvoke}( WrappedPropertySet<{symbol.Type.Name}> ) is required on {typeToInvokeOn?.Name}." );
				return;
			}

			var propertyType = symbol.Type.FullName();
			var existingGetter = node.AccessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.GetAccessorDeclaration );
			var existingSetter = node.AccessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.SetAccessorDeclaration );

			if ( existingSetter is null )
			{
				// There is no setter to wrap.
				return;
			}

			// Generate cached delegate field names (include attribute name for multiple attributes on the same property)
			var attributeSuffix = attribute.AttributeClass?.Name ?? "Unknown";
			var setterFieldName = $"__{symbol.Name}_{attributeSuffix}__CachedSetter";
			var getterFieldName = $"__{symbol.Name}_{attributeSuffix}__CachedSetterGetter";
			var staticModifier = symbol.IsStatic ? "static " : "";

			if ( generatedFields.Add( setterFieldName ) )
			{
				master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] private {staticModifier}global::System.Action<{propertyType}> {setterFieldName};\n", false );
			}

			if ( generatedFields.Add( getterFieldName ) )
			{
				master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] private {staticModifier}global::System.Func<{propertyType}> {getterFieldName};\n", false );
			}

			var accessors = new List<AccessorDeclarationSyntax>();

			// GET accessor is kept as-is
			if ( existingGetter is not null )
			{
				accessors.Add( existingGetter );
			}

			// SET accessor
			BlockSyntax setterInnerBody;
			if ( existingSetter.ExpressionBody is not null )
			{
				setterInnerBody = Block( ExpressionStatement( existingSetter.ExpressionBody.Expression ) );
			}
			else if ( existingSetter.Body is not null )
			{
				setterInnerBody = existingSetter.Body;
			}
			else
			{
				// Auto-setter: generate field = value;
				setterInnerBody = Block( ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, FieldExpression(), IdentifierName( "value" ) ) ) );
			}

			// The lambda parameter replaces 'value'. Pick a name the body doesn't already use so it can't
			// clash with a local, or silently capture a field/property of the same name.
			var valueParameter = GetUnusedIdentifier( setterInnerBody, "v" );
			var rewrittenSetterBody = RewriteValueToParameter( setterInnerBody, valueParameter );
			var setterLambda = ParenthesizedLambdaExpression( ParameterList( SingletonSeparatedList( Parameter( Identifier( valueParameter ) ) ) ), rewrittenSetterBody );

			// Cached setter: __CachedSetter ??= (v) => { ... }
			var cachedSetterExpr = AssignmentExpression( SyntaxKind.CoalesceAssignmentExpression, IdentifierName( setterFieldName ), setterLambda );

			// Cached getter: __CachedGetter ??= () => PropertyName
			// Calls the property by name, which goes through all wrapped getters
			// This avoids inlining wrapped getter code which would cause recursion
			var getterLambda = ParenthesizedLambdaExpression( IdentifierName( symbol.Name ) );
			var cachedGetterExpr = AssignmentExpression( SyntaxKind.CoalesceAssignmentExpression, IdentifierName( getterFieldName ), getterLambda );

			var initializers = new List<ExpressionSyntax>
			{
				InitializerAssignment( "Value", IdentifierName( "value" ) ),
				InitializerAssignment( "Object", ObjectOrNull( symbol ) ),
				InitializerAssignment( "Setter", cachedSetterExpr ),
				InitializerAssignment( "Getter", cachedGetterExpr ),
			};
			initializers.AddRange( PropertyIdentityInitializers( symbol ) );

			var invocation = BuildWrapperCallback( callbackName, $"global::Sandbox.WrappedPropertySet<{propertyType}>", initializers );

			var set = AccessorDeclaration( SyntaxKind.SetAccessorDeclaration )
				.WithBody( Block( ExpressionStatement( invocation ) ) )
				.WithModifiers( existingSetter.Modifiers )
				.WithAttributeLists( existingSetter.AttributeLists );

			accessors.Add( set );

			node = node.WithAccessorList( AccessorList( List( accessors ) ) )
				.NormalizeWhitespace();
		}

		/// <summary>
		/// Replaces the property's get accessor with one that calls <paramref name="callbackName"/> with a
		/// WrappedPropertyGet whose Value is produced by the previous getter (via a cached delegate), and
		/// returns the callback's result cast to the property type. Chained get wrappers nest this way.
		/// Does nothing if the property has no getter or the static/instance flags don't apply.
		/// </summary>
		private static void HandleWrapGet( AttributeData attribute, Flags type, string callbackName, ref PropertyDeclarationSyntax node, IPropertySymbol symbol, Worker master, HashSet<string> generatedFields )
		{
			if ( symbol.IsStatic && !type.Contains( Flags.Static ) ) return;
			if ( !symbol.IsStatic && !type.Contains( Flags.Instance ) ) return;

			if ( !TryResolveCallbackTarget( attribute, callbackName, symbol.ContainingType, node.GetLocation(), master, out var typeToInvokeOn, out var methodToInvoke, out var isStaticCallback ) )
				return;

			var propertyType = symbol.Type.FullName();

			if ( typeToInvokeOn is null || !ValidatePropertyCallback( symbol.ContainingType, typeToInvokeOn, methodToInvoke, isStaticCallback, symbol.Type, "global::Sandbox.WrappedPropertyGet<" ) )
			{
				master.AddError( node.GetLocation(), $"A method {symbol.Type.Name} {methodToInvoke}( WrappedPropertyGet<{symbol.Type.Name}> ) is required on {typeToInvokeOn?.Name}." );
				return;
			}

			var existingGetter = node.AccessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.GetAccessorDeclaration );
			var existingSetter = node.AccessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.SetAccessorDeclaration );

			if ( existingGetter is null )
			{
				// There is no getter to wrap.
				return;
			}

			// Generate cached delegate field name (include attribute name for multiple attributes on same property)
			var attributeSuffix = attribute.AttributeClass?.Name ?? "Unknown";
			var getterFieldName = $"__{symbol.Name}_{attributeSuffix}__CachedGetter";
			var staticModifier = symbol.IsStatic ? "static " : "";

			if ( generatedFields.Add( getterFieldName ) )
			{
				master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] private {staticModifier}global::System.Func<{propertyType}> {getterFieldName};\n", false );
			}

			var accessors = new List<AccessorDeclarationSyntax>();

			// SET accessor is kept as-is
			if ( existingSetter is not null )
			{
				accessors.Add( existingSetter );
			}

			// GET accessor
			// Wrap the current getter body - this allows get wrappers to chain
			var getterLambda = ParenthesizedLambdaExpression( GetDirectGetterBody( existingGetter ) );

			// Cached getter: (__CachedGetter ??= () => <current getter>)()
			var cachedGetterExpr = AssignmentExpression( SyntaxKind.CoalesceAssignmentExpression, IdentifierName( getterFieldName ), getterLambda );
			var defaultValueExpression = InvocationExpression( ParenthesizedExpression( cachedGetterExpr ) );

			var initializers = new List<ExpressionSyntax>
			{
				InitializerAssignment( "Value", defaultValueExpression ),
				InitializerAssignment( "Object", ObjectOrNull( symbol ) ),
			};
			initializers.AddRange( PropertyIdentityInitializers( symbol ) );

			var invocation = BuildWrapperCallback( callbackName, $"global::Sandbox.WrappedPropertyGet<{propertyType}>", initializers );

			var get = AccessorDeclaration( SyntaxKind.GetAccessorDeclaration )
				.WithBody( Block( ReturnStatement( CastExpression( ParseTypeName( propertyType ), invocation ) ) ) )
				.WithModifiers( existingGetter.Modifiers )
				.WithAttributeLists( existingGetter.AttributeLists );

			accessors.Add( get );

			node = node.WithAccessorList( AccessorList( List( accessors ) ) )
				.NormalizeWhitespace();
		}

		#endregion

		#region Method Wrapping

		/// <summary>
		/// Builds <c>new global::Sandbox.WrappedMethod[&lt;T&gt;] { Resume = () =&gt; originalBody, ... }</c>.
		/// With <paramref name="usesObjectFallback"/> the generic argument is <c>object</c>
		/// (or <c>Task&lt;object&gt;</c> when the method returns <c>Task&lt;T&gt;</c>).
		/// </summary>
		private static ExpressionSyntax BuildWrappedMethodExpression( IMethodSymbol symbol, CSharpSyntaxNode resumeBodyNode, int methodIdentity, bool usesObjectFallback = false )
		{
			string genericSuffix;
			if ( symbol.ReturnsVoid )
			{
				genericSuffix = string.Empty;
			}
			else if ( usesObjectFallback )
			{
				// Use object (or Task<object> for async Task<T>)
				genericSuffix = symbol.ReturnType.FullName().StartsWith( "global::System.Threading.Tasks.Task<" )
					? "<global::System.Threading.Tasks.Task<object>>"
					: "<object>";
			}
			else
			{
				genericSuffix = $"<{symbol.ReturnType.FullName()}>";
			}

			var wrappedType = ParseTypeName( $"global::Sandbox.WrappedMethod{genericSuffix}" );

			var resumeLambda = ParenthesizedLambdaExpression( resumeBodyNode );
			if ( symbol.IsAsync )
			{
				resumeLambda = resumeLambda.WithAsyncKeyword( Token( SyntaxKind.AsyncKeyword ) );
			}

			var typeName = symbol.ContainingType.FullName().Replace( "global::", "" );
			var attrsFieldName = $"__{MakeMethodIdentitySafe( methodIdentity )}__Attrs";

			// new[] { typeof(A), typeof(B) } for generic methods, null otherwise
			ExpressionSyntax genericArgsExpression = LiteralExpression( SyntaxKind.NullLiteralExpression );
			if ( symbol.IsGenericMethod )
			{
				var typeofExpressions = symbol.TypeArguments
					.Select( t => (ExpressionSyntax)TypeOfExpression( ParseTypeName( t.ToDisplayString( SymbolDisplayFormat.FullyQualifiedFormat ) ) ) );

				genericArgsExpression = ImplicitArrayCreationExpression( InitializerExpression( SyntaxKind.ArrayInitializerExpression, SeparatedList( typeofExpressions ) ) );
			}

			var assignments = new List<ExpressionSyntax>
			{
				InitializerAssignment( "Resume", resumeLambda ),
				InitializerAssignment( "Object", ObjectOrNull( symbol ) ),
				InitializerAssignment( "MethodIdentity", LiteralExpression( SyntaxKind.NumericLiteralExpression, Literal( methodIdentity ) ) ),
				InitializerAssignment( "MethodName", ParseExpression( symbol.Name.QuoteSafe() ) ),
				InitializerAssignment( "TypeName", ParseExpression( typeName.QuoteSafe() ) ),
				InitializerAssignment( "IsStatic", BoolLiteral( symbol.IsStatic ) ),
				InitializerAssignment( "Attributes", IdentifierName( attrsFieldName ) ),
				InitializerAssignment( "GenericArguments", genericArgsExpression ),
			};

			return ObjectCreationExpression( wrappedType )
				.WithInitializer( InitializerExpression( SyntaxKind.ObjectInitializerExpression, SeparatedList( assignments ) ) );
		}

		/// <summary>
		/// Reference to <paramref name="parameter"/>, re-escaped with '@' when its name is a reserved keyword
		/// (e.g. a parameter declared as <c>@class</c>).
		/// </summary>
		private static IdentifierNameSyntax ParameterReference( IParameterSymbol parameter )
		{
			var name = parameter.Name;
			if ( SyntaxFacts.GetKeywordKind( name ) == SyntaxKind.None )
				return IdentifierName( name );

			return IdentifierName( VerbatimIdentifier( TriviaList(), "@" + name, name, TriviaList() ) );
		}

		/// <summary>
		/// Builds <c>Callback( wrappedMethod, param1, param2, ... )</c>.
		/// </summary>
		private static ExpressionSyntax BuildCallbackInvocation( string callbackName, ExpressionSyntax wrappedMethodExpr, IEnumerable<IParameterSymbol> parameters )
		{
			var callbackExpr = ParseExpression( callbackName );
			var args = new List<ArgumentSyntax> { Argument( wrappedMethodExpr ) };
			args.AddRange( parameters.Select( p => Argument( ParameterReference( p ) ) ) );
			return InvocationExpression( callbackExpr, ArgumentList( SeparatedList( args ) ) );
		}

		/// <summary>
		/// Syntax for T in a method returning Task&lt;T&gt;.
		/// </summary>
		private static TypeSyntax AwaitedTypeSyntax( IMethodSymbol symbol )
		{
			var innerType = (symbol.ReturnType as INamedTypeSymbol)?.TypeArguments[0];
			return ParseTypeName( innerType.FullName() );
		}

		/// <summary>
		/// Replaces the body of <paramref name="node"/> with a call to <paramref name="callbackName"/>, passing a
		/// WrappedMethod whose Resume delegate runs the original body, followed by the method's arguments.
		/// Reports an error and returns false if a compatible callback can't be found.
		/// </summary>
		/// <returns>True if the method was rewritten.</returns>
		private static bool HandleWrapCall( AttributeData attribute, Flags type, string callbackName, ref MethodDeclarationSyntax node, IMethodSymbol symbol, Worker master )
		{
			if ( node.Body is null && node.ExpressionBody is null )
				return false;

			var parameterCount = symbol.Parameters.Length;

			if ( symbol.IsStatic && !type.Contains( Flags.Static ) ) return false;
			if ( !symbol.IsStatic && !type.Contains( Flags.Instance ) ) return false;

			if ( !TryResolveCallbackTarget( attribute, callbackName, symbol.ContainingType, node.GetLocation(), master, out var typeToInvokeOn, out var methodToInvoke, out var isStaticCallback ) )
				return false;

			var usesObjectFallback = false;
			var success = typeToInvokeOn is not null
				&& ValidateMethodCallback( symbol.ContainingType, typeToInvokeOn, methodToInvoke, isStaticCallback, symbol.ReturnsVoid ? null : symbol.ReturnType, parameterCount, out usesObjectFallback );

			if ( !success )
			{
				var returnType = symbol.ReturnsVoid ? string.Empty : $"{symbol.ReturnType.Name} ";
				var paramsString = string.Join( ", ", Enumerable.Repeat( "Object", parameterCount ) );

				if ( symbol.ReturnsVoid )
				{
					master.AddError( node.GetLocation(), parameterCount > 0
						? $"A method {returnType}{methodToInvoke}( WrappedMethod, {paramsString} ) is required on {typeToInvokeOn?.Name}."
						: $"A method {returnType}{methodToInvoke}( WrappedMethod ) is required on {typeToInvokeOn?.Name}." );
				}
				else
				{
					master.AddError( node.GetLocation(), parameterCount > 0
						? $"A method {returnType}{methodToInvoke}( WrappedMethod<{symbol.ReturnType.Name}>, {paramsString} ) is required on {typeToInvokeOn?.Name}."
						: $"A method {returnType}{methodToInvoke}( WrappedMethod<{symbol.ReturnType.Name}> ) is required on {typeToInvokeOn?.Name}." );
				}

				return false;
			}

			// Capture original body/expression before we replace them
			var originalBody = node.Body;
			var originalExpressionBody = node.ExpressionBody;
			var resumeBodyNode = originalBody is not null ? originalBody : (CSharpSyntaxNode)originalExpressionBody.Expression;

			var methodIdentity = GetUniqueMethodIdentity( symbol );
			var wrappedMethodExpr = BuildWrappedMethodExpression( symbol, resumeBodyNode, methodIdentity, usesObjectFallback );
			var callbackInvocation = BuildCallbackInvocation( callbackName, wrappedMethodExpr, symbol.Parameters );

			var fullReturnType = symbol.ReturnType.FullName();
			var isGenericTaskType = fullReturnType.StartsWith( "global::System.Threading.Tasks.Task<" );
			var isTaskType = fullReturnType == "global::System.Threading.Tasks.Task";

			// With the object fallback a synchronous callback returns 'object', so cast it back to the
			// declared return type. (A non-async Task<T> can't be recovered from Task<object> by a cast,
			// so that shape is left unchanged.)
			ExpressionSyntax syncResult = callbackInvocation;
			if ( usesObjectFallback && !symbol.IsAsync && !symbol.ReturnsVoid && !isGenericTaskType )
			{
				syncResult = CastExpression( ParseTypeName( fullReturnType ), callbackInvocation );
			}

			if ( originalExpressionBody is null )
			{
				List<StatementSyntax> statements;

				if ( symbol.IsAsync )
				{
					if ( isGenericTaskType )
					{
						if ( usesObjectFallback )
						{
							// var __result = await Callback(...); return (T)__result;
							statements =
							[
								LocalDeclarationStatement( VariableDeclaration( IdentifierName( "var" ) )
									.WithVariables( SingletonSeparatedList( VariableDeclarator( "__result" )
										.WithInitializer( EqualsValueClause( AwaitExpression( callbackInvocation ) ) ) ) ) ),
								ReturnStatement( CastExpression( AwaitedTypeSyntax( symbol ), IdentifierName( "__result" ) ) )
							];
						}
						else
						{
							// return await Callback(...);
							statements = [ReturnStatement( AwaitExpression( callbackInvocation ) )];
						}
					}
					else if ( isTaskType )
					{
						// await Callback(...); return;
						statements = [ExpressionStatement( AwaitExpression( callbackInvocation ) ), ReturnStatement()];
					}
					else if ( symbol.ReturnsVoid )
					{
						// Callback(...);
						statements = [ExpressionStatement( callbackInvocation )];
					}
					else
					{
						// return Callback(...);
						statements = [ReturnStatement( callbackInvocation )];
					}
				}
				else if ( symbol.ReturnsVoid )
				{
					statements = [ExpressionStatement( callbackInvocation )];
				}
				else
				{
					statements = [ReturnStatement( syncResult )];
				}

				var block = Block( statements );
				var newBody = block.WithCloseBraceToken( block.CloseBraceToken.WithTrailingTrivia( SyntaxTriviaList.Empty ) );

				node = node
					.WithBody( newBody )
					.WithExpressionBody( null )
					.WithSemicolonToken( Token( SyntaxKind.None ) )
					.NormalizeWhitespace();
			}
			else if ( symbol.IsAsync && isTaskType )
			{
				// An expression-bodied async Task method becomes { await Callback(...); return; }
				var awaitExpr = AwaitExpression( Token( SyntaxKind.AwaitKeyword ), callbackInvocation );
				var statements = new StatementSyntax[] { ExpressionStatement( awaitExpr ), ReturnStatement() };

				node = node
					.WithExpressionBody( null )
					.WithSemicolonToken( Token( SyntaxKind.None ) )
					.WithBody( Block( statements ) )
					.NormalizeWhitespace();
			}
			else
			{
				var expression = syncResult;

				if ( symbol.IsAsync && isGenericTaskType )
				{
					ExpressionSyntax awaited = AwaitExpression( callbackInvocation );

					// => (T)(await Callback(...)) when the callback hands back Task<object>
					expression = usesObjectFallback
						? CastExpression( AwaitedTypeSyntax( symbol ), ParenthesizedExpression( awaited ) )
						: awaited;
				}

				node = node.WithExpressionBody( ArrowExpressionClause( expression ) )
					.NormalizeWhitespace();
			}

			return true;
		}

		#endregion

		private static readonly Dictionary<string, string> TypeAliases = new()
		{
			["object"] = "System.Object",
			["string"] = "System.String",
			["bool"] = "System.Boolean",
			["byte"] = "System.Byte",
			["sbyte"] = "System.SByte",
			["short"] = "System.Int16",
			["ushort"] = "System.UInt16",
			["int"] = "System.Int32",
			["uint"] = "System.UInt32",
			["long"] = "System.Int64",
			["ulong"] = "System.UInt64",
			["float"] = "System.Single",
			["double"] = "System.Double",
			["decimal"] = "System.Decimal",
			["char"] = "System.Char"
		};

		/// <summary>
		/// Type name as used in method identity strings. Arrays get "[]" appended to the element name;
		/// with <paramref name="fullName"/> the full name without "global::" and generic arguments is used.
		/// </summary>
		private static string SanitizeTypeName( ITypeSymbol type, bool fullName = false )
		{
			if ( type is IArrayTypeSymbol a )
				return $"{SanitizeTypeName( a.ElementType )}[]";

			if ( !fullName )
			{
				return TypeAliases.TryGetValue( type.Name, out var alias ) ? alias : type.Name;
			}

			return type.FullName()
				.Replace( "global::", "" )
				.Split( '<' )
				.FirstOrDefault();
		}

		private static string GetUniqueMethodIdentityString( IMethodSymbol method )
		{
			// Needs to keep in sync with Sandbox.MethodDescription.GetIdentityHashString()
			var returnTypeName = method.ReturnsVoid ? "Void" : SanitizeTypeName( method.ReturnType );
			return $"{returnTypeName}.{SanitizeTypeName( method.ContainingType, true )}.{method.Name}.{string.Join( ",", method.Parameters.Select( p => SanitizeTypeName( p.Type ) ) )}";
		}

		private static int GetUniqueMethodIdentity( IMethodSymbol method )
		{
			return GetUniqueMethodIdentityString( method ).FastHash();
		}

		/// <summary>
		/// Makes an identity hash usable inside an identifier (negative values get "m_" instead of "-").
		/// </summary>
		private static string MakeMethodIdentitySafe( int identity )
		{
			return identity.ToString().Replace( "-", "m_" );
		}

		/// <summary>
		/// Methods named <paramref name="methodName"/> on <paramref name="parent"/> and, for non-static lookups,
		/// its base types. Private methods are only considered on the root type.
		/// </summary>
		private static IEnumerable<IMethodSymbol> FetchValidMethods( INamedTypeSymbol parent, string methodName, bool isStatic = false, bool isRootType = false )
		{
			var validMethods = parent.GetMembers().OfType<IMethodSymbol>()
				.Where( s => (!isStatic || s.IsStatic) && s.Name == methodName )
				.Where( s => s.DeclaredAccessibility != Accessibility.Private || isRootType );

			foreach ( var symbol in validMethods )
			{
				yield return symbol;
			}

			// If our target method is static we shouldn't look at base types.
			if ( isStatic ) yield break;
			if ( parent.BaseType is null ) yield break;

			foreach ( var symbol in FetchValidMethods( parent.BaseType, methodName ) )
			{
				yield return symbol;
			}
		}

		/// <summary>
		/// Looks for a callback accepting WrappedMethod (or WrappedMethod&lt;T&gt;) plus one argument per method parameter
		/// (or a params object[]). Exact return-type matches are preferred; otherwise an object-typed callback is
		/// accepted and <paramref name="usesObjectFallback"/> is set.
		/// </summary>
		private static bool ValidateMethodCallback( INamedTypeSymbol containingType, INamedTypeSymbol parent, string methodName, bool isStatic, ITypeSymbol returnType, int argCount, out bool usesObjectFallback )
		{
			usesObjectFallback = false;

			// Materialized once; both passes below walk the same candidates.
			var validMethods = FetchValidMethods( parent, methodName, isStatic, SymbolEqualityComparer.Default.Equals( containingType, parent ) ).ToList();

			if ( validMethods.Any( method => IsValidMethodCallback( method, returnType, argCount, requireExactType: true ) ) )
				return true;

			// Second pass: look for object fallback (only if we have a return type)
			if ( returnType is null )
				return false;

			if ( !validMethods.Any( method => IsValidMethodCallback( method, returnType, argCount, requireExactType: false ) ) )
				return false;

			usesObjectFallback = true;
			return true;
		}

		private static bool IsValidMethodCallback( IMethodSymbol method, ITypeSymbol returnType, int argCount, bool requireExactType )
		{
			var hasObjectParams = method.Parameters.Length > 1 && method.Parameters[1].IsParams && method.Parameters[1].Type.FullName() == "object[]";

			if ( !hasObjectParams && method.Parameters.Length != argCount + 1 )
				return false;

			var firstParameterType = method.Parameters[0].Type;
			var firstParameterName = firstParameterType.FullName();

			if ( returnType is null )
			{
				return firstParameterName == "global::Sandbox.WrappedMethod";
			}

			if ( !firstParameterName.StartsWith( "global::Sandbox.WrappedMethod<" ) )
				return false;

			var wrappedArg = (firstParameterType as INamedTypeSymbol)?.TypeArguments[0];
			if ( wrappedArg is null )
				return false;

			var cbReturn = method.ReturnType;

			if ( requireExactType )
			{
				// Exact match or compatible generic
				if ( !SymbolEqualityComparer.Default.Equals( wrappedArg, returnType ) && !IsTypeCompatible( wrappedArg, returnType ) )
					return false;

				if ( !SymbolEqualityComparer.Default.Equals( cbReturn, returnType ) && !IsTypeCompatible( cbReturn, returnType ) && cbReturn is not ITypeParameterSymbol )
					return false;
			}
			else
			{
				if ( !IsObjectOrTaskOfObject( wrappedArg, returnType ) )
					return false;

				if ( !IsObjectOrTaskOfObject( cbReturn, returnType ) && cbReturn is not ITypeParameterSymbol )
					return false;
			}

			return true;
		}

		/// <summary>
		/// True if <paramref name="candidate"/> is object, or Task&lt;object&gt; when <paramref name="targetForShape"/> is a Task&lt;T&gt;.
		/// </summary>
		private static bool IsObjectOrTaskOfObject( ITypeSymbol candidate, ITypeSymbol targetForShape )
		{
			if ( candidate.SpecialType == SpecialType.System_Object )
				return true;

			// Check for Task<object> when the target is Task<T>
			if ( candidate is not INamedTypeSymbol namedCandidate || targetForShape is not INamedTypeSymbol namedTarget )
				return false;

			// Both must be generic Task
			var isCandidateGenericTask = namedCandidate.Name == "Task" && namedCandidate.TypeArguments.Length == 1 && namedCandidate.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks";
			var isTargetGenericTask = namedTarget.Name == "Task" && namedTarget.TypeArguments.Length == 1 && namedTarget.ContainingNamespace?.ToDisplayString() == "System.Threading.Tasks";

			if ( !isCandidateGenericTask || !isTargetGenericTask )
				return false;

			// Candidate should be Task<object>
			return namedCandidate.TypeArguments[0].SpecialType == SpecialType.System_Object;
		}

		/// <summary>
		/// True if <paramref name="candidate"/> is a type parameter, or the same generic definition as
		/// <paramref name="target"/> with every type argument equal or (recursively) compatible.
		/// </summary>
		private static bool IsTypeCompatible( ITypeSymbol candidate, ITypeSymbol target )
		{
			if ( candidate is ITypeParameterSymbol )
				return true;

			if ( candidate is not INamedTypeSymbol namedCandidate || target is not INamedTypeSymbol namedTarget )
				return false;

			if ( !SymbolEqualityComparer.Default.Equals( namedCandidate.OriginalDefinition, namedTarget.OriginalDefinition ) )
				return false;

			var candidateArgs = namedCandidate.TypeArguments;
			var targetArgs = namedTarget.TypeArguments;

			if ( candidateArgs.Length != targetArgs.Length )
				return false;

			for ( var i = 0; i < candidateArgs.Length; i++ )
			{
				var candidateArg = candidateArgs[i];
				var targetArg = targetArgs[i];

				if ( SymbolEqualityComparer.Default.Equals( candidateArg, targetArg ) )
					continue;

				if ( candidateArg is ITypeParameterSymbol )
					continue;

				if ( candidateArg is not INamedTypeSymbol candidateNamedArg || targetArg is not INamedTypeSymbol targetNamedArg )
					return false;

				if ( !IsTypeCompatible( candidateNamedArg, targetNamedArg ) )
					return false;
			}

			return true;
		}

		/// <summary>
		/// True if a method named <paramref name="methodName"/> takes exactly one parameter whose type starts with
		/// <paramref name="wrapperTypePrefix"/> (e.g. "global::Sandbox.WrappedPropertySet&lt;") and whose type argument
		/// is <paramref name="propertyType"/> or a type parameter.
		/// </summary>
		private static bool ValidatePropertyCallback( INamedTypeSymbol containingType, INamedTypeSymbol parent, string methodName, bool isStatic, ITypeSymbol propertyType, string wrapperTypePrefix )
		{
			var isRootType = SymbolEqualityComparer.Default.Equals( containingType, parent );

			foreach ( var method in FetchValidMethods( parent, methodName, isStatic, isRootType ) )
			{
				if ( method.Parameters.Length != 1 )
					continue;

				var parameterType = method.Parameters[0].Type;
				if ( !parameterType.FullName().StartsWith( wrapperTypePrefix ) )
					continue;

				var wrappedType = (parameterType as INamedTypeSymbol)?.TypeArguments[0];
				if ( SymbolEqualityComparer.Default.Equals( wrappedType, propertyType ) || wrappedType is ITypeParameterSymbol )
					return true;
			}

			return false;
		}

		private static bool IsCodeGeneratorAttribute( AttributeData attribute )
		{
			return attribute.AttributeClass is not null
				&& attribute.AttributeClass.FullName() == "global::Sandbox.CodeGeneratorAttribute";
		}

		/// <summary>
		/// The [CodeGenerator] attributes declared on the class of <paramref name="parent"/>;
		/// empty when the attribute class could not be resolved.
		/// </summary>
		private static IEnumerable<AttributeData> GetCodeGeneratorAttributes( AttributeData parent )
		{
			return parent.AttributeClass?.GetAttributes().Where( IsCodeGeneratorAttribute )
				?? Enumerable.Empty<AttributeData>();
		}
	}
}