diff --git a/include/net/af_unix.h b/include/net/af_unix.h index 1614d78c60ed..861045fbf936 100644 --- a/include/net/af_unix.h +++ b/include/net/af_unix.h @@ -10,6 +10,7 @@ extern void unix_inflight(struct file *fp); extern void unix_notinflight(struct file *fp); extern void unix_gc(void); extern void wait_for_unix_gc(void); +extern struct sock *unix_get_socket(struct file *filp); #define UNIX_HASH_SIZE 256 @@ -56,6 +57,7 @@ struct unix_sock { spinlock_t lock; unsigned int gc_candidate : 1; unsigned int gc_maybe_cycle : 1; + unsigned char recursion_level; wait_queue_head_t peer_wait; }; #define unix_sk(__sk) ((struct unix_sock *)__sk) diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index bdcdd01d8810..db8d51a974b4 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -1325,9 +1325,25 @@ static void unix_destruct_fds(struct sk_buff *skb) sock_wfree(skb); } +#define MAX_RECURSION_LEVEL 4 + static int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb) { int i; + unsigned char max_level = 0; + int unix_sock_count = 0; + + for (i = scm->fp->count - 1; i >= 0; i--) { + struct sock *sk = unix_get_socket(scm->fp->fp[i]); + + if (sk) { + unix_sock_count++; + max_level = max(max_level, + unix_sk(sk)->recursion_level); + } + } + if (unlikely(max_level > MAX_RECURSION_LEVEL)) + return -ETOOMANYREFS; /* * Need to duplicate file references for the sake of garbage @@ -1338,10 +1354,12 @@ static int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb) if (!UNIXCB(skb).fp) return -ENOMEM; - for (i = scm->fp->count-1; i >= 0; i--) - unix_inflight(scm->fp->fp[i]); + if (unix_sock_count) { + for (i = scm->fp->count - 1; i >= 0; i--) + unix_inflight(scm->fp->fp[i]); + } skb->destructor = unix_destruct_fds; - return 0; + return max_level; } /* @@ -1363,6 +1381,7 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock, struct sk_buff *skb; long timeo; struct scm_cookie tmp_scm; + int max_level = 0; if (NULL == siocb->scm) siocb->scm = &tmp_scm; @@ -1403,8 +1422,9 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock, memcpy(UNIXCREDS(skb), &siocb->scm->creds, sizeof(struct ucred)); if (siocb->scm->fp) { err = unix_attach_fds(siocb->scm, skb); - if (err) + if (err < 0) goto out_free; + max_level = err + 1; } unix_get_secdata(siocb->scm, skb); @@ -1485,6 +1505,8 @@ restart: } skb_queue_tail(&other->sk_receive_queue, skb); + if (max_level > unix_sk(other)->recursion_level) + unix_sk(other)->recursion_level = max_level; unix_state_unlock(other); other->sk_data_ready(other, len); sock_put(other); @@ -1515,6 +1537,7 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, int sent = 0; struct scm_cookie tmp_scm; bool fds_sent = false; + int max_level = 0; if (NULL == siocb->scm) siocb->scm = &tmp_scm; @@ -1579,10 +1602,11 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, /* Only send the fds in the first buffer */ if (siocb->scm->fp && !fds_sent) { err = unix_attach_fds(siocb->scm, skb); - if (err) { + if (err < 0) { kfree_skb(skb); goto out_err; } + max_level = err + 1; fds_sent = true; } @@ -1599,6 +1623,8 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, goto pipe_err_free; skb_queue_tail(&other->sk_receive_queue, skb); + if (max_level > unix_sk(other)->recursion_level) + unix_sk(other)->recursion_level = max_level; unix_state_unlock(other); other->sk_data_ready(other, size); sent += size; @@ -1827,6 +1853,7 @@ static int unix_stream_recvmsg(struct kiocb *iocb, struct socket *sock, unix_state_lock(sk); skb = skb_dequeue(&sk->sk_receive_queue); if (skb == NULL) { + unix_sk(sk)->recursion_level = 0; if (copied >= target) goto unlock; diff --git a/net/unix/garbage.c b/net/unix/garbage.c index 19c17e4a0c8b..97d1a380423c 100644 --- a/net/unix/garbage.c +++ b/net/unix/garbage.c @@ -97,7 +97,7 @@ static DECLARE_WAIT_QUEUE_HEAD(unix_gc_wait); unsigned int unix_tot_inflight; -static struct sock *unix_get_socket(struct file *filp) +struct sock *unix_get_socket(struct file *filp) { struct sock *u_sock = NULL; struct inode *inode = filp->f_path.dentry->d_inode;