mbedtls: Avoid memory leak when handling ECDSA keys

Signed-off-by: Jakub Jelen <jjelen@redhat.com>
Reviewed-by: Andreas Schneider <asn@cryptomilk.org>
This commit is contained in:
Jakub Jelen
2024-07-01 19:57:53 +02:00
parent 2d3b7e07af
commit ec6363d6b5
5 changed files with 128 additions and 105 deletions

View File

@@ -60,7 +60,7 @@ struct ssh_key_struct {
gcry_sexp_t rsa;
gcry_sexp_t ecdsa;
#elif defined(HAVE_LIBMBEDCRYPTO)
mbedtls_pk_context *rsa;
mbedtls_pk_context *pk;
mbedtls_ecdsa_context *ecdsa;
#elif defined(HAVE_LIBCRYPTO)
/* This holds either ENGINE key for PKCS#11 support or just key in

View File

@@ -1693,18 +1693,22 @@ int ssh_userauth_agent_pubkey(ssh_session session,
key->type = publickey->type;
key->type_c = ssh_key_type_to_char(key->type);
key->flags = SSH_KEY_FLAG_PUBLIC;
#ifndef HAVE_LIBCRYPTO
key->rsa = publickey->rsa_pub;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
key->pk = publickey->rsa_pub;
#elif defined(HAVE_LIBCRYPTO)
key->key = publickey->key_pub;
#else
key->rsa = publickey->rsa_pub;
#endif /* HAVE_LIBCRYPTO */
rc = ssh_userauth_agent_publickey(session, username, key);
#ifndef HAVE_LIBCRYPTO
key->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
key->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
key->key = NULL;
#else
key->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(key);

View File

@@ -83,17 +83,21 @@ int ssh_userauth_pubkey(ssh_session session,
key->type = privatekey->type;
key->type_c = ssh_key_type_to_char(key->type);
key->flags = SSH_KEY_FLAG_PRIVATE|SSH_KEY_FLAG_PUBLIC;
#ifndef HAVE_LIBCRYPTO
key->rsa = privatekey->rsa_priv;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
key->pk = privatekey->rsa_priv;
#elif defined(HAVE_LIBCRYPTO)
key->key = privatekey->key_priv;
#else
key->rsa = privatekey->rsa_priv;
#endif /* HAVE_LIBCRYPTO */
rc = ssh_userauth_publickey(session, username, key);
#ifndef HAVE_LIBCRYPTO
key->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
key->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
key->key = NULL;
#else
key->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(key);
@@ -386,17 +390,21 @@ ssh_public_key publickey_from_privatekey(ssh_private_key prv) {
privkey->type = prv->type;
privkey->type_c = ssh_key_type_to_char(privkey->type);
privkey->flags = SSH_KEY_FLAG_PRIVATE | SSH_KEY_FLAG_PUBLIC;
#ifndef HAVE_LIBCRYPTO
privkey->rsa = prv->rsa_priv;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
privkey->pk = prv->rsa_priv;
#elif defined(HAVE_LIBCRYPTO)
privkey->key = prv->key_priv;
#else
privkey->rsa = prv->rsa_priv;
#endif /* HAVE_LIBCRYPTO */
rc = ssh_pki_export_privkey_to_pubkey(privkey, &pubkey);
#ifndef HAVE_LIBCRYPTO
privkey->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
privkey->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
privkey->key = NULL;
#else
privkey->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(privkey);
if (rc < 0) {
@@ -443,14 +451,15 @@ ssh_private_key privatekey_from_file(ssh_session session,
}
privkey->type = key->type;
#ifndef HAVE_LIBCRYPTO
privkey->rsa_priv = key->rsa;
key->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
privkey->rsa_priv = key->pk;
key->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
privkey->key_priv = key->key;
key->key = NULL;
#else
privkey->rsa_priv = key->rsa;
key->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(key);
@@ -537,12 +546,15 @@ ssh_public_key publickey_from_string(ssh_session session, ssh_string pubkey_s) {
pubkey->type = key->type;
pubkey->type_c = key->type_c;
#ifndef HAVE_LIBCRYPTO
pubkey->rsa_pub = key->rsa;
key->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
pubkey->rsa_pub = key->pk;
key->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
pubkey->key_pub = key->key;
key->key = NULL;
#else
pubkey->rsa_pub = key->rsa;
key->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(key);
@@ -567,10 +579,12 @@ ssh_string publickey_to_string(ssh_public_key pubkey) {
key->type = pubkey->type;
key->type_c = pubkey->type_c;
#ifndef HAVE_LIBCRYPTO
key->rsa = pubkey->rsa_pub;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
key->pk = pubkey->rsa_pub;
#elif defined(HAVE_LIBCRYPTO)
key->key = pubkey->key_pub;
#else
key->rsa = pubkey->rsa_pub;
#endif /* HAVE_LIBCRYPTO */
rc = ssh_pki_export_pubkey_blob(key, &key_blob);
@@ -578,10 +592,12 @@ ssh_string publickey_to_string(ssh_public_key pubkey) {
key_blob = NULL;
}
#ifndef HAVE_LIBCRYPTO
key->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
key->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
key->key = NULL;
#else
key->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(key);

View File

@@ -1201,12 +1201,15 @@ ssh_public_key ssh_pki_convert_key_to_publickey(const ssh_key key)
pub->type = tmp->type;
pub->type_c = tmp->type_c;
#ifndef HAVE_LIBCRYPTO
pub->rsa_pub = tmp->rsa;
tmp->rsa = NULL;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
pub->rsa_pub = tmp->pk;
tmp->pk = NULL;
#elif defined(HAVE_LIBCRYPTO)
pub->key_pub = tmp->key;
tmp->key = NULL;
#else
pub->rsa_pub = tmp->rsa;
tmp->rsa = NULL;
#endif /* HAVE_LIBCRYPTO */
ssh_key_free(tmp);
@@ -1225,10 +1228,12 @@ ssh_private_key ssh_pki_convert_key_to_privatekey(const ssh_key key)
}
privkey->type = key->type;
#ifndef HAVE_LIBCRYPTO
privkey->rsa_priv = key->rsa;
#else
#if defined(HAVE_LIBMBEDCRYPTO)
privkey->rsa_priv = key->pk;
#elif defined(HAVE_LIBCRYPTO)
privkey->key_priv = key->key;
#else
privkey->rsa_priv = key->rsa;
#endif /* HAVE_LIBCRYPTO */
return privkey;

View File

@@ -43,9 +43,9 @@ void pki_key_clean(ssh_key key)
if (key == NULL)
return;
if (key->rsa != NULL) {
mbedtls_pk_free(key->rsa);
SAFE_FREE(key->rsa);
if (key->pk != NULL) {
mbedtls_pk_free(key->pk);
SAFE_FREE(key->pk);
}
if (key->ecdsa != NULL) {
@@ -183,7 +183,7 @@ ssh_key pki_private_key_from_base64(const char *b64_key, const char *passphrase,
switch (mbed_type) {
case MBEDTLS_PK_RSA:
case MBEDTLS_PK_RSA_ALT:
key->rsa = pk;
key->pk = pk;
pk = NULL;
key->type = SSH_KEYTYPE_RSA;
break;
@@ -191,7 +191,6 @@ ssh_key pki_private_key_from_base64(const char *b64_key, const char *passphrase,
case MBEDTLS_PK_ECDSA: {
/* type will be set later */
mbedtls_ecp_keypair *keypair = mbedtls_pk_ec(*pk);
pk = NULL;
key->ecdsa = malloc(sizeof(mbedtls_ecdsa_context));
if (key->ecdsa == NULL) {
@@ -200,8 +199,7 @@ ssh_key pki_private_key_from_base64(const char *b64_key, const char *passphrase,
mbedtls_ecdsa_init(key->ecdsa);
mbedtls_ecdsa_from_keypair(key->ecdsa, keypair);
mbedtls_pk_free(pk);
SAFE_FREE(pk);
key->pk = pk;
key->ecdsa_nid = pki_key_ecdsa_to_nid(key->ecdsa);
@@ -246,21 +244,21 @@ int pki_privkey_build_rsa(ssh_key key,
const mbedtls_pk_info_t *pk_info = NULL;
int rc;
key->rsa = malloc(sizeof(mbedtls_pk_context));
if (key->rsa == NULL) {
key->pk = malloc(sizeof(mbedtls_pk_context));
if (key->pk == NULL) {
return SSH_ERROR;
}
mbedtls_pk_init(key->rsa);
mbedtls_pk_init(key->pk);
pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
mbedtls_pk_setup(key->rsa, pk_info);
mbedtls_pk_setup(key->pk, pk_info);
rc = mbedtls_pk_can_do(key->rsa, MBEDTLS_PK_RSA);
rc = mbedtls_pk_can_do(key->pk, MBEDTLS_PK_RSA);
if (rc == 0) {
goto fail;
}
rsa = mbedtls_pk_rsa(*key->rsa);
rsa = mbedtls_pk_rsa(*key->pk);
rc = mbedtls_rsa_import_raw(rsa,
ssh_string_data(n), ssh_string_len(n),
ssh_string_data(p), ssh_string_len(p),
@@ -287,8 +285,8 @@ int pki_privkey_build_rsa(ssh_key key,
return SSH_OK;
fail:
mbedtls_pk_free(key->rsa);
SAFE_FREE(key->rsa);
mbedtls_pk_free(key->pk);
SAFE_FREE(key->pk);
return SSH_ERROR;
}
@@ -302,21 +300,21 @@ int pki_pubkey_build_rsa(ssh_key key, ssh_string e, ssh_string n)
#endif
int rc;
key->rsa = malloc(sizeof(mbedtls_pk_context));
if (key->rsa == NULL) {
key->pk = malloc(sizeof(mbedtls_pk_context));
if (key->pk == NULL) {
return SSH_ERROR;
}
mbedtls_pk_init(key->rsa);
mbedtls_pk_init(key->pk);
pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
mbedtls_pk_setup(key->rsa, pk_info);
mbedtls_pk_setup(key->pk, pk_info);
rc = mbedtls_pk_can_do(key->rsa, MBEDTLS_PK_RSA);
rc = mbedtls_pk_can_do(key->pk, MBEDTLS_PK_RSA);
if (rc == 0) {
goto fail;
}
rsa = mbedtls_pk_rsa(*key->rsa);
rsa = mbedtls_pk_rsa(*key->pk);
#if MBEDTLS_VERSION_MAJOR > 2
mbedtls_mpi_init(&N);
mbedtls_mpi_init(&E);
@@ -359,8 +357,8 @@ int pki_pubkey_build_rsa(ssh_key key, ssh_string e, ssh_string n)
goto exit;
fail:
rc = SSH_ERROR;
mbedtls_pk_free(key->rsa);
SAFE_FREE(key->rsa);
mbedtls_pk_free(key->pk);
SAFE_FREE(key->pk);
exit:
#if MBEDTLS_VERSION_MAJOR > 2
mbedtls_mpi_free(&N);
@@ -407,23 +405,22 @@ ssh_key pki_key_dup(const ssh_key key, int demote)
case SSH_KEYTYPE_RSA: {
mbedtls_rsa_context *rsa, *new_rsa;
new->rsa = malloc(sizeof(mbedtls_pk_context));
if (new->rsa == NULL) {
new->pk = malloc(sizeof(mbedtls_pk_context));
if (new->pk == NULL) {
goto fail;
}
mbedtls_pk_init(new->rsa);
mbedtls_pk_init(new->pk);
pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
mbedtls_pk_setup(new->rsa, pk_info);
mbedtls_pk_setup(new->pk, pk_info);
if (!mbedtls_pk_can_do(key->rsa, MBEDTLS_PK_RSA) ||
!mbedtls_pk_can_do(new->rsa, MBEDTLS_PK_RSA))
{
if (!mbedtls_pk_can_do(key->pk, MBEDTLS_PK_RSA) ||
!mbedtls_pk_can_do(new->pk, MBEDTLS_PK_RSA)) {
goto fail;
}
rsa = mbedtls_pk_rsa(*key->rsa);
new_rsa = mbedtls_pk_rsa(*new->rsa);
rsa = mbedtls_pk_rsa(*key->pk);
new_rsa = mbedtls_pk_rsa(*new->pk);
if (!demote && (key->flags & SSH_KEY_FLAG_PRIVATE)) {
#if MBEDTLS_VERSION_MAJOR > 2
@@ -572,27 +569,27 @@ int pki_key_generate_rsa(ssh_key key, int parameter)
int rc;
const mbedtls_pk_info_t *info = NULL;
key->rsa = malloc(sizeof(mbedtls_pk_context));
if (key->rsa == NULL) {
key->pk = malloc(sizeof(mbedtls_pk_context));
if (key->pk == NULL) {
return SSH_ERROR;
}
mbedtls_pk_init(key->rsa);
mbedtls_pk_init(key->pk);
info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
rc = mbedtls_pk_setup(key->rsa, info);
rc = mbedtls_pk_setup(key->pk, info);
if (rc != 0) {
return SSH_ERROR;
}
if (mbedtls_pk_can_do(key->rsa, MBEDTLS_PK_RSA)) {
rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(*key->rsa),
if (mbedtls_pk_can_do(key->pk, MBEDTLS_PK_RSA)) {
rc = mbedtls_rsa_gen_key(mbedtls_pk_rsa(*key->pk),
mbedtls_ctr_drbg_random,
ssh_get_mbedtls_ctr_drbg_context(),
parameter,
65537);
if (rc != 0) {
mbedtls_pk_free(key->rsa);
mbedtls_pk_free(key->pk);
return SSH_ERROR;
}
}
@@ -626,30 +623,28 @@ int pki_key_compare(const ssh_key k1, const ssh_key k2, enum ssh_keycmp_e what)
switch (ssh_key_type_plain(k1->type)) {
case SSH_KEYTYPE_RSA: {
mbedtls_rsa_context *rsa1, *rsa2;
if (!mbedtls_pk_can_do(k1->rsa, MBEDTLS_PK_RSA) ||
!mbedtls_pk_can_do(k2->rsa, MBEDTLS_PK_RSA))
{
if (!mbedtls_pk_can_do(k1->pk, MBEDTLS_PK_RSA) ||
!mbedtls_pk_can_do(k2->pk, MBEDTLS_PK_RSA)) {
break;
}
if (mbedtls_pk_get_type(k1->rsa) != mbedtls_pk_get_type(k2->rsa) ||
mbedtls_pk_get_bitlen(k1->rsa) !=
mbedtls_pk_get_bitlen(k2->rsa))
{
if (mbedtls_pk_get_type(k1->pk) != mbedtls_pk_get_type(k2->pk) ||
mbedtls_pk_get_bitlen(k1->pk) !=
mbedtls_pk_get_bitlen(k2->pk)) {
rc = 1;
goto cleanup;
}
if (what == SSH_KEY_CMP_PUBLIC) {
#if MBEDTLS_VERSION_MAJOR > 2
rsa1 = mbedtls_pk_rsa(*k1->rsa);
rsa1 = mbedtls_pk_rsa(*k1->pk);
rc = mbedtls_rsa_export(rsa1, &N1, NULL, NULL, NULL, &E1);
if (rc != 0) {
rc = 1;
goto cleanup;
}
rsa2 = mbedtls_pk_rsa(*k2->rsa);
rsa2 = mbedtls_pk_rsa(*k2->pk);
rc = mbedtls_rsa_export(rsa2, &N2, NULL, NULL, NULL, &E2);
if (rc != 0) {
rc = 1;
@@ -666,8 +661,8 @@ int pki_key_compare(const ssh_key k1, const ssh_key k2, enum ssh_keycmp_e what)
goto cleanup;
}
#else
rsa1 = mbedtls_pk_rsa(*k1->rsa);
rsa2 = mbedtls_pk_rsa(*k2->rsa);
rsa1 = mbedtls_pk_rsa(*k1->pk);
rsa2 = mbedtls_pk_rsa(*k2->pk);
if (mbedtls_mpi_cmp_mpi(&rsa1->N, &rsa2->N) != 0) {
rc = 1;
goto cleanup;
@@ -680,14 +675,14 @@ int pki_key_compare(const ssh_key k1, const ssh_key k2, enum ssh_keycmp_e what)
#endif
} else if (what == SSH_KEY_CMP_PRIVATE) {
#if MBEDTLS_VERSION_MAJOR > 2
rsa1 = mbedtls_pk_rsa(*k1->rsa);
rsa1 = mbedtls_pk_rsa(*k1->pk);
rc = mbedtls_rsa_export(rsa1, &N1, &P1, &Q1, NULL, &E1);
if (rc != 0) {
rc = 1;
goto cleanup;
}
rsa2 = mbedtls_pk_rsa(*k2->rsa);
rsa2 = mbedtls_pk_rsa(*k2->pk);
rc = mbedtls_rsa_export(rsa2, &N2, &P2, &Q2, NULL, &E2);
if (rc != 0) {
rc = 1;
@@ -714,8 +709,8 @@ int pki_key_compare(const ssh_key k1, const ssh_key k2, enum ssh_keycmp_e what)
goto cleanup;
}
#else
rsa1 = mbedtls_pk_rsa(*k1->rsa);
rsa2 = mbedtls_pk_rsa(*k2->rsa);
rsa1 = mbedtls_pk_rsa(*k1->pk);
rsa2 = mbedtls_pk_rsa(*k2->pk);
if (mbedtls_mpi_cmp_mpi(&rsa1->N, &rsa2->N) != 0) {
rc = 1;
goto cleanup;
@@ -916,12 +911,12 @@ ssh_string pki_key_to_blob(const ssh_key key, enum ssh_key_e type)
switch (key->type) {
case SSH_KEYTYPE_RSA: {
mbedtls_rsa_context *rsa;
if (mbedtls_pk_can_do(key->rsa, MBEDTLS_PK_RSA) == 0) {
if (mbedtls_pk_can_do(key->pk, MBEDTLS_PK_RSA) == 0) {
SSH_BUFFER_FREE(buffer);
return NULL;
}
rsa = mbedtls_pk_rsa(*key->rsa);
rsa = mbedtls_pk_rsa(*key->pk);
#if MBEDTLS_VERSION_MAJOR > 2
rc = mbedtls_rsa_export(rsa, &N, NULL, NULL, NULL, &E);
@@ -1274,12 +1269,12 @@ static ssh_signature pki_signature_from_rsa_blob(const ssh_key pubkey, const
size_t rsalen = 0;
size_t len = ssh_string_len(sig_blob);
if (pubkey->rsa == NULL) {
if (pubkey->pk == NULL) {
SSH_LOG(SSH_LOG_TRACE, "Pubkey RSA field NULL");
goto errout;
}
rsalen = mbedtls_pk_get_bitlen(pubkey->rsa) / 8;
rsalen = mbedtls_pk_get_bitlen(pubkey->pk) / 8;
if (len > rsalen) {
SSH_LOG(SSH_LOG_TRACE,
"Signature is too big: %lu > %lu",
@@ -1527,7 +1522,7 @@ ssh_signature pki_do_sign_hash(const ssh_key privkey,
switch(privkey->type) {
case SSH_KEYTYPE_RSA:
sig->rsa_sig = rsa_do_sign_hash(hash, hlen, privkey->rsa, hash_type);
sig->rsa_sig = rsa_do_sign_hash(hash, hlen, privkey->pk, hash_type);
if (sig->rsa_sig == NULL) {
ssh_signature_free(sig);
return NULL;
@@ -1735,9 +1730,12 @@ int pki_verify_data_signature(ssh_signature signature,
switch (pubkey->type) {
case SSH_KEYTYPE_RSA:
case SSH_KEYTYPE_RSA_CERT01:
rc = mbedtls_pk_verify(pubkey->rsa, md, hash, hlen,
ssh_string_data(signature->rsa_sig),
ssh_string_len(signature->rsa_sig));
rc = mbedtls_pk_verify(pubkey->pk,
md,
hash,
hlen,
ssh_string_data(signature->rsa_sig),
ssh_string_len(signature->rsa_sig));
if (rc != 0) {
char error_buf[100];
mbedtls_strerror(rc, error_buf, 100);
@@ -2002,7 +2000,7 @@ int ssh_key_size(ssh_key key)
case SSH_KEYTYPE_RSA:
case SSH_KEYTYPE_RSA_CERT01:
case SSH_KEYTYPE_RSA1:
return mbedtls_pk_get_bitlen(key->rsa);
return mbedtls_pk_get_bitlen(key->pk);
case SSH_KEYTYPE_ECDSA_P256:
case SSH_KEYTYPE_ECDSA_P256_CERT01:
case SSH_KEYTYPE_SK_ECDSA: