x86/mm: WIP backup: track the mm loaded in CR3 as cpu_tlbstate.loaded_mm instead of active_mm
diff --git a/arch/x86/events/core.c b/arch/x86/events/core.c
index 580b60f..77a3309 100644
--- a/arch/x86/events/core.c
+++ b/arch/x86/events/core.c
@@ -2101,8 +2101,12 @@
 
 static void refresh_pce(void *ignored)
 {
-	if (current->active_mm)
-		load_mm_cr4(current->active_mm);
+	/*
+	 * Use the mm whose page tables are actually loaded on this CPU:
+	 * in lazy TLB mode current->active_mm can still be the previous
+	 * user mm after we've switched back to swapper_pg_dir.
+	 */
+	load_mm_cr4(this_cpu_read(cpu_tlbstate.loaded_mm));
 }
 
 static void x86_pmu_event_mapped(struct perf_event *event)
diff --git a/arch/x86/include/asm/tlbflush.h b/arch/x86/include/asm/tlbflush.h
index fcc3315..fcdbddb 100644
--- a/arch/x86/include/asm/tlbflush.h
+++ b/arch/x86/include/asm/tlbflush.h
@@ -66,7 +66,13 @@
 #endif
 
 struct tlb_state {
-	struct mm_struct *active_mm;
+	/*
+	 * cpu_tlbstate.loaded_mm is the mm whose pgd is loaded in CR3; it
+	 * should match CR3 whenever interrupts are on.  It can therefore
+	 * differ from current->active_mm, which still holds the previous
+	 * user mm in lazy TLB mode even after we're back on swapper_pg_dir.
+	 */
+	struct mm_struct *loaded_mm;
 	int state;
 
 	/*
@@ -246,7 +252,13 @@
 static inline void reset_lazy_tlbstate(void)
 {
 	this_cpu_write(cpu_tlbstate.state, 0);
-	this_cpu_write(cpu_tlbstate.active_mm, &init_mm);
+	this_cpu_write(cpu_tlbstate.loaded_mm, &init_mm);
+
+	/*
+	 * loaded_mm now says init_mm is loaded, so CR3 is expected to
+	 * already point at swapper_pg_dir; warn if a caller broke that.
+	 */
+	WARN_ON(read_cr3() != __pa_symbol(swapper_pg_dir));
 }
 
 static inline void arch_tlbbatch_add_mm(struct arch_tlbflush_unmap_batch *batch,
diff --git a/arch/x86/mm/init.c b/arch/x86/mm/init.c
index 0381638..dee6701 100644
--- a/arch/x86/mm/init.c
+++ b/arch/x86/mm/init.c
@@ -764,7 +764,7 @@
 }
 
 DEFINE_PER_CPU_SHARED_ALIGNED(struct tlb_state, cpu_tlbstate) = {
-	.active_mm = &init_mm,
+	.loaded_mm = &init_mm,
 	.state = 0,
 	.cr4 = ~0UL,	/* fail hard if we screw up cr4 shadow initialization */
 };
diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index da1416c..4a2627c 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -34,20 +34,25 @@
  */
 void leave_mm(int cpu)
 {
-	struct mm_struct *active_mm = this_cpu_read(cpu_tlbstate.active_mm);
-	if (this_cpu_read(cpu_tlbstate.state) == TLBSTATE_OK)
-		BUG();
-	if (cpumask_test_cpu(cpu, mm_cpumask(active_mm))) {
-		cpumask_clear_cpu(cpu, mm_cpumask(active_mm));
-		load_cr3(swapper_pg_dir);
-		/*
-		 * This gets called in the idle path where RCU
-		 * functions differently.  Tracing normally
-		 * uses RCU, so we have to call the tracepoint
-		 * specially here.
-		 */
-		trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
-	}
+	struct mm_struct *loaded_mm = this_cpu_read(cpu_tlbstate.loaded_mm);
+
+	/*
+	 * It's plausible that we're in lazy TLB mode while our mm is init_mm.
+	 * If so, our callers still expect us to flush the TLB, but there
+	 * aren't any user TLB entries in init_mm to worry about.
+	 *
+	 * This must be checked before the TLBSTATE_OK sanity check below:
+	 * switch_mm() leaves the CPU in TLBSTATE_OK, so right after one
+	 * leave_mm() we are in TLBSTATE_OK with init_mm loaded, and a
+	 * repeated call (e.g. from the idle path) must not BUG().
+	 */
+	if (loaded_mm == &init_mm)
+		return;
+
+	if (this_cpu_read(cpu_tlbstate.state) == TLBSTATE_OK)
+		BUG();
+
+	switch_mm(NULL, &init_mm, NULL);
 }
 EXPORT_SYMBOL_GPL(leave_mm);
 
