From d9b49a682c2a03b58e50681dd6cafed665097eac Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Deuc=D0=B5?= <5-Deuce@users.noreply.gitlab.synchro.net>
Date: Fri, 19 Jan 2024 20:09:35 +0000
Subject: [PATCH] Use new rwlock for ssl certificate.

---
 src/sbbs3/ssl.c | 127 ++++++++++++++----------------------------------
 1 file changed, 37 insertions(+), 90 deletions(-)

diff --git a/src/sbbs3/ssl.c b/src/sbbs3/ssl.c
index 9c4f23f203..6e782c9f82 100644
--- a/src/sbbs3/ssl.c
+++ b/src/sbbs3/ssl.c
@@ -1,6 +1,7 @@
 #include <stdbool.h>
 #include <stdio.h>
 
+#include "rwlockwrap.h"
 #include <threadwrap.h>
 #include "xpprintf.h"
 #include "eventwrap.h"
@@ -235,10 +236,8 @@ bool get_crypt_error_string(int status, CRYPT_HANDLE sess, char **estr, const ch
 }
 
 static pthread_once_t crypt_init_once = PTHREAD_ONCE_INIT;
-static pthread_mutex_t ssl_cert_mutex;
+static rwlock_t ssl_rwlock;
 static pthread_mutex_t ssl_cert_list_mutex;
-static xpevent_t ssl_cert_read_available;
-static xpevent_t ssl_cert_write_available;
 static bool cryptlib_initialized;
 static int cryptInit_error;
 
@@ -259,10 +258,8 @@ static void internal_do_cryptInit(void)
 	else {
 		cryptInit_error = ret; 
 	}
-	pthread_mutex_init(&ssl_cert_mutex, NULL);
+	rwlock_init(&ssl_rwlock);
 	pthread_mutex_init(&ssl_cert_list_mutex, NULL);
-	ssl_cert_read_available = CreateEvent(NULL, TRUE, TRUE, NULL);
-	ssl_cert_write_available = CreateEvent(NULL, TRUE, TRUE, NULL);
 	return;
 }
 
@@ -283,80 +280,6 @@ bool is_crypt_initialized(void)
 	return cryptlib_initialized;
 }
 
-static uint32_t readers;
-static uint32_t writers;
-static uint32_t writers_waiting;
-
-static void lock_ssl_cert(void)
-{
-	int done = 0;
-
-	pthread_mutex_lock(&ssl_cert_mutex);
-	do {
-		done = (writers == 0 && writers_waiting == 0);
-		if (!done) {
-			pthread_mutex_unlock(&ssl_cert_mutex);
-			WaitForEvent(ssl_cert_read_available, INFINITE);
-			pthread_mutex_lock(&ssl_cert_mutex);
-		}
-	} while (!done);
-	ResetEvent(ssl_cert_write_available);
-	readers++;
-	pthread_mutex_unlock(&ssl_cert_mutex);
-}
-
-static void lock_ssl_cert_write(void)
-{
-	int done;
-
-	ResetEvent(ssl_cert_read_available);
-	pthread_mutex_lock(&ssl_cert_mutex);
-	writers_waiting++;
-	do {
-		done = (readers == 0 && writers == 0);
-		if (!done) {
-			pthread_mutex_unlock(&ssl_cert_mutex);
-			WaitForEvent(ssl_cert_write_available, INFINITE);
-			pthread_mutex_lock(&ssl_cert_mutex);
-		}
-	} while(!done);
-	ResetEvent(ssl_cert_write_available);
-	writers_waiting--;
-	writers++;
-	pthread_mutex_unlock(&ssl_cert_mutex);
-}
-
-static void unlock_ssl_cert(int (*lprintf)(int level, const char* fmt, ...))
-{
-	pthread_mutex_lock(&ssl_cert_mutex);
-	if (readers == 0) {
-		lprintf(LOG_ERR, "Unlocking ssl cert for read when it's not locked.");
-	}
-	else {
-		readers--;
-		if (readers == 0) {
-			SetEvent(ssl_cert_write_available);
-		}
-	}
-	pthread_mutex_unlock(&ssl_cert_mutex);
-}
-
-static void unlock_ssl_cert_write(int (*lprintf)(int level, const char* fmt, ...))
-{
-	pthread_mutex_lock(&ssl_cert_mutex);
-	if (writers == 0) {
-		lprintf(LOG_ERR, "Unlocking ssl cert for write when it's not locked.");
-	}
-	else {
-		writers--;
-		SetEvent(ssl_cert_write_available);
-		if (writers_waiting == 0) {
-			SetEvent(ssl_cert_read_available);
-		}
-	}
-	pthread_mutex_unlock(&ssl_cert_mutex);
-}
-
 // TODO: Failures in get_ssl_cert() aren't logged.
 #define DO(action, handle, x)	cryptStatusOK(x)
 
