audit: fix auditd/kernel connection state tracking

What started as a rather straightforward race condition reported by
Dmitry using the syzkaller fuzzer ended up revealing some major
problems with how the audit subsystem managed its netlink sockets and
its connection with the userspace audit daemon.  Fixing this properly
had quite the cascading effect and what we are left with is this rather
large and complicated patch.  My initial goal was to try and decompose
this patch into multiple smaller patches, but the way these changes
are intertwined makes it difficult to split these changes into
meaningful pieces that don't break or somehow make things worse for
the intermediate states.

The patch makes a number of changes, but the most significant are
highlighted below:

* The auditd tracking variables, e.g. audit_sock, are now gone and
replaced by a RCU/spin_lock protected variable auditd_conn which is
a structure containing all of the auditd tracking information.

* We no longer track the auditd sock directly, instead we track it
via the network namespace in which it resides and we use the audit
socket associated with that namespace.  In spirit, this is what the
code was trying to do prior to this patch (at least I think that is
what the original authors intended), but it was done rather poorly
and added a layer of obfuscation that only masked the underlying
problems.

* Big backlog queue cleanup, again.  In v4.10 we made some pretty big
changes to how the audit backlog queues work, here we haven't changed
the queue design so much as cleaned up the implementation.  Brought
about by the locking changes, we've simplified kauditd_thread() quite
a bit by consolidating the queue handling into a new helper function,
kauditd_send_queue(), which allows us to eliminate a lot of very
similar code and makes the looping logic in kauditd_thread() clearer.

* All netlink messages sent to auditd are now sent via
auditd_send_unicast_skb().  Other than just making sense, this makes
the lock handling easier.

* Change the audit_log_start() sleep behavior so that we never sleep
on auditd events (unchanged) or if the caller is holding the
audit_cmd_mutex (changed).  Previously we didn't sleep if the caller
was auditd or if the message type fell between a certain range; the
type check was a poor effort of doing what the cmd_mutex check now
does.  Richard Guy Briggs originally proposed not sleeping the
cmd_mutex owner several years ago but his patch wasn't acceptable
at the time.  At least the idea lives on here.

* A problem with the lost record counter has been resolved.  Steve
Grubb and I both happened to notice this problem and according to
some quick testing by Steve, this problem goes back quite some time.
It's largely a harmless problem, although it may have left some
careful sysadmins quite puzzled.

Cc: <stable@vger.kernel.org> # 4.10.x-
Reported-by: Dmitry Vyukov <dvyukov@google.com>
Signed-off-by: Paul Moore <paul@paul-moore.com>
diff --git a/kernel/audit.c b/kernel/audit.c
index e794544..2f4964c 100644
--- a/kernel/audit.c
+++ b/kernel/audit.c
@@ -54,6 +54,10 @@
 #include <linux/kthread.h>
 #include <linux/kernel.h>
 #include <linux/syscalls.h>
+#include <linux/spinlock.h>
+#include <linux/rcupdate.h>
+#include <linux/mutex.h>
+#include <linux/gfp.h>
 
 #include <linux/audit.h>
 
@@ -90,13 +94,34 @@
 /* If auditing cannot proceed, audit_failure selects what happens. */
 static u32	audit_failure = AUDIT_FAIL_PRINTK;
 
-/*
- * If audit records are to be written to the netlink socket, audit_pid
- * contains the pid of the auditd process and audit_nlk_portid contains
- * the portid to use to send netlink messages to that process.
+/* private audit network namespace index */
+static unsigned int audit_net_id;
+
+/**
+ * struct audit_net - audit private network namespace data
+ * @sk: communication socket
  */
-int		audit_pid;
-static __u32	audit_nlk_portid;
+struct audit_net {
+	struct sock *sk;
+};
+
+/**
+ * struct auditd_connection - kernel/auditd connection state
+ * @pid: auditd PID
+ * @portid: netlink portid
+ * @net: the associated network namespace
+ * @lock: spinlock to protect write access
+ *
+ * Description:
+ * This struct is RCU protected; you must either hold the RCU lock for reading
+ * or the included spinlock for writing.
+ */
+static struct auditd_connection {
+	int pid;
+	u32 portid;
+	struct net *net;
+	spinlock_t lock;
+} auditd_conn;
 
 /* If audit_rate_limit is non-zero, limit the rate of sending audit records
  * to that number per second.  This prevents DoS attacks, but results in
@@ -123,10 +148,6 @@
 */
 static atomic_t	audit_lost = ATOMIC_INIT(0);
 
-/* The netlink socket. */
-static struct sock *audit_sock;
-static unsigned int audit_net_id;
-
 /* Hash for inode-based rules */
 struct list_head audit_inode_hash[AUDIT_INODE_BUCKETS];
 
