-
Notifications
You must be signed in to change notification settings - Fork 134
ML-KEM encaps key modulus check optimization #1874
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -60,54 +60,77 @@ int crypto_kem_keypair(ml_kem_params *params, | |||||
return 0; | ||||||
} | ||||||
|
||||||
// REFERENCE IMPLEMENTATION OF SEVERAL FIPS 203 FUNCTIONS. | ||||||
// Further below we implement optimized versions of the functions | ||||||
// that are actually used. We commented out and kept the reference | ||||||
// code for posterity. | ||||||
// | ||||||
// FIPS 203. Algorithm 3 BitsToBytes
// Packs a bit array into bytes. Each element of |bits| is expected to hold a
// single bit (0 or 1); every group of eight consecutive elements becomes one
// output byte, with the first element of the group landing in the least
// significant bit. |num_bits| must be exactly 8 * |num_bytes|.
static void bits_to_bytes(uint8_t *bytes, size_t num_bytes,
                          const uint8_t *bits, size_t num_bits) {
  assert(num_bits == num_bytes * 8);

  // Walk the bit array with a single cursor instead of recomputing
  // |k * 8 + shift| on every iteration.
  const uint8_t *src = bits;
  for (size_t k = 0; k < num_bytes; k++) {
    uint8_t acc = 0;
    for (size_t shift = 0; shift < 8; shift++, src++) {
      // Truncating before the OR keeps exactly the low 8 bits that the
      // final uint8_t store would have kept anyway.
      acc |= (uint8_t)(*src << shift);
    }
    bytes[k] = acc;
  }
}
|
||||||
// static void bits_to_bytes(uint8_t *bytes, size_t num_bytes, | ||||||
// const uint8_t *bits, size_t num_bits) { | ||||||
// assert(num_bits == num_bytes * 8); | ||||||
// | ||||||
// for (size_t i = 0; i < num_bytes; i++) { | ||||||
// uint8_t byte = 0; | ||||||
// for (size_t j = 0; j < 8; j++) { | ||||||
// byte |= (bits[i * 8 + j] << j); | ||||||
// } | ||||||
// bytes[i] = byte; | ||||||
// } | ||||||
// } | ||||||
// FIPS 203. Algorithm 4 BytesToBits
// Inverse of BitsToBytes: expands every input byte into eight output
// elements, each holding 0 or 1, least significant bit first.
// |num_bits| must be exactly 8 * |num_bytes|.
static void bytes_to_bits(uint8_t *bits, size_t num_bits,
                          const uint8_t *bytes, size_t num_bytes) {
  assert(num_bits == num_bytes * 8);

  uint8_t *dst = bits;
  for (size_t k = 0; k < num_bytes; k++) {
    const uint8_t cur = bytes[k];
    for (size_t shift = 0; shift < 8; shift++) {
      *dst++ = (uint8_t)((cur >> shift) & 1);
    }
  }
}
|
||||||
#define BYTE_ENCODE_12_IN_SIZE (256)
#define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
#define BYTE_ENCODE_12_NUM_BITS (256 * 12)