@@ -376,7 +299,10 @@ bool ssl_sync(scfg_t *scfg, int (*lprintf)(int level, const char* fmt, ...))
 {
 	if (!do_cryptInit(lprintf))
 		return false;
-	lock_ssl_cert_write();
+	if (!rwlock_wrlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to lock ssl_rwlock for write at %d", __LINE__);
+		return false;
+	}
 	if (!cert_path[0])
 		SAFEPRINTF2(cert_path,"%s%s",scfg->ctrl_dir,"ssl.cert");
 	time_t fd = fdate(cert_path);
@@ -399,7 +325,10 @@ bool ssl_sync(scfg_t *scfg, int (*lprintf)(int level, const char* fmt, ...))
 		}
 	}
 	tls_cert_file_date = fd;
-	unlock_ssl_cert_write(lprintf);
+	if (!rwlock_unlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for write at %d", __LINE__);
+		return false;
+	}
 	return true;
 }
 
@@ -413,10 +342,15 @@ static CRYPT_CONTEXT get_ssl_cert(scfg_t *cfg, int (*lprintf)(int level, const c
 	if(!do_cryptInit(lprintf))
 		return -1;
 	ssl_sync(cfg, lprintf);
-	lock_ssl_cert_write();
+	if (!rwlock_wrlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to lock ssl_rwlock for write at %d", __LINE__);
+		return -1;
+	}
 	cert_entry = malloc(sizeof(*cert_entry));
 	if(cert_entry == NULL) {
-		unlock_ssl_cert_write(lprintf);
+		if (!rwlock_unlock(&ssl_rwlock)) {
+			lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for write at %d", __LINE__);
+		}
 		free(cert_entry);
 		lprintf(LOG_CRIT, "%s line %d: FAILED TO ALLOCATE %u bytes of memory", __FUNCTION__, __LINE__, sizeof *cert_entry);
 		return -1;
@@ -428,7 +362,9 @@ static CRYPT_CONTEXT get_ssl_cert(scfg_t *cfg, int (*lprintf)(int level, const c
 	/* Get the certificate... first try loading it from a file... */
 	if(cryptStatusOK(cryptKeysetOpen(&ssl_keyset, CRYPT_UNUSED, CRYPT_KEYSET_FILE, cert_path, CRYPT_KEYOPT_READONLY))) {
 		if(!DO("getting private key", ssl_keyset, cryptGetPrivateKey(ssl_keyset, &cert_entry->cert, CRYPT_KEYID_NAME, "ssl_cert", cfg->sys_pass))) {
-			unlock_ssl_cert_write(lprintf);
+			if (!rwlock_unlock(&ssl_rwlock)) {
+				lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for write at %d", __LINE__);
+			}
 			free(cert_entry);
 			return -1;
 		}
@@ -436,7 +372,9 @@ static CRYPT_CONTEXT get_ssl_cert(scfg_t *cfg, int (*lprintf)(int level, const c
 	else {
 		/* Couldn't do that... create a new context and use the cert from there... */
 		if(!DO("creating SSL context", CRYPT_UNUSED,cryptCreateContext(&cert_entry->cert, CRYPT_UNUSED, CRYPT_ALGO_RSA))) {
-			unlock_ssl_cert_write(lprintf);
+			if (!rwlock_unlock(&ssl_rwlock)) {
+				lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for write at %d", __LINE__);
+			}
 			free(cert_entry);
 			return -1;
 		}
@@ -498,7 +436,9 @@ static CRYPT_CONTEXT get_ssl_cert(scfg_t *cfg, int (*lprintf)(int level, const c
 		lprintf(LOG_DEBUG, "Created TLS private key and certificate %d", cert_entry->cert);
 		pthread_mutex_unlock(&ssl_cert_list_mutex);
 	}
-	unlock_ssl_cert_write(lprintf);
+	if (!rwlock_unlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for write at %d", __LINE__);
+	}
 	return 0;
 
 failure_return_3:
@@ -508,7 +448,9 @@ failure_return_2:
 failure_return_1:
 	cryptDestroyContext(cert_entry->cert);
 	cert_path[0] = 0;
-	unlock_ssl_cert_write(lprintf);
+	if (!rwlock_unlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for write at %d", __LINE__);
+	}
 	free(cert_entry);
 	return -1;
 }
@@ -560,7 +502,10 @@ int destroy_session(int (*lprintf)(int level, const char* fmt, ...), CRYPT_SESSI
 	struct cert_list *psess = NULL;
 	int ret = CRYPT_ERROR_NOTFOUND;
 
-	lock_ssl_cert();
+	if (!rwlock_rdlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to lock ssl_rwlock for read ast line %d", __LINE__);
+		return CRYPT_ERROR_INTERNAL;
+	}
 	pthread_mutex_lock(&ssl_cert_list_mutex);
 	sess = sess_list;
 	while (sess != NULL) {
@@ -589,7 +534,9 @@ int destroy_session(int (*lprintf)(int level, const char* fmt, ...), CRYPT_SESSI
 		sess = sess->next;
 	}
 	pthread_mutex_unlock(&ssl_cert_list_mutex);
-	unlock_ssl_cert(lprintf);
+	if (!rwlock_unlock(&ssl_rwlock)) {
+		lprintf(LOG_ERR, "Unable to unlock ssl_rwlock for read at %d", __LINE__);
+	}
 	if (ret == CRYPT_ERROR_NOTFOUND)
 		ret = cryptDestroySession(csess);
 	return ret;
-- 
GitLab