batman-adv: concentrate all curr_gw related rcu operations in select/deselect functions

Signed-off-by: Marek Lindner <lindner_marek@yahoo.de>
Signed-off-by: Sven Eckelmann <sven@narfation.org>
diff --git a/net/batman-adv/gateway_client.c b/net/batman-adv/gateway_client.c
index 42a8a7b..2acd7a66 100644
--- a/net/batman-adv/gateway_client.c
+++ b/net/batman-adv/gateway_client.c
@@ -43,61 +43,75 @@
 		call_rcu(&gw_node->rcu, gw_node_free_rcu);
 }
 
-struct orig_node *gw_get_selected(struct bat_priv *bat_priv)
+static struct gw_node *gw_get_selected_gw_node(struct bat_priv *bat_priv)
 {
-	struct gw_node *curr_gateway_tmp;
-	struct orig_node *orig_node = NULL;
+	struct gw_node *gw_node;
 
 	rcu_read_lock();
-	curr_gateway_tmp = rcu_dereference(bat_priv->curr_gw);
-	if (!curr_gateway_tmp)
+	gw_node = rcu_dereference(bat_priv->curr_gw);
+	if (!gw_node)
 		goto out;
 
-	orig_node = curr_gateway_tmp->orig_node;
-	if (!orig_node)
+	if (!atomic_inc_not_zero(&gw_node->refcount))
+		gw_node = NULL;
+
+out:
+	rcu_read_unlock();
+	return gw_node;
+}
+
+struct orig_node *gw_get_selected_orig(struct bat_priv *bat_priv)
+{
+	struct gw_node *gw_node;
+	struct orig_node *orig_node = NULL;
+
+	gw_node = gw_get_selected_gw_node(bat_priv);
+	if (!gw_node)
 		goto out;
 
+	rcu_read_lock();
+	orig_node = gw_node->orig_node;
+	if (!orig_node)
+		goto unlock;
+
 	if (!atomic_inc_not_zero(&orig_node->refcount))
 		orig_node = NULL;
 
-out:
+unlock:
 	rcu_read_unlock();
-	return orig_node;
-}
-
-void gw_deselect(struct bat_priv *bat_priv)
-{
-	struct gw_node *gw_node;
-
-	spin_lock_bh(&bat_priv->gw_list_lock);
-	gw_node = rcu_dereference(bat_priv->curr_gw);
-	rcu_assign_pointer(bat_priv->curr_gw, NULL);
-	spin_unlock_bh(&bat_priv->gw_list_lock);
-
+out:
 	if (gw_node)
 		gw_node_free_ref(gw_node);
+	return orig_node;
 }
 
 static void gw_select(struct bat_priv *bat_priv, struct gw_node *new_gw_node)
 {
 	struct gw_node *curr_gw_node;
 
+	spin_lock_bh(&bat_priv->gw_list_lock);
+
 	if (new_gw_node && !atomic_inc_not_zero(&new_gw_node->refcount))
 		new_gw_node = NULL;
 
-	spin_lock_bh(&bat_priv->gw_list_lock);
-	curr_gw_node = rcu_dereference(bat_priv->curr_gw);
+	curr_gw_node = bat_priv->curr_gw;
 	rcu_assign_pointer(bat_priv->curr_gw, new_gw_node);
-	spin_unlock_bh(&bat_priv->gw_list_lock);
 
 	if (curr_gw_node)
 		gw_node_free_ref(curr_gw_node);
+
+	spin_unlock_bh(&bat_priv->gw_list_lock);
+}
+
+void gw_deselect(struct bat_priv *bat_priv)
+{
+	gw_select(bat_priv, NULL);
 }
 
 void gw_election(struct bat_priv *bat_priv)
 {
 	struct hlist_node *node;
-	struct gw_node *gw_node, *curr_gw, *curr_gw_tmp = NULL;
+	struct gw_node *gw_node, *curr_gw = NULL, *curr_gw_tmp = NULL;
 	struct neigh_node *router;
 	uint8_t max_tq = 0;
 	uint32_t max_gw_factor = 0, tmp_gw_factor = 0;
@@ -112,25 +126,17 @@
 	if (atomic_read(&bat_priv->gw_mode) != GW_MODE_CLIENT)
 		return;
 
+	/* nothing to elect while a gateway is still selected */
+	curr_gw = gw_get_selected_gw_node(bat_priv);
+	if (curr_gw)
+		goto out;
 	rcu_read_lock();
-	curr_gw = rcu_dereference(bat_priv->curr_gw);
-	if (curr_gw) {
-		rcu_read_unlock();
-		return;
-	}
-
 	if (hlist_empty(&bat_priv->gw_list)) {
-
-		if (curr_gw) {
-			rcu_read_unlock();
-			bat_dbg(DBG_BATMAN, bat_priv,
-				"Removing selected gateway - "
-				"no gateway in range\n");
-			gw_deselect(bat_priv);
-		} else
-			rcu_read_unlock();
-
-		return;
+		bat_dbg(DBG_BATMAN, bat_priv,
+			"Removing selected gateway - "
+			"no gateway in range\n");
+		gw_deselect(bat_priv);
+		goto unlock;
 	}
 
 	hlist_for_each_entry_rcu(gw_node, node, &bat_priv->gw_list, list) {
@@ -182,7 +188,7 @@
 	if (curr_gw != curr_gw_tmp) {
 		router = orig_node_get_router(curr_gw_tmp->orig_node);
 		if (!router)
-			goto out;
+			goto unlock;
 
 		if ((curr_gw) && (!curr_gw_tmp))
 			bat_dbg(DBG_BATMAN, bat_priv,
@@ -207,8 +213,11 @@
 		gw_select(bat_priv, curr_gw_tmp);
 	}
 
-out:
+unlock:
 	rcu_read_unlock();
+out:
+	if (curr_gw)
+		gw_node_free_ref(curr_gw);
 }
 
 void gw_check_election(struct bat_priv *bat_priv, struct orig_node *orig_node)
@@ -217,7 +226,7 @@
 	struct neigh_node *router_gw = NULL, *router_orig = NULL;
 	uint8_t gw_tq_avg, orig_tq_avg;
 
-	curr_gw_orig = gw_get_selected(bat_priv);
+	curr_gw_orig = gw_get_selected_orig(bat_priv);
 	if (!curr_gw_orig)
 		goto deselect;
 
@@ -299,7 +308,11 @@
 		    struct orig_node *orig_node, uint8_t new_gwflags)
 {
 	struct hlist_node *node;
-	struct gw_node *gw_node;
+	struct gw_node *gw_node, *curr_gw;
+
+	/* curr_gw may be NULL: it is only compared, never dereferenced, and
+	 * the list must be updated even when no gateway is selected */
+	curr_gw = gw_get_selected_gw_node(bat_priv);
 
 	rcu_read_lock();
 	hlist_for_each_entry_rcu(gw_node, node, &bat_priv->gw_list, list) {
@@ -320,22 +333,26 @@
 				"Gateway %pM removed from gateway list\n",
 				orig_node->orig);
 
-			if (gw_node == rcu_dereference(bat_priv->curr_gw)) {
-				rcu_read_unlock();
-				gw_deselect(bat_priv);
-				return;
-			}
+			if (gw_node == curr_gw)
+				goto deselect;
 		}
 
-		rcu_read_unlock();
-		return;
+		goto unlock;
 	}
