Added SILC Thread Queue API
[silc.git] / lib / silcutil / win32 / silcwin32thread.c
index f55a55ac8090851af7374ebb1809d16973548bff..fa35683dbe4b19b5ebd96de2d9f8d52c45d1ca2a 100644 (file)
@@ -4,7 +4,7 @@
 
   Author: Pekka Riikonen <priikone@silcnet.org>
 
-  Copyright (C) 2001 - 2006 Pekka Riikonen
+  Copyright (C) 2001 - 2007 Pekka Riikonen
 
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
@@ -32,8 +32,6 @@ typedef struct {
   SilcBool waitable;
 } *SilcWin32Thread;
 
-static DWORD silc_thread_tls;
-
 /* Actual routine that is called by WIN32 when the thread is created.
    We will call the start_func from here. When this returns the thread
    is destroyed. */
@@ -41,10 +39,16 @@ static DWORD silc_thread_tls;
 unsigned __stdcall silc_thread_win32_start(void *context)
 {
   SilcWin32Thread thread = (SilcWin32Thread)context;
+  SilcTls tls;
+
+  tls = silc_thread_tls_init();
+  if (tls)
+    tls->platform_context = thread;
 
-  TlsSetValue(silc_thread_tls, context);
   silc_thread_exit(thread->start_func(thread->context));
 
+  silc_free(tls);
+
   return 0;
 }
 #endif
@@ -59,15 +63,18 @@ SilcThread silc_thread_create(SilcThreadStart start_func, void *context,
   SILC_LOG_DEBUG(("Creating new thread"));
 
   thread = silc_calloc(1, sizeof(*thread));
+  if (!thread)
+    return NULL;
   thread->start_func = start_func;
   thread->context = context;
   thread->waitable = waitable;
   thread->thread =
-    CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)silc_thread_win32_start,
-                (void *)thread, 0, &id);
+    _beginthreadex(NULL, 0, (LPTHREAD_START_ROUTINE)silc_thread_win32_start,
+                  (void *)thread, 0, &id);
 
   if (!thread->thread) {
     SILC_LOG_ERROR(("Could not create new thread"));
+    silc_set_errno_reason(SILC_ERR, "Could not create new thread");
     silc_free(thread);
     return NULL;
   }
@@ -83,40 +90,41 @@ SilcThread silc_thread_create(SilcThreadStart start_func, void *context,
 void silc_thread_exit(void *exit_value)
 {
 #ifdef SILC_THREADS
-  SilcWin32Thread thread = TlsGetValue(silc_thread_tls);
+  SilcTls tls = silc_thread_get_tls();
+  SilcWin32Thread thread = tls->platform_context;
 
   if (thread) {
     /* If the thread is waitable the memory is freed only in silc_thread_wait
        by another thread. If not waitable, free it now. */
-    if (!thread->waitable) {
-      TerminateThread(thread->thread, 0);
+    if (!thread->waitable)
       silc_free(thread);
-    }
-
-    TlsSetValue(silc_thread_tls, NULL);
   }
-  ExitThread(0);
+
+  _endthreadex(0);
 #endif
 }
 
 SilcThread silc_thread_self(void)
 {
 #ifdef SILC_THREADS
-  SilcWin32Thread self = TlsGetValue(silc_thread_tls);
+  SilcTls tls = silc_thread_get_tls();
+  SilcWin32Thread self = tls->platform_context;
 
   if (!self) {
-    /* This should only happen for the main thread! */
+    /* This should only happen for the main thread. */
     HANDLE handle = GetCurrentThread ();
     HANDLE process = GetCurrentProcess ();
     self = silc_calloc(1, sizeof(*self));
-    DuplicateHandle(process, handle, process,
-                   &self->thread, 0, FALSE,
-                   DUPLICATE_SAME_ACCESS);
-    TlsSetValue(silc_thread_tls, self);
+    if (self) {
+      DuplicateHandle(process, handle, process,
+                     &self->thread, 0, FALSE,
+                     DUPLICATE_SAME_ACCESS);
+      tls->platform_context = self;
+    }
   }
 
   return (SilcThread)self;
-       #else
+#else
   return NULL;
 #endif
 }
@@ -133,10 +141,9 @@ SilcBool silc_thread_wait(SilcThread thread, void **exit_value)
 
   /* The thread is waitable thus we will free all memory after the
      WaitForSingleObject returns, the thread is destroyed after that. */
-  if (WaitForSingleObject(self->thread, 2500) == WAIT_TIMEOUT)
-    TerminateThread(self->thread, 0);
+  WaitForSingleObject(self->thread, INFINITE);
+  CloseHandle(self->thread);
 
-  silc_free(self);
   if (exit_value)
     *exit_value = NULL;
 
@@ -146,6 +153,13 @@ SilcBool silc_thread_wait(SilcThread thread, void **exit_value)
 #endif
 }
 
