/*	$NetBSD: aes_impl.c,v 1.10 2022/11/05 17:36:33 jmcneill Exp $	*/

/*-
 * Copyright (c) 2020 The NetBSD Foundation, Inc.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
 * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
 * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
 * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

#include <sys/cdefs.h>
__KERNEL_RCSID(1, "$NetBSD: aes_impl.c,v 1.10 2022/11/05 17:36:33 jmcneill Exp $");

#include <sys/types.h>
#include <sys/kernel.h>
#include <sys/module.h>
#include <sys/once.h>
#include <sys/sysctl.h>
#include <sys/systm.h>

#include <crypto/aes/aes.h>
#include <crypto/aes/aes_cbc.h>
#include <crypto/aes/aes_bear.h> /* default implementation */
#include <crypto/aes/aes_impl.h>
#include <crypto/aes/aes_xts.h>

static int aes_selftest_stdkeysched(void);

static const struct aes_impl	*aes_md_impl	__read_mostly;
static const struct aes_impl	*aes_impl	__read_mostly;

static int
sysctl_kern_crypto_aes_selected(SYSCTLFN_ARGS)
{
	struct sysctlnode node;

	KASSERTMSG(aes_impl != NULL,
	    "sysctl ran before AES implementation was selected");

	node = *rnode;
	node.sysctl_data = __UNCONST(aes_impl->ai_name);
	node.sysctl_size = strlen(aes_impl->ai_name) + 1;
	return sysctl_lookup(SYSCTLFN_CALL(&node));
}

SYSCTL_SETUP(sysctl_kern_crypto_aes_setup, "sysctl kern.crypto.aes setup")
{
	const struct sysctlnode *cnode;
	const struct sysctlnode *aes_node;

	sysctl_createv(clog, 0, NULL, &cnode, 0, CTLTYPE_NODE, "crypto",
	    SYSCTL_DESCR("Kernel cryptography"),
	    NULL, 0, NULL, 0,
	    CTL_KERN, CTL_CREATE, CTL_EOL);
	sysctl_createv(clog, 0, &cnode, &aes_node, 0, CTLTYPE_NODE, "aes",
	    SYSCTL_DESCR("AES -- Advanced Encryption Standard"),
	    NULL, 0, NULL, 0,
	    CTL_CREATE, CTL_EOL);
	sysctl_createv(clog, 0, &aes_node, NULL,
	    CTLFLAG_PERMANENT|CTLFLAG_READONLY, CTLTYPE_STRING, "selected",
	    SYSCTL_DESCR("Selected AES implementation"),
	    sysctl_kern_crypto_aes_selected, 0, NULL, 0,
	    CTL_CREATE, CTL_EOL);
}

/*
 * The timing of AES implementation selection is finicky:
 *
 *	1. It has to be done _after_ cpu_attach for implementations,
 *	   such as AES-NI, that rely on fpu initialization done by
 *	   fpu_attach.
 *
 *	2. It has to be done _before_ the cgd self-tests or anything
 *	   else that might call AES.
 *
 * For the moment, doing it in module init works.  However, if a
 * driver-class module depended on the aes module, that would break.
 */

static int
aes_select(void)
{

	KASSERT(aes_impl == NULL);

	if (aes_selftest_stdkeysched())
		panic("AES is busted");

	if (aes_md_impl) {
		if (aes_selftest(aes_md_impl))
			aprint_error("aes: self-test failed: %s\n",
			    aes_md_impl->ai_name);
		else
			aes_impl = aes_md_impl;
	}
	if (aes_impl == NULL) {
		if (aes_selftest(&aes_bear_impl))
			aprint_error("aes: self-test failed: %s\n",
			    aes_bear_impl.ai_name);
		else
			aes_impl = &aes_bear_impl;
	}
	if (aes_impl == NULL)
		panic("AES self-tests failed");

	aprint_debug("aes: %s\n", aes_impl->ai_name);
	return 0;
}

MODULE(MODULE_CLASS_MISC, aes, NULL);

static int
aes_modcmd(modcmd_t cmd, void *opaque)
{

	switch (cmd) {
	case MODULE_CMD_INIT:
		return aes_select();
	case MODULE_CMD_FINI:
		return 0;
	default:
		return ENOTTY;
	}
}

static void
aes_guarantee_selected(void)
{
#if 0
	static once_t once;
	int error;

	error = RUN_ONCE(&once, aes_select);
	KASSERT(error == 0);
#endif
}

