using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System;
using System.Collections.Generic;
using System.Linq;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
namespace Sandbox.Generator
{
internal class CodeGen
{
/// <summary>
/// Code-generation behaviours requested by a code-generator attribute, parsed from its first ("Type")
/// argument. Wrap kinds are combined with <see cref="Static"/> and/or <see cref="Instance"/>; for
/// property wrapping, a static property is only wrapped when <see cref="Static"/> is set and an
/// instance property only when <see cref="Instance"/> is set.
/// </summary>
[Flags]
internal enum Flags
{
	/// <summary>No behaviour requested; the value produced when the Type argument is absent ("0").</summary>
	None = 0,

	/// <summary>Route reads of the property through the callback.</summary>
	WrapPropertyGet = 1,

	/// <summary>Route writes to the property through the callback.</summary>
	WrapPropertySet = 2,

	/// <summary>Wrap the method body (handled by HandleWrapCall).</summary>
	WrapMethod = 4,

	/// <summary>Apply to static members.</summary>
	Static = 8,

	/// <summary>Apply to instance members.</summary>
	Instance = 16
}
/// <summary>
/// Wraps a method for every code-generator attribute on it that requests <see cref="Flags.WrapMethod"/>.
/// When at least one <c>HandleWrapCall</c> reports success, every attribute on the method (not just
/// the code-generating ones) is re-created in a static <c>__{identity}__Attrs</c> array on the class.
/// Abstract methods and declarations without a block or expression body are ignored.
/// </summary>
/// <param name="node">The method declaration; may be replaced by <c>HandleWrapCall</c>.</param>
/// <param name="symbol">Semantic symbol for <paramref name="node"/>.</param>
/// <param name="master">Receives generated members.</param>
internal static void VisitMethod( ref MethodDeclarationSyntax node, IMethodSymbol symbol, Worker master )
{
	var hasBody = node.Body is not null || node.ExpressionBody is not null;
	if ( !hasBody || symbol.IsAbstract )
		return;

	var wrapped = false;
	var attributeInitializers = new List<string>();

	foreach ( var attribute in symbol.GetAttributes() )
	{
		foreach ( var generator in GetCodeGeneratorAttributes( attribute ) )
		{
			var flags = (Flags)int.Parse( generator.GetArgumentValue( 0, "Type", "0" ) );
			var callback = generator.GetArgumentValue( 1, "CallbackName", string.Empty );

			if ( flags.Contains( Flags.WrapMethod ) )
			{
				// Non-short-circuiting |= : every matching wrapper is applied even after one succeeded
				wrapped |= HandleWrapCall( attribute, flags, callback, ref node, symbol, master );
			}
		}

		// Every attribute is recorded, including ones that carry no code generator
		AddAttributeString( attribute, attributeInitializers );
	}

	if ( !wrapped || attributeInitializers.Count == 0 )
		return;

	var methodIdentity = MakeMethodIdentitySafe( GetUniqueMethodIdentity( symbol ) );
	master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] static readonly global::System.Attribute[] __{methodIdentity}__Attrs = new global::System.Attribute[] {{ {string.Join( ", ", attributeInitializers )} }};\n", false );
}
/// <summary>
/// One property-wrapping request taken from a code-generator attribute. Requests are buffered so
/// they can be applied in priority order.
/// </summary>
private struct PropertyWrapperData
{
	/// <summary>Wrap kind plus static/instance scope requested by the generator.</summary>
	public Flags Type { get; set; }

	/// <summary>Name of the method the generated accessor invokes.</summary>
	public string CallbackName { get; set; }

	/// <summary>Ordering key; higher values are applied first.</summary>
	public int Priority { get; set; }

