git: 71e6792cbe81 - stable/13 - ktls: Add simple transmit tests of kernel TLS.
- Go to: [ bottom of page ] [ top of archives ] [ this month ]
Date: Tue, 23 Nov 2021 23:12:57 UTC
The branch stable/13 has been updated by jhb: URL: https://cgit.FreeBSD.org/src/commit/?id=71e6792cbe81f6fcbfdf545ea7c04b2ae3bfda50 commit 71e6792cbe81f6fcbfdf545ea7c04b2ae3bfda50 Author: John Baldwin <jhb@FreeBSD.org> AuthorDate: 2021-11-01 18:28:10 +0000 Commit: John Baldwin <jhb@FreeBSD.org> CommitDate: 2021-11-23 23:11:45 +0000 ktls: Add simple transmit tests of kernel TLS. Note that these tests test the kernel TLS functionality directly. Rather than using OpenSSL to perform negotiation and generate keys, these tests generate random keys, send data over a pair of TCP sockets, and manually decrypt the TLS records generated by the kernel. Reviewed by: markj Sponsored by: Netflix Differential Revision: https://reviews.freebsd.org/D32652 (cherry picked from commit a10482ea7476d68d1ab028145ae6d97cef747b49) --- tests/sys/kern/Makefile | 2 + tests/sys/kern/ktls_test.c | 1033 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1035 insertions(+) diff --git a/tests/sys/kern/Makefile b/tests/sys/kern/Makefile index 6746812d9b4a..ee9decac518c 100644 --- a/tests/sys/kern/Makefile +++ b/tests/sys/kern/Makefile @@ -12,6 +12,7 @@ ATF_TESTS_C+= kern_copyin ATF_TESTS_C+= kern_descrip_test ATF_TESTS_C+= fdgrowtable_test ATF_TESTS_C+= kill_zombie +ATF_TESTS_C+= ktls_test ATF_TESTS_C+= ptrace_test TEST_METADATA.ptrace_test+= timeout="15" ATF_TESTS_C+= reaper @@ -46,6 +47,7 @@ LIBADD.sys_getrandom+= pthread LIBADD.ptrace_test+= pthread LIBADD.unix_seqpacket_test+= pthread LIBADD.kcov+= pthread +LIBADD.ktls_test+= crypto LIBADD.sendfile_helper+= pthread LIBADD.fdgrowtable_test+= util pthread kvm procstat diff --git a/tests/sys/kern/ktls_test.c b/tests/sys/kern/ktls_test.c new file mode 100644 index 000000000000..908f7f1818a2 --- /dev/null +++ b/tests/sys/kern/ktls_test.c @@ -0,0 +1,1033 @@ +/*- + * SPDX-License-Identifier: BSD-2-Clause + * + * Copyright (c) 2021 Netflix Inc. 
+ * Written by: John Baldwin <jhb@FreeBSD.org> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. 
+ */ + +#include <sys/types.h> +#include <sys/endian.h> +#include <sys/event.h> +#include <sys/ktls.h> +#include <sys/socket.h> +#include <sys/sysctl.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#include <crypto/cryptodev.h> +#include <assert.h> +#include <err.h> +#include <fcntl.h> +#include <poll.h> +#include <stdbool.h> +#include <stdlib.h> +#include <atf-c.h> + +#include <openssl/err.h> +#include <openssl/evp.h> +#include <openssl/hmac.h> + +static void +require_ktls(void) +{ + size_t len; + bool enable; + + len = sizeof(enable); + if (sysctlbyname("kern.ipc.tls.enable", &enable, &len, NULL, 0) == -1) { + if (errno == ENOENT) + atf_tc_skip("kernel does not support TLS offload"); + atf_libc_error(errno, "Failed to read kern.ipc.tls.enable"); + } + + if (!enable) + atf_tc_skip("Kernel TLS is disabled"); +} + +#define ATF_REQUIRE_KTLS() require_ktls() + +static char +rdigit(void) +{ + /* ASCII printable values between 0x20 and 0x7e */ + return (0x20 + random() % (0x7f - 0x20)); +} + +static char * +alloc_buffer(size_t len) +{ + char *buf; + size_t i; + + if (len == 0) + return (NULL); + buf = malloc(len); + for (i = 0; i < len; i++) + buf[i] = rdigit(); + return (buf); +} + +static bool +socketpair_tcp(int *sv) +{ + struct pollfd pfd; + struct sockaddr_in sin; + socklen_t len; + int as, cs, ls; + + ls = socket(PF_INET, SOCK_STREAM, 0); + if (ls == -1) { + warn("socket() for listen"); + return (false); + } + + memset(&sin, 0, sizeof(sin)); + sin.sin_len = sizeof(sin); + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + if (bind(ls, (struct sockaddr *)&sin, sizeof(sin)) == -1) { + warn("bind"); + close(ls); + return (false); + } + + if (listen(ls, 1) == -1) { + warn("listen"); + close(ls); + return (false); + } + + len = sizeof(sin); + if (getsockname(ls, (struct sockaddr *)&sin, &len) == -1) { + warn("getsockname"); + close(ls); + return (false); + } + + cs = socket(PF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); + if (cs == -1) { + 
warn("socket() for connect"); + close(ls); + return (false); + } + + if (connect(cs, (struct sockaddr *)&sin, sizeof(sin)) == -1) { + if (errno != EINPROGRESS) { + warn("connect"); + close(ls); + close(cs); + return (false); + } + } + + as = accept4(ls, NULL, NULL, SOCK_NONBLOCK); + if (as == -1) { + warn("accept4"); + close(ls); + close(cs); + return (false); + } + + close(ls); + + pfd.fd = cs; + pfd.events = POLLOUT; + pfd.revents = 0; + ATF_REQUIRE(poll(&pfd, 1, INFTIM) == 1); + ATF_REQUIRE(pfd.revents == POLLOUT); + + sv[0] = cs; + sv[1] = as; + return (true); +} + +static void +fd_set_blocking(int fd) +{ + int flags; + + ATF_REQUIRE((flags = fcntl(fd, F_GETFL)) != -1); + flags &= ~O_NONBLOCK; + ATF_REQUIRE(fcntl(fd, F_SETFL, flags) != -1); +} + +static bool +cbc_decrypt(const EVP_CIPHER *cipher, const char *key, const char *iv, + const char *input, char *output, size_t size) +{ + EVP_CIPHER_CTX *ctx; + int outl, total; + + ctx = EVP_CIPHER_CTX_new(); + if (ctx == NULL) { + warnx("EVP_CIPHER_CTX_new failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + return (false); + } + if (EVP_CipherInit_ex(ctx, cipher, NULL, (const u_char *)key, + (const u_char *)iv, 0) != 1) { + warnx("EVP_CipherInit_ex failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + EVP_CIPHER_CTX_set_padding(ctx, 0); + if (EVP_CipherUpdate(ctx, (u_char *)output, &outl, + (const u_char *)input, size) != 1) { + warnx("EVP_CipherUpdate failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + total = outl; + if (EVP_CipherFinal_ex(ctx, (u_char *)output + outl, &outl) != 1) { + warnx("EVP_CipherFinal_ex failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + total += outl; + if ((size_t)total != size) { + warnx("decrypt size mismatch: %zu vs %d", size, total); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + EVP_CIPHER_CTX_free(ctx); + 
return (true); +} + +static bool +verify_hash(const EVP_MD *md, const void *key, size_t key_len, const void *aad, + size_t aad_len, const void *buffer, size_t len, const void *digest) +{ + HMAC_CTX *ctx; + unsigned char digest2[EVP_MAX_MD_SIZE]; + u_int digest_len; + + ctx = HMAC_CTX_new(); + if (ctx == NULL) { + warnx("HMAC_CTX_new failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + return (false); + } + if (HMAC_Init_ex(ctx, key, key_len, md, NULL) != 1) { + warnx("HMAC_Init_ex failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + HMAC_CTX_free(ctx); + return (false); + } + if (HMAC_Update(ctx, aad, aad_len) != 1) { + warnx("HMAC_Update (aad) failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + HMAC_CTX_free(ctx); + return (false); + } + if (HMAC_Update(ctx, buffer, len) != 1) { + warnx("HMAC_Update (payload) failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + HMAC_CTX_free(ctx); + return (false); + } + if (HMAC_Final(ctx, digest2, &digest_len) != 1) { + warnx("HMAC_Final failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + HMAC_CTX_free(ctx); + return (false); + } + HMAC_CTX_free(ctx); + if (memcmp(digest, digest2, digest_len) != 0) { + warnx("HMAC mismatch"); + return (false); + } + return (true); +} + +static bool +aead_decrypt(const EVP_CIPHER *cipher, const char *key, const char *nonce, + const void *aad, size_t aad_len, const char *input, char *output, + size_t size, const char *tag, size_t tag_len) +{ + EVP_CIPHER_CTX *ctx; + int outl, total; + bool valid; + + ctx = EVP_CIPHER_CTX_new(); + if (ctx == NULL) { + warnx("EVP_CIPHER_CTX_new failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + return (false); + } + if (EVP_DecryptInit_ex(ctx, cipher, NULL, (const u_char *)key, + (const u_char *)nonce) != 1) { + warnx("EVP_DecryptInit_ex failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + EVP_CIPHER_CTX_set_padding(ctx, 0); + if (aad != NULL) { + if 
(EVP_DecryptUpdate(ctx, NULL, &outl, (const u_char *)aad, + aad_len) != 1) { + warnx("EVP_DecryptUpdate for AAD failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + } + if (EVP_DecryptUpdate(ctx, (u_char *)output, &outl, + (const u_char *)input, size) != 1) { + warnx("EVP_DecryptUpdate failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + total = outl; + if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, tag_len, + __DECONST(char *, tag)) != 1) { + warnx("EVP_CIPHER_CTX_ctrl(EVP_CTRL_AEAD_SET_TAG) failed: %s", + ERR_error_string(ERR_get_error(), NULL)); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + valid = (EVP_DecryptFinal_ex(ctx, (u_char *)output + outl, &outl) == 1); + total += outl; + if ((size_t)total != size) { + warnx("decrypt size mismatch: %zu vs %d", size, total); + EVP_CIPHER_CTX_free(ctx); + return (false); + } + if (!valid) + warnx("tag mismatch"); + EVP_CIPHER_CTX_free(ctx); + return (valid); +} + +static void +build_tls_enable(int cipher_alg, size_t cipher_key_len, int auth_alg, + int minor, uint64_t seqno, struct tls_enable *en) +{ + u_int auth_key_len, iv_len; + + memset(en, 0, sizeof(*en)); + + switch (cipher_alg) { + case CRYPTO_AES_CBC: + if (minor == TLS_MINOR_VER_ZERO) + iv_len = AES_BLOCK_LEN; + else + iv_len = 0; + break; + case CRYPTO_AES_NIST_GCM_16: + if (minor == TLS_MINOR_VER_TWO) + iv_len = TLS_AEAD_GCM_LEN; + else + iv_len = TLS_1_3_GCM_IV_LEN; + break; + case CRYPTO_CHACHA20_POLY1305: + iv_len = TLS_CHACHA20_IV_LEN; + break; + default: + iv_len = 0; + break; + } + switch (auth_alg) { + case CRYPTO_SHA1_HMAC: + auth_key_len = SHA1_HASH_LEN; + break; + case CRYPTO_SHA2_256_HMAC: + auth_key_len = SHA2_256_HASH_LEN; + break; + case CRYPTO_SHA2_384_HMAC: + auth_key_len = SHA2_384_HASH_LEN; + break; + default: + auth_key_len = 0; + break; + } + en->cipher_key = alloc_buffer(cipher_key_len); + en->iv = alloc_buffer(iv_len); 
+ en->auth_key = alloc_buffer(auth_key_len); + en->cipher_algorithm = cipher_alg; + en->cipher_key_len = cipher_key_len; + en->iv_len = iv_len; + en->auth_algorithm = auth_alg; + en->auth_key_len = auth_key_len; + en->tls_vmajor = TLS_MAJOR_VER_ONE; + en->tls_vminor = minor; + be64enc(en->rec_seq, seqno); +} + +static void +free_tls_enable(struct tls_enable *en) +{ + free(__DECONST(void *, en->cipher_key)); + free(__DECONST(void *, en->iv)); + free(__DECONST(void *, en->auth_key)); +} + +static const EVP_CIPHER * +tls_EVP_CIPHER(const struct tls_enable *en) +{ + switch (en->cipher_algorithm) { + case CRYPTO_AES_CBC: + switch (en->cipher_key_len) { + case 128 / 8: + return (EVP_aes_128_cbc()); + case 256 / 8: + return (EVP_aes_256_cbc()); + default: + return (NULL); + } + break; + case CRYPTO_AES_NIST_GCM_16: + switch (en->cipher_key_len) { + case 128 / 8: + return (EVP_aes_128_gcm()); + case 256 / 8: + return (EVP_aes_256_gcm()); + default: + return (NULL); + } + break; + case CRYPTO_CHACHA20_POLY1305: + return (EVP_chacha20_poly1305()); + default: + return (NULL); + } +} + +static const EVP_MD * +tls_EVP_MD(const struct tls_enable *en) +{ + switch (en->auth_algorithm) { + case CRYPTO_SHA1_HMAC: + return (EVP_sha1()); + case CRYPTO_SHA2_256_HMAC: + return (EVP_sha256()); + case CRYPTO_SHA2_384_HMAC: + return (EVP_sha384()); + default: + return (NULL); + } +} + +static size_t +tls_header_len(struct tls_enable *en) +{ + size_t len; + + len = sizeof(struct tls_record_layer); + switch (en->cipher_algorithm) { + case CRYPTO_AES_CBC: + if (en->tls_vminor != TLS_MINOR_VER_ZERO) + len += AES_BLOCK_LEN; + return (len); + case CRYPTO_AES_NIST_GCM_16: + if (en->tls_vminor == TLS_MINOR_VER_TWO) + len += sizeof(uint64_t); + return (len); + case CRYPTO_CHACHA20_POLY1305: + return (len); + default: + return (0); + } +} + +static size_t +tls_mac_len(struct tls_enable *en) +{ + switch (en->cipher_algorithm) { + case CRYPTO_AES_CBC: + switch (en->auth_algorithm) { + case 
CRYPTO_SHA1_HMAC: + return (SHA1_HASH_LEN); + case CRYPTO_SHA2_256_HMAC: + return (SHA2_256_HASH_LEN); + case CRYPTO_SHA2_384_HMAC: + return (SHA2_384_HASH_LEN); + default: + return (0); + } + case CRYPTO_AES_NIST_GCM_16: + return (AES_GMAC_HASH_LEN); + case CRYPTO_CHACHA20_POLY1305: + return (POLY1305_HASH_LEN); + default: + return (0); + } +} + +/* Includes maximum padding for MTE. */ +static size_t +tls_trailer_len(struct tls_enable *en) +{ + size_t len; + + len = tls_mac_len(en); + if (en->cipher_algorithm == CRYPTO_AES_CBC) + len += AES_BLOCK_LEN; + if (en->tls_vminor == TLS_MINOR_VER_THREE) + len++; + return (len); +} + +/* 'len' is the length of the payload application data. */ +static void +tls_mte_aad(struct tls_enable *en, size_t len, + const struct tls_record_layer *hdr, uint64_t seqno, struct tls_mac_data *ad) +{ + ad->seq = htobe64(seqno); + ad->type = hdr->tls_type; + ad->tls_vmajor = hdr->tls_vmajor; + ad->tls_vminor = hdr->tls_vminor; + ad->tls_length = htons(len); +} + +static void +tls_12_aead_aad(struct tls_enable *en, size_t len, + const struct tls_record_layer *hdr, uint64_t seqno, + struct tls_aead_data *ad) +{ + ad->seq = htobe64(seqno); + ad->type = hdr->tls_type; + ad->tls_vmajor = hdr->tls_vmajor; + ad->tls_vminor = hdr->tls_vminor; + ad->tls_length = htons(len); +} + +static void +tls_13_aad(struct tls_enable *en, const struct tls_record_layer *hdr, + uint64_t seqno, struct tls_aead_data_13 *ad) +{ + ad->type = hdr->tls_type; + ad->tls_vmajor = hdr->tls_vmajor; + ad->tls_vminor = hdr->tls_vminor; + ad->tls_length = hdr->tls_length; +} + +static void +tls_12_gcm_nonce(struct tls_enable *en, const struct tls_record_layer *hdr, + char *nonce) +{ + memcpy(nonce, en->iv, TLS_AEAD_GCM_LEN); + memcpy(nonce + TLS_AEAD_GCM_LEN, hdr + 1, sizeof(uint64_t)); +} + +static void +tls_13_nonce(struct tls_enable *en, uint64_t seqno, char *nonce) +{ + static_assert(TLS_1_3_GCM_IV_LEN == TLS_CHACHA20_IV_LEN, + "TLS 1.3 nonce length mismatch"); + 
memcpy(nonce, en->iv, TLS_1_3_GCM_IV_LEN); + *(uint64_t *)(nonce + 4) ^= htobe64(seqno); +} + +/* + * Decrypt a TLS record 'len' bytes long at 'src' and store the result at + * 'dst'. If the TLS record header length doesn't match or 'dst' doesn't + * have sufficient room ('avail'), fail the test. + */ +static size_t +decrypt_tls_aes_cbc_mte(struct tls_enable *en, uint64_t seqno, const void *src, + size_t len, void *dst, size_t avail, uint8_t *record_type) +{ + const struct tls_record_layer *hdr; + struct tls_mac_data aad; + const char *iv; + char *buf; + size_t hdr_len, mac_len, payload_len; + int padding; + + hdr = src; + hdr_len = tls_header_len(en); + mac_len = tls_mac_len(en); + ATF_REQUIRE(hdr->tls_vmajor == TLS_MAJOR_VER_ONE); + ATF_REQUIRE(hdr->tls_vminor == en->tls_vminor); + + /* First, decrypt the outer payload into a temporary buffer. */ + payload_len = len - hdr_len; + buf = malloc(payload_len); + if (en->tls_vminor == TLS_MINOR_VER_ZERO) + iv = en->iv; + else + iv = (void *)(hdr + 1); + ATF_REQUIRE(cbc_decrypt(tls_EVP_CIPHER(en), en->cipher_key, iv, + (const u_char *)src + hdr_len, buf, payload_len)); + + /* + * Copy the last encrypted block to use as the IV for the next + * record for TLS 1.0. + */ + if (en->tls_vminor == TLS_MINOR_VER_ZERO) + memcpy(__DECONST(uint8_t *, en->iv), (const u_char *)src + + (len - AES_BLOCK_LEN), AES_BLOCK_LEN); + + /* + * Verify trailing padding and strip. + * + * The kernel always generates the smallest amount of padding. + */ + padding = buf[payload_len - 1] + 1; + ATF_REQUIRE(padding > 0 && padding <= AES_BLOCK_LEN); + ATF_REQUIRE(payload_len >= mac_len + padding); + payload_len -= padding; + + /* Verify HMAC. 
*/ + payload_len -= mac_len; + tls_mte_aad(en, payload_len, hdr, seqno, &aad); + ATF_REQUIRE(verify_hash(tls_EVP_MD(en), en->auth_key, en->auth_key_len, + &aad, sizeof(aad), buf, payload_len, buf + payload_len)); + + ATF_REQUIRE(payload_len <= avail); + memcpy(dst, buf, payload_len); + *record_type = hdr->tls_type; + return (payload_len); +} + +static size_t +decrypt_tls_12_aead(struct tls_enable *en, uint64_t seqno, const void *src, + size_t len, void *dst, uint8_t *record_type) +{ + const struct tls_record_layer *hdr; + struct tls_aead_data aad; + char nonce[12]; + size_t hdr_len, mac_len, payload_len; + + hdr = src; + + hdr_len = tls_header_len(en); + mac_len = tls_mac_len(en); + payload_len = len - (hdr_len + mac_len); + ATF_REQUIRE(hdr->tls_vmajor == TLS_MAJOR_VER_ONE); + ATF_REQUIRE(hdr->tls_vminor == TLS_MINOR_VER_TWO); + + tls_12_aead_aad(en, payload_len, hdr, seqno, &aad); + if (en->cipher_algorithm == CRYPTO_AES_NIST_GCM_16) + tls_12_gcm_nonce(en, hdr, nonce); + else + tls_13_nonce(en, seqno, nonce); + + ATF_REQUIRE(aead_decrypt(tls_EVP_CIPHER(en), en->cipher_key, nonce, + &aad, sizeof(aad), (const char *)src + hdr_len, dst, payload_len, + (const char *)src + hdr_len + payload_len, mac_len)); + + *record_type = hdr->tls_type; + return (payload_len); +} + +static size_t +decrypt_tls_13_aead(struct tls_enable *en, uint64_t seqno, const void *src, + size_t len, void *dst, uint8_t *record_type) +{ + const struct tls_record_layer *hdr; + struct tls_aead_data_13 aad; + char nonce[12]; + char *buf; + size_t hdr_len, mac_len, payload_len; + + hdr = src; + + hdr_len = tls_header_len(en); + mac_len = tls_mac_len(en); + payload_len = len - (hdr_len + mac_len); + ATF_REQUIRE(payload_len >= 1); + ATF_REQUIRE(hdr->tls_type == TLS_RLTYPE_APP); + ATF_REQUIRE(hdr->tls_vmajor == TLS_MAJOR_VER_ONE); + ATF_REQUIRE(hdr->tls_vminor == TLS_MINOR_VER_TWO); + + tls_13_aad(en, hdr, seqno, &aad); + tls_13_nonce(en, seqno, nonce); + + /* + * Have to use a temporary buffer for the 
output due to the + * record type as the last byte of the trailer. + */ + buf = malloc(payload_len); + + ATF_REQUIRE(aead_decrypt(tls_EVP_CIPHER(en), en->cipher_key, nonce, + &aad, sizeof(aad), (const char *)src + hdr_len, buf, payload_len, + (const char *)src + hdr_len + payload_len, mac_len)); + + /* Trim record type. */ + *record_type = buf[payload_len - 1]; + payload_len--; + + memcpy(dst, buf, payload_len); + free(buf); + + return (payload_len); +} + +static size_t +decrypt_tls_aead(struct tls_enable *en, uint64_t seqno, const void *src, + size_t len, void *dst, size_t avail, uint8_t *record_type) +{ + const struct tls_record_layer *hdr; + size_t payload_len; + + hdr = src; + ATF_REQUIRE(ntohs(hdr->tls_length) + sizeof(*hdr) == len); + + payload_len = len - (tls_header_len(en) + tls_trailer_len(en)); + ATF_REQUIRE(payload_len <= avail); + + if (en->tls_vminor == TLS_MINOR_VER_TWO) { + ATF_REQUIRE(decrypt_tls_12_aead(en, seqno, src, len, dst, + record_type) == payload_len); + } else { + ATF_REQUIRE(decrypt_tls_13_aead(en, seqno, src, len, dst, + record_type) == payload_len); + } + + return (payload_len); +} + +static size_t +decrypt_tls_record(struct tls_enable *en, uint64_t seqno, const void *src, + size_t len, void *dst, size_t avail, uint8_t *record_type) +{ + if (en->cipher_algorithm == CRYPTO_AES_CBC) + return (decrypt_tls_aes_cbc_mte(en, seqno, src, len, dst, avail, + record_type)); + else + return (decrypt_tls_aead(en, seqno, src, len, dst, avail, + record_type)); +} + +static void +test_ktls_transmit_app_data(struct tls_enable *en, uint64_t seqno, size_t len) +{ + struct kevent ev; + struct tls_record_layer *hdr; + char *plaintext, *decrypted, *outbuf; + size_t decrypted_len, outbuf_len, outbuf_cap, record_len, written; + ssize_t rv; + int kq, sockets[2]; + uint8_t record_type; + + plaintext = alloc_buffer(len); + decrypted = malloc(len); + outbuf_cap = tls_header_len(en) + TLS_MAX_MSG_SIZE_V10_2 + + tls_trailer_len(en); + outbuf = malloc(outbuf_cap); + 
hdr = (struct tls_record_layer *)outbuf; + + ATF_REQUIRE((kq = kqueue()) != -1); + + ATF_REQUIRE_MSG(socketpair_tcp(sockets), "failed to create sockets"); + + ATF_REQUIRE(setsockopt(sockets[1], IPPROTO_TCP, TCP_TXTLS_ENABLE, en, + sizeof(*en)) == 0); + + EV_SET(&ev, sockets[0], EVFILT_READ, EV_ADD, 0, 0, NULL); + ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, NULL) == 0); + EV_SET(&ev, sockets[1], EVFILT_WRITE, EV_ADD, 0, 0, NULL); + ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, NULL) == 0); + + decrypted_len = 0; + outbuf_len = 0; + written = 0; + + while (decrypted_len != len) { + ATF_REQUIRE(kevent(kq, NULL, 0, &ev, 1, NULL) == 1); + + switch (ev.filter) { + case EVFILT_WRITE: + /* Try to write any remaining data. */ + rv = write(ev.ident, plaintext + written, + len - written); + ATF_REQUIRE_MSG(rv > 0, + "failed to write to socket"); + written += rv; + if (written == len) { + ev.flags = EV_DISABLE; + ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, + NULL) == 0); + } + break; + + case EVFILT_READ: + ATF_REQUIRE((ev.flags & EV_EOF) == 0); + + /* + * Try to read data for the next TLS record + * into outbuf. Start by reading the header + * to determine how much additional data to + * read. 
+ */ + if (outbuf_len < sizeof(struct tls_record_layer)) { + rv = read(ev.ident, outbuf + outbuf_len, + sizeof(struct tls_record_layer) - + outbuf_len); + ATF_REQUIRE_MSG(rv > 0, + "failed to read from socket"); + outbuf_len += rv; + } + + if (outbuf_len < sizeof(struct tls_record_layer)) + break; + + record_len = sizeof(struct tls_record_layer) + + ntohs(hdr->tls_length); + assert(record_len <= outbuf_cap); + assert(record_len > outbuf_len); + rv = read(ev.ident, outbuf + outbuf_len, + record_len - outbuf_len); + if (rv == -1 && errno == EAGAIN) + break; + ATF_REQUIRE_MSG(rv > 0, "failed to read from socket"); + + outbuf_len += rv; + if (outbuf_len == record_len) { + decrypted_len += decrypt_tls_record(en, seqno, + outbuf, outbuf_len, + decrypted + decrypted_len, + len - decrypted_len, &record_type); + ATF_REQUIRE(record_type == TLS_RLTYPE_APP); + + seqno++; + outbuf_len = 0; + } + break; + } + } + + ATF_REQUIRE_MSG(written == decrypted_len, + "read %zu decrypted bytes, but wrote %zu", decrypted_len, written); + + ATF_REQUIRE(memcmp(plaintext, decrypted, len) == 0); + + free(outbuf); + free(decrypted); + free(plaintext); + + close(sockets[1]); + close(sockets[0]); + close(kq); +} + +static void +ktls_send_control_message(int fd, uint8_t type, void *data, size_t len) +{ + struct msghdr msg; + struct cmsghdr *cmsg; + char cbuf[CMSG_SPACE(sizeof(type))]; + struct iovec iov; + + memset(&msg, 0, sizeof(msg)); + + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_level = IPPROTO_TCP; + cmsg->cmsg_type = TLS_SET_RECORD_TYPE; + cmsg->cmsg_len = CMSG_LEN(sizeof(type)); + *(uint8_t *)CMSG_DATA(cmsg) = type; + + iov.iov_base = data; + iov.iov_len = len; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ATF_REQUIRE(sendmsg(fd, &msg, 0) == (ssize_t)len); +} + +static void +test_ktls_transmit_control(struct tls_enable *en, uint64_t seqno, uint8_t type, + size_t len) +{ + struct tls_record_layer *hdr; + char *plaintext, 
*decrypted, *outbuf; + size_t outbuf_cap, payload_len, record_len; + ssize_t rv; + int sockets[2]; + uint8_t record_type; + + ATF_REQUIRE(len <= TLS_MAX_MSG_SIZE_V10_2); + + plaintext = alloc_buffer(len); + decrypted = malloc(len); + outbuf_cap = tls_header_len(en) + len + tls_trailer_len(en); + outbuf = malloc(outbuf_cap); + hdr = (struct tls_record_layer *)outbuf; + + ATF_REQUIRE_MSG(socketpair_tcp(sockets), "failed to create sockets"); + + ATF_REQUIRE(setsockopt(sockets[1], IPPROTO_TCP, TCP_TXTLS_ENABLE, en, + sizeof(*en)) == 0); + + fd_set_blocking(sockets[0]); + fd_set_blocking(sockets[1]); + + ktls_send_control_message(sockets[1], type, plaintext, len); + + /* + * First read the header to determine how much additional data + * to read. + */ + rv = read(sockets[0], outbuf, sizeof(struct tls_record_layer)); + ATF_REQUIRE(rv == sizeof(struct tls_record_layer)); + payload_len = ntohs(hdr->tls_length); + record_len = payload_len + sizeof(struct tls_record_layer); + assert(record_len <= outbuf_cap); + rv = read(sockets[0], outbuf + sizeof(struct tls_record_layer), + payload_len); + ATF_REQUIRE(rv == (ssize_t)payload_len); + + rv = decrypt_tls_record(en, seqno, outbuf, record_len, decrypted, len, + &record_type); + + ATF_REQUIRE_MSG((ssize_t)len == rv, + "read %zd decrypted bytes, but wrote %zu", rv, len); + ATF_REQUIRE(record_type == type); + + ATF_REQUIRE(memcmp(plaintext, decrypted, len) == 0); + + free(outbuf); + free(decrypted); + free(plaintext); + + close(sockets[1]); + close(sockets[0]); +} + +#define AES_CBC_TESTS(M) \ + M(aes128_cbc_1_0_sha1, CRYPTO_AES_CBC, 128 / 8, \ + CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ZERO) \ + M(aes256_cbc_1_0_sha1, CRYPTO_AES_CBC, 256 / 8, \ + CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ZERO) \ + M(aes128_cbc_1_1_sha1, CRYPTO_AES_CBC, 128 / 8, \ + CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ONE) \ + M(aes256_cbc_1_1_sha1, CRYPTO_AES_CBC, 256 / 8, \ + CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ONE) \ + M(aes128_cbc_1_2_sha1, CRYPTO_AES_CBC, 128 / 8, \ + CRYPTO_SHA1_HMAC, 
TLS_MINOR_VER_TWO) \ + M(aes256_cbc_1_2_sha1, CRYPTO_AES_CBC, 256 / 8, \ + CRYPTO_SHA1_HMAC, TLS_MINOR_VER_TWO) \ + M(aes128_cbc_1_2_sha256, CRYPTO_AES_CBC, 128 / 8, \ + CRYPTO_SHA2_256_HMAC, TLS_MINOR_VER_TWO) \ + M(aes256_cbc_1_2_sha256, CRYPTO_AES_CBC, 256 / 8, \ + CRYPTO_SHA2_256_HMAC, TLS_MINOR_VER_TWO) \ + M(aes128_cbc_1_2_sha384, CRYPTO_AES_CBC, 128 / 8, \ + CRYPTO_SHA2_384_HMAC, TLS_MINOR_VER_TWO) \ + M(aes256_cbc_1_2_sha384, CRYPTO_AES_CBC, 256 / 8, \ + CRYPTO_SHA2_384_HMAC, TLS_MINOR_VER_TWO) \ *** 99 LINES SKIPPED ***