void
aes_md_init(const struct aes_impl *impl)
{

	KASSERT(cold);
	KASSERTMSG(aes_impl == NULL,
	    "AES implementation `%s' already chosen, can't offer `%s'",
	    aes_impl->ai_name, impl->ai_name);
	KASSERTMSG(aes_md_impl == NULL,
	    "AES implementation `%s' already offered, can't offer `%s'",
	    aes_md_impl->ai_name, impl->ai_name);

	aes_md_impl = impl;
}

static void
aes_setenckey(struct aesenc *enc, const uint8_t key[static 16],
    uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_setenckey(enc, key, nrounds);
}

uint32_t
aes_setenckey128(struct aesenc *enc, const uint8_t key[static 16])
{
	uint32_t nrounds = AES_128_NROUNDS;

	aes_setenckey(enc, key, nrounds);
	return nrounds;
}

uint32_t
aes_setenckey192(struct aesenc *enc, const uint8_t key[static 24])
{
	uint32_t nrounds = AES_192_NROUNDS;

	aes_setenckey(enc, key, nrounds);
	return nrounds;
}

uint32_t
aes_setenckey256(struct aesenc *enc, const uint8_t key[static 32])
{
	uint32_t nrounds = AES_256_NROUNDS;

	aes_setenckey(enc, key, nrounds);
	return nrounds;
}

static void
aes_setdeckey(struct aesdec *dec, const uint8_t key[static 16],
    uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_setdeckey(dec, key, nrounds);
}

uint32_t
aes_setdeckey128(struct aesdec *dec, const uint8_t key[static 16])
{
	uint32_t nrounds = AES_128_NROUNDS;

	aes_setdeckey(dec, key, nrounds);
	return nrounds;
}

uint32_t
aes_setdeckey192(struct aesdec *dec, const uint8_t key[static 24])
{
	uint32_t nrounds = AES_192_NROUNDS;

	aes_setdeckey(dec, key, nrounds);
	return nrounds;
}

uint32_t
aes_setdeckey256(struct aesdec *dec, const uint8_t key[static 32])
{
	uint32_t nrounds = AES_256_NROUNDS;

	aes_setdeckey(dec, key, nrounds);
	return nrounds;
}

void
aes_enc(const struct aesenc *enc, const uint8_t in[static 16],
    uint8_t out[static 16], uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_enc(enc, in, out, nrounds);
}

void
aes_dec(const struct aesdec *dec, const uint8_t in[static 16],
    uint8_t out[static 16], uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_dec(dec, in, out, nrounds);
}

void
aes_cbc_enc(struct aesenc *enc, const uint8_t in[static 16],
    uint8_t out[static 16], size_t nbytes, uint8_t iv[static 16],
    uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_cbc_enc(enc, in, out, nbytes, iv, nrounds);
}

void
aes_cbc_dec(struct aesdec *dec, const uint8_t in[static 16],
    uint8_t out[static 16], size_t nbytes, uint8_t iv[static 16],
    uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_cbc_dec(dec, in, out, nbytes, iv, nrounds);
}

void
aes_xts_enc(struct aesenc *enc, const uint8_t in[static 16],
    uint8_t out[static 16], size_t nbytes, uint8_t tweak[static 16],
    uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_xts_enc(enc, in, out, nbytes, tweak, nrounds);
}

void
aes_xts_dec(struct aesdec *dec, const uint8_t in[static 16],
    uint8_t out[static 16], size_t nbytes, uint8_t tweak[static 16],
    uint32_t nrounds)
{

	aes_guarantee_selected();
	aes_impl->ai_xts_dec(dec, in, out, nbytes, tweak, nrounds);
}

void
aes_cbcmac_update1(const struct aesenc *enc, const uint8_t in[static 16],
    size_t nbytes, uint8_t auth[static 16], uint32_t nrounds)
{

	KASSERT(nbytes);
	KASSERT(nbytes % 16 == 0);

	aes_guarantee_selected();
	aes_impl->ai_cbcmac_update1(enc, in, nbytes, auth, nrounds);
}

void
aes_ccm_enc1(const struct aesenc *enc, const uint8_t in[static 16],
    uint8_t out[static 16], size_t nbytes, uint8_t authctr[static 32],
    uint32_t nrounds)
{

	KASSERT(nbytes);
	KASSERT(nbytes % 16 == 0);

	aes_guarantee_selected();
	aes_impl->ai_ccm_enc1(enc, in, out, nbytes, authctr, nrounds);
}

