1 module mecca.reactor.impl.fibril;
2 
3 // Licensed under the Boost license. Full copyright information in the AUTHORS file
4 
5 import mecca.lib.exception;
6 import mecca.log;
7 
// Disable tracing instrumentation for the whole file.
// The `@("notrace")` UDA on this body-less declaration is the marker; the function is
// never called in this file.
// NOTE(review): presumably the project's compile-time tracing scans modules for this
// declaration — confirm against the tracing implementation before renaming it.
@("notrace") void traceDisableCompileTimeInstrumentation();
10 
version (D_InlineAsm_X86_64) version (Posix) {
    private pure nothrow @trusted @nogc:

    // Build the initial frame of a fresh fibril stack.
    //
    // The frame has the same shape as what `_fibril_switch` leaves behind when it suspends a
    // context: six saved registers (R15, R14, R13, R12, RBX, RBP from low to high address)
    // below a return address. The first `_fibril_switch` into this stack therefore pops the
    // registers and `ret`s into `_fibril_trampoline`. The trampoline finds `fn` in R14 and
    // `opaque` in R15.
    //
    // Alignment: the top is rounded down to 16 bytes, and 8 slots (64 bytes) are pushed.
    // After `_fibril_switch` pops 6 slots and `ret` pops 1, RSP == top - 8 on entry to the
    // trampoline. That is the same alignment a normal `call` produces. The remaining slot is
    // a null "return address" for the entry point.
    //
    // Params:
    //   stackArea = memory for the fibril's stack; the stack grows down from its end
    //   fn        = fibril entry point (restored into R14)
    //   opaque    = argument for `fn` (restored into R15)
    // Returns: the initial saved RSP, to be stored in `Fibril.rsp`.
    void* _fibril_init_stack(void[] stackArea, void function(void*) nothrow fn, void* opaque) {
        // fake RET, RIP, RBP, RBX, R12, R13, R14, R15
        enum initialFrameSize = 8 * (void*).sizeof;

        // set rsp to top of stack, and make sure it's 16-byte aligned
        auto rsp = cast(void*)((((cast(size_t)stackArea.ptr) + stackArea.length) >> 4) << 4);
        // the initial frame must fit inside stackArea; otherwise the pushes below would
        // scribble over whatever memory precedes it
        assert (rsp >= stackArea.ptr + initialFrameSize, "stack area too small for initial fibril frame");
        auto rbp = rsp;

        void push(void* v) nothrow pure @nogc {
            rsp -= v.sizeof;
            *(cast(void**)rsp) = v;
        }

        push(null);                     // Fake RET of entrypoint
        push(&_fibril_trampoline);      // RIP
        push(rbp);                      // RBP
        push(null);                     // RBX
        push(null);                     // R12
        push(null);                     // R13
        push(fn);                       // R14
        push(opaque);                   // R15

        return rsp;
    }

    // First code executed on a fresh fibril stack, reached by the `ret` in `_fibril_switch`.
    //
    // It moves `fn` (R14) and `opaque` (R15), as placed by `_fibril_init_stack`, into RDI/RSI,
    // the registers `_fibril_wrapper` takes its two arguments in. It then jumps to the wrapper.
    // `naked` means no prologue or epilogue is emitted, so the stack is exactly as the switch
    // left it.
    extern(C) void _fibril_trampoline() nothrow {
        pragma(inline, false);
        asm pure nothrow @nogc {
            naked;
            mov RDI, R14;  // fn
            mov RSI, R15;  // opaque

            // this has to be a jmp (not a call), otherwise exception-handling will see
            // this function in the stack and be... unhappy
            jmp _fibril_wrapper;
        }
    }

    // Suspend the current context and resume the one whose saved RSP is `toRSP`.
    //
    // The caller's return address is already on the stack. This pushes RBP, RBX and
    // R12-R15 (the general-purpose callee-saved registers of the SysV x86-64 ABI), then
    // stores the resulting RSP into both `*fromRSP` and `*rspForGc`. It then loads
    // `toRSP`, pops the same six registers in reverse order, and `ret`s to whatever
    // return address sits there. That is either the caller of an earlier `_fibril_switch`,
    // or `_fibril_trampoline` for a fresh stack.
    //
    // NOTE(review): `*rspForGc` presumably lets the GC scan the suspended stack — confirm
    // with the caller.
    // NOTE(review): MXCSR and the x87 control word are not saved or restored here. If a
    // fibril changes FP control state, that state will be visible to other fibrils —
    // confirm this is intended.
    extern(C) void _fibril_switch(void** fromRSP /* RDI */, void* toRSP /* RSI */, void** rspForGc /* RDX */) {
        pragma(inline, false);
        asm pure nothrow @nogc {
            naked;

            // save current state, then store RSP into `fromRSP`
            // RET is already pushed at TOS
            push RBP;
            push RBX;
            push R12;
            push R13;
            push R14;
            push R15;
            mov [RDI], RSP;
            mov [RDX], RSP;

            // set RSP to `toRSP` and load state
            // and return to caller (RET is at TOS)
            mov RSP, RSI;
            pop R15;
            pop R14;
            pop R13;
            pop R12;
            pop RBX;
            pop RBP;
            ret;
        }
    }
}
78 
79 
// Common entry point of every fibril. `_fibril_trampoline` reaches it with `jmp`, so this is
// the outermost frame on a fibril's stack.
//
// It runs `fn(opaque)`. A fibril body is expected never to finish, so both returning and
// throwing are fatal. Returning is reported with `DIE`. Throwing logs the exception with
// `LOG_EXCEPTION` and is then reported with `DIE`. `abort()` is called last as a fallback.
//
// Params:
//   fn     = fibril body (arrives in RDI)
//   opaque = argument passed to `fn` (arrives in RSI)
extern(C) private void _fibril_wrapper(void function(void*) fn /* RDI */, void* opaque /* RSI */) {
    import core.stdc.stdlib: abort;

    try {
        fn(opaque);
        DIE("Fibril function must never return", __FILE_FULL_PATH__, __LINE__, true);
    }
    catch (Throwable ex) {
        LOG_EXCEPTION(ex);
        DIE("Fibril function must never throw", __FILE_FULL_PATH__, __LINE__, true);
    }
    // we add an extra call to abort here, so the compiler would be forced to emit `call` instead of `jmp`
    // above, thus leaving this function on the call stack. it produces a more readable backtrace.
    abort();
}
100 
101 
/**
 * A minimal execution context: only the saved stack pointer of a suspended fibril.
 *
 * All other register state is kept on the fibril's own stack by `_fibril_init_stack` and
 * `_fibril_switch`.
 */