	/// <summary>The attribute on the property that produced this request.</summary>
	public AttributeData Attribute { get; set; }
}
/// <summary>
/// Applies every property-wrapping code generator found on the attributes of <paramref name="symbol"/>.
/// Requests run from highest to lowest priority; for each request the setter wrapper runs before the
/// getter wrapper. Each attribute that carries at least one code generator is recorded once in a static
/// <c>__{Name}__Attrs</c> array, which the generated accessors reference.
/// </summary>
/// <param name="node">The property declaration; replaced by the wrap handlers when they rewrite it.</param>
/// <param name="symbol">Semantic symbol for <paramref name="node"/>.</param>
/// <param name="master">Receives generated members and diagnostics.</param>
internal static void VisitProperty( ref PropertyDeclarationSyntax node, IPropertySymbol symbol, Worker master )
{
	var attributesToWrite = new List<string>();
	var attributes = symbol.GetAttributes();
	var generatedFields = new HashSet<string>();
	var data = new List<PropertyWrapperData>();

	foreach ( var attribute in attributes )
	{
		var hasCodeGenerator = false;

		foreach ( var cg in GetCodeGeneratorAttributes( attribute ) )
		{
			hasCodeGenerator = true;

			var type = (Flags)int.Parse( cg.GetArgumentValue( 0, "Type", "0" ) );
			var callbackName = cg.GetArgumentValue( 1, "CallbackName", string.Empty );
			var priority = int.Parse( cg.GetArgumentValue( 2, "Priority", "0" ) );

			if ( type.Contains( Flags.WrapPropertySet ) || type.Contains( Flags.WrapPropertyGet ) )
			{
				data.Add( new()
				{
					Attribute = attribute,
					CallbackName = callbackName,
					Priority = priority,
					Type = type
				} );
			}
		}

		// Record the attribute once, not once per generator it declares (an attribute may declare
		// separate get and set generators). Attributes without any generator are not recorded here.
		if ( hasCodeGenerator )
		{
			AddAttributeString( attribute, attributesToWrite );
		}
	}

	// Highest priority first. OrderByDescending is a stable sort, so requests with equal priority keep
	// attribute declaration order; List<T>.Sort is unstable and may reorder them.
	foreach ( var w in data.OrderByDescending( x => x.Priority ) )
	{
		if ( w.Type.Contains( Flags.WrapPropertySet ) )
		{
			HandleWrapSet( w.Attribute, w.Type, w.CallbackName, ref node, symbol, master, generatedFields );
		}

		if ( w.Type.Contains( Flags.WrapPropertyGet ) )
		{
			HandleWrapGet( w.Attribute, w.Type, w.CallbackName, ref node, symbol, master, generatedFields );
		}
	}

	if ( attributesToWrite.Count > 0 )
	{
		master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] static readonly global::System.Attribute[] __{symbol.Name}__Attrs = new global::System.Attribute[] {{ {string.Join( ", ", attributesToWrite )} }};\n", false );
	}
}
/// <summary>
/// Appends to <paramref name="list"/> a C# expression that re-creates <paramref name="attribute"/> from
/// its source syntax, e.g. <c>new global::Ns.FooAttribute( 1,name: 2 ) { Prop = 3 }</c>.
/// Positional and <c>name:</c> arguments go to the constructor call; <c>Name = value</c> arguments become
/// object-initializer assignments. Attributes with no source syntax or no resolved class are skipped.
/// </summary>
/// <param name="attribute">The attribute application to reproduce.</param>
/// <param name="list">Receives the generated construction expression.</param>
private static void AddAttributeString( AttributeData attribute, List<string> list )
{
	if ( attribute.ApplicationSyntaxReference?.GetSyntax() is not AttributeSyntax sn )
		return;

	if ( attribute.AttributeClass is null )
		return;

	var attributeClassName = attribute.AttributeClass.FullName();
	if ( !attributeClassName.EndsWith( "Attribute", StringComparison.Ordinal ) )
		attributeClassName += "Attribute";

	var arguments = sn.ArgumentList?.Arguments.ToArray() ?? [];
	if ( arguments.Length == 0 )
	{
		list.Add( $"new {attributeClassName}()" );
		return;
	}

	var constructorArguments = new List<string>();
	var initializerAssignments = new List<string>();

	foreach ( var syntax in arguments )
	{
		var expression = syntax.Expression.ToString();

		if ( syntax.NameEquals is not null )
		{
			// [Foo( Prop = x )] sets a field/property after construction
			initializerAssignments.Add( $"{syntax.NameEquals.Name} = {expression}" );
		}
		else if ( syntax.NameColon is not null )
		{
			// [Foo( name: x )] names a constructor parameter, so it must stay a named constructor argument
			constructorArguments.Add( $"{syntax.NameColon.Name}: {expression}" );
		}
		else
		{
			constructorArguments.Add( expression );
		}
	}

	list.Add( $"new {attributeClassName}( {string.Join( ",", constructorArguments )} ) {{ {string.Join( ", ", initializerAssignments )} }}" );
}
#region Property Wrapping
/// <summary>
/// Returns a copy of <paramref name="body"/> in which the <c>value</c> identifiers selected by
/// <c>ValueIdentifierRewriter</c> are renamed to <paramref name="parameterName"/>. This lets the original
/// setter body run inside a lambda whose parameter has that name.
/// </summary>
/// <param name="body">Setter body to rewrite; the original tree is not modified.</param>
/// <param name="parameterName">Identifier to use in place of <c>value</c>.</param>
private static BlockSyntax RewriteValueToParameter( BlockSyntax body, string parameterName )
	=> (BlockSyntax)new ValueIdentifierRewriter( parameterName ).Visit( body );
