diff --git a/inc/os_win/threads_types.h b/inc/os_win/threads_types.h index 1173355..1da440f 100644 --- a/inc/os_win/threads_types.h +++ b/inc/os_win/threads_types.h @@ -4,6 +4,8 @@ #pragma once +#define ONCE_FLAG_INIT ((once_flag){0}) + typedef struct thrd_t { void *handle; } thrd_t; @@ -12,6 +14,13 @@ typedef struct tss_t { unsigned tls_index; } tss_t; +// We pretend that once_flag defined the same way as INIT_ONCE +// from winapi headers (aka _RTL_RUN_ONCE), which itself is defined +// as a union +typedef union once_flag { + void *ptr; +} once_flag; + typedef struct cnd_t { int idk_yet; } cnd_t; diff --git a/inc/threads.h b/inc/threads.h index 243c25e..c9b594d 100644 --- a/inc/threads.h +++ b/inc/threads.h @@ -1,8 +1,16 @@ #pragma once +// Note(bumbread): some of the macros and types are platform-dependent +// and can be found in os_win/ subfolder. + +#if defined(_WIN32) + #include "os_win/threads_types.h" +#else + #error "Not implemented" +#endif + // 7.28.1 p.3: Macros -#define ONCE_FLAG_INIT 1 #define TSS_DTOR_ITERATIONS 32 // TODO(bumbread): check the spec for whether thread_local needs to be declared @@ -20,19 +28,9 @@ }; #endif -#if defined(_WIN32) - #include "os_win/threads_types.h" -#else - #error "Not implemented" -#endif - typedef void(*tss_dtor_t) (void*); typedef int (*thrd_start_t)(void*); -// TODO(bumbread): this probably should be a mutex or a semaphore -// also probably can be implemented with interlocked increment -typedef int once_flag; - // 7.28.1 p.5: Enumeration constants enum { diff --git a/src/os_win/thread.c b/src/os_win/thread.c index 810254c..8fe6c73 100644 --- a/src/os_win/thread.c +++ b/src/os_win/thread.c @@ -130,16 +130,21 @@ _Noreturn void thrd_exit(int res) { #define TSS_KEYS_MAX 1088 static tss_dtor_t _tss_dtors[TSS_KEYS_MAX]; +static bool _tss_init[TSS_KEYS_MAX]; static void _thread_cleanup() { for(int i = 0; i != TSS_DTOR_ITERATIONS; ++i) { - for(unsigned k = 0; k != TSS_KEYS_MAX; ++k) { + for(unsigned k = 1; k != TSS_KEYS_MAX; ++k) { + if(!_tss_init[k]) { + continue; + } void *data = TlsGetValue(k); - if(data != NULL) { - TlsSetValue(k, NULL); - if(_tss_dtors[k]) { - _tss_dtors[k](data); - } + if(data == NULL) { + continue; + } + TlsSetValue(k, NULL); + if(_tss_dtors[k]) { + _tss_dtors[k](data); } } } @@ -156,11 +161,13 @@ int tss_create(tss_t *key, tss_dtor_t dtor) { TlsFree(tls_index); return thrd_error; } + _tss_init[tls_index] = true; _tss_dtors[tls_index] = dtor; return thrd_success; } void tss_delete(tss_t key) { + _tss_init[key.tls_index] = false; _tss_dtors[key.tls_index] = NULL; TlsFree(key.tls_index); } @@ -180,6 +187,18 @@ int tss_set(tss_t key, void *val) { return thrd_success; } +// Call once + +static BOOL _call_once_trampoline(PINIT_ONCE init_once, PVOID param, PVOID *ctx) { + void (*user_func)(void) = *ctx; + user_func(); + return TRUE; +} + +void call_once(once_flag *flag, void (*func)(void)) { + InitOnceExecuteOnce((void *)flag, _call_once_trampoline, NULL, (void **)&func); +} + // Mutex functions void mtx_destroy(mtx_t *mtx) { diff --git a/test/test_threads.c b/test/test_threads.c index 965e6a2..9d11c8b 100644 --- a/test/test_threads.c +++ b/test/test_threads.c @@ -2,33 +2,36 @@ #include #include +#define N_THREADS 1 + _Thread_local int counter; -tss_t key; +once_flag flag = ONCE_FLAG_INIT; + +void init() { + puts("Hey I got a call"); +} int f(void* thr_data) { - tss_set(key, "Thread 2 finished"); + call_once(&flag, init); for(int n = 0; n < 5; ++n) counter++; - puts(tss_get(key)); + puts("Finished"); return 0; } int main(void) { - tss_create(&key, NULL); - thrd_t thread; - int status = thrd_create(&thread, f, NULL); - if(status == thrd_error) { - puts("Failed creating threads"); + thrd_t thread[N_THREADS]; + for(int i = 0; i != N_THREADS; ++i) { + int status = thrd_create(&thread[i], f, NULL); + if(status == thrd_error) { + puts("Failed creating threads"); + } } - for(int n = 0; n < 10; ++n) { - counter++; + for(int i = 0; i != N_THREADS; ++i) { + int res; + if(thrd_join(thread[i], &res) == thrd_error) { + puts("Failed waiting on thread"); + } } - tss_set(key, "Thread 1 finished"); - int res; - if(thrd_join(thread, &res) == thrd_error) { - puts("Failed waiting on thread"); - } - puts(tss_get(key)); - tss_delete(key); }