diff --git a/include/libssh/misc.h b/include/libssh/misc.h index 1d84ffd6..fe86d251 100644 --- a/include/libssh/misc.h +++ b/include/libssh/misc.h @@ -120,6 +120,7 @@ int ssh_tmpname(char *name); char *ssh_strreplace(const char *src, const char *pattern, const char *repl); ssize_t ssh_readn(int fd, void *buf, size_t nbytes); +ssize_t ssh_writen(int fd, const void *buf, size_t nbytes); #ifdef __cplusplus } diff --git a/src/misc.c b/src/misc.c index bc072295..8bdd568e 100644 --- a/src/misc.c +++ b/src/misc.c @@ -2036,4 +2036,55 @@ ssize_t ssh_readn(int fd, void *buf, size_t nbytes) return total_bytes_read; } +/** + * @brief Write the requested number of bytes to a local file. + * + * A call to write() may perform a short write on a local file. This function + * can be used to avoid short writes. + * + * This function tries to write the requested number of bytes until those many + * bytes are written or some error occurs. + * + * On encountering an error due to an interrupt, this function ignores that + * error and continues trying to write the data. + * + * @param[in] fd The file descriptor of the local file to write to. + * + * @param[in] buf Pointer to a buffer in which data to write is stored. + * + * @param[in] nbytes Number of bytes to write. + * + * @returns Number of bytes written on success, + * SSH_ERROR on error with errno set to indicate the + * error. + */ +ssize_t ssh_writen(int fd, const void *buf, size_t nbytes) +{ + size_t total_bytes_written = 0; + ssize_t bytes_written; + + if (fd < 0 || buf == NULL || nbytes == 0) { + errno = EINVAL; + return SSH_ERROR; + } + + do { + bytes_written = write(fd, + ((const char *)buf) + total_bytes_written, + nbytes - total_bytes_written); + if (bytes_written == -1) { + if(errno == EINTR) { + /* Ignoring errors due to signal interrupts */ + continue; + } + + return SSH_ERROR; + } + + total_bytes_written += (size_t)bytes_written; + } while (total_bytes_written < nbytes); + + return total_bytes_written; +} + /** @} */