I'm playing with the idea of tampering with the state of a task's internal state machine, but I'm having trouble finding a way to actually access the state machine reference from within my task method.
using System.Diagnostics;
using System.Threading.Tasks;

class Test
{
    async Task IndexAsync()
    {
        // This references the "Test" class, but this code actually runs in the MoveNext() method
        // of the compiler-generated nested state machine class (named something like "<IndexAsync>d__0").
        var nottheactualtype = GetType();
        // This shows the method actually being run: the state machine's MoveNext().
        var actualcalledmethod = new StackTrace().GetFrame(0).GetMethod();
        // But how do I get the reference to my current state machine instance?
    }
}
How can I get access to the reference of the generated state machine currently being run?
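Finding the generated type (as opposed to the running instance) doesn't need a stack trace: the compiler marks each async method with AsyncStateMachineAttribute, whose StateMachineType property names the nested state machine type. A minimal sketch reusing the question's Test and IndexAsync names; the helper GetIndexAsyncStateMachineType is hypothetical, and it only gets you the type, not the instance currently running:

```csharp
using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

public class Test
{
    async Task IndexAsync()
    {
        await Task.Yield();
    }

    // Hypothetical helper: looks up the compiler-generated state machine type
    // for IndexAsync via the [AsyncStateMachine] attribute the compiler puts on
    // async methods. Returns null if the method or attribute can't be found.
    // This gives the type only; it does not give access to a running instance.
    public static Type GetIndexAsyncStateMachineType()
    {
        MethodInfo method = typeof(Test).GetMethod(
            nameof(IndexAsync), BindingFlags.NonPublic | BindingFlags.Instance);
        AsyncStateMachineAttribute attribute =
            method?.GetCustomAttribute<AsyncStateMachineAttribute>();
        return attribute?.StateMachineType;
    }
}
```

The returned type should be nested in Test with a name along the lines of Test+<IndexAsync>d__0; the exact name is a compiler detail.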
It's nasty, and it's not guaranteed to work (it depends on implementation details), but this works for me. It awaits a custom awaitable whose awaiter reports that it isn't complete yet, which provokes the state machine into passing a continuation to the awaiter's OnCompleted method. We can then get the state machine out of the continuation delegate's target.
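Because the object behind the continuation delegate is a runtime internal, the extraction step can be sketched on its own with a little more tolerance for different runtimes. This assumes the delegate's target holds the state machine in a field named m_stateMachine (.NET Framework's internal MoveNextRunner) or StateMachine (the AsyncStateMachineBox<TStateMachine> used by newer .NET Core runtimes); both names are undocumented details, and StateMachineExtractor is a hypothetical name:

```csharp
using System;
using System.Reflection;
using System.Runtime.CompilerServices;

public static class StateMachineExtractor
{
    // Assumed field names on the continuation's target (runtime internals,
    // not a public contract):
    //   "m_stateMachine": .NET Framework's MoveNextRunner
    //   "StateMachine":   AsyncStateMachineBox<TStateMachine> in newer .NET Core
    private static readonly string[] CandidateFieldNames = { "m_stateMachine", "StateMachine" };

    // Returns the async state machine behind an await continuation, or null
    // if the delegate's target has none of the expected fields.
    public static IAsyncStateMachine TryExtract(Action continuation)
    {
        object target = continuation?.Target;
        if (target == null)
        {
            return null;
        }

        const BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
        foreach (string name in CandidateFieldNames)
        {
            FieldInfo field = target.GetType().GetField(name, flags);
            // If the field's declared type is a struct state machine, GetValue
            // returns a boxed copy rather than the live instance.
            if (field != null && field.GetValue(target) is IAsyncStateMachine machine)
            {
                return machine;
            }
        }
        return null;
    }
}
```

Returning StateMachineExtractor.TryExtract(continuation) from an awaiter's GetResult would avoid hard-coding a single runtime's field name. On runtimes where the box stores a struct state machine (typically Release builds on newer .NET Core), the result is a copy, so writing to it would not affect the running method.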
Ugly, ugly, ugly code... but it does work for me :)
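using System;
using System.Reflection;
using System.Threading.Tasks;
using System.Runtime.CompilerServices;
using static System.Reflection.BindingFlags;

public class StateMachineProvider
{
    private static readonly StateMachineProvider instance = new StateMachineProvider();

    public static StateMachineProvider GetStateMachine() => instance;

    // Makes "await StateMachineProvider.GetStateMachine()" legal.
    public StateMachineAwaiter GetAwaiter() => new StateMachineAwaiter();

    public class StateMachineAwaiter : INotifyCompletion
    {
        private Action continuation;

        // False for a fresh awaiter, so the generated code suspends and
        // hands us its continuation via OnCompleted.
        public bool IsCompleted => continuation != null;

        public void OnCompleted(Action continuation)
        {
            this.continuation = continuation;
            // Fire the continuation in a separate task.
            // (We shouldn't just call it synchronously.)
            Task.Run(continuation);
        }

        public IAsyncStateMachine GetResult()
        {
            // The continuation's target is a runtime-internal object holding the
            // state machine. "m_stateMachine" is the internal field name on
            // .NET Framework; other runtimes may use different types and names.
            var target = continuation.Target;
            var field = target.GetType()
                .GetField("m_stateMachine", NonPublic | Instance);
            if (field == null)
            {
                throw new NotSupportedException(
                    $"Unrecognized continuation target: {target.GetType()}");
            }
            return (IAsyncStateMachine) field.GetValue(target);
        }
    }
}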
class Test
{
    static void Main()
    {
        AsyncMethod().Wait();
    }

    static async Task AsyncMethod()
    {
        int x = 10;
        IAsyncStateMachine machine = await StateMachineProvider.GetStateMachine();
        Console.WriteLine($"x={x}"); // Use x after the await so it has to be stored in a state machine field
        Console.WriteLine($"State machine type: {machine.GetType()}");
        Console.WriteLine("Fields:");
        var fields = machine.GetType().GetFields(Public | NonPublic | Instance);
        foreach (var field in fields)
        {
            Console.WriteLine($"{field.Name}: {field.GetValue(machine)}");
        }
    }
}
Output:
x=10
State machine type: Test+<AsyncMethod>d__1
Fields:
<>1__state: -1
<>t__builder: System.Runtime.CompilerServices.AsyncTaskMethodBuilder
<x>5__1: 10
<machine>5__2: Test+<AsyncMethod>d__1
<>s__3:
<>s__4: System.Reflection.FieldInfo[]
<>s__5: 6
<field>5__6: System.Reflection.FieldInfo <field>5__6
<>u__1:
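The output shows how the compiler names the fields: user locals hoisted into the state machine (here x, machine and field) become fields like <x>5__1, while the state, builder, awaiter and other synthesized values get names starting with <>. A hedged sketch that relies on that naming pattern (a Roslyn implementation detail, not a documented contract) to map fields back to local names; HoistedLocals is a hypothetical helper name:

```csharp
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Text.RegularExpressions;

public static class HoistedLocals
{
    // Assumption: Roslyn names fields for hoisted user locals "<name>5__N"
    // (e.g. "<x>5__1"). This is a compiler implementation detail.
    private static readonly Regex HoistedLocalPattern = new Regex(@"^<([^>]+)>5__\d+$");

    // True (with the local's source name) for fields like "<x>5__1"; false
    // for synthesized fields such as "<>1__state", "<>t__builder" or "<>s__3".
    public static bool TryGetLocalName(string fieldName, out string localName)
    {
        Match match = HoistedLocalPattern.Match(fieldName ?? string.Empty);
        localName = match.Success ? match.Groups[1].Value : null;
        return match.Success;
    }

    // Reads the current values of a state machine's hoisted user locals,
    // keyed by their source names.
    public static Dictionary<string, object> Read(IAsyncStateMachine machine)
    {
        var locals = new Dictionary<string, object>();
        const BindingFlags flags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
        foreach (FieldInfo field in machine.GetType().GetFields(flags))
        {
            if (TryGetLocalName(field.Name, out string name))
            {
                locals[name] = field.GetValue(machine);
            }
        }
        return locals;
    }
}
```

For the AsyncMethod above, HoistedLocals.Read(machine) would be expected to have the keys x, machine and field, matching the <x>5__1, <machine>5__2 and <field>5__6 fields in the output.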