diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c index 69775b49e485..a52fbaf52691 100644 --- a/net/netfilter/xt_socket.c +++ b/net/netfilter/xt_socket.c @@ -151,6 +151,7 @@ struct sock *xt_socket_lookup_slow_v4(struct net *net, const struct iphdr *iph = ip_hdr(skb); struct sk_buff *data_skb = NULL; int doff = 0; + struct sock *sk = skb->sk; __be32 uninitialized_var(daddr), uninitialized_var(saddr); __be16 uninitialized_var(dport), uninitialized_var(sport); u8 uninitialized_var(protocol); @@ -205,8 +206,14 @@ struct sock *xt_socket_lookup_slow_v4(struct net *net, } #endif - return xt_socket_get_sock_v4(net, data_skb, doff, protocol, saddr, - daddr, sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v4(dev_net(skb->dev), data_skb, doff, + protocol, saddr, daddr, sport, + dport, indev); + + return sk; } EXPORT_SYMBOL(xt_socket_lookup_slow_v4); @@ -240,8 +247,7 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par, transparent) pskb->mark = sk->sk_mark; - if (sk != skb->sk) - sock_gen_put(sk); + sock_gen_put(sk); if (wildcard || !transparent) sk = NULL; @@ -349,6 +355,7 @@ struct sock *xt_socket_lookup_slow_v6(struct net *net, const struct sk_buff *skb, const struct net_device *indev) { + struct sock *sk = skb->sk; __be16 uninitialized_var(dport), uninitialized_var(sport); const struct in6_addr *daddr = NULL, *saddr = NULL; struct ipv6hdr *iph = ipv6_hdr(skb); @@ -388,8 +395,14 @@ struct sock *xt_socket_lookup_slow_v6(struct net *net, return NULL; } - return xt_socket_get_sock_v6(net, data_skb, doff, tproto, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v6(dev_net(skb->dev), data_skb, doff, + tproto, saddr, daddr, sport, dport, + indev); + + return sk; } EXPORT_SYMBOL(xt_socket_lookup_slow_v6);