/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett A C Sheffield <bacs@librecast.net> */

/*
 * https://todo.sr.ht/~librecast/librecast/121
 *
 * lc_socket_multi_recv(3) - same as lc_socket_recv(3), but returns destination channel in dst
 *
 */

#include "test.h"
#include "testnet.h"
#include <librecast/net.h>
#include <pthread.h>
#include <semaphore.h>

/* seconds main() waits on the `timeout` semaphore before declaring failure */
#define WAITS 20
/* size in bytes of each send/recv buffer, and of every datagram sent */
#define PAYLOADMAX 1024
/* number of channels each thread creates and binds */
#define CHANNELS 2

/* indexes into the per-thread tid[] and buf[] arrays in main() */
enum {
	TID_SEND,
	TID_RECV
};
/* indexes into the chan[] arrays held by the sender and receiver threads */
enum {
	CHAN_LEFT,
	CHAN_RIGHT,
};

/*
 * receiver_ready: posted by recv_data() before each round of receives,
 *                 waited on by send_data() before each round of sends.
 * timeout:        posted by recv_data() once all receives have returned,
 *                 waited on (with a deadline) by main().
 */
static sem_t receiver_ready, timeout;
/* interface index from get_multicast_if(); set in main(), passed to lc_socket_bind() */
static unsigned int ifidx;

/*
 * Create a channel named after the decimal value of channel_number and bind
 * it to sock.
 *
 * lctx:           context the channel is created in
 * sock:           socket the channel is bound to
 * channel_number: must be in the range [0, CHANNELS)
 *
 * Returns the new channel, or NULL if the name could not be formatted, the
 * channel could not be created, or the bind failed (a bind failure is also
 * recorded via test_assert()).
 */
static lc_channel_t *create_and_bind_channel(lc_ctx_t *lctx, lc_socket_t *sock, int channel_number)
{
	lc_channel_t *chan;
	/* large enough for any int; do not rely on CHANNELS staying below 10
	 * or on assert() being compiled in to keep the write in bounds */
	char chanstr[16];
	int len;
	int rc;

	assert(channel_number >= 0 && channel_number < CHANNELS);
	len = snprintf(chanstr, sizeof chanstr, "%d", channel_number);
	if (len < 0 || (size_t)len >= sizeof chanstr) return NULL;
	chan = lc_channel_new(lctx, chanstr);
	if (!chan) return NULL;
	rc = lc_channel_bind(sock, chan);
	if (!test_assert(rc == 0, "lc_channel_bind() %s", chanstr)) return NULL;
	return chan;
}

/*
 * Receiver thread.
 *
 * arg: buffer of at least PAYLOADMAX bytes; every datagram is received into it.
 *
 * Creates its own context and socket, binds the socket to ifidx, then creates,
 * binds and joins CHANNELS channels. Two rounds of receives follow, the first
 * using lc_socket_multi_recv() and the second lc_socket_multi_recvmsg(). Each
 * round is announced by posting receiver_ready and expects one PAYLOADMAX-byte
 * datagram on the LEFT channel followed by one on the RIGHT channel. After the
 * last receive returns, `timeout` is posted to release main().
 *
 * Returns arg, or NULL if the context could not be created.
 */
static void *recv_data(void * arg)
{
	char *payload = (char *)arg;
	struct iovec iov = { .iov_base = payload, .iov_len = PAYLOADMAX };
	struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1 };
	lc_channel_t *chan[CHANNELS];
	lc_channel_t *dst;
	lc_socket_t *sock;
	lc_ctx_t *lctx;
	ssize_t nbytes;
	int err;

	lctx = lc_ctx_new();
	if (lctx == NULL) return NULL;
	sock = lc_socket_new(lctx);
	if (sock == NULL) goto cleanup;
	if (lc_socket_bind(sock, ifidx) == -1) goto cleanup;
	lc_socket_loop(sock, 1);
	for (int n = 0; n < CHANNELS; n++) {
		chan[n] = create_and_bind_channel(lctx, sock, n);
		if (chan[n] == NULL) goto cleanup;
		err = lc_channel_join(chan[n]);
		if (!test_assert(err == 0, "lc_channel_join() %i", n))
			goto cleanup;
	}

	/* round one: lc_socket_multi_recv(), LEFT channel is expected first */
	sem_post(&receiver_ready);

	dst = NULL;
	nbytes = lc_socket_multi_recv(sock, payload, PAYLOADMAX, 0, &dst);
	test_assert(dst == chan[CHAN_LEFT], "lc_channel_recv() - LEFT channel");
	test_assert(nbytes == PAYLOADMAX, "bytes recved on LEFT = %zi", nbytes);

	/* ... then the RIGHT channel */
	dst = NULL;
	nbytes = lc_socket_multi_recv(sock, payload, PAYLOADMAX, 0, &dst);
	test_assert(dst == chan[CHAN_RIGHT], "lc_channel_recv() - RIGHT channel");
	test_assert(nbytes == PAYLOADMAX, "bytes recved on RIGHT = %zi", nbytes);

	/* round two: same sequence through lc_socket_multi_recvmsg() */
	sem_post(&receiver_ready);

	dst = NULL;
	nbytes = lc_socket_multi_recvmsg(sock, &msg, 0, &dst);
	test_assert(dst == chan[CHAN_LEFT], "lc_channel_recvmsg() - LEFT channel");
	test_assert(nbytes == PAYLOADMAX, "bytes recved on LEFT = %zi", nbytes);

	dst = NULL;
	nbytes = lc_socket_multi_recvmsg(sock, &msg, 0, &dst);
	test_assert(dst == chan[CHAN_RIGHT], "lc_channel_recvmsg() - RIGHT channel");
	test_assert(nbytes == PAYLOADMAX, "bytes recved on RIGHT = %zi", nbytes);

	/* all four datagrams handled: let main() stop waiting */
	sem_post(&timeout);