// FIPS 203. Algorithm 5 ByteEncode_12
// Serializes 256 coefficients of 12 bits each into a 32 * 12 = 384-byte
// array. Reference form: each coefficient is first expanded into 12 bit
// elements (least significant bit first), and the resulting bit array is
// then packed with bits_to_bytes.
static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE],
                           const int16_t in[BYTE_ENCODE_12_IN_SIZE]) {
  uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0};
  uint8_t *dst = bits;
  for (size_t n = 0; n < BYTE_ENCODE_12_IN_SIZE; n++) {
    const int16_t coeff = in[n];
    // Only the low 12 bits of each coefficient are emitted.
    for (size_t b = 0; b < 12; b++) {
      *dst++ = (uint8_t)((coeff >> b) & 1);
    }
  }
  bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS);
}
// static void bytes_to_bits(uint8_t *bits, size_t num_bits, | ||||||
// const uint8_t *bytes, size_t num_bytes) { | ||||||
// assert(num_bits == num_bytes * 8); | ||||||
// | ||||||
// for (size_t i = 0; i < num_bytes; i++) { | ||||||
// uint8_t byte = bytes[i]; | ||||||
// for (size_t j = 0; j < 8; j++) { | ||||||
// bits[i * 8 + j] = (byte >> j) & 1; | ||||||
// } | ||||||
// } | ||||||
// } | ||||||
// | ||||||
// #define BYTE_ENCODE_12_IN_SIZE (256) | ||||||
// #define BYTE_ENCODE_12_OUT_SIZE (32 * 12) | ||||||
// #define BYTE_ENCODE_12_NUM_BITS (256 * 12) | ||||||
// | ||||||
// // FIPS 203. Algorithm 5 ByteEncode_12 | ||||||
// // Encodes an array of 256 12-bit integers into a byte array. | ||||||
// static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE], | ||||||
// const int16_t in[BYTE_ENCODE_12_IN_SIZE]) { | ||||||
// uint8_t bits[BYTE_ENCODE_12_NUM_BITS] = {0}; | ||||||
// for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE; i++) { | ||||||
// int16_t a = in[i]; | ||||||
// for (size_t j = 0; j < 12; j++) { | ||||||
// bits[i * 12 + j] = a & 1; | ||||||
// a = a >> 1; | ||||||
// } | ||||||
// } | ||||||
// bits_to_bytes(out, BYTE_ENCODE_12_OUT_SIZE, bits, BYTE_ENCODE_12_NUM_BITS); | ||||||
// } | ||||||
// | ||||||
// #define BYTE_DECODE_12_OUT_SIZE (256) | ||||||
// #define BYTE_DECODE_12_IN_SIZE (32 * 12) | ||||||
// #define BYTE_DECODE_12_NUM_BITS (256 * 12) | ||||||
// | ||||||
// // FIPS 203. Algorithm 6 ByteDecode_12 | ||||||
// // Decodes a byte array into an array of 256 12-bit integers. | ||||||
// static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE], | ||||||
// const uint8_t in[BYTE_DECODE_12_IN_SIZE]) { | ||||||
// uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0}; | ||||||
// bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE); | ||||||
// for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) { | ||||||
// int16_t val = 0; | ||||||
// for (size_t j = 0; j < 12; j++) { | ||||||
// val |= bits[i * 12 + j] << j; | ||||||
// } | ||||||
// out[i] = centered_to_positive_representative(barrett_reduce(val)); | ||||||
// } | ||||||
// } | ||||||
|
||||||
// Converts a centered representative |in| which is an integer in | ||||||
// {-(q-1)/2, ..., (q-1)/2}, to a positive representative in {0, ..., q-1}. | ||||||
|
@@ -120,22 +143,58 @@ static int16_t centered_to_positive_representative(int16_t in) { | |||||
return constant_time_select_int(mask, in, in_fixed); | ||||||
} | ||||||
|
||||||
// Buffer sizes for ByteEncode_12 / ByteDecode_12: 256 coefficients of 12 bits
// each pack into 256 * 12 / 8 = 384 bytes. The decoder consumes exactly what
// the encoder produces, so its sizes are expressed in terms of the encoder's.
// Each name is defined exactly once: redefining a macro with a different
// replacement list is a constraint violation in C.
#define BYTE_ENCODE_12_IN_SIZE (256)
#define BYTE_ENCODE_12_OUT_SIZE (32 * 12)
#define BYTE_DECODE_12_OUT_SIZE (BYTE_ENCODE_12_IN_SIZE)
#define BYTE_DECODE_12_IN_SIZE (BYTE_ENCODE_12_OUT_SIZE)
#define BYTE_DECODE_12_NUM_BITS (256 * 12)
|
||||||
// FIPS 203. Algorithm 5 ByteEncode_12 | ||||||
// Encodes an array of 256 12-bit integers into a byte array. | ||||||
// Intuition for the implementation: | ||||||
// in: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ... | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit confused how this is the same as the commented-out implementation: I thought that each 16-bit array member contained a 12-bit digit and the top 4 are discarded. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
that's correct, and that's what we are doing here as well There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see it now, thanks. |
||||||
// out: |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ... | ||||||
// We divide the input in pairs of elements (2 x 12 bits = 24 bits), | ||||||
// and the output in triples (3 x 8 bits = 24 bits). For each pair/triplet we: | ||||||
// - out0 <-- first eight bits of in0, | ||||||
// - out1 <-- concatenate last 4 bits of in0 and first 4 bits of in1, | ||||||
// - out2 <-- last 8 bits of in1. | ||||||
static void byte_encode_12(uint8_t out[BYTE_ENCODE_12_OUT_SIZE], | ||||||
const int16_t in[BYTE_ENCODE_12_IN_SIZE]) { | ||||||
for (size_t i = 0; i < BYTE_ENCODE_12_IN_SIZE / 2; i++) { | ||||||
int16_t in0 = in[2 * i]; | ||||||
int16_t in1 = in[2 * i + 1]; | ||||||
out[3 * i] = in0 & 0xff; | ||||||
out[3 * i + 1] = ((in0 >> 8) & 0xf) | ((in1 & 0xf) << 4); | ||||||
out[3 * i + 2] = (in1 >> 4) & 0xff; | ||||||
} | ||||||
} | ||||||
|
||||||
// FIPS 203. Algorithm 6 ByteDecode_12 | ||||||
// Decodes a byte array into an array of 256 12-bit integers. | ||||||
// Intuition for the implementation: | ||||||
// in: |xxxxxxxx| |yyyyyyyy| |zzzzzzzz| ... | ||||||
// out: |xxxxxxxxyyyy| |yyyyzzzzzzzz| ... | ||||||
// We divide the input in triples of elements (3 x 8 bits = 24 bits), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit
Suggested change
|
||||||
// and the output in pairs (2 x 12 bits = 24 bits). For each pair/triplet we: | ||||||
// - out[0] <-- concatenate eight bits of in[0] and first 4 bits of in[1], | ||||||
// - out[1] <-- concatenate last 4 bits of in[1] and 8 bits of in[2]. | ||||||
// Additionally we reduce the output elements mod Q as specified in FIPS 203. | ||||||
static void byte_decode_12(int16_t out[BYTE_DECODE_12_OUT_SIZE], | ||||||
const uint8_t in[BYTE_DECODE_12_IN_SIZE]) { | ||||||
uint8_t bits[BYTE_DECODE_12_NUM_BITS] = {0}; | ||||||
bytes_to_bits(bits, BYTE_DECODE_12_NUM_BITS, in, BYTE_DECODE_12_IN_SIZE); | ||||||
for (size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE; i++) { | ||||||
int16_t val = 0; | ||||||
for (size_t j = 0; j < 12; j++) { | ||||||
val |= bits[i * 12 + j] << j; | ||||||
} | ||||||
out[i] = centered_to_positive_representative(barrett_reduce(val)); | ||||||
for(size_t i = 0; i < BYTE_DECODE_12_OUT_SIZE / 2; i++) { | ||||||
// Cast to 16-bit wide uint's to avoid any issues | ||||||
// with shifting and implicit casting. | ||||||
uint16_t in0 = (uint16_t) in[3 * i]; | ||||||
uint16_t in1 = (uint16_t) in[3 * i + 1]; | ||||||
uint16_t in2 = (uint16_t) in[3 * i + 2]; | ||||||
|
||||||
// Build the output pair. | ||||||
uint16_t out0 = in0 | ((in1 & 0xf) << 8); | ||||||
uint16_t out1 = (in1 >> 4) | (in2 << 4); | ||||||
|
||||||
// Reduce mod Q. | ||||||
out[2 * i] = centered_to_positive_representative(barrett_reduce(out0)); | ||||||
out[2 * i + 1] = centered_to_positive_representative(barrett_reduce(out1)); | ||||||
} | ||||||
} | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have to? It'll still be in the git history; you can mention that fact here.