udp: Use hlist_nulls in UDP RCU code

This is a straightforward patch, using hlist_nulls infrastructure.

RCUification already done on UDP two weeks ago.

Using hlist_nulls permits us to avoid some memory barriers, both
at lookup time and delete time.

Patch is large because it adds new macros to include/net/sock.h.
These macros will be used by TCP & DCCP in next patch.

Signed-off-by: Eric Dumazet <dada1@cosmosbay.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/linux/rculist.h b/include/linux/rculist.h
index 3ba2998..e649bd3 100644
--- a/include/linux/rculist.h
+++ b/include/linux/rculist.h
@@ -383,22 +383,5 @@
 		({ tpos = hlist_entry(pos, typeof(*tpos), member); 1; }); \
 		pos = rcu_dereference(pos->next))
 
-/**
- * hlist_for_each_entry_rcu_safenext - iterate over rcu list of given type
- * @tpos:	the type * to use as a loop cursor.
- * @pos:	the &struct hlist_node to use as a loop cursor.
- * @head:	the head for your list.
- * @member:	the name of the hlist_node within the struct.
- * @next:       the &struct hlist_node to use as a next cursor
- *
- * Special version of hlist_for_each_entry_rcu that make sure
- * each next pointer is fetched before each iteration.
- */
-#define hlist_for_each_entry_rcu_safenext(tpos, pos, head, member, next) \
-	for (pos = rcu_dereference((head)->first);			 \
-		pos && ({ next = pos->next; smp_rmb(); prefetch(next); 1; }) &&	\
-		({ tpos = hlist_entry(pos, typeof(*tpos), member); 1; }); \
-		pos = rcu_dereference(next))
-
 #endif	/* __KERNEL__ */
 #endif
diff --git a/include/net/sock.h b/include/net/sock.h
index 8b2b821..0a63894 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -42,6 +42,7 @@
 
 #include <linux/kernel.h>
 #include <linux/list.h>
+#include <linux/list_nulls.h>
 #include <linux/timer.h>
 #include <linux/cache.h>
 #include <linux/module.h>
@@ -52,6 +53,7 @@
 #include <linux/security.h>
 
 #include <linux/filter.h>
+#include <linux/rculist_nulls.h>
 
 #include <asm/atomic.h>
 #include <net/dst.h>
@@ -106,6 +108,7 @@
  *	@skc_reuse: %SO_REUSEADDR setting
  *	@skc_bound_dev_if: bound device index if != 0
  *	@skc_node: main hash linkage for various protocol lookup tables
+ *	@skc_nulls_node: main hash linkage for UDP/UDP-Lite protocol
  *	@skc_bind_node: bind hash linkage for various protocol lookup tables
  *	@skc_refcnt: reference count
  *	@skc_hash: hash value used with various protocol lookup tables
@@ -120,7 +123,10 @@
 	volatile unsigned char	skc_state;
 	unsigned char		skc_reuse;
 	int			skc_bound_dev_if;
-	struct hlist_node	skc_node;
+	union {
+		struct hlist_node	skc_node;
+		struct hlist_nulls_node skc_nulls_node;
+	};
 	struct hlist_node	skc_bind_node;
 	atomic_t		skc_refcnt;
 	unsigned int		skc_hash;
@@ -206,6 +212,7 @@
 #define sk_reuse		__sk_common.skc_reuse
 #define sk_bound_dev_if		__sk_common.skc_bound_dev_if
 #define sk_node			__sk_common.skc_node
+#define sk_nulls_node		__sk_common.skc_nulls_node
 #define sk_bind_node		__sk_common.skc_bind_node
 #define sk_refcnt		__sk_common.skc_refcnt
 #define sk_hash			__sk_common.skc_hash
@@ -300,12 +307,30 @@
 	return hlist_empty(head) ? NULL : __sk_head(head);
 }
 
+static inline struct sock *__sk_nulls_head(const struct hlist_nulls_head *head)
+{
+	return hlist_nulls_entry(head->first, struct sock, sk_nulls_node);
+}
+
+static inline struct sock *sk_nulls_head(const struct hlist_nulls_head *head)
+{
+	return hlist_nulls_empty(head) ? NULL : __sk_nulls_head(head);
+}
+
 static inline struct sock *sk_next(const struct sock *sk)
 {
 	return sk->sk_node.next ?
 		hlist_entry(sk->sk_node.next, struct sock, sk_node) : NULL;
 }
 