void
aes_ccm_dec1(const struct aesenc *enc, const uint8_t in[static 16],
    uint8_t out[static 16], size_t nbytes, uint8_t authctr[static 32],
    uint32_t nrounds)
{

	KASSERT(nbytes);
	KASSERT(nbytes % 16 == 0);

	aes_guarantee_selected();
	aes_impl->ai_ccm_dec1(enc, in, out, nbytes, authctr, nrounds);
}

/*
 * Known-answer self-tests for the standard key schedule.
 */
static int
aes_selftest_stdkeysched(void)
{
	static const uint8_t key[32] = {
		0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,
		0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f,
		0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17,
		0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f,
	};
	static const uint32_t rk128enc[] = {
		0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
		0xfd74aad6, 0xfa72afd2, 0xf178a6da, 0xfe76abd6,
		0x0bcf92b6, 0xf1bd3d64, 0x00c59bbe, 0xfeb33068,
		0x4e74ffb6, 0xbfc9c2d2, 0xbf0c596c, 0x41bf6904,
		0xbcf7f747, 0x033e3595, 0xbc326cf9, 0xfd8d05fd,
		0xe8a3aa3c, 0xeb9d9fa9, 0x57aff350, 0xaa22f6ad,
		0x7d0f395e, 0x9692a6f7, 0xc13d55a7, 0x6b1fa30a,
		0x1a70f914, 0x8ce25fe3, 0x4ddf0a44, 0x26c0a94e,
		0x35874347, 0xb9651ca4, 0xf4ba16e0, 0xd27abfae,
		0xd1329954, 0x685785f0, 0x9ced9310, 0x4e972cbe,
		0x7f1d1113, 0x174a94e3, 0x8ba707f3, 0xc5302b4d,
	};
	static const uint32_t rk192enc[] = {
		0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
		0x13121110, 0x17161514, 0xf9f24658, 0xfef4435c,
		0xf5fe4a54, 0xfaf04758, 0xe9e25648, 0xfef4435c,
		0xb349f940, 0x4dbdba1c, 0xb843f048, 0x42b3b710,
		0xab51e158, 0x55a5a204, 0x41b5ff7e, 0x0c084562,
		0xb44bb52a, 0xf6f8023a, 0x5da9e362, 0x080c4166,
		0x728501f5, 0x7e8d4497, 0xcac6f1bd, 0x3c3ef387,
		0x619710e5, 0x699b5183, 0x9e7c1534, 0xe0f151a3,
		0x2a37a01e, 0x16095399, 0x779e437c, 0x1e0512ff,
		0x880e7edd, 0x68ff2f7e, 0x42c88f60, 0x54c1dcf9,
		0x235f9f85, 0x3d5a8d7a, 0x5229c0c0, 0x3ad6efbe,
		0x781e60de, 0x2cdfbc27, 0x0f8023a2, 0x32daaed8,
		0x330a97a4, 0x09dc781a, 0x71c218c4, 0x5d1da4e3,
	};
	static const uint32_t rk256enc[] = {
		0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
		0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c,
		0x9fc273a5, 0x98c476a1, 0x93ce7fa9, 0x9cc072a5,
		0xcda85116, 0xdabe4402, 0xc1a45d1a, 0xdeba4006,
		0xf0df87ae, 0x681bf10f, 0xfbd58ea6, 0x6715fc03,
		0x48f1e16d, 0x924fa56f, 0x53ebf875, 0x8d51b873,
		0x7f8256c6, 0x1799a7c9, 0xec4c296f, 0x8b59d56c,
		0x753ae23d, 0xe7754752, 0xb49ebf27, 0x39cf0754,
		0x5f90dc0b, 0x48097bc2, 0xa44552ad, 0x2f1c87c1,
		0x60a6f545, 0x87d3b217, 0x334d0d30, 0x0a820a64,
		0x1cf7cf7c, 0x54feb4be, 0xf0bbe613, 0xdfa761d2,
		0xfefa1af0, 0x7929a8e7, 0x4a64a5d7, 0x40e6afb3,
		0x71fe4125, 0x2500f59b, 0xd5bb1388, 0x0a1c725a,
		0x99665a4e, 0xe04ff2a9, 0xaa2b577e, 0xeacdf8cd,
		0xcc79fc24, 0xe97909bf, 0x3cc21a37, 0x36de686d,
	};
	static const uint32_t rk128dec[] = {
		0x7f1d1113, 0x174a94e3, 0x8ba707f3, 0xc5302b4d,
		0xbe29aa13, 0xf6af8f9c, 0x80f570f7, 0x03bff700,
		0x63a46213, 0x4886258f, 0x765aff6b, 0x834a87f7,
		0x74fc828d, 0x2b22479c, 0x3edcdae4, 0xf510789c,
		0x8d09e372, 0x5fdec511, 0x15fe9d78, 0xcbcca278,
		0x2710c42e, 0xd2d72663, 0x4a205869, 0xde323f00,
		0x04f5a2a8, 0xf5c7e24d, 0x98f77e0a, 0x94126769,
		0x91e3c6c7, 0xf13240e5, 0x6d309c47, 0x0ce51963,
		0x9902dba0, 0x60d18622, 0x9c02dca2, 0x61d58524,
		0xf0df568c, 0xf9d35d82, 0xfcd35a80, 0xfdd75986,
		0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
	};
	static const uint32_t rk192dec[] = {
		0x330a97a4, 0x09dc781a, 0x71c218c4, 0x5d1da4e3,
		0x0dbdbed6, 0x49ea09c2, 0x8073b04d, 0xb91b023e,
		0xc999b98f, 0x3968b273, 0x9dd8f9c7, 0x728cc685,
		0xc16e7df7, 0xef543f42, 0x7f317853, 0x4457b714,
		0x90654711, 0x3b66cf47, 0x8dce0e9b, 0xf0f10bfc,
		0xb6a8c1dc, 0x7d3f0567, 0x4a195ccc, 0x2e3a42b5,
		0xabb0dec6, 0x64231e79, 0xbe5f05a4, 0xab038856,
		0xda7c1bdd, 0x155c8df2, 0x1dab498a, 0xcb97c4bb,
		0x08f7c478, 0xd63c8d31, 0x01b75596, 0xcf93c0bf,
		0x10efdc60, 0xce249529, 0x15efdb62, 0xcf20962f,
		0xdbcb4e4b, 0xdacf4d4d, 0xc7d75257, 0xdecb4949,
		0x1d181f1a, 0x191c1b1e, 0xd7c74247, 0xdecb4949,
		0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
	};
	static const uint32_t rk256dec[] = {
		0xcc79fc24, 0xe97909bf, 0x3cc21a37, 0x36de686d,
		0xffd1f134, 0x2faacebf, 0x5fe2e9fc, 0x6e015825,
		0xeb48165e, 0x0a354c38, 0x46b77175, 0x84e680dc,
		0x8005a3c8, 0xd07b3f8b, 0x70482743, 0x31e3b1d9,
		0x138e70b5, 0xe17d5a66, 0x4c823d4d, 0xc251f1a9,
		0xa37bda74, 0x507e9c43, 0xa03318c8, 0x41ab969a,
		0x1597a63c, 0xf2f32ad3, 0xadff672b, 0x8ed3cce4,
		0xf3c45ff8, 0xf3054637, 0xf04d848b, 0xe1988e52,
		0x9a4069de, 0xe7648cef, 0x5f0c4df8, 0x232cabcf,
		0x1658d5ae, 0x00c119cf, 0x0348c2bc, 0x11d50ad9,
		0xbd68c615, 0x7d24e531, 0xb868c117, 0x7c20e637,
		0x0f85d77f, 0x1699cc61, 0x0389db73, 0x129dc865,
		0xc940282a, 0xc04c2324, 0xc54c2426, 0xc4482720,
		0x1d181f1a, 0x191c1b1e, 0x15101712, 0x11141316,
		0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c,
	};
	static const struct {
		unsigned	len;
		unsigned	nr;
		const uint32_t	*enc, *dec;
	} C[] = {
		{ 16, AES_128_NROUNDS, rk128enc, rk128dec },
		{ 24, AES_192_NROUNDS, rk192enc, rk192dec },
		{ 32, AES_256_NROUNDS, rk256enc, rk256dec },
	};
	uint32_t rk[60];
	unsigned i;

	for (i = 0; i < __arraycount(C); i++) {
		if (br_aes_ct_keysched_stdenc(rk, key, C[i].len) != C[i].nr)
			return -1;
		if (memcmp(rk, C[i].enc, 4*(C[i].nr + 1)))
			return -1;
		if (br_aes_ct_keysched_stddec(rk, key, C[i].len) != C[i].nr)
			return -1;
		if (memcmp(rk, C[i].dec, 4*(C[i].nr + 1)))
			return -1;
	}

	return 0;
}