diff --git a/include/net/xfrm.h b/include/net/xfrm.h index 505542501162..0740f1025630 100644 --- a/include/net/xfrm.h +++ b/include/net/xfrm.h @@ -212,6 +212,7 @@ struct xfrm_state { /* Data for encapsulator */ struct xfrm_encap_tmpl *encap; + struct sock __rcu *encap_sk; /* Data for care-of address */ xfrm_address_t *coaddr; diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c index 8f5417ff355d..419969b26822 100644 --- a/net/ipv4/esp4.c +++ b/net/ipv4/esp4.c @@ -118,16 +118,47 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp) } #ifdef CONFIG_INET_ESPINTCP +struct esp_tcp_sk { + struct sock *sk; + struct rcu_head rcu; +}; + +static void esp_free_tcp_sk(struct rcu_head *head) +{ + struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); + + sock_put(esk->sk); + kfree(esk); +} + static struct sock *esp_find_tcp_sk(struct xfrm_state *x) { struct xfrm_encap_tmpl *encap = x->encap; struct net *net = xs_net(x); + struct esp_tcp_sk *esk; __be16 sport, dport; + struct sock *nsk; struct sock *sk; + sk = rcu_dereference(x->encap_sk); + if (sk && sk->sk_state == TCP_ESTABLISHED) + return sk; + spin_lock_bh(&x->lock); sport = encap->encap_sport; dport = encap->encap_dport; + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (sk && sk == nsk) { + esk = kmalloc(sizeof(*esk), GFP_ATOMIC); + if (!esk) { + spin_unlock_bh(&x->lock); + return ERR_PTR(-ENOMEM); + } + RCU_INIT_POINTER(x->encap_sk, NULL); + esk->sk = sk; + call_rcu(&esk->rcu, esp_free_tcp_sk); + } spin_unlock_bh(&x->lock); sk = inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, x->id.daddr.a4, @@ -140,6 +171,20 @@ static struct sock *esp_find_tcp_sk(struct xfrm_state *x) return ERR_PTR(-EINVAL); } + spin_lock_bh(&x->lock); + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (encap->encap_sport != sport || + encap->encap_dport != dport) { + sock_put(sk); + sk = nsk ?: ERR_PTR(-EREMCHG); + } else if (sk == nsk) { + sock_put(sk); + } else { + rcu_assign_pointer(x->encap_sk, sk); + } + spin_unlock_bh(&x->lock); + return sk; } @@ -162,8 +207,6 @@ static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) err = espintcp_push_skb(sk, skb); bh_unlock_sock(sk); - sock_put(sk); - out: rcu_read_unlock(); return err; @@ -348,8 +391,6 @@ static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x, if (IS_ERR(sk)) return ERR_CAST(sk); - sock_put(sk); - *lenp = htons(len); esph = (struct ip_esp_hdr *)(lenp + 1); diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c index 085a83b807af..a021c88d3d9b 100644 --- a/net/ipv6/esp6.c +++ b/net/ipv6/esp6.c @@ -135,16 +135,47 @@ static void esp_ssg_unref(struct xfrm_state *x, void *tmp) } #ifdef CONFIG_INET6_ESPINTCP +struct esp_tcp_sk { + struct sock *sk; + struct rcu_head rcu; +}; + +static void esp_free_tcp_sk(struct rcu_head *head) +{ + struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); + + sock_put(esk->sk); + kfree(esk); +} + static struct sock *esp6_find_tcp_sk(struct xfrm_state *x) { struct xfrm_encap_tmpl *encap = x->encap; struct net *net = xs_net(x); + struct esp_tcp_sk *esk; __be16 sport, dport; + struct sock *nsk; struct sock *sk; + sk = rcu_dereference(x->encap_sk); + if (sk && sk->sk_state == TCP_ESTABLISHED) + return sk; + spin_lock_bh(&x->lock); sport = encap->encap_sport; dport = encap->encap_dport; + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (sk && sk == nsk) { + esk = kmalloc(sizeof(*esk), GFP_ATOMIC); + if (!esk) { + spin_unlock_bh(&x->lock); + return ERR_PTR(-ENOMEM); + } + RCU_INIT_POINTER(x->encap_sk, NULL); + esk->sk = sk; + call_rcu(&esk->rcu, esp_free_tcp_sk); + } spin_unlock_bh(&x->lock); sk = __inet6_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, &x->id.daddr.in6, @@ -157,6 +188,20 @@ static struct sock *esp6_find_tcp_sk(struct xfrm_state *x) return ERR_PTR(-EINVAL); } + spin_lock_bh(&x->lock); + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (encap->encap_sport != sport || + encap->encap_dport != dport) { + sock_put(sk); + sk = nsk ?: ERR_PTR(-EREMCHG); + } else if (sk == nsk) { + sock_put(sk); + } else { + rcu_assign_pointer(x->encap_sk, sk); + } + spin_unlock_bh(&x->lock); + return sk; } @@ -179,8 +224,6 @@ static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) err = espintcp_push_skb(sk, skb); bh_unlock_sock(sk); - sock_put(sk); - out: rcu_read_unlock(); return err; @@ -384,8 +427,6 @@ static struct ip_esp_hdr *esp6_output_tcp_encap(struct xfrm_state *x, if (IS_ERR(sk)) return ERR_CAST(sk); - sock_put(sk); - *lenp = htons(len); esph = (struct ip_esp_hdr *)(lenp + 1); diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index f43c09b85f96..a38ea5ae9d5c 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -694,6 +694,9 @@ int __xfrm_state_delete(struct xfrm_state *x) net->xfrm.state_num--; spin_unlock(&net->xfrm.xfrm_state_lock); + if (x->encap_sk) + sock_put(rcu_dereference_raw(x->encap_sk)); + xfrm_dev_state_delete(x); /* All xfrm_state objects are created by xfrm_state_alloc.