-	rcu_read_unlock();
 
 	if (new_gwflags == 0)
-		return;
+		goto unlock;
 
 	gw_node_add(bat_priv, orig_node, new_gwflags);
+	goto unlock;
+
+deselect:
+	gw_deselect(bat_priv);
+unlock:
+	rcu_read_unlock();
+	/* drop the reference taken by gw_get_selected_gw_node() */
+	if (curr_gw)
+		gw_node_free_ref(curr_gw);
 }
 
 void gw_node_delete(struct bat_priv *bat_priv, struct orig_node *orig_node)
@@ -345,9 +362,12 @@
 
 void gw_node_purge(struct bat_priv *bat_priv)
 {
-	struct gw_node *gw_node;
+	struct gw_node *gw_node, *curr_gw;
 	struct hlist_node *node, *node_tmp;
 	unsigned long timeout = 2 * PURGE_TIMEOUT * HZ;
+	char do_deselect = 0;
+
+	curr_gw = gw_get_selected_gw_node(bat_priv);
 
 	spin_lock_bh(&bat_priv->gw_list_lock);
 
@@ -358,15 +378,21 @@
 		    atomic_read(&bat_priv->mesh_state) == MESH_ACTIVE)
 			continue;
 
-		if (rcu_dereference(bat_priv->curr_gw) == gw_node)
-			gw_deselect(bat_priv);
+		if (curr_gw == gw_node)
+			do_deselect = 1;
 
 		hlist_del_rcu(&gw_node->list);
 		gw_node_free_ref(gw_node);
 	}
 
