/*
 * src/kcov/stub.c: KCOV sanitizer coverage hooks
 *
 * Copyright (c) 2026 Ali Polatel <alip@chesswob.org>
 * SPDX-License-Identifier: GPL-3.0
 */

/*
 * # Safety
 *
 * This file is compiled WITHOUT instrumentation to avoid infinite recursion.
 */

#include <errno.h>
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <unistd.h>

// KCOV context, filled in by the Rust side through syd_kcov_get_ctx().
//
// NOTE(review): this layout presumably mirrors a #[repr(C)] struct on the
// Rust side; confirm both definitions before changing field order or types.
struct kcov_ctx {
	// Memory file descriptor backing the coverage buffer; the recorders
	// access it with pread()/pwrite() at byte offsets.
	int fd;
	// Buffer size in 64-bit words, including the leading count word at
	// offset 0 (the recorders use words - 1 payload words).
	uint64_t words;
	// 0=PC, 1=CMP
	int mode;
};

// External rust functions called via FFI.
//
// Declared weak: when no definition is linked in, the function's address is
// NULL, and the recorders check for that before calling it. A false return
// means no context is available and the recorders skip recording.
extern bool syd_kcov_get_ctx(struct kcov_ctx *out_ctx) __attribute__((weak));

// Atomic recursion guard (not instrumented).
//
// Per-thread flag set while a recorder runs on this thread. Any coverage
// hook re-entered on the same thread while it is set (e.g. from inside the
// FFI call above) returns immediately instead of recursing.
static _Thread_local atomic_bool tls_rec = false;

/*
 * Syscall number being handled on this thread (C-side mirror of Rust's
 * TLS_SYS). Starts at -1, meaning "no syscall set".
 *
 * Every syscall is routed through the same internal Syd code paths, so a
 * given Syd function yields identical PCs whichever syscall is in flight.
 * Folding this number into each PC lets that function emit distinct signal
 * per syscall.
 */
static _Thread_local long tls_sys = -1L;

/*
 * Record the syscall number about to be handled on this thread.
 * Invoked from Rust right before the syscall handler runs.
 */
__attribute__((no_sanitize("coverage"))) void syd_kcov_set_syscall(long nr)
{
	tls_sys = nr;
}

/*
 * Fold this thread's current syscall number into a relative PC.
 *
 * When no positive syscall number is set, the PC is returned unchanged.
 * Otherwise it is XORed with the number multiplied by a fixed 64-bit odd
 * constant, which spreads the number across the PC's bits.
 */
static inline uint64_t mix_syscall(uint64_t pc)
{
	const long nr = tls_sys;

	if (nr <= 0) {
		return pc;
	}
	return pc ^ ((uint64_t)nr * 0x517cc1b727220a95ULL);
}

/*
 * Binary base address for ASLR-independent PCs.
 *
 * Syd may be a static-pie binary; __builtin_return_address(0) yields
 * ASLR-randomised absolute addresses that differ between runs.
 * Subtracting the load base turns each PC into a fixed offset within
 * the binary, producing the same canonical PCs across runs.
 *
 * __executable_start is declared weak: if the linker does not provide it,
 * its address is NULL and the base is taken as 0 (PCs stay absolute).
 */
extern char __executable_start[] __attribute__((weak));

/*
 * Return the load base of the running binary, or 0 if unknown.
 *
 * The symbol's address is fixed for the life of the process, so it is read
 * directly on every call. A lazily-filled plain (non-atomic) cache would be
 * written concurrently by every thread that races past an unset "initialised"
 * flag, which is a data race (undefined behaviour) under C11.
 */
static inline uint64_t get_base_addr(void)
{
	const char *start = __executable_start;

	// Go through uintptr_t so the pointer-to-integer cast matches pointer
	// width on 32-bit targets as well.
	return start ? (uint64_t)(uintptr_t)start : 0;
}

/*
 * Convert a raw return address into an offset from the binary's load base,
 * so the value no longer depends on where ASLR placed the binary.
 */
static inline uint64_t pc_rel(uint64_t raw_pc)
{
	const uint64_t base = get_base_addr();

	return raw_pc - base;
}

// Map a (relative, syscall-mixed) PC into the canonical PC range that
// Syzkaller's signal filter expects.
//
// 64-bit: results lie in [0xFFFFFFFF80000000, 0xFFFFFFFFBFFFFFF0], inside the
//         kernel text range [0xFFFFFFFF80000000, 0xFFFFFFFFFF000000).
// 32-bit: results lie in [0x80000000, 0x8FFFFFF0].
//
// Each offset mask keeps at most the low ~1 GiB (64-bit) or ~256 MiB (32-bit)
// of the input. It also clears the low 4 bits, so every result is 16-byte
// aligned.
static uint64_t to_canon_pc(uint64_t v)
{
#if __SIZEOF_POINTER__ == 8
	const uint64_t text_base = 0xFFFFFFFF80000000ULL;
	const uint64_t offset_mask = 0x3FFFFFF0ULL;
#else
	const uint64_t text_base = 0x80000000ULL;
	const uint64_t offset_mask = 0x0FFFFFF0ULL;
#endif

	return text_base | (v & offset_mask);
}

