/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett A C Sheffield <bacs@librecast.net> */

/*
 * generate random test payloads, each at least 1500 (MTU) bytes
 * on the receiver:
 * - create a socket and bind some channels to it
 * - enable RaptorQ encoding on the channels
 * - join the channels
 * - recv data on the socket
 *
 * on the sender:
 * - enable RaptorQ encoding on the channels
 * - send symbols interleaved
 *
 * TEST: all payloads received and match
 */

#include "test.h"

#ifdef HAVE_LIBLCRQ

#include "testdata.h"
#include "testnet.h"
#include <errno.h>
#include <librecast/crypto.h>
#include <librecast/net.h>
#include <pthread.h>
#include <semaphore.h>

/* 1 => randomly interleave packets for different channels */
#define DO_INTERLEAVED 1
/* maximum payload size; also the row size of srcdata/dstdata */
#define PAYLOADMAX 16384
/* minimum payload size */
#define PAYLOADMIN 1500
/* number of channels per socket (one payload per channel) */
#define PAYLOADS 3
/* timeout (seconds) main() waits for the receiver thread */
#define WAITS 20

/* indices into main()'s pthread_t array */
enum {
	TID_SEND = 0,
	TID_RECV = 1,
};

/* posted by recv_data_fec() once every channel is joined; sender waits on it */
static sem_t receiver_ready;
/* posted by recv_data_fec() when it exits; main() waits on it with a deadline */
static sem_t timeout;

/* byte length of each payload in srcdata */
static int sz[PAYLOADS];
/* interface index both threads bind their sockets to */
static unsigned int ifidx;
/* payloads to send, one row per channel */
static char srcdata[PAYLOADS][PAYLOADMAX];
/* payloads received, one row per channel */
static char dstdata[PAYLOADS][PAYLOADMAX];

/*
 * Create the channel named by channel_number as two decimal digits ("00".."99")
 * in lctx, bind it to sock and enable RaptorQ (LC_CODE_FEC_RQ) coding on it.
 *
 * channel_number must be in [0, 99].
 * Returns the channel, or NULL if creation, binding or coding setup fails.
 * The channel is not freed here on failure; NOTE(review): presumably it is
 * owned by lctx and released by lc_ctx_free() - confirm.
 */
static lc_channel_t *create_and_bind_channel(lc_ctx_t *lctx, lc_socket_t *sock, int channel_number)
{
	lc_channel_t *chan;
	char chanstr[3]; /* two digits + NUL */
	int rc;

	assert(channel_number < 100 && channel_number >= 0);
	/* bounded write: stays inside chanstr even if the assert is compiled out (NDEBUG) */
	snprintf(chanstr, sizeof chanstr, "%02d", channel_number);
	chan = lc_channel_new(lctx, chanstr);
	if (!chan) return NULL;
	rc = lc_channel_bind(sock, chan);
	if (!test_assert(rc == 0, "%s: lc_channel_bind()", chanstr)) return NULL;
	/* lc_channel_coding_set() is expected to return the coding now in effect */
	rc = lc_channel_coding_set(chan, LC_CODE_FEC_RQ);
	if (!test_assert(rc == LC_CODE_FEC_RQ, "%s: lc_channel_coding_set()", chanstr)) return NULL;
	return chan;
}

#if DO_INTERLEAVED
/*
 * Report whether every channel has used up its packet budget.
 *
 * sent[i] is the number of packets still to send on channel i. Any count
 * <= 0 counts as finished. Testing for exactly zero would never succeed once
 * a finished channel's counter was decremented past zero, and the caller's
 * loop would then run until the thread is cancelled.
 *
 * Returns nonzero when all channels are finished, 0 otherwise.
 */
static int alldone(const int sent[PAYLOADS])
{
	for (int i = 0; i < PAYLOADS; i++) {
		if (sent[i] > 0) return 0;
	}
	return 1;
}
#endif

/*
 * Sender thread.
 *
 * Creates its own context and socket bound to ifidx, binds PAYLOADS
 * RaptorQ-coded channels, waits for the receiver to signal receiver_ready,
 * then sends srcdata[i] (sz[i] bytes) on channel i.
 *
 * With DO_INTERLEAVED, the first packet of every channel is sent in order.
 * Remaining packets then go out on randomly chosen channels until each channel
 * has sent its budget. Otherwise each channel's packets are sent back to back.
 * Both paths use the same budget.
 *
 * arg is returned unchanged; results are communicated only over the network.
 */