+static inline struct sock *sk_nulls_next(const struct sock *sk)
+{
+	return (!is_a_nulls(sk->sk_nulls_node.next)) ?
+		hlist_nulls_entry(sk->sk_nulls_node.next,
+				  struct sock, sk_nulls_node) :
+		NULL;
+}
+
 static inline int sk_unhashed(const struct sock *sk)
 {
 	return hlist_unhashed(&sk->sk_node);
@@ -321,6 +346,11 @@
 	node->pprev = NULL;
 }
 
+static __inline__ void sk_nulls_node_init(struct hlist_nulls_node *node)
+{
+	node->pprev = NULL;
+}
+
 static __inline__ void __sk_del_node(struct sock *sk)
 {
 	__hlist_del(&sk->sk_node);
@@ -367,18 +397,18 @@
 	return rc;
 }
 
-static __inline__ int __sk_del_node_init_rcu(struct sock *sk)
+static __inline__ int __sk_nulls_del_node_init_rcu(struct sock *sk)
 {
 	if (sk_hashed(sk)) {
-		hlist_del_init_rcu(&sk->sk_node);
+		hlist_nulls_del_init_rcu(&sk->sk_nulls_node);
 		return 1;
 	}
 	return 0;
 }
 
-static __inline__ int sk_del_node_init_rcu(struct sock *sk)
+static __inline__ int sk_nulls_del_node_init_rcu(struct sock *sk)
 {
-	int rc = __sk_del_node_init_rcu(sk);
+	int rc = __sk_nulls_del_node_init_rcu(sk);
 
 	if (rc) {
 		/* paranoid for a while -acme */
@@ -399,15 +429,15 @@
 	__sk_add_node(sk, list);
 }
 
-static __inline__ void __sk_add_node_rcu(struct sock *sk, struct hlist_head *list)
+static __inline__ void __sk_nulls_add_node_rcu(struct sock *sk, struct hlist_nulls_head *list)
 {
-	hlist_add_head_rcu(&sk->sk_node, list);
+	hlist_nulls_add_head_rcu(&sk->sk_nulls_node, list);
 }
 
-static __inline__ void sk_add_node_rcu(struct sock *sk, struct hlist_head *list)
+static __inline__ void sk_nulls_add_node_rcu(struct sock *sk, struct hlist_nulls_head *list)
 {
 	sock_hold(sk);
-	__sk_add_node_rcu(sk, list);
+	__sk_nulls_add_node_rcu(sk, list);
 }
 
 static __inline__ void __sk_del_bind_node(struct sock *sk)
@@ -423,11 +453,16 @@
 
 #define sk_for_each(__sk, node, list) \
 	hlist_for_each_entry(__sk, node, list, sk_node)
-#define sk_for_each_rcu_safenext(__sk, node, list, next) \
-	hlist_for_each_entry_rcu_safenext(__sk, node, list, sk_node, next)
+#define sk_nulls_for_each(__sk, node, list) \
+	hlist_nulls_for_each_entry(__sk, node, list, sk_nulls_node)
+#define sk_nulls_for_each_rcu(__sk, node, list) \
+	hlist_nulls_for_each_entry_rcu(__sk, node, list, sk_nulls_node)
 #define sk_for_each_from(__sk, node) \
 	if (__sk && ({ node = &(__sk)->sk_node; 1; })) \
 		hlist_for_each_entry_from(__sk, node, sk_node)
+#define sk_nulls_for_each_from(__sk, node) \
+	if (__sk && ({ node = &(__sk)->sk_nulls_node; 1; })) \
+		hlist_nulls_for_each_entry_from(__sk, node, sk_nulls_node)
 #define sk_for_each_continue(__sk, node) \
 	if (__sk && ({ node = &(__sk)->sk_node; 1; })) \
 		hlist_for_each_entry_continue(__sk, node, sk_node)
diff --git a/include/net/udp.h b/include/net/udp.h
index df2bfe5..90e6ce5 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -51,7 +51,7 @@
 #define UDP_SKB_CB(__skb)	((struct udp_skb_cb *)((__skb)->cb))
 
 struct udp_hslot {
-	struct hlist_head	head;
+	struct hlist_nulls_head	head;
 	spinlock_t		lock;
 } __attribute__((aligned(2 * sizeof(long))));
 struct udp_table {
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 54badc9..fea2d87 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -127,9 +127,9 @@
 						 const struct sock *sk2))
 {
 	struct sock *sk2;
-	struct hlist_node *node;
+	struct hlist_nulls_node *node;
 
-	sk_for_each(sk2, node, &hslot->head)
+	sk_nulls_for_each(sk2, node, &hslot->head)
 		if (net_eq(sock_net(sk2), net)			&&
 		    sk2 != sk					&&
 		    sk2->sk_hash == num				&&
@@ -189,12 +189,7 @@
 	inet_sk(sk)->num = snum;
 	sk->sk_hash = snum;
 	if (sk_unhashed(sk)) {
-		/*
-		 * We need that previous write to sk->sk_hash committed
-		 * before write to sk->next done in following add_node() variant
-		 */
-		smp_wmb();
-		sk_add_node_rcu(sk, &hslot->head);
+		sk_nulls_add_node_rcu(sk, &hslot->head);
 		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
 	}
 	error = 0;
@@ -261,7 +256,7 @@
 		int dif, struct udp_table *udptable)
 {
 	struct sock *sk, *result;
-	struct hlist_node *node, *next;
+	struct hlist_nulls_node *node;
 	unsigned short hnum = ntohs(dport);
 	unsigned int hash = udp_hashfn(net, hnum);
 	struct udp_hslot *hslot = &udptable->hash[hash];
@@ -271,13 +266,7 @@
 begin:
 	result = NULL;
 	badness = -1;
-	sk_for_each_rcu_safenext(sk, node, &hslot->head, next) {
-		/*
-		 * lockless reader, and SLAB_DESTROY_BY_RCU items:
-		 * We must check this item was not moved to another chain
-		 */
-		if (udp_hashfn(net, sk->sk_hash) != hash)
-			goto begin;
+	sk_nulls_for_each_rcu(sk, node, &hslot->head) {
 		score = compute_score(sk, net, saddr, hnum, sport,
 				      daddr, dport, dif);
 		if (score > badness) {
@@ -285,6 +274,14 @@
 			badness = score;
 		}
 	}
+	/*
+	 * if the nulls value we got at the end of this lookup is
+	 * not the expected one, we must restart lookup.
+	 * We probably met an item that was moved to another chain.
+	 */
+	if (get_nulls_value(node) != hash)
+		goto begin;
+
 	if (result) {
 		if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
 			result = NULL;
@@ -325,11 +322,11 @@
 					     __be16 rmt_port, __be32 rmt_addr,
 					     int dif)
 {
-	struct hlist_node *node;
+	struct hlist_nulls_node *node;
 	struct sock *s = sk;
 	unsigned short hnum = ntohs(loc_port);
 
-	sk_for_each_from(s, node) {
+	sk_nulls_for_each_from(s, node) {
 		struct inet_sock *inet = inet_sk(s);
 
 		if (!net_eq(sock_net(s), net)				||
@@ -977,7 +974,7 @@
 	struct udp_hslot *hslot = &udptable->hash[hash];
 
 	spin_lock_bh(&hslot->lock);
-	if (sk_del_node_init_rcu(sk)) {
+	if (sk_nulls_del_node_init_rcu(sk)) {
 		inet_sk(sk)->num = 0;
 		sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
 	}
@@ -1130,7 +1127,7 @@
 	int dif;
 
 	spin_lock(&hslot->lock);
-	sk = sk_head(&hslot->head);
+	sk = sk_nulls_head(&hslot->head);
 	dif = skb->dev->ifindex;
 	sk = udp_v4_mcast_next(net, sk, uh->dest, daddr, uh->source, saddr, dif);
 	if (sk) {
@@ -1139,7 +1136,7 @@
 		do {
 			struct sk_buff *skb1 = skb;
 
-			sknext = udp_v4_mcast_next(net, sk_next(sk), uh->dest,
+			sknext = udp_v4_mcast_next(net, sk_nulls_next(sk), uh->dest,
 						   daddr, uh->source, saddr,
 						   dif);
 			if (sknext)
@@ -1560,10 +1557,10 @@
 	struct net *net = seq_file_net(seq);
 
 	for (state->bucket = start; state->bucket < UDP_HTABLE_SIZE; ++state->bucket) {
-		struct hlist_node *node;
+		struct hlist_nulls_node *node;
 		struct udp_hslot *hslot = &state->udp_table->hash[state->bucket];
 		spin_lock_bh(&hslot->lock);
-		sk_for_each(sk, node, &hslot->head) {
+		sk_nulls_for_each(sk, node, &hslot->head) {
 			if (!net_eq(sock_net(sk), net))
 				continue;
 			if (sk->sk_family == state->family)
@@ -1582,7 +1579,7 @@
 	struct net *net = seq_file_net(seq);
 
 	do {
-		sk = sk_next(sk);
+		sk = sk_nulls_next(sk);
 	} while (sk && (!net_eq(sock_net(sk), net) || sk->sk_family != state->family));
 
 	if (!sk) {
@@ -1753,7 +1750,7 @@
 	int i;
 
 	for (i = 0; i < UDP_HTABLE_SIZE; i++) {
-		INIT_HLIST_HEAD(&table->hash[i].head);
+		INIT_HLIST_NULLS_HEAD(&table->hash[i].head, i);
 		spin_lock_init(&table->hash[i].lock);
 	}
 }
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 8dafa36..fd2d9ad4 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -98,7 +98,7 @@
 				      int dif, struct udp_table *udptable)
 {
 	struct sock *sk, *result;
-	struct hlist_node *node, *next;
+	struct hlist_nulls_node *node;
 	unsigned short hnum = ntohs(dport);
 	unsigned int hash = udp_hashfn(net, hnum);
 	struct udp_hslot *hslot = &udptable->hash[hash];
@@ -108,19 +108,21 @@
 begin:
 	result = NULL;
 	badness = -1;
-	sk_for_each_rcu_safenext(sk, node, &hslot->head, next) {
-		/*
-		 * lockless reader, and SLAB_DESTROY_BY_RCU items:
-		 * We must check this item was not moved to another chain
-		 */
-		if (udp_hashfn(net, sk->sk_hash) != hash)
-			goto begin;
+	sk_nulls_for_each_rcu(sk, node, &hslot->head) {
 		score = compute_score(sk, net, hnum, saddr, sport, daddr, dport, dif);
 		if (score > badness) {
 			result = sk;
 			badness = score;
 		}
 	}
+	/*
+	 * if the nulls value we got at the end of this lookup is
+	 * not the expected one, we must restart lookup.
+	 * We probably met an item that was moved to another chain.
+	 */
+	if (get_nulls_value(node) != hash)
+		goto begin;
+
 	if (result) {
 		if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
 			result = NULL;
@@ -374,11 +376,11 @@
 				      __be16 rmt_port, struct in6_addr *rmt_addr,
 				      int dif)
 {
-	struct hlist_node *node;
+	struct hlist_nulls_node *node;
 	struct sock *s = sk;
 	unsigned short num = ntohs(loc_port);
 
-	sk_for_each_from(s, node) {
+	sk_nulls_for_each_from(s, node) {
 		struct inet_sock *inet = inet_sk(s);
 
 		if (!net_eq(sock_net(s), net))
@@ -423,7 +425,7 @@
 	int dif;
 
 	spin_lock(&hslot->lock);
-	sk = sk_head(&hslot->head);
+	sk = sk_nulls_head(&hslot->head);
 	dif = inet6_iif(skb);
 	sk = udp_v6_mcast_next(net, sk, uh->dest, daddr, uh->source, saddr, dif);
 	if (!sk) {
@@ -432,7 +434,7 @@
 	}
 
 	sk2 = sk;
-	while ((sk2 = udp_v6_mcast_next(net, sk_next(sk2), uh->dest, daddr,
+	while ((sk2 = udp_v6_mcast_next(net, sk_nulls_next(sk2), uh->dest, daddr,
 					uh->source, saddr, dif))) {
 		struct sk_buff *buff = skb_clone(skb, GFP_ATOMIC);
 		if (buff) {