+void silc_thread_yield(void)
+{
+#ifdef SILC_THREADS
+  SleepEx (0,0);
+#endif /* SILC_THREADS */
+}
+
 
 /***************************** SILC Mutex API *******************************/
 
@@ -153,10 +167,8 @@ SilcBool silc_thread_wait(SilcThread thread, void **exit_value)
 struct SilcMutexStruct {
 #ifdef SILC_THREADS
   CRITICAL_SECTION mutex;
-  BOOL locked;
-#else
-  void *tmp;
 #endif /* SILC_THREADS */
+  unsigned int locked : 1;
 };
 
 SilcBool silc_mutex_alloc(SilcMutex *mutex)
@@ -187,7 +199,7 @@ void silc_mutex_lock(SilcMutex mutex)
 #ifdef SILC_THREADS
   if (mutex) {
     EnterCriticalSection(&mutex->mutex);
-    assert(mutex->locked == FALSE);
+    SILC_ASSERT(mutex->locked == FALSE);
     mutex->locked = TRUE;
   }
 #endif /* SILC_THREADS */
@@ -197,18 +209,114 @@ void silc_mutex_unlock(SilcMutex mutex)
 {
 #ifdef SILC_THREADS
   if (mutex) {
-    assert(mutex->locked == TRUE);
+    SILC_ASSERT(mutex->locked == TRUE);
     mutex->locked = FALSE;
     LeaveCriticalSection(&mutex->mutex);
   }
 #endif /* SILC_THREADS */
 }
 
+void silc_mutex_assert_locked(SilcMutex mutex)
+{
+#ifdef SILC_THREADS
+  if (mutex)
+    SILC_ASSERT(mutex->locked);
+#endif /* SILC_THREADS */
+}
+
+
+/***************************** SILC Rwlock API ******************************/
+
+/* SILC read/write lock structure */
+struct SilcRwLockStruct {
+#ifdef SILC_THREADS
+  SilcMutex mutex;
+  SilcCond cond;
+#endif /* SILC_THREADS */
+  unsigned int readers : 31;
+  unsigned int locked  : 1;
+};
 
-/**************************** SILC CondVar API ******************************/
+SilcBool silc_rwlock_alloc(SilcRwLock *rwlock)
+{
+#ifdef SILC_THREADS
+  *rwlock = silc_calloc(1, sizeof(**rwlock));
+  if (!(*rwlock))
+    return FALSE;
+  if (!silc_mutex_alloc(&(*rwlock)->mutex)) {
+    silc_free(*rwlock);
+    return FALSE;
+  }
+  if (!silc_cond_alloc(&(*rwlock)->cond)) {
+    silc_mutex_free((*rwlock)->mutex);
+    silc_free(*rwlock);
+    return FALSE;
+  }
+  return TRUE;
+#else
+  return FALSE;
+#endif /* SILC_THREADS */
+}
+
+void silc_rwlock_free(SilcRwLock rwlock)
+{
+#ifdef SILC_THREADS
+  if (rwlock) {
+    silc_mutex_free(rwlock->mutex);
+    silc_cond_free(rwlock->cond);
+    silc_free(rwlock);
+  }
+#endif /* SILC_THREADS */
+}
+
+void silc_rwlock_rdlock(SilcRwLock rwlock)
+{
+#ifdef SILC_THREADS
+  if (rwlock) {
+    silc_mutex_lock(rwlock->mutex);
+    rwlock->readers++;
+    silc_mutex_unlock(rwlock->mutex);
+  }
+#endif /* SILC_THREADS */
+}
+
+void silc_rwlock_wrlock(SilcRwLock rwlock)
+{
+#ifdef SILC_THREADS
+  if (rwlock) {
+    silc_mutex_lock(rwlock->mutex);
+    while (rwlock->readers > 0)
+      silc_cond_wait(rwlock->cond, rwlock->mutex);
+    rwlock->locked = TRUE;
+  }
+#endif /* SILC_THREADS */
+}
+
+void silc_rwlock_unlock(SilcRwLock rwlock)
+{
+#ifdef SILC_THREADS
+  if (rwlock) {
+    if (rwlock->locked) {
+      /* Unlock writer */
+      rwlock->locked = FALSE;
+      silc_mutex_unlock(rwlock->mutex);
+      return;
+    }
+
+    /* Unlock reader */
+    silc_mutex_lock(rwlock->mutex);
+    rwlock->readers--;
+    silc_cond_broadcast(rwlock->cond);
+    silc_mutex_unlock(rwlock->mutex);
+  }
+#endif /* SILC_THREADS */
+}
+
+
+/**************************** SILC Cond API ******************************/
 
 /* SILC Conditional Variable context */
