net/mlx5e: Handle netdev events in IPsec policy

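Update the IPsec policy rules when a neighbour's MAC address changes.
Inbound tunnel-mode policies are now tracked in a new xarray (poldb);
when a neighbour update matches a policy's source or destination
address, the policy's steering rule is rebuilt with the new MAC address.
The SA neighbour mark is also restricted to outbound tunnel-mode SAs.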

Reviewed-by: Cosmin Ratiu <cratiu@nvidia.com>
Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.c
index dfae5be..c570164 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.c
@@ -45,7 +45,7 @@
 #include "en_rep.h"
 
 #define MLX5_IPSEC_RESCHED msecs_to_jiffies(1000)
-#define MLX5E_IPSEC_TUNNEL_SA XA_MARK_1
+#define MLX5E_IPSEC_NEIGH XA_MARK_1
 
 static struct mlx5e_ipsec_sa_entry *to_ipsec_sa_entry(struct xfrm_state *x)
 {
@@ -671,7 +671,7 @@ static void mlx5e_ipsec_set_esn_ops(struct mlx5e_ipsec_sa_entry *sa_entry)
 	sa_entry->set_iv_op = mlx5e_ipsec_set_iv;
 }
 
-static void mlx5e_ipsec_handle_netdev_event(struct work_struct *_work)
+static void mlx5e_ipsec_handle_sa_netdev_event(struct work_struct *_work)
 {
 	struct mlx5e_ipsec_work *work =
 		container_of(_work, struct mlx5e_ipsec_work, work);
@@ -695,7 +695,7 @@ static void mlx5e_ipsec_handle_netdev_event(struct work_struct *_work)
 	mlx5e_accel_ipsec_fs_modify(sa_entry);
 }
 
-static int mlx5_ipsec_create_work(struct mlx5e_ipsec_sa_entry *sa_entry)
+static int mlx5_ipsec_create_sa_work(struct mlx5e_ipsec_sa_entry *sa_entry)
 {
 	struct xfrm_state *x = sa_entry->x;
 	struct mlx5e_ipsec_work *work;
@@ -732,7 +732,7 @@ static int mlx5_ipsec_create_work(struct mlx5e_ipsec_sa_entry *sa_entry)
 		if (!data)
 			goto free_work;
 
-		INIT_WORK(&work->work, mlx5e_ipsec_handle_netdev_event);
+		INIT_WORK(&work->work, mlx5e_ipsec_handle_sa_netdev_event);
 		break;
 	default:
 		break;
@@ -748,6 +748,54 @@ static int mlx5_ipsec_create_work(struct mlx5e_ipsec_sa_entry *sa_entry)
 	return -ENOMEM;
 }
 
+static void mlx5e_ipsec_handle_pol_netdev_event(struct work_struct *_work)
+{
+	struct mlx5e_ipsec_work *work = container_of(_work, struct mlx5e_ipsec_work, work);
+	struct mlx5e_ipsec_pol_entry *pol_entry = work->pol_entry;
+	struct mlx5_accel_pol_xfrm_attrs *attrs = &pol_entry->attrs;
+	struct mlx5e_ipsec_netevent_data *data = work->data;
+
+	/* Only inbound tunnel-mode packet-offload policies are expected here. */
+	WARN_ON_ONCE(attrs->dir != XFRM_DEV_OFFLOAD_IN ||
+		     attrs->type != XFRM_DEV_OFFLOAD_PACKET ||
+		     attrs->mode != XFRM_MODE_TUNNEL);
+
+	/* Adopt the MAC snapshotted by the netevent notifier; clear drop. */
+	ether_addr_copy(attrs->addrs.dmac, data->addr);
+	attrs->drop = false;
+	mlx5e_ipsec_fs_pol_modify(pol_entry);
+}
+
+/* Allocate netevent work for inbound tunnel-mode policies; others need none. */
+static int mlx5_ipsec_create_pol_work(struct mlx5e_ipsec_pol_entry *pol_entry)
+{
+	struct mlx5_accel_pol_xfrm_attrs *attrs = &pol_entry->attrs;
+	struct mlx5e_ipsec_netevent_data *data;
+	struct mlx5e_ipsec_work *work;
+
+	if (attrs->mode != XFRM_MODE_TUNNEL ||
+	    attrs->dir != XFRM_DEV_OFFLOAD_IN)
+		return 0;
+
+	work = kzalloc(sizeof(*work), GFP_KERNEL);
+	if (!work)
+		return -ENOMEM;
+
+	data = kzalloc(sizeof(*data), GFP_KERNEL);
+	if (!data)
+		goto free_work;
+
+	INIT_WORK(&work->work, mlx5e_ipsec_handle_pol_netdev_event);
+	work->data = data;
+	work->pol_entry = pol_entry;
+	pol_entry->work = work;
+	return 0;
+
+free_work:
+	kfree(work);
+	return -ENOMEM;
+}
+
 static int mlx5e_ipsec_create_dwork(struct mlx5e_ipsec_sa_entry *sa_entry)
 {
 	struct xfrm_state *x = sa_entry->x;
@@ -820,7 +868,7 @@ static int mlx5e_xfrm_add_state(struct net_device *dev,
 
 	mlx5e_ipsec_build_accel_xfrm_attrs(sa_entry, &sa_entry->attrs);
 
-	err = mlx5_ipsec_create_work(sa_entry);
+	err = mlx5_ipsec_create_sa_work(sa_entry);
 	if (err)
 		goto unblock_ipsec;
 
@@ -860,11 +908,12 @@ static int mlx5e_xfrm_add_state(struct net_device *dev,
 		queue_delayed_work(ipsec->wq, &sa_entry->dwork->dwork,
 				   MLX5_IPSEC_RESCHED);
 
-	if (x->xso.type == XFRM_DEV_OFFLOAD_PACKET &&
-	    x->props.mode == XFRM_MODE_TUNNEL) {
+	if (sa_entry->attrs.type == XFRM_DEV_OFFLOAD_PACKET &&
+	    sa_entry->attrs.mode == XFRM_MODE_TUNNEL &&
+	    sa_entry->attrs.dir == XFRM_DEV_OFFLOAD_OUT) {
 		xa_lock_bh(&ipsec->sadb);
 		__xa_set_mark(&ipsec->sadb, sa_entry->ipsec_obj_id,
-			      MLX5E_IPSEC_TUNNEL_SA);
+			      MLX5E_IPSEC_NEIGH);
 		xa_unlock_bh(&ipsec->sadb);
 	}
 
@@ -928,12 +977,32 @@ static void mlx5e_xfrm_free_state(struct net_device *dev, struct xfrm_state *x)
 	kfree(sa_entry);
 }
 
+/* Return true when the key of neighbour @n equals either the source or
+ * the destination address in @addrs (the addresses of an SA or a policy),
+ * i.e. when a neighbour update may change the MAC used by that entry.
+ * Any family other than AF_INET is compared as IPv6.
+ */
+static bool mlx5e_ipsec_neigh_addr_equal(const struct mlx5e_ipsec_addr *addrs,
+					 const struct neighbour *n)
+{
+	if (addrs->family == AF_INET)
+		return neigh_key_eq32(n, &addrs->saddr.a4) ||
+		       neigh_key_eq32(n, &addrs->daddr.a4);
+
+	/* neigh_key_eq128() compares 16 bytes: pass the 128-bit a6 member
+	 * instead of the 4-byte a4 member that aliases its first word.
+	 */
+	return neigh_key_eq128(n, &addrs->saddr.a6) ||
+	       neigh_key_eq128(n, &addrs->daddr.a6);
+}
+
 static int mlx5e_ipsec_netevent_event(struct notifier_block *nb,
 				      unsigned long event, void *ptr)
 {
-	struct mlx5_accel_esp_xfrm_attrs *attrs;
+	struct mlx5e_ipsec_pol_entry *pol_entry;
 	struct mlx5e_ipsec_netevent_data *data;
 	struct mlx5e_ipsec_sa_entry *sa_entry;
+	struct mlx5e_ipsec_addr *addrs;
 	struct mlx5e_ipsec *ipsec;
 	struct neighbour *n = ptr;
 	unsigned long idx;
@@ -942,18 +1011,11 @@ static int mlx5e_ipsec_netevent_event(struct notifier_block *nb,
 		return NOTIFY_DONE;
 
 	ipsec = container_of(nb, struct mlx5e_ipsec, netevent_nb);
-	xa_for_each_marked(&ipsec->sadb, idx, sa_entry, MLX5E_IPSEC_TUNNEL_SA) {
-		attrs = &sa_entry->attrs;
+	xa_for_each_marked(&ipsec->sadb, idx, sa_entry, MLX5E_IPSEC_NEIGH) {
+		addrs = &sa_entry->attrs.addrs;
 
-		if (attrs->addrs.family == AF_INET) {
-			if (!neigh_key_eq32(n, &attrs->addrs.saddr.a4) &&
-			    !neigh_key_eq32(n, &attrs->addrs.daddr.a4))
-				continue;
-		} else {
-			if (!neigh_key_eq128(n, &attrs->addrs.saddr.a4) &&
-			    !neigh_key_eq128(n, &attrs->addrs.daddr.a4))
-				continue;
-		}
+		if (!mlx5e_ipsec_neigh_addr_equal(addrs, n))
+			continue;
 
 		data = sa_entry->work->data;
 
@@ -961,6 +1023,18 @@ static int mlx5e_ipsec_netevent_event(struct notifier_block *nb,
 		queue_work(ipsec->wq, &sa_entry->work->work);
 	}
 
+	xa_for_each_marked(&ipsec->poldb, idx, pol_entry, MLX5E_IPSEC_NEIGH) {
+		addrs = &pol_entry->attrs.addrs;
+
+		if (!mlx5e_ipsec_neigh_addr_equal(addrs, n))
+			continue;
+
+		data = pol_entry->work->data;
+
+		neigh_ha_snapshot(data->addr, n, pol_entry->x->xdo.dev);
+		queue_work(ipsec->wq, &pol_entry->work->work);
+	}
+
 	return NOTIFY_DONE;
 }
 
@@ -979,6 +1053,7 @@ void mlx5e_ipsec_init(struct mlx5e_priv *priv)
 		return;
 
 	xa_init_flags(&ipsec->sadb, XA_FLAGS_ALLOC);
+	xa_init_flags(&ipsec->poldb, XA_FLAGS_ALLOC);
 	ipsec->mdev = priv->mdev;
 	init_completion(&ipsec->comp);
 	ipsec->wq = alloc_workqueue("mlx5e_ipsec: %s", WQ_UNBOUND, 0,
@@ -1249,6 +1324,10 @@ mlx5e_ipsec_build_accel_pol_attrs(struct mlx5e_ipsec_pol_entry *pol_entry,
 	attrs->upspec.proto = sel->proto;
 	attrs->prio = x->priority;
 	attrs->mode = x->xfrm_vec[0].mode;
+
+	if (attrs->mode == XFRM_MODE_TUNNEL)
+		attrs->drop = mlx5e_ipsec_init_macs(x->xdo.dev, &attrs->addrs,
+						    &attrs->upspec, attrs->dir);
 }
 
 static int mlx5e_xfrm_add_policy(struct xfrm_policy *x,
@@ -1256,6 +1335,7 @@ static int mlx5e_xfrm_add_policy(struct xfrm_policy *x,
 {
 	struct net_device *netdev = x->xdo.dev;
 	struct mlx5e_ipsec_pol_entry *pol_entry;
+	struct mlx5e_ipsec *ipsec;
 	struct mlx5e_priv *priv;
 	int err;
 
@@ -1274,7 +1354,8 @@ static int mlx5e_xfrm_add_policy(struct xfrm_policy *x,
 		return -ENOMEM;
 
 	pol_entry->x = x;
-	pol_entry->ipsec = priv->ipsec;
+	ipsec = priv->ipsec;
+	pol_entry->ipsec = ipsec;
 
 	if (!mlx5_eswitch_block_ipsec(priv->mdev)) {
 		err = -EBUSY;
@@ -1282,14 +1363,32 @@ static int mlx5e_xfrm_add_policy(struct xfrm_policy *x,
 	}
 
 	mlx5e_ipsec_build_accel_pol_attrs(pol_entry, &pol_entry->attrs);
+	err = xa_alloc(&ipsec->poldb, &pol_entry->idx, pol_entry, xa_limit_32b, GFP_KERNEL);
+	if (err)
+		goto unblock_ipsec;
+
+	err = mlx5_ipsec_create_pol_work(pol_entry);
+	if (err)
+		goto release_idx;
+
 	err = mlx5e_accel_ipsec_fs_add_pol(pol_entry);
 	if (err)
-		goto err_fs;
+		goto release_work;
+
+	if (pol_entry->attrs.mode == XFRM_MODE_TUNNEL &&
+	    pol_entry->attrs.dir == XFRM_DEV_OFFLOAD_IN)
+		xa_set_mark(&ipsec->poldb, pol_entry->idx, MLX5E_IPSEC_NEIGH);
 
 	x->xdo.offload_handle = (unsigned long)pol_entry;
 	return 0;
 
-err_fs:
+release_work:
+	if (pol_entry->work)
+		kfree(pol_entry->work->data);
+	kfree(pol_entry->work);
+release_idx:
+	xa_erase(&ipsec->poldb, pol_entry->idx);
+unblock_ipsec:
 	mlx5_eswitch_unblock_ipsec(priv->mdev);
 ipsec_busy:
 	kfree(pol_entry);
@@ -1300,8 +1399,16 @@ static int mlx5e_xfrm_add_policy(struct xfrm_policy *x,
 static void mlx5e_xfrm_del_policy(struct xfrm_policy *x)
 {
 	struct mlx5e_ipsec_pol_entry *pol_entry = to_ipsec_pol_entry(x);
+	struct mlx5e_ipsec *ipsec = pol_entry->ipsec;
+
+	xa_erase(&ipsec->poldb, pol_entry->idx);	/* hide from netevent lookups */
+	if (pol_entry->work)	/* cancel a queued rule update or wait for a running one */
+		cancel_work_sync(&pol_entry->work->work);
 
 	mlx5e_accel_ipsec_fs_del_pol(pol_entry);
+	if (pol_entry->work)
+		kfree(pol_entry->work->data);
+	kfree(pol_entry->work);	/* kfree(NULL) is a no-op */
 	mlx5_eswitch_unblock_ipsec(pol_entry->ipsec->mdev);
 }
 
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.h b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.h
index 1be007f..7bc8a89 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.h
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec.h
@@ -158,7 +158,10 @@ struct mlx5e_ipsec_tx;
 
 struct mlx5e_ipsec_work {
 	struct work_struct work;
-	struct mlx5e_ipsec_sa_entry *sa_entry;
+	union {	/* a work item serves either an SA or a policy, never both */
+		struct mlx5e_ipsec_sa_entry *sa_entry;	/* for SA work */
+		struct mlx5e_ipsec_pol_entry *pol_entry;	/* for policy work */
+	};
 	void *data;
 };
 
@@ -240,6 +243,7 @@ struct mlx5e_ipsec_mpv_work {
 struct mlx5e_ipsec {
 	struct mlx5_core_dev *mdev;
 	struct xarray sadb;
+	struct xarray poldb;
 	struct mlx5e_ipsec_sw_stats sw_stats;
 	struct mlx5e_ipsec_hw_stats hw_stats;
 	struct workqueue_struct *wq;
@@ -292,6 +296,7 @@ struct mlx5_accel_pol_xfrm_attrs {
 	struct mlx5e_ipsec_addr addrs;
 	struct upspec upspec;
 	u8 action;
+	u8 drop : 1;
 	u8 mode : 1;
 	u8 type : 2;
 	u8 dir : 2;
@@ -302,8 +307,10 @@ struct mlx5_accel_pol_xfrm_attrs {
 struct mlx5e_ipsec_pol_entry {
 	struct xfrm_policy *x;
 	struct mlx5e_ipsec *ipsec;
+	struct mlx5e_ipsec_work *work;	/* inbound tunnel-mode policies only */
 	struct mlx5e_ipsec_rule ipsec_rule;
 	struct mlx5_accel_pol_xfrm_attrs attrs;
+	u32 idx;	/* index of this entry in ipsec->poldb */
 };
 
 #ifdef CONFIG_MLX5_EN_IPSEC
@@ -319,6 +326,7 @@ void mlx5e_accel_ipsec_fs_del_rule(struct mlx5e_ipsec_sa_entry *sa_entry);
 int mlx5e_accel_ipsec_fs_add_pol(struct mlx5e_ipsec_pol_entry *pol_entry);
 void mlx5e_accel_ipsec_fs_del_pol(struct mlx5e_ipsec_pol_entry *pol_entry);
 void mlx5e_accel_ipsec_fs_modify(struct mlx5e_ipsec_sa_entry *sa_entry);
+void mlx5e_ipsec_fs_pol_modify(struct mlx5e_ipsec_pol_entry *pol_entry);
 bool mlx5e_ipsec_fs_tunnel_enabled(struct mlx5e_ipsec_sa_entry *sa_entry);
 
 int mlx5_ipsec_create_sa_ctx(struct mlx5e_ipsec_sa_entry *sa_entry);
diff --git a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_fs.c b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_fs.c
index fa36dc7..70b46f5 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_fs.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/en_accel/ipsec_fs.c
@@ -2373,20 +2373,13 @@ static int rx_add_policy(struct mlx5e_ipsec_pol_entry *pol_entry)
 		break;
 	}
 
-	switch (attrs->action) {
-	case XFRM_POLICY_ALLOW:
+	if (attrs->action == XFRM_POLICY_ALLOW && !attrs->drop) {
 		flow_act.action |= MLX5_FLOW_CONTEXT_ACTION_FWD_DEST;
-		break;
-	case XFRM_POLICY_BLOCK:
+	} else {
 		flow_act.action |= MLX5_FLOW_CONTEXT_ACTION_DROP | MLX5_FLOW_CONTEXT_ACTION_COUNT;
 		dest[dstn].type = MLX5_FLOW_DESTINATION_TYPE_COUNTER;
 		dest[dstn].counter = rx->fc->drop;
 		dstn++;
-		break;
-	default:
-		WARN_ON(true);
-		err = -EINVAL;
-		goto err_action;
 	}
 
 	flow_act.flags |= FLOW_ACT_NO_APPEND;
@@ -2861,6 +2854,23 @@ void mlx5e_accel_ipsec_fs_modify(struct mlx5e_ipsec_sa_entry *sa_entry)
 	memcpy(sa_entry, &sa_entry_shadow, sizeof(*sa_entry));
 }
 
+/* Swap in a freshly built steering rule for @pol_entry's current attrs. */
+void mlx5e_ipsec_fs_pol_modify(struct mlx5e_ipsec_pol_entry *pol_entry)
+{
+	struct mlx5e_ipsec_pol_entry shadow = *pol_entry;
+
+	/* Build the new rule on a copy with a cleared rule handle, so the
+	 * installed rule stays in place if creating the new one fails.
+	 */
+	memset(&shadow.ipsec_rule, 0, sizeof(shadow.ipsec_rule));
+	if (mlx5e_accel_ipsec_fs_add_pol(&shadow))
+		return;
+
+	/* New rule is installed: drop the old one and adopt the copy. */
+	mlx5e_accel_ipsec_fs_del_pol(pol_entry);
+	*pol_entry = shadow;
+}
+
 bool mlx5e_ipsec_fs_tunnel_enabled(struct mlx5e_ipsec_sa_entry *sa_entry)
 {
 	struct mlx5_accel_esp_xfrm_attrs *attrs = &sa_entry->attrs;