diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c index c8ff76945ac8..8a2a489b2cd3 100644 --- a/net/netfilter/xt_socket.c +++ b/net/netfilter/xt_socket.c @@ -148,6 +148,7 @@ struct sock *xt_socket_lookup_slow_v4(struct net *net, const struct net_device *indev) { const struct iphdr *iph = ip_hdr(skb); + struct sock *sk = skb->sk; __be32 uninitialized_var(daddr), uninitialized_var(saddr); __be16 uninitialized_var(dport), uninitialized_var(sport); u8 uninitialized_var(protocol); @@ -198,8 +199,14 @@ struct sock *xt_socket_lookup_slow_v4(struct net *net, } #endif - return xt_socket_get_sock_v4(net, protocol, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v4(dev_net(skb->dev), protocol, + saddr, daddr, sport, dport, + indev); + + return sk; } EXPORT_SYMBOL(xt_socket_lookup_slow_v4); @@ -233,8 +240,7 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par, transparent) pskb->mark = sk->sk_mark; - if (sk != skb->sk) - sock_gen_put(sk); + sock_gen_put(sk); if (wildcard || !transparent) sk = NULL; @@ -341,6 +347,7 @@ struct sock *xt_socket_lookup_slow_v6(struct net *net, const struct sk_buff *skb, const struct net_device *indev) { + struct sock *sk = skb->sk; __be16 uninitialized_var(dport), uninitialized_var(sport); const struct in6_addr *daddr = NULL, *saddr = NULL; struct ipv6hdr *iph = ipv6_hdr(skb); @@ -374,8 +381,14 @@ struct sock *xt_socket_lookup_slow_v6(struct net *net, return NULL; } - return xt_socket_get_sock_v6(net, tproto, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v6(dev_net(skb->dev), tproto, + saddr, daddr, sport, dport, + indev); + + return sk; } EXPORT_SYMBOL(xt_socket_lookup_slow_v6);