Compare commits

...

4 Commits

Author SHA1 Message Date
Jakub Jelen
ef50a3c0f0 tests: Remove tests of operations on freed channels
These tests are flaky because even though the care was taken to guess if
the ssh_channel_free() really freed the channel, it might not always be correct
and call to operation on the freed channel results in use after free.

Generally, no operation should be called after the channel is freed by the user.

Signed-off-by: Jakub Jelen <jjelen@redhat.com>
Reviewed-by: Andreas Schneider <asn@cryptomilk.org>
2025-08-06 11:18:45 +02:00
Jakub Jelen
e7cffe7e1b pki: Simplify ed25519 private key duplication
Signed-off-by: Jakub Jelen <jjelen@redhat.com>
Reviewed-by: Andreas Schneider <asn@cryptomilk.org>
2025-08-06 11:18:20 +02:00
Jakub Jelen
d1bf9068a9 Use calloc instead of zeroizing structure after malloc
Signed-off-by: Jakub Jelen <jjelen@redhat.com>
Reviewed-by: Andreas Schneider <asn@cryptomilk.org>
2025-08-06 11:16:58 +02:00
Jakub Jelen
737f9ecc3c agent: Reformat the rest of the file
Signed-off-by: Jakub Jelen <jjelen@redhat.com>
Reviewed-by: Andreas Schneider <asn@cryptomilk.org>
2025-08-06 11:16:58 +02:00
7 changed files with 215 additions and 326 deletions

View File