cleanup:
	lc_ctx_free(lctx);
	return arg;
}

/*
 * Sender thread.
 *
 * arg: buffer holding PAYLOADMAX bytes of payload to transmit.
 *
 * Creates its own context and socket, binds the socket to ifidx, and creates
 * and binds CHANNELS channels. It then runs two rounds; in each it waits for
 * the receiver to post receiver_ready, then sends the payload on the LEFT
 * channel followed by the RIGHT channel.
 *
 * Returns arg, or NULL if the context could not be created.
 */
static void *send_data(void * arg)
{
	char *payload = (char *)arg;
	lc_channel_t *chan[CHANNELS];
	lc_socket_t *sock;
	lc_ctx_t *lctx;
	ssize_t sent;
	int round;

	lctx = lc_ctx_new();
	if (lctx == NULL) return NULL;
	sock = lc_socket_new(lctx);
	if (sock == NULL) goto cleanup;
	if (lc_socket_bind(sock, ifidx) == -1) goto cleanup;
	lc_socket_loop(sock, 1);
	for (int n = 0; n < CHANNELS; n++) {
		chan[n] = create_and_bind_channel(lctx, sock, n);
		if (chan[n] == NULL) goto cleanup;
	}

	/* one round per receive API under test (recv, then recvmsg) */
	round = 0;
	while (round < 2) {
		/* do not transmit until the receiver says it is listening */
		sem_wait(&receiver_ready);

		/* LEFT goes out first ... */
		sent = lc_channel_send(chan[CHAN_LEFT], payload, PAYLOADMAX, 0);
		test_assert(sent == PAYLOADMAX, "lc_channel_recv() - data sent to LEFT channel");
		/* ... followed by RIGHT */
		sent = lc_channel_send(chan[CHAN_RIGHT], payload, PAYLOADMAX, 0);
		test_assert(sent == PAYLOADMAX, "lc_channel_recv() - data sent to RIGHT channel");

		round++;
	}

cleanup:
	lc_ctx_free(lctx);
	return arg;
}

/*
 * Test driver.
 *
 * Picks a multicast interface, starts one sender and one receiver thread, and
 * waits up to WAITS seconds for the receiver to post `timeout`. Afterwards the
 * threads are joined and the received buffer is compared with the sent one.
 *
 * Returns test_status.
 */
int main(void)
{
	char name[] = "lc_socket_multi_recv()";
	/* one buffer per thread, indexed by TID_SEND / TID_RECV (not per channel) */
	char buf[2][PAYLOADMAX];
	pthread_t tid[2];
	struct timespec ts;
	int rc;

	test_name(name);
	test_require_net(TEST_NET_BASIC);

	ifidx = get_multicast_if();
	if (!test_assert(ifidx > 0, "get_multicast_if()")) return test_status;

	sem_init(&timeout, 0, 0);
	sem_init(&receiver_ready, 0, 0);

	/* generate random data and clear recv buffer */
	arc4random_buf(buf[TID_SEND], PAYLOADMAX);
	memset(buf[TID_RECV], 0, PAYLOADMAX);

	pthread_create(&tid[TID_SEND], NULL, &send_data, buf[TID_SEND]);
	pthread_create(&tid[TID_RECV], NULL, &recv_data, buf[TID_RECV]);
	clock_gettime(CLOCK_REALTIME, &ts);
	ts.tv_sec += WAITS;
	rc = sem_timedwait(&timeout, &ts);
	test_assert(rc == 0, "timeout");
	pthread_cancel(tid[TID_RECV]);
	/*
	 * If the receiver never finished, the sender may still be blocked in
	 * sem_wait(&receiver_ready) waiting for a round that will never be
	 * announced. Cancel it too, otherwise the join below never returns and
	 * the test hangs instead of failing.
	 */
	if (rc != 0) pthread_cancel(tid[TID_SEND]);
	pthread_join(tid[TID_RECV], NULL);
	pthread_join(tid[TID_SEND], NULL);

	sem_destroy(&receiver_ready);
	sem_destroy(&timeout);

	rc = memcmp(buf[TID_RECV], buf[TID_SEND], PAYLOADMAX);
	test_assert(rc == 0, "recv and send buffers match");

	return test_status;
}