static void *send_data_fec(void *arg)
{
	lc_ctx_t *lctx;
	lc_socket_t *sock;
	lc_channel_t *chan[PAYLOADS];

	lctx = lc_ctx_new();
	if (!lctx) goto err_exit;
	sock = lc_socket_new(lctx);
	if (!sock) goto err_ctx_free;
	if (lc_socket_bind(sock, ifidx) == -1) goto err_ctx_free;
	/* NOTE(review): presumably enables multicast loopback so the receiver on this host sees our packets */
	lc_socket_loop(sock, 1);
	lc_ctx_ratelimit(lctx, 1024 * 1024, -1);
	for (int i = 0; i < PAYLOADS; i++) {
		chan[i] = create_and_bind_channel(lctx, sock, i);
		if (!chan[i]) goto err_ctx_free;
	}

	test_log("waiting for receiver\n");
	sem_wait(&receiver_ready);

#if DO_INTERLEAVED
	/* send stuff, interleaved */
	int sent[PAYLOADS]; /* packets still to send on each channel */
	/* encode payload and send first packet on each channel */
	for (int i = 0; i < PAYLOADS; i++) {
		test_log("%d: attempting to send %i bytes\n", i, sz[i]);
		/* the first send encodes the payload; the RaptorQ state read below depends on it */
		if (lc_channel_send(chan[i], srcdata[i], sz[i], 0) == -1) {
			test_log("error sending to channel %d\n", i);
			goto err_ctx_free;
		}
		/* packet loss might be quite high sending this much at once:
		 * double the packets sent to compensate */
		sent[i] = (rq_KP(lc_channel_rq(chan[i])) + RQ_OVERHEAD) * 2;
	}
	/* keep sending on randomly chosen channels until every channel has sent its budget */
	do {
		int p = arc4random_uniform(PAYLOADS);
		/* channel already finished: pick again rather than overshoot its budget */
		if (sent[p] <= 0) continue;
		test_log("sending to channel %d\n", p);
		if (lc_channel_send(chan[p], NULL, 0, 0) == -1) {
			test_log("error sending to channel %d\n", p);
			break;
		}
		sent[p]--;
	}
	while (!alldone(sent));
#else
	/* interleaving disabled: send to each channel in order */
	for (int i = 0; i < PAYLOADS; i++) {
		test_log("%d: attempting to send %i bytes\n", i, sz[i]);
		if (lc_channel_send(chan[i], srcdata[i], sz[i], 0) == -1) {
			test_log("error sending to channel %d\n", i);
			goto err_ctx_free;
		}
		/* same budget as the interleaved path: double the minimum to cover packet loss */
		const int pkts = (rq_KP(lc_channel_rq(chan[i])) + RQ_OVERHEAD) * 2;
		test_log("%d: sending %d packets\n", i, pkts);
		/* the first packet was sent above, so pkts - 1 remain */
		for (int j = 1; j < pkts; j++) {
			if (lc_channel_send(chan[i], NULL, 0, 0) == -1) break;
		}
	}
#endif

err_ctx_free:
	lc_ctx_free(lctx);
err_exit:
	return arg;
}

/*
 * Receiver thread.
 *
 * Creates its own context and socket bound to ifidx, then binds and joins
 * PAYLOADS RaptorQ-coded channels, and posts receiver_ready so the sender may
 * start. It then performs PAYLOADS receives on the socket. Each successful
 * receive is copied into the dstdata row of the channel it arrived on.
 *
 * timeout is posted on every exit path so main() stops waiting.
 * arg is returned unchanged.
 */
static void *recv_data_fec(void *arg)
{
	char buf[PAYLOADMAX];
	lc_ctx_t *lctx;
	lc_socket_t *sock;
	lc_channel_t *chan[PAYLOADS];
	lc_channel_t *dst = NULL;
	int rc;

	lctx = lc_ctx_new();
	if (!lctx) goto err_exit;
	sock = lc_socket_new(lctx);
	if (!sock) goto err_ctx_free;
	if (lc_socket_bind(sock, ifidx) == -1) goto err_ctx_free;
	for (int i = 0; i < PAYLOADS; i++) {
		chan[i] = create_and_bind_channel(lctx, sock, i);
		if (!chan[i]) goto err_ctx_free;
		rc = lc_channel_join(chan[i]);
		/* the format needs its argument: passing none here is undefined behaviour */
		if (!test_assert(rc == 0, "%02d: lc_channel_join()", i))
			goto err_ctx_free;
	}
	sem_post(&receiver_ready);
	for (int i = 0; i < PAYLOADS; i++) {
		/* dst is set to the channel the data arrived on */
		ssize_t byt = lc_socket_multi_recv(sock, buf, sizeof buf, 0, &dst);
		test_log("%zi byt recv'd\n", byt);
		test_log("dst = %p\n", (void *)dst);
		if (byt <= 0) continue;
		/* copy data to correct buffer for dst channel */
		for (int x = 0; x < PAYLOADS; x++) {
			test_log("[%d] = %p\n", x, (void *)chan[x]);
			if (dst == chan[x]) memcpy(dstdata[x], buf, (size_t)byt);
		}
	}
err_ctx_free:
	lc_ctx_free(lctx);
err_exit:
	sem_post(&timeout);
	return arg;
}