/// <summary>
/// Syntax rewriter that renames references to a setter's implicit <c>value</c> parameter (including the
/// verbatim spelling <c>@value</c>). An identifier spelled <c>value</c> is left untouched when it names
/// something else:
/// - a member after <c>.</c> or <c>?.</c>;
/// - the right side of a qualified name;
/// - a named argument or subpattern (<c>value:</c>);
/// - an anonymous-type or attribute member (<c>value =</c>);
/// - the member assigned in an object or <c>with</c> initializer.
/// </summary>
private class ValueIdentifierRewriter : CSharpSyntaxRewriter
{
	private readonly string _parameterName;

	public ValueIdentifierRewriter( string parameterName )
	{
		_parameterName = parameterName;
	}

	public override SyntaxNode VisitIdentifierName( IdentifierNameSyntax node )
	{
		if ( node.Identifier.ValueText == "value" && !NamesSomethingElse( node ) )
		{
			return node.WithIdentifier( Identifier( _parameterName ).WithTriviaFrom( node.Identifier ) );
		}

		return base.VisitIdentifierName( node );
	}

	/// <summary>
	/// True when <paramref name="node"/> is in a position where "value" is a member or argument name
	/// rather than a reference to the setter parameter.
	/// </summary>
	private static bool NamesSomethingElse( IdentifierNameSyntax node ) => node.Parent switch
	{
		MemberAccessExpressionSyntax access => access.Name == node,
		MemberBindingExpressionSyntax => true,
		QualifiedNameSyntax qualified => qualified.Right == node,
		NameColonSyntax => true,
		NameEqualsSyntax => true,
		AssignmentExpressionSyntax assignment => assignment.Left == node
			&& assignment.Parent is InitializerExpressionSyntax initializer
			&& initializer.Kind() is SyntaxKind.ObjectInitializerExpression or SyntaxKind.WithInitializerExpression,
		_ => false
	};
}
/// <summary>
/// Returns the syntax that evaluates the getter's own logic, so a wrapper can read the value without
/// calling the (wrapped) property again and recursing. The result is one of:
/// - the arrow expression of <c>get => expr;</c>;
/// - the block of <c>get { ... }</c>;
/// - the <c>field</c> keyword for an auto-getter or when no getter is supplied.
/// </summary>
/// <param name="existingGetter">The getter to unwrap; may be null.</param>
private static CSharpSyntaxNode GetDirectGetterBody( AccessorDeclarationSyntax existingGetter )
{
	if ( existingGetter?.ExpressionBody is { } arrowClause )
		return arrowClause.Expression;

	if ( existingGetter?.Body is { } block )
		return block;

	return FieldExpression();
}
/// <summary>
/// Rewrites the property's setter so every assignment is routed through the attribute's callback.
/// The new setter passes a <c>global::Sandbox.WrappedPropertySet</c> value to the callback. It carries:
/// - the incoming value;
/// - the owning object (null for static properties);
/// - a cached delegate that runs the original setter logic;
/// - a cached delegate that reads the property;
/// - identifying metadata.
/// Reports a diagnostic through <paramref name="master"/> and leaves the property unchanged when the
/// callback's type cannot be found or the callback fails validation.
/// </summary>
/// <param name="attribute">Attribute that requested the wrap; its class name is part of the generated field names.</param>
/// <param name="type">Requested flags; the property is skipped unless they contain Static/Instance matching it.</param>
/// <param name="callbackName">A method name on the containing type, or a dotted name whose prefix names the type holding a static callback.</param>
/// <param name="node">Property declaration; replaced with the rewritten declaration on success.</param>
/// <param name="symbol">Semantic symbol for <paramref name="node"/>.</param>
/// <param name="master">Used for type lookup, diagnostics and adding generated members to the current class.</param>
/// <param name="generatedFields">Names of cache fields already emitted for this property, so each is declared once.</param>
private static void HandleWrapSet( AttributeData attribute, Flags type, string callbackName, ref PropertyDeclarationSyntax node, IPropertySymbol symbol, Worker master, HashSet<string> generatedFields )
{
	// Lambda parameter that replaces the setter's implicit 'value'. It must not clash with anything the
	// setter body declares or references. With a short name such as "v", a local `var v` fails to compile
	// and a field `v` silently binds to the lambda parameter instead.
	const string setterParameterName = "__wrappedValue";

	if ( symbol.IsStatic && !type.Contains( Flags.Static ) )
		return;

	if ( !symbol.IsStatic && !type.Contains( Flags.Instance ) )
		return;

	// "Method" is looked up on the property's own type. "Some.Type.Method" is treated as a static
	// callback on the type named by everything before the last dot.
	var typeToInvokeOn = symbol.ContainingType;
	var methodToInvoke = callbackName;
	var splitCallbackName = callbackName.Split( '.' );
	var isStaticCallback = false;

	if ( splitCallbackName.Length > 1 )
	{
		isStaticCallback = true;
		methodToInvoke = splitCallbackName[splitCallbackName.Length - 1];
		var typeToLookFor = string.Join( ".", splitCallbackName.Take( splitCallbackName.Length - 1 ) );
		typeToInvokeOn = master.GetOrCreateTypeByMetadataName( typeToLookFor );

		if ( typeToInvokeOn is null )
		{
			master.AddError( node.GetLocation(),
				$"Unable to find {typeToLookFor} required for {attribute.AttributeClass?.Name}. Ensure that a fully qualified callback name is used." );
			return;
		}
	}

	if ( typeToInvokeOn is null || !ValidateSetterCallback( symbol.ContainingType, typeToInvokeOn, methodToInvoke, isStaticCallback, symbol.Type ) )
	{
		master.AddError( node.GetLocation(),
			$"A method {callbackName}( WrappedPropertySet ) is required on {typeToInvokeOn?.Name}." );
		return;
	}

	var propertyType = symbol.Type.FullName();
	var accessors = new List<AccessorDeclarationSyntax>();
	var existingGetter = node.AccessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.GetAccessorDeclaration );
	var existingSetter = node.AccessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.SetAccessorDeclaration );

	if ( existingSetter is null )
	{
		// There is no setter to wrap.
		return;
	}

	// Cached delegate field names include the attribute name so several attributes can wrap the same property
	var attributeSuffix = attribute.AttributeClass?.Name ?? "Unknown";
	var setterFieldName = $"__{symbol.Name}_{attributeSuffix}__CachedSetter";
	var getterFieldName = $"__{symbol.Name}_{attributeSuffix}__CachedSetterGetter";
	var staticModifier = symbol.IsStatic ? "static " : "";

	if ( generatedFields.Add( setterFieldName ) )
	{
		master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] private {staticModifier}global::System.Action<{propertyType}> {setterFieldName};\n", false );
	}

	if ( generatedFields.Add( getterFieldName ) )
	{
		master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] private {staticModifier}global::System.Func<{propertyType}> {getterFieldName};\n", false );
	}

	// GET accessor is kept as written
	if ( existingGetter is not null )
	{
		accessors.Add( existingGetter );
	}

	// Capture the original setter logic as a block
	BlockSyntax setterInnerBody;
	if ( existingSetter.ExpressionBody is not null )
	{
		var expr = existingSetter.ExpressionBody.Expression;
		setterInnerBody = Block( ExpressionStatement( expr ) );
	}
	else if ( existingSetter.Body is not null )
	{
		setterInnerBody = existingSetter.Body;
	}
	else
	{
		// Auto-setter: generate field = value;
		var assign = ExpressionStatement(
			AssignmentExpression(
				SyntaxKind.SimpleAssignmentExpression,
				FieldExpression(),
				IdentifierName( "value" ) ) );
		setterInnerBody = Block( assign );
	}

	// Original setter logic as a lambda: (__wrappedValue) => { ...'value' renamed... }
	var rewrittenSetterBody = RewriteValueToParameter( setterInnerBody, setterParameterName );
	var setterLambda = ParenthesizedLambdaExpression(
		ParameterList(
			SingletonSeparatedList(
				Parameter( Identifier( setterParameterName ) ) ) ),
		rewrittenSetterBody );

	var memberIdentity = $"{symbol.ContainingType.GetFullMetadataName().Replace( "global::", "" )}.{symbol.Name}";
	var memberHash = memberIdentity.FastHash();
	var wrappedType = ParseTypeName( $"global::Sandbox.WrappedPropertySet<{propertyType}>" );

	// Setter = __CachedSetter ??= (__wrappedValue) => { ... }
	var cachedSetterExpr = AssignmentExpression(
		SyntaxKind.CoalesceAssignmentExpression,
		IdentifierName( setterFieldName ),
		setterLambda );

	// Getter = __CachedSetterGetter ??= () => PropertyName
	// Reading the property by name goes through all wrapped getters; inlining a wrapped getter's code
	// instead would recurse.
	// NOTE(review): this is emitted even when the property has no getter, and reading a set-only
	// property does not compile — confirm whether set-only properties need a diagnostic instead.
	var getterLambda = ParenthesizedLambdaExpression( IdentifierName( symbol.Name ) );
	var cachedGetterExpr = AssignmentExpression(
		SyntaxKind.CoalesceAssignmentExpression,
		IdentifierName( getterFieldName ),
		getterLambda );

	var wrappedInitializerExpressions = new List<ExpressionSyntax>
	{
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Value" ),
			IdentifierName( "value" ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Object" ),
			symbol.IsStatic
				? LiteralExpression( SyntaxKind.NullLiteralExpression )
				: ThisExpression() ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Setter" ),
			cachedSetterExpr ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Getter" ),
			cachedGetterExpr ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "IsStatic" ),
			LiteralExpression( symbol.IsStatic ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "TypeName" ),
			ParseExpression( symbol.ContainingType.FullName().Replace( "global::", "" ).QuoteSafe() ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "PropertyName" ),
			ParseExpression( symbol.Name.QuoteSafe() ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "MemberIdent" ),
			LiteralExpression( SyntaxKind.NumericLiteralExpression, Literal( memberHash ) ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Attributes" ),
			IdentifierName( $"__{symbol.Name}__Attrs" ) )
	};

	var parameterStructExpr =
		ObjectCreationExpression( wrappedType )
			.WithInitializer(
				InitializerExpression(
					SyntaxKind.ObjectInitializerExpression,
					SeparatedList( wrappedInitializerExpressions ) ) );

	var callbackExpr = ParseExpression( callbackName );
	var argList = ArgumentList(
		SingletonSeparatedList(
			Argument( parameterStructExpr ) ) );
	var invocation = InvocationExpression( callbackExpr, argList );

	// set { Callback( new WrappedPropertySet<T> { ... } ); } keeping the original accessor modifiers
	var set = AccessorDeclaration( SyntaxKind.SetAccessorDeclaration )
		.WithBody( Block( ExpressionStatement( invocation ) ) )
		.WithModifiers( existingSetter.Modifiers );

	accessors.Add( set );

	node = node.WithAccessorList( AccessorList( List( accessors ) ) )
		.NormalizeWhitespace();
}
/// <summary>
/// Rewrites the property's getter so every read is routed through the attribute's callback.
/// The original getter logic is captured once in a cached delegate field
/// (<c>__{Name}_{Attribute}__CachedGetter</c>). The new getter evaluates it and passes the result to the
/// callback inside a <c>global::Sandbox.WrappedPropertyGet</c> value. It then returns the callback's
/// result cast to the property type. Because the captured logic is the current getter, several get
/// wrappers chain. All other accessors (set or init) are preserved. Reports a diagnostic through
/// <paramref name="master"/> and leaves the property unchanged when the callback cannot be resolved
/// or validated.
/// </summary>
/// <param name="attribute">Attribute that requested the wrap; its class name is part of the generated field name.</param>
/// <param name="type">Requested flags; the property is skipped unless they contain Static/Instance matching it.</param>
/// <param name="callbackName">A method name on the containing type, or a dotted name whose prefix names the type holding a static callback.</param>
/// <param name="node">Property declaration; replaced with the rewritten declaration on success.</param>
/// <param name="symbol">Semantic symbol for <paramref name="node"/>.</param>
/// <param name="master">Used for type lookup, diagnostics and adding generated members to the current class.</param>
/// <param name="generatedFields">Names of cache fields already emitted for this property, so each is declared once.</param>
private static void HandleWrapGet( AttributeData attribute, Flags type, string callbackName, ref PropertyDeclarationSyntax node, IPropertySymbol symbol, Worker master, HashSet<string> generatedFields )
{
	if ( symbol.IsStatic && !type.Contains( Flags.Static ) )
		return;

	if ( !symbol.IsStatic && !type.Contains( Flags.Instance ) )
		return;

	// "Method" is looked up on the property's own type. "Some.Type.Method" is treated as a static
	// callback on the type named by everything before the last dot.
	var typeToInvokeOn = symbol.ContainingType;
	var methodToInvoke = callbackName;
	var splitCallbackName = callbackName.Split( '.' );
	var isStaticCallback = false;

	if ( splitCallbackName.Length > 1 )
	{
		isStaticCallback = true;
		methodToInvoke = splitCallbackName[splitCallbackName.Length - 1];
		var typeToLookFor = string.Join( ".", splitCallbackName.Take( splitCallbackName.Length - 1 ) );
		typeToInvokeOn = master.GetOrCreateTypeByMetadataName( typeToLookFor );

		if ( typeToInvokeOn is null )
		{
			master.AddError( node.GetLocation(),
				$"Unable to find {typeToLookFor} required for {attribute.AttributeClass?.Name}. Ensure that a fully qualified callback name is used." );
			return;
		}
	}

	var propertyType = symbol.Type.FullName();

	if ( typeToInvokeOn is null || !ValidateGetterCallback( symbol.ContainingType, typeToInvokeOn, methodToInvoke, isStaticCallback, symbol.Type ) )
	{
		master.AddError( node.GetLocation(),
			$"A method {symbol.Type.Name} {methodToInvoke}( WrappedPropertyGet ) is required on {typeToInvokeOn?.Name}." );
		return;
	}

	var accessorList = node.AccessorList;
	var existingGetter = accessorList?.Accessors.FirstOrDefault( a => a.Kind() == SyntaxKind.GetAccessorDeclaration );

	if ( accessorList is null || existingGetter is null )
	{
		// There is no getter to wrap.
		return;
	}

	// Cached delegate field name includes the attribute name so several attributes can wrap the same property
	var attributeSuffix = attribute.AttributeClass?.Name ?? "Unknown";
	var getterFieldName = $"__{symbol.Name}_{attributeSuffix}__CachedGetter";
	var staticModifier = symbol.IsStatic ? "static " : "";

	if ( generatedFields.Add( getterFieldName ) )
	{
		master.AddToCurrentClass( $"[global::Sandbox.SkipHotload] private {staticModifier}global::System.Func<{propertyType}> {getterFieldName};\n", false );
	}

	// Keep every non-get accessor (set or init) exactly as written, in declaration order; the new getter
	// goes last. Rebuilding from just the setter would silently drop an init accessor.
	var accessors = accessorList.Accessors
		.Where( a => a.Kind() != SyntaxKind.GetAccessorDeclaration )
		.ToList();

	// Value = (__CachedGetter ??= () => <current getter logic>)()
	// Using the current getter body is what lets get wrappers chain.
	var directGetterBody = GetDirectGetterBody( existingGetter );
	var getterLambda = ParenthesizedLambdaExpression( directGetterBody );
	var cachedGetterExpr = AssignmentExpression(
		SyntaxKind.CoalesceAssignmentExpression,
		IdentifierName( getterFieldName ),
		getterLambda );
	var defaultValueExpression = InvocationExpression(
		ParenthesizedExpression( cachedGetterExpr ) );

	var memberIdentity = $"{symbol.ContainingType.GetFullMetadataName().Replace( "global::", "" )}.{symbol.Name}";
	var memberHash = memberIdentity.FastHash();
	var wrappedType = ParseTypeName( $"global::Sandbox.WrappedPropertyGet<{propertyType}>" );

	var wrappedInitializerExpressions = new List<ExpressionSyntax>
	{
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Value" ),
			defaultValueExpression ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Object" ),
			symbol.IsStatic
				? LiteralExpression( SyntaxKind.NullLiteralExpression )
				: ThisExpression() ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "IsStatic" ),
			LiteralExpression( symbol.IsStatic ? SyntaxKind.TrueLiteralExpression : SyntaxKind.FalseLiteralExpression ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "TypeName" ),
			ParseExpression( symbol.ContainingType.FullName().Replace( "global::", "" ).QuoteSafe() ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "PropertyName" ),
			ParseExpression( symbol.Name.QuoteSafe() ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "MemberIdent" ),
			LiteralExpression( SyntaxKind.NumericLiteralExpression, Literal( memberHash ) ) ),
		AssignmentExpression(
			SyntaxKind.SimpleAssignmentExpression,
			IdentifierName( "Attributes" ),
			IdentifierName( $"__{symbol.Name}__Attrs" ) )
	};

	var parameterStructExpr =
		ObjectCreationExpression( wrappedType )
			.WithInitializer(
				InitializerExpression(
					SyntaxKind.ObjectInitializerExpression,
					SeparatedList( wrappedInitializerExpressions ) ) );

	var callbackExpr = ParseExpression( callbackName );
	var argList = ArgumentList(
		SingletonSeparatedList(
			Argument( parameterStructExpr ) ) );
	var invocation = InvocationExpression( callbackExpr, argList );

	// get { return (T)Callback( new WrappedPropertyGet<T> { ... } ); } keeping the original accessor modifiers
	var returnTypeSyntax = ParseTypeName( propertyType );
	var get = AccessorDeclaration( SyntaxKind.GetAccessorDeclaration )
		.WithBody( Block( ReturnStatement( CastExpression( returnTypeSyntax, invocation ) ) ) )
		.WithModifiers( existingGetter.Modifiers );

	accessors.Add( get );

	node = node.WithAccessorList( AccessorList( List( accessors ) ) )
		.NormalizeWhitespace();
}
#endregion
#region Method Wrapping
private static ExpressionSyntax BuildWrappedMethodExpression( IMethodSymbol symbol, CSharpSyntaxNode resumeBodyNode, int methodIdentity, bool usesObjectFallback = false )
{
var hasReturn = !symbol.ReturnsVoid;
string parameterStructGenericType;
if ( !hasReturn )
{
parameterStructGenericType = string.Empty;
}
else if ( usesObjectFallback )
{
// Use object (or Task