/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett A C Sheffield <bacs@librecast.net> */

#include "test.h"
#include "testnet.h"
#include <librecast/net.h>
#include <pthread.h>
#include <semaphore.h>
#include <time.h>

/* Test title; both threads also use it as the channel name. */
#define TEST_NAME "FEC + OTI with variable sized objects"
/* Number of randomly sized objects sent and then received. */
#define ITERATIONS 2
/* Rate limit handed to lc_channel_ratelimit() by the sender. Parenthesised so
 * the expansion stays a single value when used inside a larger expression. */
#define SPEED_LIMIT (1024 * 1024 * 32) /* 32Mbps */
/* Seconds the test driver waits for the receiver before reporting a timeout. */
#define WAITS 4

/* Posted by thread_recv() when it finishes; test_oti_sizes() waits on it. */
static sem_t sem;
/* Length of each object: written by thread_send(), read by thread_recv(). */
static size_t len[ITERATIONS];
/* Random payloads: filled by thread_send(), compared against by thread_recv(). */
static char data[ITERATIONS][BUFSIZ];
/* Interface index from get_multicast_if(), set in main() and used to bind sockets. */
static unsigned int ifidx;

/*
 * Receiver thread.
 *
 * Creates its own Librecast context, socket and channel (FEC RQ + OTI coding),
 * joins the channel and receives ITERATIONS objects. Each object must match
 * the length and contents recorded by the sender in len[] / data[].
 *
 * Posts `sem` when it finishes (successfully or not) so the test driver can
 * stop waiting. If the thread is cancelled instead, the cleanup handler frees
 * the context.
 *
 * arg is unused; always returns NULL. Results are reported via test_assert().
 */
static void *thread_recv(void * arg)
{
	(void)arg;
	lc_ctx_t *lctx;
	lc_socket_t *sock;
	lc_channel_t *chan;
	char buf[BUFSIZ];
	ssize_t byt;
	int rc;

	lctx = lc_ctx_new();
	if (!test_assert(lctx != NULL, "lc_ctx_new()")) {
		/* nothing to clean up yet; wake the driver like every other failure path */
		sem_post(&sem);
		return NULL;
	}
	/* Register the handler only once lctx holds the real context: the argument
	 * is captured at push time, and no return may occur between push and pop. */
	pthread_cleanup_push((void (*)(void *))lc_ctx_free, lctx);
	sock = lc_socket_new(lctx);
	if (!test_assert(sock != NULL, "lc_socket_new()")) goto lctx_free;
	if (!test_assert(lc_socket_bind(sock, ifidx) == 0, "lc_socket_bind() ifx = %u", ifidx))
		goto lctx_free;
	chan = lc_channel_new(lctx, TEST_NAME);
	if (!test_assert(chan != NULL, "lc_channel_new()")) goto lctx_free;
	lc_channel_bind(sock, chan);
	lc_channel_coding_set(chan, LC_CODE_FEC_RQ | LC_CODE_FEC_OTI);
	rc = lc_channel_join(chan);
	if (!test_assert(rc == 0, "lc_channel_join()")) goto lctx_free;

	for (int i = 0; i < ITERATIONS; i++) {
		unsigned char hash[HASHSIZE];
		/* clear the buffer so stale bytes from a previous object cannot match */
		memset(buf, 0, sizeof buf);
		byt = lc_channel_recv(chan, buf, sizeof buf, 0);
		if (!test_assert(byt == (ssize_t)len[i], "%zu/%zi bytes recv", len[i], byt)) break;
		/* log hashes of the expected and received payloads to aid debugging */
		hash_generic(hash, sizeof hash, (unsigned char *)data[i], len[i]);
		fprintf(stderr, "%i: expected: ", i);
		hash_hex_debug(stderr, hash, HASHSIZE);
		hash_generic(hash, sizeof hash, (unsigned char *)buf, len[i]);
		fprintf(stderr, "%i:      got: ", i);
		hash_hex_debug(stderr, hash, HASHSIZE);
		if (!test_assert(!memcmp(data[i], buf, len[i]), "data matches (%zu bytes)", len[i])) break;
	}

lctx_free:
	pthread_cleanup_pop(1); /* lc_ctx_free(lctx); */
	sem_post(&sem);
	return NULL;
}

/*
 * Sender thread.
 *
 * Creates its own Librecast context, socket and channel (FEC RQ + OTI coding,
 * doubled RQ overhead, rate limited to SPEED_LIMIT), fills data[] with random
 * bytes and sends ITERATIONS objects of random length, recording each length
 * in len[] for the receiver to check against.
 *
 * arg is unused; always returns NULL. Results are reported via test_assert().
 */
