diff --git a/src/pki_mbedcrypto.c b/src/pki_mbedcrypto.c index d83074e6..1427ded9 100644 --- a/src/pki_mbedcrypto.c +++ b/src/pki_mbedcrypto.c @@ -85,10 +85,8 @@ ssh_key pki_private_key_from_base64(const char *b64_key, const char *passphrase, ssh_auth_callback auth_fn, void *auth_data) { ssh_key key = NULL; - mbedtls_pk_context *rsa = NULL; - mbedtls_pk_context *ecdsa = NULL; - ed25519_privkey *ed25519 = NULL; - enum ssh_keytypes_e type; + mbedtls_pk_context *pk = NULL; + mbedtls_pk_type_t mbed_type; int valid; /* mbedtls pk_parse_key expects strlen to count the 0 byte */ size_t b64len = strlen(b64_key) + 1; @@ -97,159 +95,100 @@ ssh_key pki_private_key_from_base64(const char *b64_key, const char *passphrase, mbedtls_ctr_drbg_context *ctr_drbg = ssh_get_mbedtls_ctr_drbg_context(); #endif - type = pki_privatekey_type_from_string(b64_key); - if (type == SSH_KEYTYPE_UNKNOWN) { - SSH_LOG(SSH_LOG_WARN, "Unknown or invalid private key."); - return NULL; + pk = malloc(sizeof(mbedtls_pk_context)); + if (pk == NULL) { + goto fail; } + mbedtls_pk_init(pk); - switch (type) { - case SSH_KEYTYPE_RSA: - rsa = malloc(sizeof(mbedtls_pk_context)); - if (rsa == NULL) { - return NULL; - } - - mbedtls_pk_init(rsa); - - if (passphrase == NULL) { - if (auth_fn) { - valid = auth_fn("Passphrase for private key:", (char *) tmp, - MAX_PASSPHRASE_SIZE, 0, 0, auth_data); - if (valid < 0) { - goto fail; - } - /* TODO fix signedness and strlen */ -#if MBEDTLS_VERSION_MAJOR > 2 - valid = mbedtls_pk_parse_key(rsa, - (const unsigned char *) b64_key, - b64len, tmp, - strnlen((const char *) tmp, MAX_PASSPHRASE_SIZE), - mbedtls_ctr_drbg_random, ctr_drbg); -#else - valid = mbedtls_pk_parse_key(rsa, - (const unsigned char *) b64_key, - b64len, tmp, - strnlen((const char *) tmp, MAX_PASSPHRASE_SIZE)); -#endif - } else { -#if MBEDTLS_VERSION_MAJOR > 2 - valid = mbedtls_pk_parse_key(rsa, - (const unsigned char *) b64_key, - b64len, NULL, - 0, mbedtls_ctr_drbg_random, ctr_drbg); -#else - valid = mbedtls_pk_parse_key(rsa, - (const unsigned char *) b64_key, - b64len, NULL, - 0); -#endif - } - } else { -#if MBEDTLS_VERSION_MAJOR > 2 - valid = mbedtls_pk_parse_key(rsa, - (const unsigned char *) b64_key, b64len, - (const unsigned char *) passphrase, - strnlen(passphrase, MAX_PASSPHRASE_SIZE), - mbedtls_ctr_drbg_random, ctr_drbg); -#else - valid = mbedtls_pk_parse_key(rsa, - (const unsigned char *) b64_key, b64len, - (const unsigned char *) passphrase, - strnlen(passphrase, MAX_PASSPHRASE_SIZE)); -#endif - } - - if (valid != 0) { - char error_buf[100]; - mbedtls_strerror(valid, error_buf, 100); - SSH_LOG(SSH_LOG_WARN,"Parsing private key %s", error_buf); + if (passphrase == NULL) { + if (auth_fn) { + valid = auth_fn("Passphrase for private key:", + (char *)tmp, + MAX_PASSPHRASE_SIZE, + 0, + 0, + auth_data); + if (valid < 0) { goto fail; } - break; - case SSH_KEYTYPE_ECDSA_P256: - case SSH_KEYTYPE_ECDSA_P384: - case SSH_KEYTYPE_ECDSA_P521: #if MBEDTLS_VERSION_MAJOR > 2 - ecdsa = malloc(sizeof(mbedtls_ecdsa_context)); + valid = mbedtls_pk_parse_key( + pk, + (const unsigned char *)b64_key, + b64len, + tmp, + strnlen((const char *)tmp, MAX_PASSPHRASE_SIZE), + mbedtls_ctr_drbg_random, + ctr_drbg); #else - ecdsa = malloc(sizeof(mbedtls_pk_context)); + valid = mbedtls_pk_parse_key( + pk, + (const unsigned char *)b64_key, + b64len, + tmp, + strnlen((const char *)tmp, MAX_PASSPHRASE_SIZE)); #endif - if (ecdsa == NULL) { - return NULL; - } - - mbedtls_pk_init(ecdsa); - - if (passphrase == NULL) { - if (auth_fn) { - valid = auth_fn("Passphrase for private key:", (char *) tmp, - MAX_PASSPHRASE_SIZE, 0, 0, auth_data); - if (valid < 0) { - goto fail; - } + } else { #if MBEDTLS_VERSION_MAJOR > 2 - valid = mbedtls_pk_parse_key(ecdsa, - (const unsigned char *) b64_key, - b64len, tmp, - strnlen((const char *) tmp, MAX_PASSPHRASE_SIZE), - mbedtls_ctr_drbg_random, ctr_drbg); + valid = mbedtls_pk_parse_key(pk, + (const unsigned char *)b64_key, + b64len, + NULL, + 0, + mbedtls_ctr_drbg_random, + ctr_drbg); #else - valid = mbedtls_pk_parse_key(ecdsa, - (const unsigned char *) b64_key, - b64len, tmp, - strnlen((const char *) tmp, MAX_PASSPHRASE_SIZE)); + valid = mbedtls_pk_parse_key(pk, + (const unsigned char *)b64_key, + b64len, + NULL, + 0); #endif - } else { + } + } else { #if MBEDTLS_VERSION_MAJOR > 2 - valid = mbedtls_pk_parse_key(ecdsa, - (const unsigned char *) b64_key, - b64len, NULL, - 0, mbedtls_ctr_drbg_random, ctr_drbg); + valid = mbedtls_pk_parse_key(pk, + (const unsigned char *)b64_key, + b64len, + (const unsigned char *)passphrase, + strnlen(passphrase, MAX_PASSPHRASE_SIZE), + mbedtls_ctr_drbg_random, + ctr_drbg); #else - valid = mbedtls_pk_parse_key(ecdsa, - (const unsigned char *) b64_key, - b64len, NULL, - 0); + valid = mbedtls_pk_parse_key(pk, + (const unsigned char *)b64_key, + b64len, + (const unsigned char *)passphrase, + strnlen(passphrase, MAX_PASSPHRASE_SIZE)); #endif - } - } else { -#if MBEDTLS_VERSION_MAJOR > 2 - valid = mbedtls_pk_parse_key(ecdsa, - (const unsigned char *) b64_key, b64len, - (const unsigned char *) passphrase, - strnlen(passphrase, MAX_PASSPHRASE_SIZE), - mbedtls_ctr_drbg_random, ctr_drbg); -#else - valid = mbedtls_pk_parse_key(ecdsa, - (const unsigned char *) b64_key, b64len, - (const unsigned char *) passphrase, - strnlen(passphrase, MAX_PASSPHRASE_SIZE)); -#endif - } - - if (valid != 0) { - char error_buf[100]; - mbedtls_strerror(valid, error_buf, 100); - SSH_LOG(SSH_LOG_WARN,"Parsing private key %s", error_buf); - goto fail; - } - break; - case SSH_KEYTYPE_ED25519: - /* Cannot open ed25519 keys with libmbedcrypto */ - default: - SSH_LOG(SSH_LOG_WARN, "Unknown or invalid private key type %d", - type); - return NULL; } + if (valid != 0) { + char error_buf[100]; + mbedtls_strerror(valid, error_buf, 100); + SSH_LOG(SSH_LOG_WARN, "Parsing private key %s", error_buf); + goto fail; + } + + mbed_type = mbedtls_pk_get_type(pk); key = ssh_key_new(); if (key == NULL) { goto fail; } - if (ecdsa != NULL) { - mbedtls_ecp_keypair *keypair = mbedtls_pk_ec(*ecdsa); + switch (mbed_type) { + case MBEDTLS_PK_RSA: + case MBEDTLS_PK_RSA_ALT: + key->rsa = pk; + pk = NULL; + key->type = SSH_KEYTYPE_RSA; + break; + case MBEDTLS_PK_ECKEY: + case MBEDTLS_PK_ECDSA: { + /* type will be set later */ + mbedtls_ecp_keypair *keypair = mbedtls_pk_ec(*pk); + pk = NULL; key->ecdsa = malloc(sizeof(mbedtls_ecdsa_context)); if (key->ecdsa == NULL) { @@ -258,40 +197,36 @@ ssh_key pki_private_key_from_base64(const char *b64_key, const char *passphrase, mbedtls_ecdsa_init(key->ecdsa); mbedtls_ecdsa_from_keypair(key->ecdsa, keypair); - mbedtls_pk_free(ecdsa); - SAFE_FREE(ecdsa); + mbedtls_pk_free(pk); + SAFE_FREE(pk); key->ecdsa_nid = pki_key_ecdsa_to_nid(key->ecdsa); /* pki_privatekey_type_from_string always returns P256 for ECDSA - * keys, so we need to figure out the correct type here */ - type = pki_key_ecdsa_to_key_type(key->ecdsa); - if (type == SSH_KEYTYPE_UNKNOWN) { + * keys, so we need to figure out the correct type here */ + key->type = pki_key_ecdsa_to_key_type(key->ecdsa); + if (key->type == SSH_KEYTYPE_UNKNOWN) { SSH_LOG(SSH_LOG_WARN, "Invalid private key."); goto fail; } - } else { - key->ecdsa = NULL; + break; + } + default: + SSH_LOG(SSH_LOG_WARN, + "Unknown or invalid private key type %d", + mbed_type); + return NULL; } - key->type = type; - key->type_c = ssh_key_type_to_char(type); + key->type_c = ssh_key_type_to_char(key->type); key->flags = SSH_KEY_FLAG_PRIVATE | SSH_KEY_FLAG_PUBLIC; - key->rsa = rsa; - key->ed25519_privkey = ed25519; - rsa = NULL; - ecdsa = NULL; return key; fail: ssh_key_free(key); - if (rsa != NULL) { - mbedtls_pk_free(rsa); - SAFE_FREE(rsa); - } - if (ecdsa != NULL) { - mbedtls_pk_free(ecdsa); - SAFE_FREE(ecdsa); + if (pk != NULL) { + mbedtls_pk_free(pk); + SAFE_FREE(pk); } return NULL; }