diff --git a/inc/os_win/threads_types.h b/inc/os_win/threads_types.h index 8ecc018..1173355 100644 --- a/inc/os_win/threads_types.h +++ b/inc/os_win/threads_types.h @@ -4,18 +4,18 @@ #pragma once -typedef struct cnd_t { - int idk_yet; -} cnd_t; - typedef struct thrd_t { void *handle; } thrd_t; typedef struct tss_t { - int idk_yet; + unsigned tls_index; } tss_t; +typedef struct cnd_t { + int idk_yet; +} cnd_t; + typedef struct mtx_t { int type; // Done to handle recursive mutexes diff --git a/src/os_win/thread.c b/src/os_win/thread.c index 19021c9..810254c 100644 --- a/src/os_win/thread.c +++ b/src/os_win/thread.c @@ -1,4 +1,50 @@ +// Note(bumbread): +// https://gist.github.com/wbenny/6d7fc92e9b5c3194ce56bf8c60d6191d + +#pragma comment(linker, "/merge:.CRT=.rdata") + +#pragma section(".CRT$XLA", read) +__declspec(allocate(".CRT$XLA")) const PIMAGE_TLS_CALLBACK __xl_a = NULL; + +#pragma section(".CRT$XLZ", read) +__declspec(allocate(".CRT$XLZ")) const PIMAGE_TLS_CALLBACK __xl_z = NULL; + +#pragma section(".CRT$XLM", read) +__declspec(allocate(".CRT$XLM")) extern const PIMAGE_TLS_CALLBACK TlsCallbackArray; + +char _tls_start = 0; +char _tls_end = 0; +unsigned int _tls_index = 0; + +const IMAGE_TLS_DIRECTORY _tls_used = { + (ULONG_PTR)&_tls_start, + (ULONG_PTR)&_tls_end, + (ULONG_PTR)&_tls_index, + (ULONG_PTR)(&__xl_a + 1), +}; + +static void _thread_cleanup(); + +VOID NTAPI _tls_callback( + PVOID DllHandle, + DWORD Reason, + PVOID Reserved + ) +{ + switch(Reason) { + case DLL_THREAD_ATTACH: break; + case DLL_THREAD_DETACH: { + _thread_cleanup(); + } break; + case DLL_PROCESS_ATTACH: break; + case DLL_PROCESS_DETACH: break; + } + // __debugbreak(); +} + +const PIMAGE_TLS_CALLBACK TlsCallbackArray = { &_tls_callback }; + // NOTE: debug mutexes will follow the recursive logic but error if they // actually recurse, this is slower than doing plain logic but it helps // debug weird mutex errors. @@ -7,8 +53,6 @@ // https://preshing.com/20120305/implementing-a-recursive-mutex/ // https://preshing.com/20120226/roll-your-own-lightweight-mutex/ -DWORD _tls_index = 0; - typedef struct UserClosure { thrd_start_t func; void* arg; @@ -17,8 +61,6 @@ typedef struct UserClosure { static DWORD _thread_call_user(void* arg) { UserClosure info = *((UserClosure*) arg); int result = info.func(info.arg); - - // TODO(NeGate): setup TSS dtors here return (DWORD) result; } @@ -78,11 +120,68 @@ void thrd_yield(void) { } _Noreturn void thrd_exit(int res) { - // TODO(NeGate): setup TSS dtors here + _thread_cleanup(); ExitThread((DWORD)res); __builtin_unreachable(); } +// TSS functions + +#define TSS_KEYS_MAX 1088 + +static tss_dtor_t _tss_dtors[TSS_KEYS_MAX]; + +static void _thread_cleanup() { + for(int i = 0; i != TSS_DTOR_ITERATIONS; ++i) { + for(unsigned k = 0; k != TSS_KEYS_MAX; ++k) { + void *data = TlsGetValue(k); + if(data != NULL) { + TlsSetValue(k, NULL); + if(_tss_dtors[k]) { + _tss_dtors[k](data); + } + } + } + } +} + +int tss_create(tss_t *key, tss_dtor_t dtor) { + DWORD tls_index = TlsAlloc(); + if(tls_index == TLS_OUT_OF_INDEXES) { + return thrd_error; + } + key->tls_index = tls_index; + if(tls_index >= TSS_KEYS_MAX) { + __debugbreak(); + TlsFree(tls_index); + return thrd_error; + } + _tss_dtors[tls_index] = dtor; + return thrd_success; +} + +void tss_delete(tss_t key) { + _tss_dtors[key.tls_index] = NULL; + TlsFree(key.tls_index); +} + +void *tss_get(tss_t key) { + void *data = TlsGetValue(key.tls_index); + if(data == NULL && GetLastError() != ERROR_SUCCESS) { + return NULL; + } + return data; +} + +int tss_set(tss_t key, void *val) { + if(!TlsSetValue(key.tls_index, val)) { + return thrd_error; + } + return thrd_success; +} + +// Mutex functions + void mtx_destroy(mtx_t *mtx) { CloseHandle(mtx->semaphore); } diff --git a/test.bat b/test.bat new file mode 100644 index 0000000..807e319 --- /dev/null +++ b/test.bat @@ -0,0 +1,2 @@ + +clang -g %1 -I inc utf8.obj -nostdlib -mfma -lciabatta.lib diff --git a/test/test_threads.c b/test/test_threads.c index 41df77d..ebbb723 100644 --- a/test/test_threads.c +++ b/test/test_threads.c @@ -2,29 +2,33 @@ #include #include -atomic_int acnt; -int cnt; +_Thread_local int counter; +tss_t key; -int f(void* thr_data) -{ +int f(void* thr_data) { + tss_set(key, "Thread 2 finished"); for(int n = 0; n < 5; ++n) - puts("b"); + counter++; + puts(tss_get(key)); return 0; } int main(void) { - thrd_t thread; - int status = thrd_create(&thread, f, NULL); - if(status == thrd_error) { - puts("Failed creating threads"); - } - for(int n = 0; n < 5; ++n) { - puts("a"); - } - int res; - if(thrd_join(thread, &res) == thrd_error) { - puts("Failed waiting on thread"); - } - puts("Finished"); + tss_create(&key, NULL); + // thrd_t thread; + // int status = thrd_create(&thread, f, NULL); + // if(status == thrd_error) { + // puts("Failed creating threads"); + // } + // for(int n = 0; n < 10; ++n) { + // counter++; + // } + tss_set(key, "Thread 1 finished"); + // int res; + // if(thrd_join(thread, &res) == thrd_error) { + // puts("Failed waiting on thread"); + // } + puts(tss_get(key)); + tss_delete(key); }