This commit is contained in:
bumbread 2023-02-19 10:27:59 +11:00
parent c02862b550
commit fc793c4869
4 changed files with 63 additions and 34 deletions

View File

@ -4,6 +4,8 @@
#pragma once #pragma once
#define ONCE_FLAG_INIT ((once_flag){0})
typedef struct thrd_t { typedef struct thrd_t {
void *handle; void *handle;
} thrd_t; } thrd_t;
@ -12,6 +14,13 @@ typedef struct tss_t {
unsigned tls_index; unsigned tls_index;
} tss_t; } tss_t;
// We pretend that once_flag defined the same way as INIT_ONCE
// from winapi headers (aka _RTL_RUN_ONCE), which itself is defined
// as a union
typedef union once_flag {
void *ptr;
} once_flag;
typedef struct cnd_t { typedef struct cnd_t {
int idk_yet; int idk_yet;
} cnd_t; } cnd_t;

View File

@ -1,8 +1,16 @@
#pragma once #pragma once
// Note(bumbread): some of the macros and types are platform-dependent
// and can be found in os_win/ subfolder.
#if defined(_WIN32)
#include "os_win/threads_types.h"
#else
#error "Not implemented"
#endif
// 7.28.1 p.3: Macros // 7.28.1 p.3: Macros
#define ONCE_FLAG_INIT 1
#define TSS_DTOR_ITERATIONS 32 #define TSS_DTOR_ITERATIONS 32
// TODO(bumbread): check the spec for whether thread_local needs to be declared // TODO(bumbread): check the spec for whether thread_local needs to be declared
@ -20,19 +28,9 @@
}; };
#endif #endif
#if defined(_WIN32)
#include "os_win/threads_types.h"
#else
#error "Not implemented"
#endif
typedef void(*tss_dtor_t) (void*); typedef void(*tss_dtor_t) (void*);
typedef int (*thrd_start_t)(void*); typedef int (*thrd_start_t)(void*);
// TODO(bumbread): this probably should be a mutex or a semaphore
// also probably can be implemented with interlocked increment
typedef int once_flag;
// 7.28.1 p.5: Enumeration constants // 7.28.1 p.5: Enumeration constants
enum { enum {

View File

@ -130,16 +130,21 @@ _Noreturn void thrd_exit(int res) {
#define TSS_KEYS_MAX 1088 #define TSS_KEYS_MAX 1088
static tss_dtor_t _tss_dtors[TSS_KEYS_MAX]; static tss_dtor_t _tss_dtors[TSS_KEYS_MAX];
static bool _tss_init[TSS_KEYS_MAX];
static void _thread_cleanup() { static void _thread_cleanup() {
for(int i = 0; i != TSS_DTOR_ITERATIONS; ++i) { for(int i = 0; i != TSS_DTOR_ITERATIONS; ++i) {
for(unsigned k = 0; k != TSS_KEYS_MAX; ++k) { for(unsigned k = 1; k != TSS_KEYS_MAX; ++k) {
if(!_tss_init[k]) {
continue;
}
void *data = TlsGetValue(k); void *data = TlsGetValue(k);
if(data != NULL) { if(data == NULL) {
TlsSetValue(k, NULL); continue;
if(_tss_dtors[k]) { }
_tss_dtors[k](data); TlsSetValue(k, NULL);
} if(_tss_dtors[k]) {
_tss_dtors[k](data);
} }
} }
} }
@ -156,11 +161,13 @@ int tss_create(tss_t *key, tss_dtor_t dtor) {
TlsFree(tls_index); TlsFree(tls_index);
return thrd_error; return thrd_error;
} }
_tss_init[tls_index] = true;
_tss_dtors[tls_index] = dtor; _tss_dtors[tls_index] = dtor;
return thrd_success; return thrd_success;
} }
void tss_delete(tss_t key) { void tss_delete(tss_t key) {
_tss_init[key.tls_index] = false;
_tss_dtors[key.tls_index] = NULL; _tss_dtors[key.tls_index] = NULL;
TlsFree(key.tls_index); TlsFree(key.tls_index);
} }
@ -180,6 +187,18 @@ int tss_set(tss_t key, void *val) {
return thrd_success; return thrd_success;
} }
// Call once
static BOOL _call_once_trampoline(PINIT_ONCE init_once, PVOID param, PVOID *ctx) {
void (*user_func)(void) = *ctx;
user_func();
return TRUE;
}
void call_once(once_flag *flag, void (*func)(void)) {
InitOnceExecuteOnce((void *)flag, _call_once_trampoline, NULL, (void **)&func);
}
// Mutex functions // Mutex functions
void mtx_destroy(mtx_t *mtx) { void mtx_destroy(mtx_t *mtx) {

View File

@ -2,33 +2,36 @@
#include <threads.h> #include <threads.h>
#include <stdatomic.h> #include <stdatomic.h>
#define N_THREADS 1
_Thread_local int counter; _Thread_local int counter;
tss_t key; once_flag flag = ONCE_FLAG_INIT;
void init() {
puts("Hey I got a call");
}
int f(void* thr_data) { int f(void* thr_data) {
tss_set(key, "Thread 2 finished"); call_once(&flag, init);
for(int n = 0; n < 5; ++n) for(int n = 0; n < 5; ++n)
counter++; counter++;
puts(tss_get(key)); puts("Finished");
return 0; return 0;
} }
int main(void) int main(void)
{ {
tss_create(&key, NULL); thrd_t thread[N_THREADS];
thrd_t thread; for(int i = 0; i != N_THREADS; ++i) {
int status = thrd_create(&thread, f, NULL); int status = thrd_create(&thread[i], f, NULL);
if(status == thrd_error) { if(status == thrd_error) {
puts("Failed creating threads"); puts("Failed creating threads");
}
} }
for(int n = 0; n < 10; ++n) { for(int i = 0; i != N_THREADS; ++i) {
counter++; int res;
if(thrd_join(thread[i], &res) == thrd_error) {
puts("Failed waiting on thread");
}
} }
tss_set(key, "Thread 1 finished");
int res;
if(thrd_join(thread, &res) == thrd_error) {
puts("Failed waiting on thread");
}
puts(tss_get(key));
tss_delete(key);
} }