diff --git a/src/channels.c b/src/channels.c index b54af3bb..208645da 100644 --- a/src/channels.c +++ b/src/channels.c @@ -52,16 +52,19 @@ #include "libssh/server.h" #endif -#define WINDOWBASE 1280000 -#define WINDOWLIMIT (WINDOWBASE/2) - /* * All implementations MUST be able to process packets with an * uncompressed payload length of 32768 bytes or less and a total packet * size of 35000 bytes or less. */ #define CHANNEL_MAX_PACKET 32768 -#define CHANNEL_INITIAL_WINDOW 64000 + +/* + * WINDOW_DEFAULT matches the default OpenSSH session window size. + * This controls how much data the peer can send before needing to receive + * a round-trip SSH2_MSG_CHANNEL_WINDOW_ADJUST message that increases the window. + */ +#define WINDOW_DEFAULT (64*CHANNEL_MAX_PACKET) /** * @defgroup libssh_channel The SSH channel functions @@ -422,32 +425,43 @@ ssh_channel ssh_channel_from_local(ssh_session session, uint32_t id) { * @brief grows the local window and sends a packet to the other party * @param session SSH session * @param channel SSH channel - * @param minimumsize The minimum acceptable size for the new window. * @return SSH_OK if successful; SSH_ERROR otherwise. */ static int grow_window(ssh_session session, - ssh_channel channel, - uint32_t minimumsize) + ssh_channel channel) { - uint32_t new_window = minimumsize > WINDOWBASE ? minimumsize : WINDOWBASE; + uint32_t used; + uint32_t increment; int rc; - if (new_window <= channel->local_window) { + /* Calculate the increment taking into account what the peer may still send + * (local_window) and what we've already buffered (stdout_buffer and + * stderr_buffer). + */ + used = channel->local_window; + if (channel->stdout_buffer != NULL) { + used += ssh_buffer_get_len(channel->stdout_buffer); + } + if (channel->stderr_buffer != NULL) { + used += ssh_buffer_get_len(channel->stderr_buffer); + } + /* Avoid a negative increment in case the peer sent more than the window allowed */ + increment = WINDOW_DEFAULT > used ? WINDOW_DEFAULT - used : 0; + /* Don't grow until we can request at least half a window */ + if (increment < (WINDOW_DEFAULT / 2)) { SSH_LOG(SSH_LOG_DEBUG, "growing window (channel %" PRIu32 ":%" PRIu32 ") to %" PRIu32 " bytes : not needed (%" PRIu32 " bytes)", - channel->local_channel, channel->remote_channel, new_window, + channel->local_channel, channel->remote_channel, WINDOW_DEFAULT, channel->local_window); return SSH_OK; } - /* WINDOW_ADJUST packet needs a relative increment rather than an absolute - * value, so we give here the missing bytes needed to reach new_window - */ + rc = ssh_buffer_pack(session->out_buffer, "bdd", SSH2_MSG_CHANNEL_WINDOW_ADJUST, channel->remote_channel, - new_window - channel->local_window); + increment); if (rc != SSH_OK) { ssh_set_error_oom(session); goto error; @@ -458,12 +472,12 @@ static int grow_window(ssh_session session, } SSH_LOG(SSH_LOG_DEBUG, - "growing window (channel %" PRIu32 ":%" PRIu32 ") to %" PRIu32 " bytes", + "growing window (channel %" PRIu32 ":%" PRIu32 ") by %" PRIu32 " bytes", channel->local_channel, channel->remote_channel, - new_window); + increment); - channel->local_window = new_window; + channel->local_window += increment; return SSH_OK; error: @@ -614,12 +628,17 @@ SSH_PACKET_CALLBACK(channel_rcv_data) channel->local_window, channel->remote_window); - /* What shall we do in this case? Let's accept it anyway */ if (len > channel->local_window) { SSH_LOG(SSH_LOG_RARE, "Data packet too big for our window(%" PRIu32 " vs %" PRIu32 ")", len, channel->local_window); + + SSH_STRING_FREE(str); + + ssh_set_error(session, SSH_FATAL, "Window exceeded"); + + return SSH_PACKET_USED; } data = ssh_string_data(str); @@ -629,11 +648,7 @@ SSH_PACKET_CALLBACK(channel_rcv_data) return SSH_PACKET_USED; } - if (len <= channel->local_window) { - channel->local_window -= len; - } else { - channel->local_window = 0; /* buggy remote */ - } + channel->local_window -= len; SSH_LOG(SSH_LOG_PACKET, "Channel windows are now (local win=%" PRIu32 " remote win=%" PRIu32 ")", @@ -661,19 +676,20 @@ SSH_PACKET_CALLBACK(channel_rcv_data) ssh_buffer_get_len(buf), is_stderr); if (rest > 0) { + int rc; if (channel->counter != NULL) { channel->counter->in_bytes += rest; } ssh_buffer_pass_bytes(buf, rest); + + rc = grow_window(session, channel); + if (rc == SSH_ERROR) { + return -1; + } } } ssh_callbacks_iterate_end(); - if (channel->local_window + ssh_buffer_get_len(buf) < WINDOWLIMIT) { - if (grow_window(session, channel, 0) < 0) { - return -1; - } - } return SSH_PACKET_USED; } @@ -1025,7 +1041,7 @@ int ssh_channel_open_session(ssh_channel channel) return channel_open(channel, "session", - CHANNEL_INITIAL_WINDOW, + WINDOW_DEFAULT, CHANNEL_MAX_PACKET, NULL); } @@ -1053,7 +1069,7 @@ int ssh_channel_open_auth_agent(ssh_channel channel) return channel_open(channel, "auth-agent@openssh.com", - CHANNEL_INITIAL_WINDOW, + WINDOW_DEFAULT, CHANNEL_MAX_PACKET, NULL); } @@ -1122,7 +1138,7 @@ int ssh_channel_open_forward(ssh_channel channel, const char *remotehost, rc = channel_open(channel, "direct-tcpip", - CHANNEL_INITIAL_WINDOW, + WINDOW_DEFAULT, CHANNEL_MAX_PACKET, payload); @@ -1205,7 +1221,7 @@ int ssh_channel_open_forward_unix(ssh_channel channel, rc = channel_open(channel, "direct-streamlocal@openssh.com", - CHANNEL_INITIAL_WINDOW, + WINDOW_DEFAULT, CHANNEL_MAX_PACKET, payload); @@ -2967,14 +2983,13 @@ int channel_read_buffer(ssh_channel channel, ssh_buffer buffer, uint32_t count, struct ssh_channel_read_termination_struct { ssh_channel channel; - uint32_t count; ssh_buffer buffer; }; static int ssh_channel_read_termination(void *s) { struct ssh_channel_read_termination_struct *ctx = s; - if (ssh_buffer_get_len(ctx->buffer) >= ctx->count || + if (ssh_buffer_get_len(ctx->buffer) >= 1 || ctx->channel->remote_eof || ctx->channel->session->session_state == SSH_SESSION_STATE_ERROR) return 1; @@ -3063,28 +3078,17 @@ int ssh_channel_read_timeout(ssh_channel channel, stdbuf=channel->stderr_buffer; } - /* - * We may have problem if the window is too small to accept as much data - * as asked - */ SSH_LOG(SSH_LOG_PACKET, "Read (%" PRIu32 ") buffered : %" PRIu32 " bytes. Window: %" PRIu32, count, ssh_buffer_get_len(stdbuf), channel->local_window); - if (count > ssh_buffer_get_len(stdbuf) + channel->local_window) { - if (grow_window(session, channel, count - ssh_buffer_get_len(stdbuf)) < 0) { - return -1; - } - } - /* block reading until at least one byte has been read * and ignore the trivial case count=0 */ ctx.channel = channel; ctx.buffer = stdbuf; - ctx.count = 1; if (timeout_ms < SSH_TIMEOUT_DEFAULT) { timeout_ms = SSH_TIMEOUT_INFINITE; @@ -3126,11 +3130,10 @@ int ssh_channel_read_timeout(ssh_channel channel, if (channel->delayed_close && !ssh_channel_has_unread_data(channel)) { channel->state = SSH_CHANNEL_STATE_CLOSED; } - /* Authorize some buffering while userapp is busy */ - if (channel->local_window < WINDOWLIMIT) { - if (grow_window(session, channel, 0) < 0) { - return -1; - } + + rc = grow_window(session, channel); + if (rc == SSH_ERROR) { + return -1; } return len; @@ -3290,7 +3293,6 @@ int ssh_channel_poll_timeout(ssh_channel channel, int timeout, int is_stderr) } ctx.buffer = stdbuf; ctx.channel = channel; - ctx.count = 1; rc = ssh_handle_packets_termination(channel->session, timeout, ssh_channel_read_termination, @@ -3708,7 +3710,7 @@ int ssh_channel_open_reverse_forward(ssh_channel channel, const char *remotehost pending: rc = channel_open(channel, "forwarded-tcpip", - CHANNEL_INITIAL_WINDOW, + WINDOW_DEFAULT, CHANNEL_MAX_PACKET, payload); @@ -3771,7 +3773,7 @@ int ssh_channel_open_x11(ssh_channel channel, pending: rc = channel_open(channel, "x11", - CHANNEL_INITIAL_WINDOW, + WINDOW_DEFAULT, CHANNEL_MAX_PACKET, payload);