devlink: Count struct devlink consumers

The struct devlink itself is protected by an internal lock and doesn't
need the global lock during operation. That global lock is used to
protect the addition/removal of devlink instances to/from the global
list used by all devlink consumers in the system.

The future conversion of the linked list to an xarray will allow us to
actually delete that lock, but first we need to count all struct devlink
users.

Reference counting gives us a way to ensure that no new user space
command succeeds in grabbing a devlink instance which is about to be
destroyed, which makes it safe to access the instance without the lock.

Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
diff --git a/include/net/devlink.h b/include/net/devlink.h
index b4691c4..7ecae2a 100644
--- a/include/net/devlink.h
+++ b/include/net/devlink.h
@@ -54,6 +54,8 @@
 	struct mutex lock; /* Serializes access to devlink instance specific objects such as
 			    * port, sb, dpipe, resource, params, region, traps and more.
 			    */
+	refcount_t refcount;
+	struct completion comp;
 	u8 reload_failed:1,
 	   reload_enabled:1;
 	char priv[0] __aligned(NETDEV_ALIGN);
diff --git a/net/core/devlink.c b/net/core/devlink.c
index b419b7a..c9d66c5 100644
--- a/net/core/devlink.c
+++ b/net/core/devlink.c
@@ -108,10 +108,22 @@
 }
 EXPORT_SYMBOL_GPL(devlink_net);
 
+static void devlink_put(struct devlink *devlink)
+{
+	if (refcount_dec_and_test(&devlink->refcount))
+		complete(&devlink->comp);
+}
+
+static bool __must_check devlink_try_get(struct devlink *devlink)
+{
+	return refcount_inc_not_zero(&devlink->refcount);
+}
+
 static struct devlink *devlink_get_from_attrs(struct net *net,
 					      struct nlattr **attrs)
 {
 	struct devlink *devlink;
+	bool found = false;
 	char *busname;
 	char *devname;
 
@@ -126,16 +138,16 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (strcmp(devlink->dev->bus->name, busname) == 0 &&
 		    strcmp(dev_name(devlink->dev), devname) == 0 &&
-		    net_eq(devlink_net(devlink), net))
-			return devlink;
+		    net_eq(devlink_net(devlink), net)) {
+			found = true;
+			break;
+		}
 	}
 
-	return ERR_PTR(-ENODEV);
-}
+	if (!found || !devlink_try_get(devlink))
+		devlink = ERR_PTR(-ENODEV);
 
-static struct devlink *devlink_get_from_info(struct genl_info *info)
-{
-	return devlink_get_from_attrs(genl_info_net(info), info->attrs);
+	return devlink;
 }
 
 static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink,
@@ -486,7 +498,7 @@
 	int err;
 
 	mutex_lock(&devlink_mutex);
-	devlink = devlink_get_from_info(info);
+	devlink = devlink_get_from_attrs(genl_info_net(info), info->attrs);
 	if (IS_ERR(devlink)) {
 		mutex_unlock(&devlink_mutex);
 		return PTR_ERR(devlink);
@@ -529,6 +541,7 @@
 unlock:
 	if (~ops->internal_flags & DEVLINK_NL_FLAG_NO_LOCK)
 		mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 	return err;
 }
@@ -541,6 +554,7 @@
 	devlink = info->user_ptr[0];
 	if (~ops->internal_flags & DEVLINK_NL_FLAG_NO_LOCK)
 		mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 }
 
@@ -1088,6 +1102,10 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
 			enum devlink_command cmd = DEVLINK_CMD_RATE_NEW;
@@ -1104,11 +1122,13 @@
 						   NLM_F_MULTI, NULL);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -1187,13 +1207,20 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		if (idx < start) {
 			idx++;
+			devlink_put(devlink);
 			continue;
 		}
+
 		err = devlink_nl_fill(msg, devlink, DEVLINK_CMD_NEW,
 				      NETLINK_CB(cb->skb).portid,
 				      cb->nlh->nlmsg_seq, NLM_F_MULTI);
+		devlink_put(devlink);
 		if (err)
 			goto out;
 		idx++;