-struct SilcCondVarStruct {
+struct SilcCondStruct {
 #ifdef SILC_THREADS
   HANDLE event;
 #endif /* SILC_THREADS*/
@@ -216,7 +324,7 @@ struct SilcCondVarStruct {
   unsigned int signal  : 1;
 };
 
-SilcBool silc_condvar_alloc(SilcCondVar *cond)
+SilcBool silc_cond_alloc(SilcCond *cond)
 {
 #ifdef SILC_THREADS
   *cond = silc_calloc(1, sizeof(**cond));
@@ -229,7 +337,7 @@ SilcBool silc_condvar_alloc(SilcCondVar *cond)
 #endif /* SILC_THREADS*/
 }
 
-void silc_condvar_free(SilcCondVar cond)
+void silc_cond_free(SilcCond cond)
 {
 #ifdef SILC_THREADS
   CloseHandle(cond->event);
@@ -237,7 +345,7 @@ void silc_condvar_free(SilcCondVar cond)
 #endif /* SILC_THREADS*/
 }
 
-void silc_condvar_signal(SilcCondVar cond)
+void silc_cond_signal(SilcCond cond)
 {
 #ifdef SILC_THREADS
   cond->signal = TRUE;
@@ -245,7 +353,7 @@ void silc_condvar_signal(SilcCondVar cond)
 #endif /* SILC_THREADS*/
 }
 
-void silc_condvar_broadcast(SilcCondVar cond)
+void silc_cond_broadcast(SilcCond cond)
 {
 #ifdef SILC_THREADS
   cond->signal = TRUE;
@@ -253,15 +361,15 @@ void silc_condvar_broadcast(SilcCondVar cond)
 #endif /* SILC_THREADS*/
 }
 
-void silc_condvar_wait(SilcCondVar cond, SilcMutex mutex)
+void silc_cond_wait(SilcCond cond, SilcMutex mutex)
 {
 #ifdef SILC_THREADS
-  silc_condvar_timedwait(cond, mutex, NULL);
+  silc_cond_timedwait(cond, mutex, 0);
 #endif /* SILC_THREADS*/
 }
 
-SilcBool silc_condvar_timedwait(SilcCondVar cond, SilcMutex mutex,
-                               int timeout)
+SilcBool silc_cond_timedwait(SilcCond cond, SilcMutex mutex,
+                            int timeout)
 {
 #ifdef SILC_THREADS
   DWORD ret, t = INFINITE;
@@ -288,4 +396,67 @@ SilcBool silc_condvar_timedwait(SilcCondVar cond, SilcMutex mutex,
     }
   }
 #endif /* SILC_THREADS*/
+  return TRUE;
+}
+
+/************************** Thread-local Storage ****************************/
+
+#ifdef SILC_THREADS
+
+static DWORD silc_tls;
+SilcBool silc_tls_set = FALSE;
+
+SilcTls silc_thread_tls_init(void)
+{
+  SilcTls tls;
+
+  if (!silc_tls_set) {
+    silc_tls = TlsAlloc();
+    if (silc_tls == TLS_OUT_OF_INDEXES) {
+      SILC_LOG_ERROR(("Error creating Thread-local storage"));
+      return NULL;
+    }
+
+    silc_tls_set = TRUE;
+  }
+
+  if (silc_thread_get_tls())
+    return silc_thread_get_tls();
+
+  /* Allocate Tls for the thread */
+  tls = silc_calloc(1, sizeof(*tls));
+  if (!tls) {
+    SILC_LOG_ERROR(("Error allocating Thread-local storage"));
+    return NULL;
+  }
+
+  TlsSetValue(silc_tls, tls);
+  return tls;
+}
+
+SilcTls silc_thread_get_tls(void)
+{
+  return (SilcTls)TlsGetValue(silc_tls);
 }
+
+#else
+
+SilcTlsStruct tls;
+SilcTls tls_ptr = NULL;
+
+SilcTls silc_thread_tls_init(void)
+{
+  if (silc_thread_get_tls())
+    return silc_thread_get_tls();
+
+  tls_ptr = &tls;
+  memset(tls_ptr, 0, sizeof(*tls_ptr));
+  return tls_ptr;
+}
+
+SilcTls silc_thread_get_tls(void)
+{
+  return tls_ptr;
+}
+
+#endif /* SILC_THREADS */