// Write exactly `count` bytes from `buf` to `fd` starting at byte `offset`,
// without moving the file position.
//
// Short writes are continued and EINTR is retried. Returns true once every
// byte is written. Returns false if pwrite() reports another error, or if
// it makes no progress (returns 0).
static bool write_all_at(int fd, const void *buf, size_t count, off_t offset)
{
	const uint8_t *src = buf;
	size_t left = count;

	while (left > 0) {
		const size_t done = count - left;
		const ssize_t rc = pwrite(fd, src + done, left, offset + done);

		if (rc > 0) {
			left -= (size_t)rc;
		} else if (rc < 0 && errno == EINTR) {
			continue;
		} else {
			// rc == 0: no progress; rc < 0: a hard error.
			return false;
		}
	}
	return true;
}

// Read exactly `count` bytes from `fd` at byte `offset` into `buf`, without
// moving the file position.
//
// Short reads are continued and EINTR is retried. Returns true once every
// byte is read. Returns false on any other pread() error, or on EOF before
// `count` bytes were read; `buf` may then be partially filled.
static bool read_all_at(int fd, void *buf, size_t count, off_t offset)
{
	uint8_t *dst = buf;
	size_t left = count;

	while (left > 0) {
		const size_t done = count - left;
		const ssize_t rc = pread(fd, dst + done, left, offset + done);

		if (rc > 0) {
			left -= (size_t)rc;
		} else if (rc < 0 && errno == EINTR) {
			continue;
		} else {
			// rc == 0: premature EOF; rc < 0: a hard error.
			return false;
		}
	}
	return true;
}

// Write a u64 at given offset.
static bool write_u64_at(int fd, uint64_t offset, uint64_t value)
{
	return write_all_at(fd, &value, sizeof(value), offset);
}

// Read a u64 at given offset.
static bool read_u64_at(int fd, uint64_t offset, uint64_t *out_value)
{
	return read_all_at(fd, out_value, sizeof(*out_value), offset);
}

// Append one PC entry to the coverage buffer (PC mode only).
//
// Buffer layout as used here: word 0 holds the entry count, and words
// 1..words-1 hold canonicalized PCs. A full buffer drops the entry. A count
// found above capacity is clamped back to capacity.
//
// Must only be called from record_pc_impl() while it holds tls_rec.
static void record_pc_locked(uint64_t pc)
{
	struct kcov_ctx ctx;

	// The weak FFI symbol may be unresolved, or Rust may have no context.
	if (!syd_kcov_get_ctx || !syd_kcov_get_ctx(&ctx)) {
		return;
	}
	// Not in PC mode, or no room for even one payload word.
	if (ctx.mode != 0 || ctx.words <= 1) {
		return;
	}

	const uint64_t cap = ctx.words - 1;
	uint64_t cnt;

	if (!read_u64_at(ctx.fd, 0, &cnt)) {
		return;
	}

	if (cnt < cap) {
		// Write payload[cnt] first; bump the header only if that succeeded.
		if (write_u64_at(ctx.fd, (1 + cnt) * 8,
		                 to_canon_pc(mix_syscall(pc)))) {
			(void)write_u64_at(ctx.fd, 0, cnt + 1);
		}
	} else if (cnt != cap) {
		// Clamp header to capacity.
		(void)write_u64_at(ctx.fd, 0, cap);
	}
}

// Record PC coverage.
//
// The per-thread tls_rec guard makes nested hook invocations on this thread
// (e.g. from inside the FFI call) return immediately.
//
// errno is saved and restored because these hooks run in the middle of
// arbitrary instrumented code. A trace-cmp hook fires on the very `== -1`
// check that precedes an errno read. pread()/pwrite() overwrite errno on
// failure, and a retried EINTR leaves errno == EINTR even when the retry
// succeeds, so the instrumented caller would otherwise see a bogus errno.
static void record_pc_impl(uint64_t pc)
{
	bool expected = false;

	if (!atomic_compare_exchange_strong(&tls_rec, &expected, true)) {
		// Already recording on this thread.
		return;
	}

	const int saved_errno = errno;
	record_pc_locked(pc);
	errno = saved_errno;

	atomic_store(&tls_rec, false);
}

// Build the KCOV comparison "type" word.
//
// Bit 0 is KCOV_CMP_CONST (set when one operand is a compile-time constant).
// Bits 1-2 encode the operand width: 1 byte -> 0, 2 -> 2, 4 -> 4, 8 -> 6.
// Any other width is encoded as 8 bytes (6).
static inline uint64_t kcov_cmp_type(uint8_t sz, bool is_const)
{
	uint64_t size_code;

	if (sz == 1) {
		size_code = 0;
	} else if (sz == 2) {
		size_code = 2;
	} else if (sz == 4) {
		size_code = 4;
	} else {
		// 8 bytes, and the fallback for unexpected widths.
		size_code = 6;
	}

	return size_code | (uint64_t)is_const;
}