@@ -67,87 +67,94 @@
(((x) == SSH_AGENT_FAILURE) || ((x) == SSH_COM_AGENT2_FAILURE) || \
((x) == SSH2_AGENT_FAILURE))
static uint32_t atomicio(struct ssh_agent_struct *agent, void *buf, uint32_t n, int do_read) {
char *b = buf;
uint32_t pos = 0;
ssize_t res;
ssh_pollfd_t pfd;
ssh_channel channel = agent->channel;
socket_t fd;
static uint32_t
atomicio(struct ssh_agent_struct *agent, void *buf, uint32_t n, int do_read)
{
char *b = buf;
uint32_t pos = 0;
ssize_t res;
ssh_pollfd_t pfd;
ssh_channel channel = agent->channel;
socket_t fd;
/* Using a socket ? */
if (channel == NULL) {
fd = ssh_socket_get_fd(agent->sock);
pfd.fd = fd;
pfd.events = do_read ? POLLIN : POLLOUT;
/* Using a socket ? */
if (channel == NULL) {
fd = ssh_socket_get_fd(agent->sock);
pfd.fd = fd;
pfd.events = do_read ? POLLIN : POLLOUT;
while (n > pos) {
if (do_read) {
res = recv(fd, b + pos, n - pos, 0);
} else {
res = send(fd, b + pos, n - pos, 0);
}
switch (res) {
case -1:
if (errno == EINTR) {
continue;
}
while (n > pos) {
if (do_read) {
res = recv(fd, b + pos, n - pos, 0);
} else {
res = send(fd, b + pos, n - pos, 0);
}
switch (res) {
case -1:
if (errno == EINTR) {
continue;
}
#ifdef EWOULDBLOCK
if (errno == EAGAIN || errno == EWOULDBLOCK) {
if (errno == EAGAIN || errno == EWOULDBLOCK) {
#else
if (errno == EAGAIN) {
if (errno == EAGAIN) {
#endif
(void) ssh_poll(&pfd, 1, -1);
continue;
}
return 0;
case 0:
/* read returns 0 on end-of-file */
errno = do_read ? 0 : EPIPE;
return pos;
default:
pos += (uint32_t) res;
(void)ssh_poll(&pfd, 1, -1);
continue;
}
return 0;
case 0:
/* read returns 0 on end-of-file */
errno = do_read ? 0 : EPIPE;
return pos;
default:
pos += (uint32_t)res;
}
}
}
return pos;
return pos;
} else {
/* using an SSH channel */
while (n > pos){
if (do_read)
res = ssh_channel_read(channel,b + pos, n-pos, 0);
else
res = ssh_channel_write(channel, b+pos, n-pos);
if (res == SSH_AGAIN)
continue;
if (res == SSH_ERROR)
return 0;
pos += (uint32_t)res;
}
return pos;
/* using an SSH channel */
while (n > pos) {
if (do_read) {
res = ssh_channel_read(channel, b + pos, n - pos, 0);
} else {
res = ssh_channel_write(channel, b + pos, n - pos);
}
if (res == SSH_AGAIN) {
continue;
}
if (res == SSH_ERROR) {
return 0;
}
pos += (uint32_t)res;
}
return pos;
}
}
ssh_agent ssh_agent_new(struct ssh_session_struct *session) {
ssh_agent agent = NULL;
ssh_agent ssh_agent_new(struct ssh_session_struct *session)
{
ssh_agent agent = NULL;
agent = malloc(sizeof(struct ssh_agent_struct));
if (agent == NULL) {
return NULL;
}
ZERO_STRUCTP(agent);
agent = calloc(1, sizeof(struct ssh_agent_struct));
if (agent == NULL) {
return NULL;
}
agent->count = 0;
agent->sock = ssh_socket_new(session);
if (agent->sock == NULL) {
SAFE_FREE(agent);
return NULL;
}
agent->channel = NULL;
return agent;
agent->count = 0;
agent->sock = ssh_socket_new(session);
if (agent->sock == NULL) {
SAFE_FREE(agent);
return NULL;
}
agent->channel = NULL;
return agent;
}
static void agent_set_channel(struct ssh_agent_struct *agent, ssh_channel channel){
agent->channel = channel;
static void agent_set_channel(struct ssh_agent_struct *agent,
ssh_channel channel)
{
agent->channel = channel;
}
/**
@@ -168,15 +175,19 @@ static void agent_set_channel(struct ssh_agent_struct *agent, ssh_channel channe
* @returns SSH_OK in case of success
* SSH_ERROR in case of an error
*/
int ssh_set_agent_channel(ssh_session session, ssh_channel channel){
if (!session)
return SSH_ERROR;
if (!session->agent){
ssh_set_error(session, SSH_REQUEST_DENIED, "Session has no active agent");
return SSH_ERROR;
}
agent_set_channel(session->agent, channel);
return SSH_OK;
int ssh_set_agent_channel(ssh_session session, ssh_channel channel)
{
if (!session) {
return SSH_ERROR;
}
if (!session->agent) {
ssh_set_error(session,
SSH_REQUEST_DENIED,
"Session has no active agent");
return SSH_ERROR;
}
agent_set_channel(session->agent, channel);
return SSH_OK;
}
/** @brief sets the SSH agent socket.
@@ -187,64 +198,72 @@ int ssh_set_agent_channel(ssh_session session, ssh_channel channel){
* @returns SSH_OK in case of success
* SSH_ERROR in case of an error
*/
int ssh_set_agent_socket(ssh_session session, socket_t fd){
if (!session)
return SSH_ERROR;
if (!session->agent){
ssh_set_error(session, SSH_REQUEST_DENIED, "Session has no active agent");
return SSH_ERROR;
}
int ssh_set_agent_socket(ssh_session session, socket_t fd)
{
if (!session) {
return SSH_ERROR;
}
if (!session->agent) {
ssh_set_error(session,
SSH_REQUEST_DENIED,
"Session has no active agent");
return SSH_ERROR;
}
ssh_socket_set_fd(session->agent->sock, fd);
return SSH_OK;
ssh_socket_set_fd(session->agent->sock, fd);
return SSH_OK;
}
/**
* @}
*/
void ssh_agent_close(struct ssh_agent_struct *agent) {
if (agent == NULL) {
return;
}
void ssh_agent_close(struct ssh_agent_struct *agent)
{
if (agent == NULL) {
return;
}
ssh_socket_close(agent->sock);
ssh_socket_close(agent->sock);
}
void ssh_agent_free(ssh_agent agent) {
if (agent) {
if (agent->ident) {
SSH_BUFFER_FREE(agent->ident);
void ssh_agent_free(ssh_agent agent)
{
if (agent) {
if (agent->ident) {
SSH_BUFFER_FREE(agent->ident);
}
if (agent->sock) {
ssh_agent_close(agent);
ssh_socket_free(agent->sock);
}
SAFE_FREE(agent);
}
if (agent->sock) {
ssh_agent_close(agent);
ssh_socket_free(agent->sock);
}
SAFE_FREE(agent);
}
}
static int agent_connect(ssh_session session) {
const char *auth_sock = NULL;
static int agent_connect(ssh_session session)
{
const char *auth_sock = NULL;
if (session == NULL || session->agent == NULL) {
return -1;
}
if (session->agent->channel != NULL) {
return 0;
}
auth_sock = session->opts.agent_socket ? session->opts.agent_socket
: getenv("SSH_AUTH_SOCK");
if (auth_sock && *auth_sock) {
if (ssh_socket_unix(session->agent->sock, auth_sock) < 0) {
return -1;
}
return 0;
}
if (session == NULL || session->agent == NULL) {
return -1;
}
if (session->agent->channel != NULL)
return 0;
auth_sock = session->opts.agent_socket ?
session->opts.agent_socket : getenv("SSH_AUTH_SOCK");
if (auth_sock && *auth_sock) {
if (ssh_socket_unix(session->agent->sock, auth_sock) < 0) {
return -1;
}
return 0;
}
return -1;
}
#if 0
@@ -268,61 +287,66 @@ static int agent_decode_reply(struct ssh_session_struct *session, int type) {
#endif
static int agent_talk(struct ssh_session_struct *session,
struct ssh_buffer_struct *request, struct ssh_buffer_struct *reply) {
uint32_t len = 0;
uint8_t tmpbuf[4];
uint8_t *payload = tmpbuf;
char err_msg[SSH_ERRNO_MSG_MAX] = {0};
struct ssh_buffer_struct *request,
struct ssh_buffer_struct *reply)
{
uint32_t len = 0;
uint8_t tmpbuf[4];
uint8_t *payload = tmpbuf;
char err_msg[SSH_ERRNO_MSG_MAX] = {0};
len = ssh_buffer_get_len(request);
SSH_LOG(SSH_LOG_TRACE, "Request length: %" PRIu32, len);
PUSH_BE_U32(payload, 0, len);
len = ssh_buffer_get_len(request);
SSH_LOG(SSH_LOG_TRACE, "Request length: %" PRIu32, len);
PUSH_BE_U32(payload, 0, len);
/* send length and then the request packet */
if (atomicio(session->agent, payload, 4, 0) == 4) {
if (atomicio(session->agent, ssh_buffer_get(request), len, 0)
!= len) {
SSH_LOG(SSH_LOG_TRACE, "atomicio sending request failed: %s",
strerror(errno));
return -1;
/* send length and then the request packet */
if (atomicio(session->agent, payload, 4, 0) == 4) {
if (atomicio(session->agent, ssh_buffer_get(request), len, 0) != len) {
SSH_LOG(SSH_LOG_TRACE,
"atomicio sending request failed: %s",
ssh_strerror(errno, err_msg, SSH_ERRNO_MSG_MAX));
return -1;
}
} else {
SSH_LOG(SSH_LOG_TRACE,
"atomicio sending request length failed: %s",
ssh_strerror(errno, err_msg, SSH_ERRNO_MSG_MAX));
return -1;
}
} else {
SSH_LOG(SSH_LOG_TRACE,
"atomicio sending request length failed: %s",
ssh_strerror(errno, err_msg, SSH_ERRNO_MSG_MAX));
return -1;
}
/* wait for response, read the length of the response packet */
if (atomicio(session->agent, payload, 4, 1) != 4) {
SSH_LOG(SSH_LOG_TRACE, "atomicio read response length failed: %s",
strerror(errno));
return -1;
}
/* wait for response, read the length of the response packet */
if (atomicio(session->agent, payload, 4, 1) != 4) {
SSH_LOG(SSH_LOG_TRACE,
"atomicio read response length failed: %s",
ssh_strerror(errno, err_msg, SSH_ERRNO_MSG_MAX));
return -1;
}
len = PULL_BE_U32(payload, 0);
if (len > 256 * 1024) {
ssh_set_error(session, SSH_FATAL,
"Authentication response too long: %" PRIu32, len);
return -1;
}
SSH_LOG(SSH_LOG_TRACE, "Response length: %" PRIu32, len);
len = PULL_BE_U32(payload, 0);
if (len > 256 * 1024) {
ssh_set_error(session,
SSH_FATAL,
"Authentication response too long: %" PRIu32,
len);
return -1;
}
SSH_LOG(SSH_LOG_TRACE, "Response length: %" PRIu32, len);
payload = ssh_buffer_allocate(reply, len);
if (payload == NULL) {
SSH_LOG(SSH_LOG_DEBUG, "Not enough space");
return -1;
}
payload = ssh_buffer_allocate(reply, len);
if (payload == NULL) {
SSH_LOG(SSH_LOG_DEBUG, "Not enough space");
return -1;
}
if (atomicio(session->agent, payload, len, 1) != len) {
SSH_LOG(SSH_LOG_DEBUG,
"Error reading response from authentication socket.");
/* Rollback the unused space */
ssh_buffer_pass_bytes_end(reply, len);
return -1;
}
if (atomicio(session->agent, payload, len, 1) != len) {
SSH_LOG(SSH_LOG_DEBUG,
"Error reading response from authentication socket.");
/* Rollback the unused space */
ssh_buffer_pass_bytes_end(reply, len);
return -1;
}
return 0;
return 0;
}
uint32_t ssh_agent_get_ident_count(struct ssh_session_struct *session)
@@ -471,22 +495,23 @@ ssh_key ssh_agent_get_next_ident(struct ssh_session_struct *session,
return key;
}
int ssh_agent_is_running(ssh_session session) {
if (session == NULL || session->agent == NULL) {
return 0;
}
if (ssh_socket_is_open(session->agent->sock)) {
return 1;
} else {
if (agent_connect(session) < 0) {
return 0;
} else {
return 1;
int ssh_agent_is_running(ssh_session session)
{
if (session == NULL || session->agent == NULL) {
return 0;
}
}
return 0;
if (ssh_socket_is_open(session->agent->sock)) {
return 1;
} else {
if (agent_connect(session) < 0) {
return 0;
} else {
return 1;
}
}
return 0;
}
ssh_string ssh_agent_sign_data(ssh_session session,

View File

@@ -1027,12 +1027,11 @@ int ssh_userauth_agent(ssh_session session, const char *username)
}
if (!session->agent_state) {
session->agent_state = malloc(sizeof(struct ssh_agent_state_struct));
session->agent_state = calloc(1, sizeof(struct ssh_agent_state_struct));
if (!session->agent_state) {
ssh_set_error_oom(session);
return SSH_AUTH_ERROR;
}
ZERO_STRUCTP(session->agent_state);
session->agent_state->state = SSH_AGENT_STATE_NONE;
}

View File

@@ -126,11 +126,10 @@ ssh_pcap_file ssh_pcap_file_new(void)
{
struct ssh_pcap_file_struct *pcap = NULL;
pcap = malloc(sizeof(struct ssh_pcap_file_struct));
pcap = calloc(1, sizeof(struct ssh_pcap_file_struct));
if (pcap == NULL) {
return NULL;
}
ZERO_STRUCTP(pcap);
return pcap;
}
@@ -296,12 +295,13 @@ void ssh_pcap_file_free(ssh_pcap_file pcap)
*/
ssh_pcap_context ssh_pcap_context_new(ssh_session session)
{
ssh_pcap_context ctx = (struct ssh_pcap_context_struct *)malloc(sizeof(struct ssh_pcap_context_struct));
ssh_pcap_context ctx = NULL;
ctx = calloc(1, sizeof(struct ssh_pcap_context_struct));
if (ctx == NULL) {
ssh_set_error_oom(session);
return NULL;
}
ZERO_STRUCTP(ctx);
ctx->session = session;
return ctx;
}

View File

@@ -833,11 +833,10 @@ ssh_signature ssh_signature_new(void)
{
struct ssh_signature_struct *sig = NULL;
sig = malloc(sizeof(struct ssh_signature_struct));
sig = calloc(1, sizeof(struct ssh_signature_struct));
if (sig == NULL) {
return NULL;
}
ZERO_STRUCTP(sig);
return sig;
}

View File

@@ -723,47 +723,11 @@ ssh_key pki_key_dup(const ssh_key key, int demote)
if (!demote && (key->flags & SSH_KEY_FLAG_PRIVATE) &&
key->type == SSH_KEYTYPE_ED25519) {
unsigned char *ed25519_privkey = NULL;
size_t key_len = 0;
rc = EVP_PKEY_get_raw_private_key(key->key, NULL, &key_len);
rc = EVP_PKEY_up_ref(key->key);
if (rc != 1) {
SSH_LOG(SSH_LOG_TRACE,
"Failed to get ed25519 raw private key length: %s",
ERR_error_string(ERR_get_error(), NULL));
goto fail;
}
if (key_len != ED25519_KEY_LEN) {
SSH_LOG(SSH_LOG_TRACE,
"Unexpected length of private key %zu. Expected %d.",
key_len,
ED25519_KEY_LEN);
goto fail;
}
ed25519_privkey = malloc(key_len);
if (ed25519_privkey == NULL) {
SSH_LOG(SSH_LOG_TRACE, "Out of memory");
goto fail;
}
rc = EVP_PKEY_get_raw_private_key(key->key,
ed25519_privkey,
&key_len);
if (rc != 1) {
SSH_LOG(SSH_LOG_TRACE,
"Failed to get ed25519 raw private key: %s",
ERR_error_string(ERR_get_error(), NULL));
free(ed25519_privkey);
goto fail;
}
new->key = EVP_PKEY_new_raw_private_key(EVP_PKEY_ED25519,
NULL,
ed25519_privkey,
key_len);
free(ed25519_privkey);
new->key = key->key;
} else {
unsigned char *ed25519_pubkey = NULL;
size_t key_len = 0;

View File

@@ -321,51 +321,6 @@ static void torture_freed_channel_poll(void **state)
assert_int_equal(rc, SSH_ERROR);
}
/* Ensure that calling 'ssh_channel_poll_timeout' on a freed channel does not
* lead to segmentation faults. */
static void torture_freed_channel_poll_timeout(void **state)
{
struct torture_state *s = *state;
ssh_session session = s->ssh.session;
ssh_channel channel;
bool channel_freed = false;
char request[256];
char buff[256] = {0};
int rc;
snprintf(request, 256,
"dd if=/dev/urandom of=/tmp/file bs=64000 count=2; hexdump -C /tmp/file");
channel = ssh_channel_new(session);
assert_non_null(channel);
rc = ssh_channel_open_session(channel);
assert_ssh_return_code(session, rc);
/* Make the request, read parts with close */
rc = ssh_channel_request_exec(channel, request);
assert_ssh_return_code(session, rc);
do {
rc = ssh_channel_read(channel, buff, 256, 0);
} while(rc > 0);
assert_ssh_return_code(session, rc);
/* when either of these conditions is met the call to ssh_channel_free will
* actually free the channel so calling poll on that channel will be
* use-after-free */
if ((channel->flags & SSH_CHANNEL_FLAG_CLOSED_REMOTE) ||
(channel->flags & SSH_CHANNEL_FLAG_NOT_BOUND)) {
channel_freed = true;
}
ssh_channel_free(channel);
if (!channel_freed) {
rc = ssh_channel_poll_timeout(channel, 500, 0);
assert_int_equal(rc, SSH_ERROR);
}
}
/* Ensure that calling 'ssh_channel_read_nonblocking' on a freed channel does
* not lead to segmentation faults. */
static void torture_freed_channel_read_nonblocking(void **state)
@@ -461,52 +416,6 @@ static void torture_channel_exit_signal(void **state)
SAFE_FREE(exit_signal);
}
/* Ensure that calling 'ssh_channel_get_exit_status' on a freed channel does not
* lead to segmentation faults. */
static void torture_freed_channel_get_exit_status(void **state)
{
struct torture_state *s = *state;
ssh_session session = s->ssh.session;
ssh_channel channel;
bool channel_freed = false;
char request[256];
char buff[256] = {0};
int rc;
snprintf(request, 256,
"dd if=/dev/urandom of=/tmp/file bs=64000 count=2; hexdump -C /tmp/file");
channel = ssh_channel_new(session);
assert_non_null(channel);
rc = ssh_channel_open_session(channel);
assert_ssh_return_code(session, rc);
/* Make the request, read parts with close */
rc = ssh_channel_request_exec(channel, request);
assert_ssh_return_code(session, rc);
do {
rc = ssh_channel_read(channel, buff, 256, 0);
} while(rc > 0);
assert_ssh_return_code(session, rc);
/* when either of these conditions is met the call to ssh_channel_free will
* actually free the channel so calling poll on that channel will be
* use-after-free */
if ((channel->flags & SSH_CHANNEL_FLAG_CLOSED_REMOTE) ||
(channel->flags & SSH_CHANNEL_FLAG_NOT_BOUND)) {
channel_freed = true;
}
SSH_CHANNEL_FREE(channel);
if (!channel_freed) {
rc = ssh_channel_get_exit_status(channel);
assert_ssh_return_code_equal(session, rc, SSH_ERROR);
}
}
static void
torture_channel_read_stderr(void **state)
{
@@ -611,9 +520,6 @@ int torture_run_tests(void) {
cmocka_unit_test_setup_teardown(torture_freed_channel_poll,
session_setup,
session_teardown),
cmocka_unit_test_setup_teardown(torture_freed_channel_poll_timeout,
session_setup,
session_teardown),
cmocka_unit_test_setup_teardown(torture_freed_channel_read_nonblocking,
session_setup,
session_teardown),
@@ -623,9 +529,6 @@ int torture_run_tests(void) {
cmocka_unit_test_setup_teardown(torture_channel_exit_signal,
session_setup,
session_teardown),
cmocka_unit_test_setup_teardown(torture_freed_channel_get_exit_status,
session_setup,
session_teardown),
cmocka_unit_test_setup_teardown(torture_channel_read_stderr,
session_setup,
session_teardown),

View File

@@ -21,11 +21,10 @@ static int myauthcallback (const char *prompt, char *buf, size_t len,
static int setup(void **state)
{
struct ssh_callbacks_struct *cb;
struct ssh_callbacks_struct *cb = NULL;
cb = malloc(sizeof(struct ssh_callbacks_struct));
cb = calloc(1, sizeof(struct ssh_callbacks_struct));
assert_non_null(cb);
ZERO_STRUCTP(cb);
cb->userdata = (void *) 0x0badc0de;
cb->auth_function = myauthcallback;