@@ -139,6 +160,7 @@
 
 /* queue msgs to send via kauditd_task */
 static struct sk_buff_head audit_queue;
+static void kauditd_hold_skb(struct sk_buff *skb);
 /* queue msgs due to temporary unicast send problems */
 static struct sk_buff_head audit_retry_queue;
 /* queue msgs waiting for new auditd connection */
@@ -192,6 +214,43 @@
 	struct sk_buff *skb;
 };
 
+/**
+ * auditd_test_task - Check to see if a given task is an audit daemon
+ * @task: the task to check
+ *
+ * Description:
+ * Return 1 if the task is a registered audit daemon, 0 otherwise.
+ */
+int auditd_test_task(const struct task_struct *task)
+{
+	int rc;
+
+	rcu_read_lock();
+	rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0);
+	rcu_read_unlock();
+
+	return rc;
+}
+
+/**
+ * audit_get_sk - Return the audit socket for the given network namespace
+ * @net: the destination network namespace
+ *
+ * Description:
+ * Returns the sock pointer if valid, NULL otherwise.  The caller must ensure
+ * that a reference is held for the network namespace while the sock is in use.
+ */
+static struct sock *audit_get_sk(const struct net *net)
+{
+	struct audit_net *aunet;
+
+	if (!net)
+		return NULL;
+
+	aunet = net_generic(net, audit_net_id);
+	return aunet->sk;
+}
+
 static void audit_set_portid(struct audit_buffer *ab, __u32 portid)
 {
 	if (ab) {
@@ -210,9 +269,7 @@
 			pr_err("%s\n", message);
 		break;
 	case AUDIT_FAIL_PANIC:
-		/* test audit_pid since printk is always losey, why bother? */
-		if (audit_pid)
-			panic("audit: %s\n", message);
+		panic("audit: %s\n", message);
 		break;
 	}
 }
@@ -370,21 +427,87 @@
 	return audit_do_config_change("audit_failure", &audit_failure, state);
 }
 
-/*
- * For one reason or another this nlh isn't getting delivered to the userspace
- * audit daemon, just send it to printk.
+/**
+ * auditd_set - Set/Reset the auditd connection state
+ * @pid: auditd PID
+ * @portid: auditd netlink portid
+ * @net: auditd network namespace pointer
+ *
+ * Description:
+ * This function will obtain and drop network namespace references as
+ * necessary.
+ */
+static void auditd_set(int pid, u32 portid, struct net *net)
+{
+	unsigned long flags;
+
+	spin_lock_irqsave(&auditd_conn.lock, flags);
+	auditd_conn.pid = pid;
+	auditd_conn.portid = portid;
+	if (auditd_conn.net)
+		put_net(auditd_conn.net);
+	if (net)
+		auditd_conn.net = get_net(net);
+	else
+		auditd_conn.net = NULL;
+	spin_unlock_irqrestore(&auditd_conn.lock, flags);
+}
+
+/**
+ * auditd_reset - Disconnect the auditd connection
+ *
+ * Description:
+ * Break the auditd/kauditd connection and move all the queued records into the
+ * hold queue in case auditd reconnects.
+ */
+static void auditd_reset(void)
+{
+	struct sk_buff *skb;
+
+	/* if it isn't already broken, break the connection */
+	rcu_read_lock();
+	if (auditd_conn.pid)
+		auditd_set(0, 0, NULL);
+	rcu_read_unlock();
+
+	/* flush all of the main and retry queues to the hold queue */
+	while ((skb = skb_dequeue(&audit_retry_queue)))
+		kauditd_hold_skb(skb);
+	while ((skb = skb_dequeue(&audit_queue)))
+		kauditd_hold_skb(skb);
+}
+
+/**
+ * kauditd_print_skb - Print the audit record to the ring buffer
+ * @skb: audit record
+ *
+ * Whatever the reason, this packet may not make it to the auditd connection
+ * so write it via printk so the information isn't completely lost.
  */
 static void kauditd_printk_skb(struct sk_buff *skb)
 {
 	struct nlmsghdr *nlh = nlmsg_hdr(skb);
 	char *data = nlmsg_data(nlh);
 
-	if (nlh->nlmsg_type != AUDIT_EOE) {
-		if (printk_ratelimit())
-			pr_notice("type=%d %s\n", nlh->nlmsg_type, data);
-		else
-			audit_log_lost("printk limit exceeded");
-	}
+	if (nlh->nlmsg_type != AUDIT_EOE && printk_ratelimit())
+		pr_notice("type=%d %s\n", nlh->nlmsg_type, data);
+}
+
+/**
+ * kauditd_rehold_skb - Handle a audit record send failure in the hold queue
+ * @skb: audit record
+ *
+ * Description:
+ * This should only be used by the kauditd_thread when it fails to flush the
+ * hold queue.
+ */
+static void kauditd_rehold_skb(struct sk_buff *skb)
+{
+	/* put the record back in the queue at the same place */
+	skb_queue_head(&audit_hold_queue, skb);
+
+	/* fail the auditd connection */
+	auditd_reset();
 }
 
 /**
@@ -421,6 +544,9 @@
 	/* we have no other options - drop the message */
 	audit_log_lost("kauditd hold queue overflow");
 	kfree_skb(skb);