/*
 * Fill each row of srcdata with a random payload and record its length in sz.
 *
 * Lengths are drawn uniformly from [PAYLOADMIN, PAYLOADMAX - 1], so every
 * payload fits in its PAYLOADMAX-byte row. Bytes beyond sz[i] in row i are
 * not touched.
 */
static void generate_source_data(int sz[PAYLOADS], char srcdata[PAYLOADS][PAYLOADMAX])
{
	for (int i = 0; i < PAYLOADS; i++) {
		sz[i] = arc4random_uniform(PAYLOADMAX - PAYLOADMIN) + PAYLOADMIN;
		/* one fill is enough; the duplicated call only overwrote the same bytes */
		arc4random_buf(srcdata[i], sz[i]);
		test_log("generated %i bytes\n", sz[i]);
	}
}
#endif /* HAVE_LIBLCRQ */

/*
 * Run the multi-channel RaptorQ test, or skip it when built without liblcrq.
 *
 * Generates PAYLOADS random payloads and starts the sender and receiver
 * threads. It waits up to WAITS seconds for the receiver to finish, then
 * checks that every received payload matches its source byte for byte.
 * Returns test_status.
 */
int main(void)
{
	char name[] = "multi-channel socket with RaptorQ encoding";
#ifndef HAVE_LIBLCRQ
	return test_skip(name);
#else
	struct timespec ts;
	pthread_t tid[2];
	int rc;

	test_name(name);
	test_require_net(TEST_NET_BASIC);
	ifidx = get_multicast_if();
	generate_source_data(sz, srcdata);

	/* ensure buffers differ before sync */
	for (int i = 0; i < PAYLOADS; i++) {
		memset(dstdata[i], 0, PAYLOADMAX);
		test_assert(memcmp(dstdata[i], srcdata[i], PAYLOADMAX) != 0, "payload %i differs", i);
	}

	sem_init(&timeout, 0, 0);
	sem_init(&receiver_ready, 0, 0);

	/* a tid[] entry is only valid if its thread was created: never cancel
	 * or join a thread whose pthread_create() failed */
	rc = pthread_create(&tid[TID_SEND], NULL, &send_data_fec, srcdata);
	if (!test_assert(rc == 0, "pthread_create() sender")) goto err_sem_destroy;
	rc = pthread_create(&tid[TID_RECV], NULL, &recv_data_fec, dstdata);
	if (!test_assert(rc == 0, "pthread_create() receiver")) {
		pthread_cancel(tid[TID_SEND]);
		pthread_join(tid[TID_SEND], NULL);
		goto err_sem_destroy;
	}

	/* the receiver posts &timeout when it exits; allow it WAITS seconds */
	clock_gettime(CLOCK_REALTIME, &ts);
	ts.tv_sec += WAITS;
	do {
		/* absolute deadline, so retrying after a signal does not extend it */
		rc = sem_timedwait(&timeout, &ts);
	} while (rc == -1 && errno == EINTR);
	test_assert(rc == 0, "timeout");
	pthread_cancel(tid[TID_SEND]);
	pthread_cancel(tid[TID_RECV]);
	pthread_join(tid[TID_RECV], NULL);
	pthread_join(tid[TID_SEND], NULL);

	for (int i = 0; i < PAYLOADS; i++) {
		/* report the first differing byte to aid debugging */
		for (int j = 0; j < sz[i]; j++) {
			if (srcdata[i][j] == dstdata[i][j]) continue;
			test_assert(0, "%d: diff at %d / %d", i, j, sz[i]);
			break;
		}
		test_assert(memcmp(dstdata[i], srcdata[i], sz[i]) == 0, "payload %i matches", i);
	}

err_sem_destroy:
	sem_destroy(&receiver_ready);
	sem_destroy(&timeout);

	return test_status;
#endif
}
