bpf: Call synchronize_rcu_mult once per trampoline batch
diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
index 587f2b2..cb0d811 100644
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -31,6 +31,7 @@
 #include <linux/poll.h>
 #include <linux/bpf-netns.h>
 #include <linux/rcupdate_trace.h>
+#include <linux/rcupdate_wait.h>
 
 #define IS_FD_ARRAY(map) ((map)->map_type == BPF_MAP_TYPE_PERF_EVENT_ARRAY || \
 			  (map)->map_type == BPF_MAP_TYPE_CGROUP_ARRAY || \
@@ -2920,6 +2921,8 @@ static int bpf_trampoline_batch(const union bpf_attr *attr, int cmd)
 	if (!batch)
 		goto out_clean;
 
+	synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
+
 	for (i = 0; i < count; i++) {
 		if (cmd == BPF_TRAMPOLINE_BATCH_ATTACH) {
 			prog = bpf_prog_get(in[i]);
diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
index 14d0936..97be1856 100644
--- a/kernel/bpf/trampoline.c
+++ b/kernel/bpf/trampoline.c
@@ -271,7 +271,8 @@ static int bpf_trampoline_update(struct bpf_trampoline *tr,
 	 * programs finish executing.
 	 * Wait for these two grace periods together.
 	 */
-	synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
+	if (!batch)
+		synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
 
 	err = arch_prepare_bpf_trampoline(new_image, new_image + PAGE_SIZE / 2,
 					  &tr->func.model, flags, tprogs,