+
+	/* fail the auditd connection */
+	auditd_reset();
 }
 
 /**
@@ -441,51 +567,122 @@
 }
 
 /**
- * auditd_reset - Disconnect the auditd connection
+ * auditd_send_unicast_skb - Send a record via unicast to auditd
+ * @skb: audit record
  *
  * Description:
- * Break the auditd/kauditd connection and move all the records in the retry
- * queue into the hold queue in case auditd reconnects.  The audit_cmd_mutex
- * must be held when calling this function.
+ * Send a skb to the audit daemon, returns positive/zero values on success and
+ * negative values on failure; in all cases the skb will be consumed by this
+ * function.  If the send results in -ECONNREFUSED the connection with auditd
+ * will be reset.  This function may sleep so callers should not hold any locks
+ * where this would cause a problem.
  */
-static void auditd_reset(void)
+static int auditd_send_unicast_skb(struct sk_buff *skb)
 {
-	struct sk_buff *skb;
+	int rc;
+	u32 portid;
+	struct net *net;
+	struct sock *sk;
 
-	/* break the connection */
-	if (audit_sock) {
-		sock_put(audit_sock);
-		audit_sock = NULL;
+	/* NOTE: we can't call netlink_unicast while in the RCU section so
+	 *       take a reference to the network namespace and grab local
+	 *       copies of the namespace, the sock, and the portid; the
+	 *       namespace and sock aren't going to go away while we hold a
+	 *       reference and if the portid does become invalid after the RCU
+	 *       section netlink_unicast() should safely return an error */
+
+	rcu_read_lock();
+	if (!auditd_conn.pid) {
+		rcu_read_unlock();
+		rc = -ECONNREFUSED;
+		goto err;
 	}
-	audit_pid = 0;
-	audit_nlk_portid = 0;
+	net = auditd_conn.net;
+	get_net(net);
+	sk = audit_get_sk(net);
+	portid = auditd_conn.portid;
+	rcu_read_unlock();
 
-	/* flush all of the retry queue to the hold queue */
-	while ((skb = skb_dequeue(&audit_retry_queue)))
-		kauditd_hold_skb(skb);
+	rc = netlink_unicast(sk, skb, portid, 0);
+	put_net(net);
+	if (rc < 0)
+		goto err;
+
+	return rc;
+
+err:
+	if (rc == -ECONNREFUSED)
+		auditd_reset();
+	return rc;
 }
 
 /**
- * kauditd_send_unicast_skb - Send a record via unicast to auditd
- * @skb: audit record
+ * kauditd_send_queue - Helper for kauditd_thread to flush skb queues
+ * @sk: the sending sock
+ * @portid: the netlink destination
+ * @queue: the skb queue to process
+ * @retry_limit: limit on number of netlink unicast failures
+ * @skb_hook: per-skb hook for additional processing
+ * @err_hook: hook called if the skb fails the netlink unicast send
+ *
+ * Description:
+ * Run through the given queue and attempt to send the audit records to auditd,
+ * returns zero on success, negative values on failure.  It is up to the caller
+ * to ensure that the @sk is valid for the duration of this function.
+ *
  */
-static int kauditd_send_unicast_skb(struct sk_buff *skb)
+static int kauditd_send_queue(struct sock *sk, u32 portid,
+			      struct sk_buff_head *queue,
+			      unsigned int retry_limit,
+			      void (*skb_hook)(struct sk_buff *skb),
+			      void (*err_hook)(struct sk_buff *skb))
 {
-	int rc;
+	int rc = 0;
+	struct sk_buff *skb;
+	static unsigned int failed = 0;
 
-	/* if we know nothing is connected, don't even try the netlink call */
-	if (!audit_pid)
-		return -ECONNREFUSED;
+	/* NOTE: kauditd_thread takes care of all our locking, we just use
+	 *       the netlink info passed to us (e.g. sk and portid) */
 
-	/* get an extra skb reference in case we fail to send */
-	skb_get(skb);
-	rc = netlink_unicast(audit_sock, skb, audit_nlk_portid, 0);
-	if (rc >= 0) {
-		consume_skb(skb);
-		rc = 0;
+	while ((skb = skb_dequeue(queue))) {
+		/* call the skb_hook for each skb we touch */
+		if (skb_hook)
+			(*skb_hook)(skb);
+
+		/* can we send to anyone via unicast? */
+		if (!sk) {
+			if (err_hook)
+				(*err_hook)(skb);
+			continue;
+		}
+
+		/* grab an extra skb reference in case of error */
+		skb_get(skb);
+		rc = netlink_unicast(sk, skb, portid, 0);
+		if (rc < 0) {
+			/* fatal failure for our queue flush attempt? */
+			if (++failed >= retry_limit ||
+			    rc == -ECONNREFUSED || rc == -EPERM) {
+				/* yes - error processing for the queue */
+				sk = NULL;
+				if (err_hook)
+					(*err_hook)(skb);
+				if (!skb_hook)
+					goto out;
+				/* keep processing with the skb_hook */
+				continue;
+			} else
+				/* no - requeue to preserve ordering */
+				skb_queue_head(queue, skb);
+		} else {
+			/* it worked - drop the extra reference and continue */
+			consume_skb(skb);
+			failed = 0;
+		}
 	}
 
-	return rc;
+out:
+	return (rc >= 0 ? 0 : rc);
 }
 
 /*
@@ -493,16 +690,19 @@
  * @skb: audit record
  *
  * Description:
- * This function doesn't consume an skb as might be expected since it has to
- * copy it anyways.
+ * Write a multicast message to anyone listening in the initial network
+ * namespace.  This function doesn't consume an skb as might be expected since
+ * it has to copy it anyways.
  */
 static void kauditd_send_multicast_skb(struct sk_buff *skb)
 {
 	struct sk_buff *copy;
-	struct audit_net *aunet = net_generic(&init_net, audit_net_id);
-	struct sock *sock = aunet->nlsk;
+	struct sock *sock = audit_get_sk(&init_net);
 	struct nlmsghdr *nlh;
 
+	/* NOTE: we are not taking an additional reference for init_net since
+	 *       we don't have to worry about it going away */
+
 	if (!netlink_has_listeners(sock, AUDIT_NLGRP_READLOG))
 		return;
 
@@ -526,149 +726,75 @@
 }
 
 /**
- * kauditd_wake_condition - Return true when it is time to wake kauditd_thread
- *
- * Description:
- * This function is for use by the wait_event_freezable() call in
- * kauditd_thread().
+ * kauditd_thread - Worker thread to send audit records to userspace
+ * @dummy: unused
  */
-static int kauditd_wake_condition(void)
-{
-	static int pid_last = 0;
-	int rc;
-	int pid = audit_pid;
-
-	/* wake on new messages or a change in the connected auditd */
-	rc = skb_queue_len(&audit_queue) || (pid && pid != pid_last);
-	if (rc)
-		pid_last = pid;
-
-	return rc;
-}
-
 static int kauditd_thread(void *dummy)
 {
 	int rc;
-	int auditd = 0;
-	int reschedule = 0;
-	struct sk_buff *skb;
-	struct nlmsghdr *nlh;
+	u32 portid = 0;
+	struct net *net = NULL;
+	struct sock *sk = NULL;
 
 #define UNICAST_RETRIES 5
-#define AUDITD_BAD(x,y) \
-	((x) == -ECONNREFUSED || (x) == -EPERM || ++(y) >= UNICAST_RETRIES)
-
-	/* NOTE: we do invalidate the auditd connection flag on any sending
-	 * errors, but we only "restore" the connection flag at specific places
-	 * in the loop in order to help ensure proper ordering of audit
-	 * records */
 
 	set_freezable();
 	while (!kthread_should_stop()) {
-		/* NOTE: possible area for future improvement is to look at
-		 *       the hold and retry queues, since only this thread
-		 *       has access to these queues we might be able to do
-		 *       our own queuing and skip some/all of the locking */
-
-		/* NOTE: it might be a fun experiment to split the hold and
-		 *       retry queue handling to another thread, but the
-		 *       synchronization issues and other overhead might kill
-		 *       any performance gains */
+		/* NOTE: see the lock comments in auditd_send_unicast_skb() */
+		rcu_read_lock();
+		if (!auditd_conn.pid) {
+			rcu_read_unlock();
+			goto main_queue;
+		}
+		net = auditd_conn.net;
+		get_net(net);
+		sk = audit_get_sk(net);
+		portid = auditd_conn.portid;
+		rcu_read_unlock();
 
 		/* attempt to flush the hold queue */
-		while (auditd && (skb = skb_dequeue(&audit_hold_queue))) {
-			rc = kauditd_send_unicast_skb(skb);
-			if (rc) {
-				/* requeue to the same spot */
-				skb_queue_head(&audit_hold_queue, skb);
-
-				auditd = 0;
-				if (AUDITD_BAD(rc, reschedule)) {
-					mutex_lock(&audit_cmd_mutex);
-					auditd_reset();
-					mutex_unlock(&audit_cmd_mutex);
-					reschedule = 0;
-				}
-			} else
-				/* we were able to send successfully */
-				reschedule = 0;
+		rc = kauditd_send_queue(sk, portid,
+					&audit_hold_queue, UNICAST_RETRIES,
+					NULL, kauditd_rehold_skb);
+		if (rc < 0) {
+			sk = NULL;
+			goto main_queue;
 		}
 
 		/* attempt to flush the retry queue */
-		while (auditd && (skb = skb_dequeue(&audit_retry_queue))) {
-			rc = kauditd_send_unicast_skb(skb);
-			if (rc) {
-				auditd = 0;
-				if (AUDITD_BAD(rc, reschedule)) {
-					kauditd_hold_skb(skb);
-					mutex_lock(&audit_cmd_mutex);
-					auditd_reset();
-					mutex_unlock(&audit_cmd_mutex);
-					reschedule = 0;
-				} else
-					/* temporary problem (we hope), queue
-					 * to the same spot and retry */
-					skb_queue_head(&audit_retry_queue, skb);
-			} else
-				/* we were able to send successfully */
-				reschedule = 0;
+		rc = kauditd_send_queue(sk, portid,
+					&audit_retry_queue, UNICAST_RETRIES,
+					NULL, kauditd_hold_skb);
+		if (rc < 0) {
+			sk = NULL;
+			goto main_queue;
 		}
 
-		/* standard queue processing, try to be as quick as possible */
-quick_loop:
-		skb = skb_dequeue(&audit_queue);
-		if (skb) {
-			/* setup the netlink header, see the comments in
-			 * kauditd_send_multicast_skb() for length quirks */
-			nlh = nlmsg_hdr(skb);
-			nlh->nlmsg_len = skb->len - NLMSG_HDRLEN;
+main_queue:
+		/* process the main queue - do the multicast send and attempt
+		 * unicast, dump failed record sends to the retry queue; if
+		 * sk == NULL due to previous failures we will just do the
+		 * multicast send and move the record to the retry queue */
+		kauditd_send_queue(sk, portid, &audit_queue, 1,
+				   kauditd_send_multicast_skb,
+				   kauditd_retry_skb);
 
-			/* attempt to send to any multicast listeners */
-			kauditd_send_multicast_skb(skb);
-
-			/* attempt to send to auditd, queue on failure */
-			if (auditd) {
-				rc = kauditd_send_unicast_skb(skb);
-				if (rc) {
-					auditd = 0;
-					if (AUDITD_BAD(rc, reschedule)) {
-						mutex_lock(&audit_cmd_mutex);
-						auditd_reset();
-						mutex_unlock(&audit_cmd_mutex);
-						reschedule = 0;
-					}
-
-					/* move to the retry queue */
-					kauditd_retry_skb(skb);
-				} else
-					/* everything is working so go fast! */
-					goto quick_loop;
-			} else if (reschedule)
-				/* we are currently having problems, move to
-				 * the retry queue */
-				kauditd_retry_skb(skb);
-			else
-				/* dump the message via printk and hold it */
-				kauditd_hold_skb(skb);
-		} else {
-			/* we have flushed the backlog so wake everyone */
-			wake_up(&audit_backlog_wait);
-
-			/* if everything is okay with auditd (if present), go
-			 * to sleep until there is something new in the queue
-			 * or we have a change in the connected auditd;
-			 * otherwise simply reschedule to give things a chance
-			 * to recover */
-			if (reschedule) {
-				set_current_state(TASK_INTERRUPTIBLE);
-				schedule();
-			} else
-				wait_event_freezable(kauditd_wait,
-						     kauditd_wake_condition());
-
-			/* update the auditd connection status */
-			auditd = (audit_pid ? 1 : 0);
+		/* drop our netns reference, no auditd sends past this line */
+		if (net) {
+			put_net(net);
+			net = NULL;
 		}
+		sk = NULL;
+
+		/* we have processed all the queues so wake everyone */
+		wake_up(&audit_backlog_wait);
+
+		/* NOTE: we want to wake up if there is anything on the queue,
+		 *       regardless of if an auditd is connected, as we need to
+		 *       do the multicast send and rotate records from the
+		 *       main queue to the retry/hold queues */
+		wait_event_freezable(kauditd_wait,
+				     (skb_queue_len(&audit_queue) ? 1 : 0));
 	}
 
 	return 0;
@@ -678,17 +804,16 @@
 {
 	struct audit_netlink_list *dest = _dest;
 	struct sk_buff *skb;
-	struct net *net = dest->net;
-	struct audit_net *aunet = net_generic(net, audit_net_id);
+	struct sock *sk = audit_get_sk(dest->net);
 
 	/* wait for parent to finish and send an ACK */
 	mutex_lock(&audit_cmd_mutex);
 	mutex_unlock(&audit_cmd_mutex);
 
 	while ((skb = __skb_dequeue(&dest->q)) != NULL)
-		netlink_unicast(aunet->nlsk, skb, dest->portid, 0);
+		netlink_unicast(sk, skb, dest->portid, 0);
 
-	put_net(net);
+	put_net(dest->net);
 	kfree(dest);
 
 	return 0;
@@ -722,16 +847,15 @@
 static int audit_send_reply_thread(void *arg)
 {
 	struct audit_reply *reply = (struct audit_reply *)arg;
-	struct net *net = reply->net;
-	struct audit_net *aunet = net_generic(net, audit_net_id);
+	struct sock *sk = audit_get_sk(reply->net);
 
 	mutex_lock(&audit_cmd_mutex);
 	mutex_unlock(&audit_cmd_mutex);
 
 	/* Ignore failure. It'll only happen if the sender goes away,
 	   because our timeout is set to infinite. */
-	netlink_unicast(aunet->nlsk , reply->skb, reply->portid, 0);
-	put_net(net);
+	netlink_unicast(sk, reply->skb, reply->portid, 0);
+	put_net(reply->net);
 	kfree(reply);
 	return 0;
 }
@@ -949,12 +1073,12 @@
 
 static int audit_replace(pid_t pid)
 {
-	struct sk_buff *skb = audit_make_reply(0, 0, AUDIT_REPLACE, 0, 0,
-					       &pid, sizeof(pid));
+	struct sk_buff *skb;
 
+	skb = audit_make_reply(0, 0, AUDIT_REPLACE, 0, 0, &pid, sizeof(pid));
 	if (!skb)
 		return -ENOMEM;
-	return netlink_unicast(audit_sock, skb, audit_nlk_portid, 0);
+	return auditd_send_unicast_skb(skb);
 }
 
 static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
@@ -981,7 +1105,9 @@
 		memset(&s, 0, sizeof(s));
 		s.enabled		= audit_enabled;
 		s.failure		= audit_failure;
-		s.pid			= audit_pid;
+		rcu_read_lock();
+		s.pid			= auditd_conn.pid;
+		rcu_read_unlock();
 		s.rate_limit		= audit_rate_limit;
 		s.backlog_limit		= audit_backlog_limit;
 		s.lost			= atomic_read(&audit_lost);
@@ -1014,30 +1140,44 @@
 			 *       from the initial pid namespace, but something
 			 *       to keep in mind if this changes */
 			int new_pid = s.pid;
+			pid_t auditd_pid;
 			pid_t requesting_pid = task_tgid_vnr(current);
 
-			if ((!new_pid) && (requesting_pid != audit_pid)) {
-				audit_log_config_change("audit_pid", new_pid, audit_pid, 0);
+			/* test the auditd connection */
+			audit_replace(requesting_pid);
+
+			rcu_read_lock();
+			auditd_pid = auditd_conn.pid;
+			/* only the current auditd can unregister itself */
+			if ((!new_pid) && (requesting_pid != auditd_pid)) {
+				rcu_read_unlock();
+				audit_log_config_change("audit_pid", new_pid,
+							auditd_pid, 0);
 				return -EACCES;
 			}
-			if (audit_pid && new_pid &&
-			    audit_replace(requesting_pid) != -ECONNREFUSED) {
-				audit_log_config_change("audit_pid", new_pid, audit_pid, 0);
+			/* replacing a healthy auditd is not allowed */
+			if (auditd_pid && new_pid) {
+				rcu_read_unlock();
+				audit_log_config_change("audit_pid", new_pid,
+							auditd_pid, 0);
 				return -EEXIST;
 			}
+			rcu_read_unlock();
+
 			if (audit_enabled != AUDIT_OFF)
-				audit_log_config_change("audit_pid", new_pid, audit_pid, 1);
+				audit_log_config_change("audit_pid", new_pid,
+							auditd_pid, 1);
+
 			if (new_pid) {
-				if (audit_sock)
-					sock_put(audit_sock);
-				audit_pid = new_pid;
-				audit_nlk_portid = NETLINK_CB(skb).portid;
-				sock_hold(skb->sk);
-				audit_sock = skb->sk;
-			} else {
+				/* register a new auditd connection */
+				auditd_set(new_pid,
+					   NETLINK_CB(skb).portid,
+					   sock_net(NETLINK_CB(skb).sk));
+				/* try to process any backlog */
+				wake_up_interruptible(&kauditd_wait);
+			} else
+				/* unregister the auditd connection */
 				auditd_reset();
-			}
-			wake_up_interruptible(&kauditd_wait);
 		}
 		if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
 			err = audit_set_rate_limit(s.rate_limit);
@@ -1090,7 +1230,6 @@
 				if (err)
 					break;
 			}
-			mutex_unlock(&audit_cmd_mutex);
 			audit_log_common_recv_msg(&ab, msg_type);
 			if (msg_type != AUDIT_USER_TTY)
 				audit_log_format(ab, " msg='%.*s'",
@@ -1108,7 +1247,6 @@
 			}
 			audit_set_portid(ab, NETLINK_CB(skb).portid);
 			audit_log_end(ab);
-			mutex_lock(&audit_cmd_mutex);
 		}
 		break;
 	case AUDIT_ADD_RULE:
@@ -1298,26 +1436,26 @@
 
 	struct audit_net *aunet = net_generic(net, audit_net_id);
 
-	aunet->nlsk = netlink_kernel_create(net, NETLINK_AUDIT, &cfg);
-	if (aunet->nlsk == NULL) {
+	aunet->sk = netlink_kernel_create(net, NETLINK_AUDIT, &cfg);
+	if (aunet->sk == NULL) {
 		audit_panic("cannot initialize netlink socket in namespace");
 		return -ENOMEM;
 	}
-	aunet->nlsk->sk_sndtimeo = MAX_SCHEDULE_TIMEOUT;
+	aunet->sk->sk_sndtimeo = MAX_SCHEDULE_TIMEOUT;
+
 	return 0;
 }
 
 static void __net_exit audit_net_exit(struct net *net)
 {
 	struct audit_net *aunet = net_generic(net, audit_net_id);
-	struct sock *sock = aunet->nlsk;
-	mutex_lock(&audit_cmd_mutex);
-	if (sock == audit_sock)
-		auditd_reset();
-	mutex_unlock(&audit_cmd_mutex);
 
-	netlink_kernel_release(sock);
-	aunet->nlsk = NULL;
+	rcu_read_lock();
+	if (net == auditd_conn.net)
+		auditd_reset();
+	rcu_read_unlock();
+
+	netlink_kernel_release(aunet->sk);
 }
 
 static struct pernet_operations audit_net_ops __net_initdata = {
@@ -1335,20 +1473,24 @@
 	if (audit_initialized == AUDIT_DISABLED)
 		return 0;
 
-	pr_info("initializing netlink subsys (%s)\n",
-		audit_default ? "enabled" : "disabled");
-	register_pernet_subsys(&audit_net_ops);
+	memset(&auditd_conn, 0, sizeof(auditd_conn));
+	spin_lock_init(&auditd_conn.lock);
 
 	skb_queue_head_init(&audit_queue);
 	skb_queue_head_init(&audit_retry_queue);
 	skb_queue_head_init(&audit_hold_queue);
-	audit_initialized = AUDIT_INITIALIZED;
-	audit_enabled = audit_default;
-	audit_ever_enabled |= !!audit_default;
 
 	for (i = 0; i < AUDIT_INODE_BUCKETS; i++)
 		INIT_LIST_HEAD(&audit_inode_hash[i]);
 
+	pr_info("initializing netlink subsys (%s)\n",
+		audit_default ? "enabled" : "disabled");
+	register_pernet_subsys(&audit_net_ops);
+
+	audit_initialized = AUDIT_INITIALIZED;
+	audit_enabled = audit_default;
+	audit_ever_enabled |= !!audit_default;
+
 	kauditd_task = kthread_run(kauditd_thread, NULL, "kauditd");
 	if (IS_ERR(kauditd_task)) {
 		int err = PTR_ERR(kauditd_task);
@@ -1519,20 +1661,16 @@
 	if (unlikely(!audit_filter(type, AUDIT_FILTER_TYPE)))
 		return NULL;
 
-	/* don't ever fail/sleep on these two conditions:
+	/* NOTE: don't ever fail/sleep on these two conditions:
 	 * 1. auditd generated record - since we need auditd to drain the
 	 *    queue; also, when we are checking for auditd, compare PIDs using
 	 *    task_tgid_vnr() since auditd_pid is set in audit_receive_msg()
 	 *    using a PID anchored in the caller's namespace
-	 * 2. audit command message - record types 1000 through 1099 inclusive
-	 *    are command messages/records used to manage the kernel subsystem
-	 *    and the audit userspace, blocking on these messages could cause
-	 *    problems under load so don't do it (note: not all of these
-	 *    command types are valid as record types, but it is quicker to
-	 *    just check two ints than a series of ints in a if/switch stmt) */
-	if (!((audit_pid && audit_pid == task_tgid_vnr(current)) ||
-	      (type >= 1000 && type <= 1099))) {
-		long sleep_time = audit_backlog_wait_time;
+	 * 2. generator holding the audit_cmd_mutex - we don't want to block
+	 *    while holding the mutex */
+	if (!(auditd_test_task(current) ||
+	      (current == __mutex_owner(&audit_cmd_mutex)))) {
+		long stime = audit_backlog_wait_time;
 
 		while (audit_backlog_limit &&
 		       (skb_queue_len(&audit_queue) > audit_backlog_limit)) {
@@ -1541,14 +1679,13 @@
 
 			/* sleep if we are allowed and we haven't exhausted our
 			 * backlog wait limit */
-			if ((gfp_mask & __GFP_DIRECT_RECLAIM) &&
-			    (sleep_time > 0)) {
+			if (gfpflags_allow_blocking(gfp_mask) && (stime > 0)) {
 				DECLARE_WAITQUEUE(wait, current);
 
 				add_wait_queue_exclusive(&audit_backlog_wait,
 							 &wait);
 				set_current_state(TASK_UNINTERRUPTIBLE);
-				sleep_time = schedule_timeout(sleep_time);
+				stime = schedule_timeout(stime);
 				remove_wait_queue(&audit_backlog_wait, &wait);
 			} else {
 				if (audit_rate_check() && printk_ratelimit())
@@ -2127,15 +2264,27 @@
  */
 void audit_log_end(struct audit_buffer *ab)
 {
+	struct sk_buff *skb;
+	struct nlmsghdr *nlh;
+
 	if (!ab)
 		return;
-	if (!audit_rate_check()) {
-		audit_log_lost("rate limit exceeded");
-	} else {
-		skb_queue_tail(&audit_queue, ab->skb);
-		wake_up_interruptible(&kauditd_wait);
+
+	if (audit_rate_check()) {
+		skb = ab->skb;
 		ab->skb = NULL;
-	}
+
+		/* setup the netlink header, see the comments in
+		 * kauditd_send_multicast_skb() for length quirks */
+		nlh = nlmsg_hdr(skb);
+		nlh->nlmsg_len = skb->len - NLMSG_HDRLEN;
+
+		/* queue the netlink packet and poke the kauditd thread */
+		skb_queue_tail(&audit_queue, skb);
+		wake_up_interruptible(&kauditd_wait);
+	} else
+		audit_log_lost("rate limit exceeded");
+
 	audit_buffer_free(ab);
 }
 
diff --git a/kernel/audit.h b/kernel/audit.h
index ca57988..0f1cf6d 100644
--- a/kernel/audit.h
+++ b/kernel/audit.h
@@ -218,7 +218,7 @@
 			   struct audit_names *n, const struct path *path,
 			   int record_num, int *call_panic);
 
-extern int audit_pid;
+extern int auditd_test_task(const struct task_struct *task);
 
 #define AUDIT_INODE_BUCKETS	32
 extern struct list_head audit_inode_hash[AUDIT_INODE_BUCKETS];
@@ -250,10 +250,6 @@
 
 int audit_send_list(void *);
 
-struct audit_net {
-	struct sock *nlsk;
-};
-
 extern int selinux_audit_rule_update(void);
 
 extern struct mutex audit_filter_mutex;
@@ -340,8 +336,7 @@
 extern int __audit_signal_info(int sig, struct task_struct *t);
 static inline int audit_signal_info(int sig, struct task_struct *t)
 {
-	if (unlikely((audit_pid && t->tgid == audit_pid) ||
-		     (audit_signals && !audit_dummy_context())))
+	if (auditd_test_task(t) || (audit_signals && !audit_dummy_context()))
 		return __audit_signal_info(sig, t);
 	return 0;
 }
diff --git a/kernel/auditsc.c b/kernel/auditsc.c
index d6a8de5..e59ffc7 100644
--- a/kernel/auditsc.c
+++ b/kernel/auditsc.c
@@ -762,7 +762,7 @@
 	struct audit_entry *e;
 	enum audit_state state;
 
-	if (audit_pid && tsk->tgid == audit_pid)
+	if (auditd_test_task(tsk))
 		return AUDIT_DISABLED;
 
 	rcu_read_lock();
@@ -816,7 +816,7 @@
 {
 	struct audit_names *n;
 
-	if (audit_pid && tsk->tgid == audit_pid)
+	if (auditd_test_task(tsk))
 		return;
 
 	rcu_read_lock();
@@ -2256,7 +2256,7 @@
 	struct audit_context *ctx = tsk->audit_context;
 	kuid_t uid = current_uid(), t_uid = task_uid(t);
 
-	if (audit_pid && t->tgid == audit_pid) {
+	if (auditd_test_task(t)) {
 		if (sig == SIGTERM || sig == SIGHUP || sig == SIGUSR1 || sig == SIGUSR2) {
 			audit_sig_pid = task_tgid_nr(tsk);
 			if (uid_valid(tsk->loginuid))