@@ -1242,6 +1269,10 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_port, &devlink->port_list, list) {
 			if (idx < start) {
@@ -1256,11 +1287,13 @@
 						   cb->extack);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -1902,6 +1935,10 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			if (idx < start) {
@@ -1915,11 +1952,13 @@
 						 NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -2047,6 +2086,10 @@
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_pool_get)
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			err = __sb_pool_get_dumpit(msg, start, &idx, devlink,
@@ -2057,10 +2100,12 @@
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -2260,6 +2305,10 @@
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
 		    !devlink->ops->sb_port_pool_get)
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			err = __sb_port_pool_get_dumpit(msg, start, &idx,
@@ -2270,10 +2319,12 @@
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -2502,6 +2553,9 @@
 		    !devlink->ops->sb_tc_pool_bind_get)
 			continue;
 
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
 			err = __sb_tc_pool_bind_get_dumpit(msg, start, &idx,
@@ -2513,10 +2567,12 @@
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -4555,6 +4611,10 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(param_item, &devlink->param_list, list) {
 			if (idx < start) {
@@ -4570,11 +4630,13 @@
 				err = 0;
 			} else if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -4823,6 +4885,10 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(devlink_port, &devlink->port_list, list) {
 			list_for_each_entry(param_item,
@@ -4842,12 +4908,14 @@
 					err = 0;
 				} else if (err) {
 					mutex_unlock(&devlink->lock);
+					devlink_put(devlink);
 					goto out;
 				}
 				idx++;
 			}
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -5392,8 +5460,13 @@
 	list_for_each_entry(devlink, &devlink_list, list) {
 		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
 			continue;
+
+		if (!devlink_try_get(devlink))
+			continue;
+
 		err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink,
 							       &idx, start);
+		devlink_put(devlink);
 		if (err)
 			goto out;
 	}
@@ -5756,6 +5829,7 @@
 	nla_nest_end(skb, chunks_attr);
 	genlmsg_end(skb, hdr);
 	mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 
 	return skb->len;
@@ -5764,6 +5838,7 @@
 	genlmsg_cancel(skb, hdr);
 out_unlock:
 	mutex_unlock(&devlink->lock);
+	devlink_put(devlink);
 out_dev:
 	mutex_unlock(&devlink_mutex);
 	return err;
@@ -5939,6 +6014,7 @@
 			break;
 		idx++;
 	}
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 
 	if (err != -EMSGSIZE)
@@ -7023,6 +7099,7 @@
 		goto unlock;
 
 	reporter = devlink_health_reporter_get_from_attrs(devlink, attrs);
+	devlink_put(devlink);
 	mutex_unlock(&devlink_mutex);
 	return reporter;
 unlock:
@@ -7094,8 +7171,14 @@
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
+			devlink_put(devlink);
+			continue;
+		}
+
 		mutex_lock(&devlink->reporters_lock);
 		list_for_each_entry(reporter, &devlink->reporter_list,
 				    list) {
@@ -7111,16 +7194,24 @@
 							      NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->reporters_lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->reporters_lock);
+		devlink_put(devlink);
 	}
 
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
+			devlink_put(devlink);
+			continue;
+		}
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(port, &devlink->port_list, list) {
 			mutex_lock(&port->reporters_lock);
@@ -7137,6 +7228,7 @@
 				if (err) {
 					mutex_unlock(&port->reporters_lock);
 					mutex_unlock(&devlink->lock);
+					devlink_put(devlink);
 					goto out;
 				}
 				idx++;
@@ -7144,6 +7236,7 @@
 			mutex_unlock(&port->reporters_lock);
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -7677,8 +7770,14 @@
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
+			devlink_put(devlink);
+			continue;
+		}
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(trap_item, &devlink->trap_list, list) {
 			if (idx < start) {
@@ -7692,11 +7791,13 @@
 						   NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -8202,8 +8303,14 @@
 
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
-		if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
+		if (!devlink_try_get(devlink))
 			continue;
+
+		if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
+			devlink_put(devlink);
+			continue;
+		}
+
 		mutex_lock(&devlink->lock);
 		list_for_each_entry(policer_item, &devlink->trap_policer_list,
 				    list) {
@@ -8218,11 +8325,13 @@
 							   NLM_F_MULTI);
 			if (err) {
 				mutex_unlock(&devlink->lock);
+				devlink_put(devlink);
 				goto out;
 			}
 			idx++;
 		}
 		mutex_unlock(&devlink->lock);
+		devlink_put(devlink);
 	}
 out:
 	mutex_unlock(&devlink_mutex);
@@ -8803,6 +8912,9 @@
 	INIT_LIST_HEAD(&devlink->trap_policer_list);
 	mutex_init(&devlink->lock);
 	mutex_init(&devlink->reporters_lock);
+	refcount_set(&devlink->refcount, 1);
+	init_completion(&devlink->comp);
+
 	return devlink;
 }
 EXPORT_SYMBOL_GPL(devlink_alloc_ns);
@@ -8832,6 +8944,9 @@
  */
 void devlink_unregister(struct devlink *devlink)
 {
+	devlink_put(devlink);
+	wait_for_completion(&devlink->comp);
+
 	mutex_lock(&devlink_mutex);
 	WARN_ON(devlink_reload_supported(devlink->ops) &&
 		devlink->reload_enabled);
@@ -11288,9 +11403,13 @@
 	 */
 	mutex_lock(&devlink_mutex);
 	list_for_each_entry(devlink, &devlink_list, list) {
+		if (!devlink_try_get(devlink))
+			continue;
+
 		if (net_eq(devlink_net(devlink), net)) {
 			if (WARN_ON(!devlink_reload_supported(devlink->ops)))
-				continue;
+				goto retry;
+
 			err = devlink_reload(devlink, &init_net,
 					     DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
 					     DEVLINK_RELOAD_LIMIT_UNSPEC,
@@ -11298,6 +11417,8 @@
 			if (err && err != -EOPNOTSUPP)
 				pr_warn("Failed to reload devlink instance into init_net\n");
 		}
+retry:
+		devlink_put(devlink);
 	}
 	mutex_unlock(&devlink_mutex);
 }