// Record CMP coverage.
static void record_cmp_impl(uint8_t sz, bool is_const, uint64_t a, uint64_t b,
                            uint64_t ip)
{
	// Fast atomic guard to prevent recursion.
	bool expected = false;
	if (!atomic_compare_exchange_strong(&tls_rec, &expected, true)) {
		// Already recording.
		return;
	}

	// Check if FFI functions are available (weak symbols may be NULL).
	if (!syd_kcov_get_ctx) {
		atomic_store(&tls_rec, false);
		return;
	}

	// Get context from Rust.
	struct kcov_ctx ctx;
	if (syd_kcov_get_ctx(&ctx)) {
		if (ctx.mode == 1 && ctx.words > 4) { // CMP mode.
			uint64_t payload_words = ctx.words - 1;
			uint64_t cap = payload_words / 4;
			uint64_t cnt;

			// Read header.
			if (!read_u64_at(ctx.fd, 0, &cnt)) {
				atomic_store(&tls_rec, false);
				return;
			}

			if (cnt < cap) {
				uint64_t base = cnt * 4;
				uint64_t ty = kcov_cmp_type(sz, is_const);

				// Write CMP record (4 words).
				if (!write_u64_at(ctx.fd, (1 + base) * 8, ty) ||
				    !write_u64_at(ctx.fd, (1 + base + 1) * 8, a) ||
				    !write_u64_at(ctx.fd, (1 + base + 2) * 8, b) ||
				    !write_u64_at(ctx.fd, (1 + base + 3) * 8,
				                  to_canon_pc(mix_syscall(ip)))) {
					atomic_store(&tls_rec, false);
					return;
				}
				// Increment header.
				if (!write_u64_at(ctx.fd, 0, cnt + 1)) {
					atomic_store(&tls_rec, false);
					return;
				}
			} else if (cnt != cap) {
				// Clamp header to capacity.
				if (!write_u64_at(ctx.fd, 0, cap)) {
					atomic_store(&tls_rec, false);
					return;
				}
			}
		}
	}

	atomic_store(&tls_rec, false);
}

// Sanitizer hooks which call the recording functions.
//
// __builtin_return_address(0) is the address the instrumented caller resumes
// at. It is converted through uintptr_t so the pointer-to-integer cast has
// matching width on 32-bit targets too.

// Edge/basic-block hook: record the caller's PC.
void __sanitizer_cov_trace_pc(void)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_pc_impl(pc);
}

// Non-constant comparison hooks: record (width, a, b, caller PC).
// The return address is converted through uintptr_t so the cast has
// matching width on 32-bit targets.
void __sanitizer_cov_trace_cmp1(uint8_t a, uint8_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(1, false, a, b, pc);
}

void __sanitizer_cov_trace_cmp2(uint16_t a, uint16_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(2, false, a, b, pc);
}

void __sanitizer_cov_trace_cmp4(uint32_t a, uint32_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(4, false, a, b, pc);
}

void __sanitizer_cov_trace_cmp8(uint64_t a, uint64_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(8, false, a, b, pc);
}

// Constant comparison hooks: like the cmp hooks, but the record is tagged
// KCOV_CMP_CONST. Per the SanitizerCoverage interface, `a` is the
// compile-time constant operand. The return address is converted through
// uintptr_t so the cast has matching width on 32-bit targets.
void __sanitizer_cov_trace_const_cmp1(uint8_t a, uint8_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(1, true, a, b, pc);
}

void __sanitizer_cov_trace_const_cmp2(uint16_t a, uint16_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(2, true, a, b, pc);
}

void __sanitizer_cov_trace_const_cmp4(uint32_t a, uint32_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(4, true, a, b, pc);
}

void __sanitizer_cov_trace_const_cmp8(uint64_t a, uint64_t b)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	record_cmp_impl(8, true, a, b, pc);
}

// Switch hook.
//
// SanitizerCoverage layout: cases[0] is the number of case constants,
// cases[1] is the operand width in bits, and cases[2..] are the constants.
//
// PC mode: records the switch site's PC (record_cmp_impl() ignores non-CMP
// contexts).
// CMP mode: records one constant comparison (case value vs. `val`) per case,
// like the kernel's KCOV (record_pc_impl() ignores non-PC contexts).
void __sanitizer_cov_trace_switch(uint64_t val, uint64_t *cases)
{
	uint64_t pc = pc_rel((uint64_t)(uintptr_t)__builtin_return_address(0));
	uint8_t sz;

	record_pc_impl(pc);

	if (!cases) {
		return;
	}

	switch (cases[1]) {
	case 8:
		sz = 1;
		break;
	case 16:
		sz = 2;
		break;
	case 32:
		sz = 4;
		break;
	case 64:
		sz = 8;
		break;
	default:
		// Unknown operand width: skip comparison records.
		return;
	}

	for (uint64_t i = 0; i < cases[0]; i++) {
		record_cmp_impl(sz, true, cases[i + 2], val, pc);
	}
}
