bpf: track must_write argument slots in subprog liveness analysis

Compute, per callee argument, the set of 4-byte stack slots that are written on every path to BPF_EXIT (must_write), using an AND-join forward dataflow over the existing arg-tracking state. Callers map must_write slots into stack_def (killing liveness across the call) instead of conservatively treating all write-only bits as stack_use. The effect on existing programs is minimal.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index 08a81e2..e40c074 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -245,6 +245,18 @@ static inline void spis_or(u64 dst[2], const u64 src[2])
 	dst[1] |= src[1];
 }
 
+static inline void spis_and(u64 dst[2], const u64 src[2])
+{
+	dst[0] &= src[0];
+	dst[1] &= src[1];
+}
+
+static inline void spis_andnot(u64 dst[2], const u64 src[2])
+{
+	dst[0] &= ~src[0];
+	dst[1] &= ~src[1];
+}
+
 static inline void spis_set_all(u64 spis[2])
 {
 	spis[0] = U64_MAX;
@@ -817,6 +829,7 @@ struct bpf_subprog_info {
 struct subprog_arg_access {
 	u64 read[NUM_AT_IDS][2];   /* 4-byte slot bitmask: callee reads */
 	u64 write[NUM_AT_IDS][2];  /* 4-byte slot bitmask: callee writes */
+	u64 must_write[NUM_AT_IDS][2]; /* 4-byte slot bitmask: callee MUST write */
 	u32 unknown_args;
 	/*
 	 * Per-instruction backward liveness of argument reads.
diff --git a/kernel/bpf/liveness.c b/kernel/bpf/liveness.c
index 321c41a..6f694fd 100644
--- a/kernel/bpf/liveness.c
+++ b/kernel/bpf/liveness.c
@@ -341,12 +341,14 @@ static void print_subprog_arg_access(struct bpf_verifier_env *env,
 	for (i = 1; i < NUM_AT_IDS && i <= nr_args; i++) {
 		u64 *r = access->read[i];
 		u64 *w = access->write[i];
+		u64 *mw = access->must_write[i];
 
 		if (spis_is_all(r) || spis_is_all(w))
 			verbose(env, "  r%d: all (conservative)\n", i);
 		else if (!spis_is_zero(r) || !spis_is_zero(w))
-			verbose(env, "  r%d read: 0x%llx:%llx  write: 0x%llx:%llx\n",
-				i, r[1], r[0], w[1], w[0]);
+			verbose(env, "  r%d r:0x%llx:%llx w:0x%llx:%llx mw:0x%llx:%llx\n",
+				i, r[1], r[0], w[1], w[0],
+				mw[1], mw[0]);
 	}
 }
 
@@ -902,6 +904,63 @@ static int compute_arg_live(struct bpf_verifier_env *env,
 	return 0;
 }
 
+/*
+ * Extract per-arg write bitmask for a single instruction.
+ * Only slot-aligned writes that fully cover a 4-byte slot qualify,
+ * same criteria as record_arg_mem_access() applies to tracked stores.
+ * Writes are accumulated into writes[arg] bitmask.
+ */
+static void get_insn_arg_writes(struct bpf_insn *insn,
+				struct arg_track *at,
+				u64 (*writes)[2])
+{
+	u8 class = BPF_CLASS(insn->code);
+	struct arg_track *ptr;
+	u32 sz, slot, slot_hi, s;
+	s64 acc_off;
+	int arg;
+
+	if (class == BPF_STX && BPF_MODE(insn->code) == BPF_MEM) {
+		ptr = &at[insn->dst_reg];
+		sz = bpf_size_to_bytes(BPF_SIZE(insn->code));
+	} else if (class == BPF_ST && BPF_MODE(insn->code) == BPF_MEM) {
+		ptr = &at[insn->dst_reg];
+		sz = bpf_size_to_bytes(BPF_SIZE(insn->code));
+	} else if (class == BPF_STX && BPF_MODE(insn->code) == BPF_ATOMIC &&
+		   insn->imm == BPF_STORE_REL) {
+		ptr = &at[insn->dst_reg];
+		sz = bpf_size_to_bytes(BPF_SIZE(insn->code));
+	} else {
+		return;
+	}
+
+	if (ptr->arg < 1 || ptr->arg >= NUM_AT_IDS)
+		return;
+	arg = ptr->arg;
+
+	if (ptr->off == OFF_IMPRECISE)
+		return;
+
+	acc_off = ptr->off + insn->off;
+	if (acc_off < 0)
+		return;
+
+	/* Only slot-aligned, full-4-byte writes qualify */
+	if (sz < STACK_SLOT_SZ || (acc_off % STACK_SLOT_SZ))
+		return;
+
+	slot = acc_off / STACK_SLOT_SZ;
+	if (slot >= STACK_SLOTS)
+		return;
+
+	slot_hi = (acc_off + sz - 1) / STACK_SLOT_SZ;
+	if (slot_hi >= STACK_SLOTS)
+		slot_hi = STACK_SLOTS - 1;
+
+	for (s = slot; s <= slot_hi; s++)
+		spis_set_bit(writes[arg], s);
+}
+
 /* Per-subprog intermediate state kept alive across analysis phases */
 struct subprog_at_info {
 	struct arg_track (*at_in)[MAX_BPF_REG];
@@ -929,9 +988,12 @@ static int compute_subprog_arg_tracking(struct bpf_verifier_env *env,
 	struct arg_track (*at_stack_in)[MAX_ARG_SPILL_SLOTS] = NULL;
 	struct arg_track at_stack_out[MAX_ARG_SPILL_SLOTS];
 	u64 (*arg_use)[NUM_AT_IDS][2] = NULL;
-	bool changed;
+	u64 (*mw_in)[NUM_AT_IDS][2] = NULL;
+	u64 mw_out[NUM_AT_IDS][2];
+	u64 insn_writes[NUM_AT_IDS][2];
+	bool changed, seen_exit;
 	u32 mask;
-	int i, r, err;
+	int i, r, a, err;
 
 	for (i = 0; i < NUM_AT_IDS; i++) {
 		spis_clear(access->read[i]);
@@ -956,6 +1018,12 @@ static int compute_subprog_arg_tracking(struct bpf_verifier_env *env,
 		goto err_free;
 	}
 
+	mw_in = kvmalloc_array(len, sizeof(*mw_in), GFP_KERNEL_ACCOUNT);
+	if (!mw_in) {
+		err = -ENOMEM;
+		goto err_free;
+	}
+
 	/* Initialize all registers to unvisited */
 	for (i = 0; i < len; i++)
 		for (r = 0; r < MAX_BPF_REG; r++)
@@ -966,6 +1034,16 @@ static int compute_subprog_arg_tracking(struct bpf_verifier_env *env,
 		for (r = 0; r < MAX_ARG_SPILL_SLOTS; r++)
 			at_stack_in[i][r] = (struct arg_track){ .off = 0, .arg = ARG_UNVISITED };
 
+	/*
+	 * must_write: all-ones = identity for AND (unvisited),
+	 * entry = all-zeros (nothing must-written yet).
+	 */
+	for (i = 0; i < len; i++)
+		for (a = 0; a < NUM_AT_IDS; a++)
+			spis_set_all(mw_in[i][a]);
+	for (a = 0; a < NUM_AT_IDS; a++)
+		spis_clear(mw_in[0][a]);
+
 	/* Entry: R1-R5 are arg-derived with offset 0, FP is identity 0 */
 	for (r = 0; r < MAX_BPF_REG; r++)
 		at_in[0][r] = (struct arg_track){ .off = 0, .arg = ARG_NONE };
@@ -997,6 +1075,20 @@ static int compute_subprog_arg_tracking(struct bpf_verifier_env *env,
 			memcpy(at_stack_out, at_stack_in[i], sizeof(at_stack_out));
 			arg_track_xfer(env, insn, access, at_out, at_stack_out);
 
+			/*
+			 * must_write transfer: mw_out = mw_in | writes.
+			 * Writes come from the current at_in, which may
+			 * not have converged yet; as at_in descends,
+			 * writes only shrink, so the AND-join stays monotone.
+			 */
+			for (a = 0; a < NUM_AT_IDS; a++) {
+				spis_clear(insn_writes[a]);
+				spis_copy(mw_out[a], mw_in[i][a]);
+			}
+			get_insn_arg_writes(insn, at_in[i], insn_writes);
+			for (a = 0; a < NUM_AT_IDS; a++)
+				spis_or(mw_out[a], insn_writes[a]);
+
 			/* Log transfer function changes */
 			if (env->log.level & BPF_LOG_LEVEL2) {
 				for (r = 0; r < MAX_BPF_REG; r++) {
@@ -1056,10 +1148,56 @@ static int compute_subprog_arg_tracking(struct bpf_verifier_env *env,
 						changed = true;
 					}
 				}
+
+				/* must_write AND-join */
+				for (a = 0; a < NUM_AT_IDS; a++) {
+					u64 old[2];
+
+					spis_copy(old, mw_in[ti][a]);
+					spis_and(mw_in[ti][a], mw_out[a]);
+					if (!spis_equal(old, mw_in[ti][a]))
+						changed = true;
+				}
 			}
 		}
 	}
 
+	/* Collect must_write at all BPF_EXIT instructions (intersect) */
+	seen_exit = false;
+	for (i = 0; i < len; i++) {
+		struct bpf_insn *insn = &insns[start + i];
+
+		if (insn->code != (BPF_JMP | BPF_EXIT))
+			continue;
+		if (at_in[i][0].arg == ARG_UNVISITED &&
+		    at_in[i][1].arg == ARG_UNVISITED)
+			continue;
+
+		/* Apply transfer at exit (exits have no successors) */
+		for (a = 0; a < NUM_AT_IDS; a++) {
+			spis_clear(insn_writes[a]);
+			spis_copy(mw_out[a], mw_in[i][a]);
+		}
+		get_insn_arg_writes(insn, at_in[i], insn_writes);
+		for (a = 0; a < NUM_AT_IDS; a++)
+			spis_or(mw_out[a], insn_writes[a]);
+
+		if (!seen_exit) {
+			for (a = 0; a < NUM_AT_IDS; a++)
+				spis_copy(access->must_write[a], mw_out[a]);
+			seen_exit = true;
+		} else {
+			for (a = 0; a < NUM_AT_IDS; a++)
+				spis_and(access->must_write[a], mw_out[a]);
+		}
+	}
+
+	/* Clamp: must_write can only include bits in write */
+	for (a = 0; a < NUM_AT_IDS; a++)
+		spis_and(access->must_write[a], access->write[a]);
+
+	kvfree(mw_in);
+
 	/* Apply unknown_args to access masks */
 	mask = access->unknown_args;
 	while (mask) {
@@ -1092,6 +1230,7 @@ static int compute_subprog_arg_tracking(struct bpf_verifier_env *env,
 	return 0;
 
 err_free:
+	kvfree(mw_in);
 	kvfree(arg_use);
 	kvfree(at_stack_in);
 	kvfree(at_in);
@@ -1574,9 +1713,12 @@ static void fp_off_insn_xfer(struct bpf_insn *insn,
 }
 
 /*
- * Apply subprogram argument access masks to caller's stack_use.
- * Maps callee's per-argument read/write bitmasks onto the caller's
- * stack slots starting at @spi.
+ * Apply subprogram argument access masks to caller's stack liveness.
+ * Maps callee's per-argument read/write/must_write bitmasks onto the
+ * caller's stack slots starting at @slot.
+ *
+ * must_write slots go into stack_def (they kill liveness).
+ * Remaining read + conditional-write bits go into stack_use.
  */
 static void apply_callee_stack_access(struct insn_live_regs *st,
 				      struct subprog_arg_access *sa,
@@ -1595,13 +1737,13 @@ static void apply_callee_stack_access(struct insn_live_regs *st,
 		return;
 	}
 
-	/*
-	 * Both read and write-only bits become stack_use.
-	 * Write-only bits are "may write" not "must write",
-	 * so they cannot be stack_def.
-	 */
+	/* Slots the callee must-write go into stack_def — they kill liveness */
+	arg_slots_to_spis(st->stack_def, slot, sa->must_write[arg]);
+
+	/* Read + conditional-write bits (excluding must-write) become stack_use */
 	spis_copy(combined, sa->read[arg]);
 	spis_or(combined, sa->write[arg]);
+	spis_andnot(combined, sa->must_write[arg]);
 	arg_slots_to_spis(st->stack_use, slot, combined);
 }