struct Fibril {
    /// Saved RSP while suspended; null if never set (or after `reset`).
    void* rsp;

    /// Forget the saved context so that `set` may be called again.
    void reset() nothrow @nogc {
        rsp = null;
    }

    /**
     * Prepare this fibril to run `fn(opaque)` on `stackArea` the first time it is switched to.
     *
     * The fibril must not already be set (asserted).
     */
    void set(void[] stackArea, void function(void*) nothrow fn, void* opaque) nothrow @nogc {
        assert (rsp is null, "already set");
        rsp = _fibril_init_stack(stackArea, fn, opaque);
    }

    /**
     * Delegate overload: run `dg` on `stackArea`.
     *
     * The delegate is split into its function pointer and its context pointer. The context is
     * then passed as the function's single `void*` argument.
     *
     * NOTE(review): this relies on the compiler passing a delegate's context pointer the same
     * way as a first `void*` argument — confirm for every supported compiler.
     */
    void set(void[] stackArea, void delegate() nothrow dg) nothrow @nogc {
        set(stackArea, cast(void function(void*) nothrow)dg.funcptr, dg.ptr);
    }

    /**
     * Suspend the current context into `this` and resume `next`.
     *
     * The call returns only when some other context later switches back to `this`.
     *
     * Params:
     *   next     = context to resume; it must hold a saved RSP (either `set`, or previously
     *              switched away from)
     *   rspForGc = also receives this context's saved RSP
     */
    void switchTo(ref Fibril next, void** rspForGc) nothrow @trusted @nogc {
        pragma(inline, true);
        // resuming a null RSP would pop registers from address 0 and crash far from the cause
        assert (next.rsp !is null, "switching to a fibril that was never set");
        _fibril_switch(&this.rsp, next.rsp, rspForGc);
    }
}
120 
121 
unittest {
    import std.stdio;
    import std.range;

    enum ITERS = 10;

    // Two dedicated stacks; the main context keeps using the thread's own stack.
    ubyte[4096] stackA;
    ubyte[4096] stackB;

    Fibril mainCtx, ctxA, ctxB;
    void*  mainGc, aGc, bGc;
    char[] trace;

    // Each body records its tag, then hands control on: A -> B -> main.
    void bodyA() nothrow {
        for (;;) {
            trace ~= '1';
            ctxA.switchTo(ctxB, &aGc);
        }
    }
    void bodyB() nothrow {
        for (;;) {
            trace ~= '2';
            ctxB.switchTo(mainCtx, &bGc);
        }
    }

    ctxA.set(stackA, &bodyA);
    ctxB.set(stackB, &bodyB);
    trace.reserve(ITERS * 3);

    // One round trip main -> A -> B -> main must leave exactly "M12" behind.
    for (size_t round = 0; round < ITERS; ++round) {
        trace ~= 'M';
        mainCtx.switchTo(ctxA, &mainGc);
    }

    assert (trace == "M12".repeat(ITERS).join(""), trace);
}
162 
unittest {
    // Micro-benchmark of `Fibril.switchTo`: bounce main -> fib1 -> fib2 -> main ITERS times
    // and report the cycles spent per context switch.
    import std.stdio;

    ubyte[4096] stack1;
    ubyte[4096] stack2;

    Fibril mainFib, fib1, fib2;
    void*  mainGcRsp, fib1GcRsp, fib2GcRsp;
    size_t counter;

    void func1() nothrow {
        while (true) {
            counter++;
            fib1.switchTo(fib2, &fib1GcRsp);
        }
    }
    void func2() nothrow {
        while (true) {
            fib2.switchTo(mainFib, &fib2GcRsp);
        }
    }

    fib1.set(stack1, &func1);
    fib2.set(stack2, &func2);

    enum ITERS = 10_000_000;

    import mecca.lib.time: TscTimePoint;
    auto t0 = TscTimePoint.hardNow;
    foreach(_; 0 .. ITERS) {
        mainFib.switchTo(fib1, &mainGcRsp);
    }
    auto dt = TscTimePoint.hardNow.diff!"cycles"(t0);
    assert (counter == ITERS);
    // Each loop iteration performs three switches (main -> fib1 -> fib2 -> main), so
    // dividing by ITERS * 3 gives the cost of a single switch, not of one iteration.
    writefln("total %s cycles, per switch %s", dt, dt / (ITERS * 3.0));
}
199