@@ -65,108 +64,114 @@
 			struct task_struct *tsk)
 {
 	unsigned cpu = smp_processor_id();
+	struct mm_struct *real_prev = this_cpu_read(cpu_tlbstate.loaded_mm);
 
-	if (likely(prev != next)) {
-		if (IS_ENABLED(CONFIG_VMAP_STACK)) {
-			/*
-			 * If our current stack is in vmalloc space and isn't
-			 * mapped in the new pgd, we'll double-fault.  Forcibly
-			 * map it.
-			 */
-			unsigned int stack_pgd_index = pgd_index(current_stack_pointer());
+	/*
+	 * NB: The scheduler will call us with prev == next when
+	 * switching from lazy TLB mode to normal mode if active_mm
+	 * isn't changing.  When this happens, there is no guarantee
+	 * that CR3 (and hence cpu_tlbstate.loaded_mm) matches next.
+	 *
+	 * NB: leave_mm() calls us with prev == NULL and tsk == NULL.
+	 */
 
-			pgd_t *pgd = next->pgd + stack_pgd_index;
+	this_cpu_write(cpu_tlbstate.state, TLBSTATE_OK);
 
-			if (unlikely(pgd_none(*pgd)))
-				set_pgd(pgd, init_mm.pgd[stack_pgd_index]);
-		}
-
-		this_cpu_write(cpu_tlbstate.state, TLBSTATE_OK);
-		this_cpu_write(cpu_tlbstate.active_mm, next);
-
-		cpumask_set_cpu(cpu, mm_cpumask(next));
-
+	if (real_prev == next) {
 		/*
-		 * Re-load page tables.
-		 *
-		 * This logic has an ordering constraint:
-		 *
-		 *  CPU 0: Write to a PTE for 'next'
-		 *  CPU 0: load bit 1 in mm_cpumask.  if nonzero, send IPI.
-		 *  CPU 1: set bit 1 in next's mm_cpumask
-		 *  CPU 1: load from the PTE that CPU 0 writes (implicit)
-		 *
-		 * We need to prevent an outcome in which CPU 1 observes
-		 * the new PTE value and CPU 0 observes bit 1 clear in
-		 * mm_cpumask.  (If that occurs, then the IPI will never
-		 * be sent, and CPU 0's TLB will contain a stale entry.)
-		 *
-		 * The bad outcome can occur if either CPU's load is
-		 * reordered before that CPU's store, so both CPUs must
-		 * execute full barriers to prevent this from happening.
-		 *
-		 * Thus, switch_mm needs a full barrier between the
-		 * store to mm_cpumask and any operation that could load
-		 * from next->pgd.  TLB fills are special and can happen
-		 * due to instruction fetches or for no reason at all,
-		 * and neither LOCK nor MFENCE orders them.
-		 * Fortunately, load_cr3() is serializing and gives the
-		 * ordering guarantee we need.
-		 *
+		 * There's nothing to do: we always keep the per-mm control
+		 * regs in sync with cpu_tlbstate.loaded_mm.  Just
+		 * sanity-check mm_cpumask.
 		 */
-		load_cr3(next->pgd);
+		if (WARN_ON_ONCE(!cpumask_test_cpu(cpu, mm_cpumask(next))))
+			cpumask_set_cpu(cpu, mm_cpumask(next));
+		return;
+	}
 
-		trace_tlb_flush(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
+	if (IS_ENABLED(CONFIG_VMAP_STACK)) {
+		/*
+		 * If our current stack is in vmalloc space and isn't
+		 * mapped in the new pgd, we'll double-fault.  Forcibly
+		 * map it.
+		 */
+		unsigned int stack_pgd_index = pgd_index(current_stack_pointer());
 
-		/* Stop flush ipis for the previous mm */
-		cpumask_clear_cpu(cpu, mm_cpumask(prev));
+		pgd_t *pgd = next->pgd + stack_pgd_index;
 
-		/* Load per-mm CR4 state */
-		load_mm_cr4(next);
+		if (unlikely(pgd_none(*pgd)))
+			set_pgd(pgd, init_mm.pgd[stack_pgd_index]);
+	}
+
+	this_cpu_write(cpu_tlbstate.loaded_mm, next);
+
+	WARN_ON_ONCE(cpumask_test_cpu(cpu, mm_cpumask(next)));
+	cpumask_set_cpu(cpu, mm_cpumask(next));
+
+	/*
+	 * Re-load page tables.
+	 *
+	 * This logic has an ordering constraint:
+	 *
+	 *  CPU 0: Write to a PTE for 'next'
+	 *  CPU 0: load bit 1 in mm_cpumask.  if nonzero, send IPI.
+	 *  CPU 1: set bit 1 in next's mm_cpumask
+	 *  CPU 1: load from the PTE that CPU 0 writes (implicit)
+	 *
+	 * We need to prevent an outcome in which CPU 1 observes
+	 * the new PTE value and CPU 0 observes bit 1 clear in
+	 * mm_cpumask.  (If that occurs, then the IPI will never
+	 * be sent, and CPU 0's TLB will contain a stale entry.)
+	 *
+	 * The bad outcome can occur if either CPU's load is
+	 * reordered before that CPU's store, so both CPUs must
+	 * execute full barriers to prevent this from happening.
+	 *
+	 * Thus, switch_mm needs a full barrier between the
+	 * store to mm_cpumask and any operation that could load
+	 * from next->pgd.  TLB fills are special and can happen
+	 * due to instruction fetches or for no reason at all,
+	 * and neither LOCK nor MFENCE orders them.
+	 * Fortunately, load_cr3() is serializing and gives the
+	 * ordering guarantee we need.
+	 *
+	 */
+	load_cr3(next->pgd);
+
+	/*
+	 * This gets called via leave_mm() in the idle path where RCU
+	 * functions differently.  Tracing normally uses RCU, so we have to
+	 * call the tracepoint specially here.
+	 */
+	trace_tlb_flush_rcuidle(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
+
+	/*
+	 * Stop flush ipis for the previous mm.  Skip the sanity check for
+	 * init_mm: loaded_mm starts out as (and reset_lazy_tlbstate() resets
+	 * it to) &init_mm, which need not set this CPU in its mm_cpumask.
+	 */
+	WARN_ON_ONCE(!cpumask_test_cpu(cpu, mm_cpumask(real_prev)) &&
+		     real_prev != &init_mm);
+	cpumask_clear_cpu(cpu, mm_cpumask(real_prev));
+
+	/* Load per-mm CR4 state */
+	load_mm_cr4(next);
 
 #ifdef CONFIG_MODIFY_LDT_SYSCALL
-		/*
-		 * Load the LDT, if the LDT is different.
-		 *
-		 * It's possible that prev->context.ldt doesn't match
-		 * the LDT register.  This can happen if leave_mm(prev)
-		 * was called and then modify_ldt changed
-		 * prev->context.ldt but suppressed an IPI to this CPU.
-		 * In this case, prev->context.ldt != NULL, because we
-		 * never set context.ldt to NULL while the mm still
-		 * exists.  That means that next->context.ldt !=
-		 * prev->context.ldt, because mms never share an LDT.
-		 */
-		if (unlikely(prev->context.ldt != next->context.ldt))
-			load_mm_ldt(next);
+	/*
+	 * Load the LDT, if the LDT is different.
+	 *
+	 * It's possible that prev->context.ldt doesn't match
+	 * the LDT register.  This can happen if leave_mm(prev)
+	 * was called and then modify_ldt changed
+	 * prev->context.ldt but suppressed an IPI to this CPU.
+	 * In this case, prev->context.ldt != NULL, because we
+	 * never set context.ldt to NULL while the mm still
+	 * exists.  That means that next->context.ldt !=
+	 * prev->context.ldt, because mms never share an LDT.
+	 */
+	if (unlikely(real_prev->context.ldt != next->context.ldt))
+		load_mm_ldt(next);
 #endif
-	} else {
-		this_cpu_write(cpu_tlbstate.state, TLBSTATE_OK);
-		BUG_ON(this_cpu_read(cpu_tlbstate.active_mm) != next);
-
-		if (!cpumask_test_cpu(cpu, mm_cpumask(next))) {
-			/*
-			 * On established mms, the mm_cpumask is only changed
-			 * from irq context, from ptep_clear_flush() while in
-			 * lazy tlb mode, and here. Irqs are blocked during
-			 * schedule, protecting us from simultaneous changes.
-			 */
-			cpumask_set_cpu(cpu, mm_cpumask(next));
-
-			/*
-			 * We were in lazy tlb mode and leave_mm disabled
-			 * tlb flush IPI delivery. We must reload CR3
-			 * to make sure to use no freed page tables.
-			 *
-			 * As above, load_cr3() is serializing and orders TLB
-			 * fills with respect to the mm_cpumask write.
-			 */
-			load_cr3(next->pgd);
-			trace_tlb_flush(TLB_FLUSH_ON_TASK_SWITCH, TLB_FLUSH_ALL);
-			load_mm_cr4(next);
-			load_mm_ldt(next);
-		}
-	}
 }
 
 /*
@@ -246,7 +246,12 @@
 
 	inc_irq_stat(irq_tlb_count);
 
-	if (f->mm && f->mm != this_cpu_read(cpu_tlbstate.active_mm))
+	/*
+	 * Skip flushes targeted at an mm that isn't loaded on this CPU.
+	 * Compare against cpu_tlbstate.loaded_mm, not active_mm, which can
+	 * differ from what CR3 holds while in lazy TLB mode.
+	 */
+	if (f->mm && f->mm != this_cpu_read(cpu_tlbstate.loaded_mm))
 		return;
 
 	count_vm_tlb_event(NR_TLB_REMOTE_FLUSH_RECEIVED);
diff --git a/arch/x86/xen/mmu.c b/arch/x86/xen/mmu.c
index 87a3c12..cd97d6e 100644
--- a/arch/x86/xen/mmu.c
+++ b/arch/x86/xen/mmu.c
@@ -1005,17 +1005,26 @@
 static void drop_other_mm_ref(void *info)
 {
 	struct mm_struct *mm = info;
-	struct mm_struct *active_mm;
 
-	active_mm = this_cpu_read(cpu_tlbstate.active_mm);
-
-	if (active_mm == mm && this_cpu_read(cpu_tlbstate.state) != TLBSTATE_OK)
+	if (this_cpu_read(cpu_tlbstate.loaded_mm) == mm) {
+		WARN_ON(this_cpu_read(cpu_tlbstate.state) == TLBSTATE_OK);
 		leave_mm(smp_processor_id());
+	}
 
 	/* If this cpu still has a stale cr3 reference, then make sure
 	   it has been flushed. */
-	if (this_cpu_read(xen_current_cr3) == __pa(mm->pgd))
-		load_cr3(swapper_pg_dir);
+	if (this_cpu_read(xen_current_cr3) == __pa(mm->pgd)) {
+		/*
+		 * Don't load swapper_pg_dir here: that would leave CR3 out
+		 * of sync with cpu_tlbstate.loaded_mm.  If loaded_mm was mm,
+		 * leave_mm() above has already switched away from it, so a
+		 * remaining match presumably means the CR3 switch is still
+		 * sitting in an unissued multicall batch; flush that batch.
+		 * NOTE(review): confirm this is the only way xen_current_cr3
+		 * can still name mm->pgd at this point.
+		 */
+		xen_mc_flush();
+	}
 }
 
 static void xen_drop_mm_ref(struct mm_struct *mm)
@@ -1023,16 +1031,18 @@
 	cpumask_var_t mask;
 	unsigned cpu;
 
-	if (current->active_mm == mm) {
-		if (current->mm == mm)
-			load_cr3(swapper_pg_dir);
-		else
-			leave_mm(smp_processor_id());
-	}
+	/*
+	 * Drop this CPU's own reference exactly as remote CPUs do, so a
+	 * stale xen_current_cr3 here is handled without relying on the
+	 * cross-CPU calls below.
+	 */
+	drop_other_mm_ref(mm);
 
 	/* Get the "official" set of cpus referring to our pagetable. */
 	if (!alloc_cpumask_var(&mask, GFP_ATOMIC)) {
 		for_each_online_cpu(cpu) {
+			/*
+			 * What if xen_cr3 == mm->pgd but xen_current_cr3 != mm->pgd?
+			 */
 			if (!cpumask_test_cpu(cpu, mm_cpumask(mm))
 			    && per_cpu(xen_current_cr3, cpu) != __pa(mm->pgd))
 				continue;
@@ -1059,8 +1068,10 @@
 #else
 static void xen_drop_mm_ref(struct mm_struct *mm)
 {
-	if (current->active_mm == mm)
-		load_cr3(swapper_pg_dir);
+	if (this_cpu_read(cpu_tlbstate.loaded_mm) == mm) {
+		WARN_ON(this_cpu_read(cpu_tlbstate.state) == TLBSTATE_OK);
+		leave_mm(smp_processor_id());
+	}
 }
 #endif