bpf: Move synchronize_rcu_mult for batch processing
diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
index 587f2b2..cb0d811 100644
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -31,6 +31,7 @@
#include <linux/poll.h>
#include <linux/bpf-netns.h>
#include <linux/rcupdate_trace.h>
+#include <linux/rcupdate_wait.h>
#define IS_FD_ARRAY(map) ((map)->map_type == BPF_MAP_TYPE_PERF_EVENT_ARRAY || \
(map)->map_type == BPF_MAP_TYPE_CGROUP_ARRAY || \
@@ -2920,6 +2921,8 @@ static int bpf_trampoline_batch(const union bpf_attr *attr, int cmd)
if (!batch)
goto out_clean;
+ synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
+
for (i = 0; i < count; i++) {
if (cmd == BPF_TRAMPOLINE_BATCH_ATTACH) {
prog = bpf_prog_get(in[i]);
diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
index 14d0936..97be1856 100644
--- a/kernel/bpf/trampoline.c
+++ b/kernel/bpf/trampoline.c
@@ -271,7 +271,8 @@ static int bpf_trampoline_update(struct bpf_trampoline *tr,
* programs finish executing.
* Wait for these two grace periods together.
*/
- synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
+ if (!batch)
+ synchronize_rcu_mult(call_rcu_tasks, call_rcu_tasks_trace);
err = arch_prepare_bpf_trampoline(new_image, new_image + PAGE_SIZE / 2,
&tr->func.model, flags, tprogs,