-
 	spin_unlock_bh(&bat_priv->gw_list_lock);
+
+	/* gw_deselect() needs to acquire the gw_list_lock */
+	if (do_deselect)
+		gw_deselect(bat_priv);
+
+	if (curr_gw)
+		gw_node_free_ref(curr_gw);
 }
 
 /**
@@ -385,22 +411,22 @@
 	if (!router)
 		goto out;
 
-	rcu_read_lock();
-	curr_gw = rcu_dereference(bat_priv->curr_gw);
+	curr_gw = gw_get_selected_gw_node(bat_priv);
 
 	ret = seq_printf(seq, "%s %pM (%3i) %pM [%10s]: %3i - %i%s/%i%s\n",
-		       (curr_gw == gw_node ? "=>" : "  "),
-		       gw_node->orig_node->orig,
-		       router->tq_avg, router->addr,
-		       router->if_incoming->net_dev->name,
-		       gw_node->orig_node->gw_flags,
-		       (down > 2048 ? down / 1024 : down),
-		       (down > 2048 ? "MBit" : "KBit"),
-		       (up > 2048 ? up / 1024 : up),
-		       (up > 2048 ? "MBit" : "KBit"));
+			 (curr_gw == gw_node ? "=>" : "  "),
+			 gw_node->orig_node->orig,
+			 router->tq_avg, router->addr,
+			 router->if_incoming->net_dev->name,
+			 gw_node->orig_node->gw_flags,
+			 (down > 2048 ? down / 1024 : down),
+			 (down > 2048 ? "MBit" : "KBit"),
+			 (up > 2048 ? up / 1024 : up),
+			 (up > 2048 ? "MBit" : "KBit"));
 
-	rcu_read_unlock();
 	neigh_node_free_ref(router);
+	if (curr_gw)
+		gw_node_free_ref(curr_gw);
 out:
 	return ret;
 }
@@ -459,6 +485,7 @@
 	struct iphdr *iphdr;
 	struct ipv6hdr *ipv6hdr;
 	struct udphdr *udphdr;
+	struct gw_node *curr_gw;
 	unsigned int header_len = 0;
 
 	if (atomic_read(&bat_priv->gw_mode) == GW_MODE_OFF)
@@ -523,12 +550,11 @@
 	if (atomic_read(&bat_priv->gw_mode) == GW_MODE_SERVER)
 		return -1;
 
-	rcu_read_lock();
-	if (!rcu_dereference(bat_priv->curr_gw)) {
-		rcu_read_unlock();
+	curr_gw = gw_get_selected_gw_node(bat_priv);
+	if (!curr_gw)
 		return 0;
-	}
-	rcu_read_unlock();
 
+	/* curr_gw is non-NULL here; only its presence was of interest */
+	gw_node_free_ref(curr_gw);
 	return 1;
 }