static void *thread_send(void * arg)
{
	(void)arg;
	lc_ctx_t *lctx;
	lc_socket_t *sock;
	lc_channel_t *chan;
	ssize_t byt;

	lctx = lc_ctx_new();
	if (!test_assert(lctx != NULL, "lc_ctx_new()")) return NULL;
	/* Register the handler only once lctx holds the real context: the argument
	 * is captured at push time, and no return may occur between push and pop. */
	pthread_cleanup_push((void (*)(void *))lc_ctx_free, lctx);
	sock = lc_socket_new(lctx);
	if (!test_assert(sock != NULL, "lc_socket_new()")) goto lctx_free;
	if (!test_assert(lc_socket_bind(sock, ifidx) == 0, "lc_socket_bind() ifx = %u", ifidx))
		goto lctx_free;
	/* presumably enables multicast loopback so the local receiver sees our
	 * packets - TODO confirm against the Librecast API docs */
	lc_socket_loop(sock, 1);
	chan = lc_channel_new(lctx, TEST_NAME);
	if (!test_assert(chan != NULL, "lc_channel_new()")) goto lctx_free;
	lc_channel_bind(sock, chan);
	lc_channel_coding_set(chan, LC_CODE_FEC_RQ | LC_CODE_FEC_OTI);
	lc_channel_rq_overhead(chan, RQ_OVERHEAD * 2);
	lc_channel_ratelimit(chan, SPEED_LIMIT, 0);

	/* generate payload data */
	arc4random_buf(data, sizeof data);
	for (int i = 0; i < ITERATIONS; i++) {
		unsigned char hash[HASHSIZE];
		/* NOTE(review): arc4random_uniform(BUFSIZ) may return 0, so a zero-length
		 * object is possible - confirm lc_channel_send() supports that.
		 * NOTE(review): len[] and data[] are shared with the receiver without
		 * explicit synchronisation; the receiver only reads entry i after
		 * object i has arrived. */
		len[i] = arc4random_uniform(BUFSIZ);
		/* log a hash of the outgoing payload to aid debugging */
		hash_generic(hash, sizeof hash, (unsigned char *)data[i], len[i]);
		fprintf(stderr, "%i: ", i);
		hash_hex_debug(stderr, hash, HASHSIZE);
		byt = lc_channel_send(chan, data[i], len[i], 0);
		if (!test_assert(byt == (ssize_t)len[i], "%zi/%zu bytes sent", byt, len[i])) break;
	}

lctx_free:
	pthread_cleanup_pop(1); /* lc_ctx_free(lctx); */
	return NULL;
}

/*
 * Test driver: start the sender and receiver threads, wait up to WAITS
 * seconds for the receiver to post `sem`, then cancel and join every thread
 * that was actually started.
 *
 * Returns the global test_status.
 */
static int test_oti_sizes(void)
{
	enum { TID_SEND, TID_RECV, TID_COUNT };
	pthread_t tid[TID_COUNT];
	struct timespec timeout;
	int started = 0; /* threads successfully created, in tid[] order */
	int rc;

	rc = sem_init(&sem, 0, 0);
	if (!test_assert(rc == 0, "sem_init()")) return test_status;

	rc = pthread_create(&tid[TID_SEND], NULL, &thread_send, NULL);
	if (!test_assert(rc == 0, "create send thread")) goto free_sem;
	started++;
	rc = pthread_create(&tid[TID_RECV], NULL, &thread_recv, NULL);
	/* the send thread is already running here, so it must still be stopped
	 * and joined rather than skipped */
	if (!test_assert(rc == 0, "create recv thread")) goto stop_threads;
	started++;

	/* sem_timedwait() takes an absolute CLOCK_REALTIME deadline */
	rc = clock_gettime(CLOCK_REALTIME, &timeout);
	if (!test_assert(rc == 0, "clock_gettime()")) goto stop_threads;
	timeout.tv_sec += WAITS;
	rc = sem_timedwait(&sem, &timeout);
	test_assert(rc == 0, "timeout");
stop_threads:
	/* only touch thread ids that pthread_create() actually filled in */
	for (int i = 0; i < started; i++) {
		pthread_cancel(tid[i]);
		pthread_join(tid[i], NULL);
	}
free_sem:
	sem_destroy(&sem);

	return test_status;
}

/*
 * Entry point: announce the test, require basic network support, look up a
 * multicast-capable interface and, if one is found, run the test.
 * The exit code is the global test_status.
 */
int main(void)
{
	char title[] = TEST_NAME;

	test_name(title);
	test_require_net(TEST_NET_BASIC);

	ifidx = get_multicast_if();
	if (test_assert(ifidx > 0, "get_multicast_if()")) {
		/* result is recorded in test_status, so the return value is not needed */
		test_oti_sizes();
	}

	return test_status;
}
