diff --git a/include/libssh/hybrid_mlkem.h b/include/libssh/hybrid_mlkem.h index 6e9d9e5d..ca9b6a20 100644 --- a/include/libssh/hybrid_mlkem.h +++ b/include/libssh/hybrid_mlkem.h @@ -34,6 +34,9 @@ extern "C" { #endif +#define NISTP256_SHARED_SECRET_SIZE 32 +#define NISTP384_SHARED_SECRET_SIZE 48 + int ssh_client_hybrid_mlkem_init(ssh_session session); void ssh_client_hybrid_mlkem_remove_callbacks(ssh_session session); diff --git a/src/hybrid_mlkem.c b/src/hybrid_mlkem.c index df2169c2..18a0744d 100644 --- a/src/hybrid_mlkem.c +++ b/src/hybrid_mlkem.c @@ -43,43 +43,65 @@ static struct ssh_packet_callbacks_struct ssh_hybrid_mlkem_client_callbacks = { .user = NULL, }; -static ssh_string derive_ecdh_secret(ssh_session session) +static ssh_string derive_curve25519_secret(ssh_session session) +{ + ssh_string secret = NULL; + int rc; + + secret = ssh_string_new(CURVE25519_PUBKEY_SIZE); + if (secret == NULL) { + ssh_set_error_oom(session); + return NULL; + } + + rc = ssh_curve25519_create_k(session, ssh_string_data(secret)); + if (rc != SSH_OK) { + ssh_set_error(session, + SSH_FATAL, + "Curve25519 secret derivation failed"); + ssh_string_free(secret); + return NULL; + } + + return secret; +} + +static ssh_string derive_nist_curve_secret(ssh_session session, + size_t secret_size) { struct ssh_crypto_struct *crypto = session->next_crypto; ssh_string secret = NULL; - size_t secret_size; int rc; - switch (crypto->kex_type) { + rc = ecdh_build_k(session); + if (rc != SSH_OK) { + ssh_set_error(session, SSH_FATAL, "ECDH secret derivation failed"); + return NULL; + } + + secret = ssh_make_padded_bignum_string(crypto->shared_secret, secret_size); + if (secret == NULL) { + ssh_set_error(session, SSH_FATAL, "Failed to encode the shared secret"); + } + + bignum_safe_free(crypto->shared_secret); + + return secret; +} + +static ssh_string derive_ecdh_secret(ssh_session session) +{ + ssh_string secret = NULL; + + switch (session->next_crypto->kex_type) { case SSH_KEX_MLKEM768X25519_SHA256: - secret = ssh_string_new(CURVE25519_PUBKEY_SIZE); - if (secret == NULL) { - ssh_set_error_oom(session); - return NULL; - } - rc = ssh_curve25519_create_k(session, ssh_string_data(secret)); - if (rc != SSH_OK) { - ssh_set_error(session, SSH_FATAL, "Curve25519 secret derivation failed"); - ssh_string_free(secret); - return NULL; - } + secret = derive_curve25519_secret(session); break; case SSH_KEX_MLKEM768NISTP256_SHA256: + secret = derive_nist_curve_secret(session, NISTP256_SHARED_SECRET_SIZE); + break; case SSH_KEX_MLKEM1024NISTP384_SHA384: - rc = ecdh_build_k(session); - if (rc != SSH_OK) { - ssh_set_error(session, SSH_FATAL, "ECDH secret derivation failed"); - return NULL; - } - secret_size = bignum_num_bytes(crypto->shared_secret); - secret = ssh_string_new(secret_size); - if (secret == NULL) { - ssh_set_error_oom(session); - bignum_safe_free(crypto->shared_secret); - return NULL; - } - bignum_bn2bin(crypto->shared_secret, secret_size, ssh_string_data(secret)); - bignum_safe_free(crypto->shared_secret); + secret = derive_nist_curve_secret(session, NISTP384_SHARED_SECRET_SIZE); break; default: ssh_set_error(session, SSH_FATAL, "Unsupported KEX type");