inet_diag.c: added filter building code from socketstat (Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>)
             added filter options to create function
             added listening states constant
  	     added process, pid and fd getter functions
             using standard inet_idag.h instead of copy
	     using inet_diag_req_v2
	     various memory leak cleanups
diag_filter.h: new header file
               added various filter condition constants
               added diag_filter structure

Signed-off-by: James Hulka <james.hulka@gmail.com>
Signed-off-by: Jiri Kastner <jkastner@redhat.com>
diff --git a/python-inet_diag/diag_filter.h b/python-inet_diag/diag_filter.h
new file mode 100644
index 0000000..c7fd725
--- /dev/null
+++ b/python-inet_diag/diag_filter.h
@@ -0,0 +1,26 @@
+#ifndef _DIAG_FILTER_H_
+#define _DIAG_FILTER_H_ 1
+
+#define DIAG_FILTER_OR	1
+#define DIAG_FILTER_AND	2
+#define DIAG_FILTER_NOT	3
+#define DIAG_BC_NOP 4
+#define DIAG_BC_JMP 5
+#define DIAG_BC_S_GE 6
+#define DIAG_BC_S_LE 7
+#define DIAG_BC_D_GE 8
+#define DIAG_BC_D_LE 9
+#define DIAG_BC_AUTO 10
+#define DIAG_BC_S_COND 11
+#define DIAG_BC_D_COND 12
+
+#include <linux/types.h>
+
+/* Simple filter using the INET_DIAG_BC_* types */
+struct diag_filter {
+	int type;
+	int value;
+	struct diag_filter *post;
+	struct diag_filter *pred;
+};
+#endif /* _DIAG_FILTER_H_ */
diff --git a/python-inet_diag/inet_diag.c b/python-inet_diag/inet_diag.c
index b7bc32d..0e983ba 100644
--- a/python-inet_diag/inet_diag.c
+++ b/python-inet_diag/inet_diag.c
@@ -25,8 +25,20 @@
 #include <linux/rtnetlink.h>
 #include <arpa/inet.h>
 #include <netinet/tcp.h>
-#include "inet_diag_copy.h"
+#include <linux/inet_diag.h>
+#include <linux/sock_diag.h>
 
+#include <stdlib.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <string.h>
+#include <netdb.h>
+#include <dirent.h>
+#include <fnmatch.h>
+#include <fcntl.h>
+#include <sys/ioctl.h>
+
+#include "diag_filter.h"
 #ifndef __unused
 #define __unused __attribute__ ((unused))
 #endif
@@ -89,6 +101,8 @@
 					     (1 << SS_TIME_WAIT) |
 					     (1 << SS_SYN_RECV));
 
+static const int listen_states = (1<<SS_LISTEN) | (1<<SS_CLOSE);
+
 static const char *tmr_name[] = {
 	"off",
 	"on",
@@ -103,6 +117,7 @@
 	struct inet_diag_msg msg;
 	struct inet_diag_meminfo *ext_memory;
 	struct tcp_info *ext_protocol;
+    struct user_ent *proc;
 	char *ext_congestion;
 };
 
@@ -111,6 +126,7 @@
 {
 	free(self->ext_memory);
 	free(self->ext_protocol);
+    free(self->proc);
 	free(self->ext_congestion);
 	PyObject_Del(self);
 }
@@ -136,6 +152,161 @@
 	return 0;
 }
 
+/* process and user lookup */
+struct user_ent {
+    struct user_ent *next;
+    unsigned int    ino;
+    int             pid;
+    int             fd;
+    char            process[4096];
+};
+
+#define USER_ENT_HASH_SIZE  256
+struct user_ent *user_ent_hash[USER_ENT_HASH_SIZE] = {0};
+
+int show_users = 0;
+
+static int user_ent_hashfn(unsigned int ino)
+{
+    int val = (ino >> 24) ^ (ino >> 16) ^ (ino >> 8) ^ ino;
+
+    return val & (USER_ENT_HASH_SIZE - 1);
+}
+
+static void user_ent_add(unsigned int ino, int pid, int fd)
+{
+    struct user_ent *p, **pp;
+
+    p = malloc(sizeof(struct user_ent));
+    if (!p)
+        abort();
+    p->next = NULL;
+    p->ino = ino;
+    p->pid = pid;
+    p->fd = fd;
+
+    pp = &user_ent_hash[user_ent_hashfn(ino)];
+    p->next = *pp;
+    *pp = p;
+}
+
+static void user_ent_hash_build(void)
+{   
+    const char *root = getenv("PROC_ROOT") ? : "/proc/";
+    struct dirent *d;
+    char name[1024];
+    int nameoff;
+    DIR *dir;
+
+    strcpy(name, root);
+    if (strlen(name) == 0 || name[strlen(name)-1] != '/')
+        strcat(name, "/");
+
+    nameoff = strlen(name);
+
+    dir = opendir(name);
+    if (!dir)
+        return;
+
+    while ((d = readdir(dir)) != NULL) {
+        struct dirent *d1;
+        int pid, pos;
+        DIR *dir1;
+        char crap;
+
+        if (sscanf(d->d_name, "%d%c", &pid, &crap) != 1)
+            continue;
+
+        sprintf(name + nameoff, "%d/fd/", pid);
+        pos = strlen(name);
+        if ((dir1 = opendir(name)) == NULL)
+            continue;
+
+        while ((d1 = readdir(dir1)) != NULL) {
+            const char *pattern = "socket:[";
+            unsigned int ino;
+            char lnk[64];
+            int fd;
+            ssize_t link_len;
+
+            if (sscanf(d1->d_name, "%d%c", &fd, &crap) != 1)
+                continue;
+
+            sprintf(name+pos, "%d", fd);
+
+            link_len = readlink(name, lnk, sizeof(lnk)-1);
+            if (link_len == -1)
+                continue;
+            lnk[link_len] = '\0';
+
+            if (strncmp(lnk, pattern, strlen(pattern)))
+                continue;
+
+            sscanf(lnk, "socket:[%u]", &ino);
+
+            user_ent_add(ino, pid, fd);
+        }
+        closedir(dir1);
+    }
+    closedir(dir);
+}
+
+static int find_users(unsigned ino, struct user_ent *found)
+{
+    struct user_ent *p;
+    int cnt = 0;
+
+    if (!ino)
+        return 0;
+
+    p = user_ent_hash[user_ent_hashfn(ino)];
+    while (p) {
+        if (p->ino != ino)
+            goto next;
+
+        found->ino  = p->ino;
+        found->fd   = p->fd;
+        found->pid  = p->pid;
+        found->next = NULL;
+
+        cnt++;
+    next:
+        p = p->next;
+    }
+
+    //get the full process path
+    char tmp[4096];
+    const char *root = getenv("PROC_ROOT") ? : "/proc/";
+
+    snprintf(tmp, sizeof(tmp), "%s/%d/exe", root, found->pid);
+    char *bin_path = canonicalize_file_name(tmp);
+    if ( bin_path != NULL ) {
+        strncpy(found->process, bin_path, 4096);
+        free(bin_path);
+    }
+
+    return cnt;
+}
+
+static void clear_users(void)
+{
+    struct user_ent *p;
+    struct user_ent *temp;
+    int i;
+
+    for (i = 0; i < USER_ENT_HASH_SIZE; i++) {
+        if ( user_ent_hash[i] != 0 ) {
+            p = user_ent_hash[i];
+            while ( p != NULL ) {
+                temp = p;
+                p    = p->next;
+                free(temp);
+            }
+        }
+    user_ent_hash[i] = 0;
+    }
+}
+
 static char inet_socket__daddr_doc__[] =
 "daddr() -- get internet socket destination address";
 static PyObject *inet_socket__daddr(struct inet_socket *self,
@@ -159,7 +330,7 @@
 }
 
 static char inet_socket__sock_doc__[] =
-"saddr() -- get internet socket pointer";
+"sock() -- get internet socket pointer";
 static PyObject *inet_socket__sock(struct inet_socket *self,
 				   PyObject *args __unused)
 {
@@ -168,7 +339,7 @@
 }
 
 static char inet_socket__congestion_algorithm_doc__[] =
-"saddr() -- get internet socket congestion algorithm being used";
+"congestion_algorithm() -- get internet socket congestion algorithm being used";
 static PyObject *inet_socket__congestion_algorithm(struct inet_socket *self,
 						   PyObject *args __unused)
 {
@@ -180,6 +351,19 @@
 	return PyString_FromString(self->ext_congestion);
 }
 
+static char inet_socket__process_doc__[] =
+"process() -- get name of process";
+static PyObject *inet_socket__process(struct inet_socket *self,
+                    PyObject *args __unused)
+{
+    if (self->proc == NULL) {
+        PyErr_SetString(PyExc_OSError,          
+                "no process found or proc not specified");
+        return NULL;
+    }
+    return PyString_FromString(self->proc->process);
+}
+
 #define INET_SOCK__STR_METHOD(name, field, table, doc)		\
 static char inet_socket__##name##_doc__[] = #name "() -- " doc;	\
 static PyObject *inet_socket__##name(struct inet_socket *self,	\
@@ -211,6 +395,12 @@
 	return Py_BuildValue("l", self->ext_##ext->field); 	\
 }
 		
+#define INET_SOCK__PROC_INT_METHOD(name, field, doc)            \
+static char inet_socket__##name##_doc__[] = #name "() -- " doc; \
+static PyObject *inet_socket__##name(struct inet_socket *self,  \
+                     PyObject *args __unused)   \
+{ return Py_BuildValue("i", self->proc->field); }
+
 INET_SOCK__NET_INT_METHOD(dport, id.idiag_dport,
 			  "get internet socket destination port");
 INET_SOCK__NET_INT_METHOD(sport, id.idiag_sport,
@@ -266,6 +456,10 @@
 			  "get socket congestion window");
 INET_SOCK__EXT_INT_METHOD(ssthresh, protocol, tcpi_snd_ssthresh,
 			  "get socket slow start threshold");
+INET_SOCK__PROC_INT_METHOD(pid, pid,
+              "get process id");
+INET_SOCK__PROC_INT_METHOD(fd, fd,
+              "get file descriptor");
 
 #define INET_SOCK__METHOD(name)	{			\
 	.ml_name  = #name,				\
@@ -305,6 +499,9 @@
 	INET_SOCK__METHOD(ato),
 	INET_SOCK__METHOD(cwnd),
 	INET_SOCK__METHOD(ssthresh),
+    INET_SOCK__METHOD(process),
+    INET_SOCK__METHOD(pid),
+    INET_SOCK__METHOD(fd),
 	{ .ml_name = NULL, }
 };
 
@@ -332,7 +529,7 @@
 };
 
 /* constructor */
-static PyObject *inet_socket__new(struct inet_diag_msg *r, int nlmsg_len)
+static PyObject *inet_socket__new(struct inet_diag_msg *r, int nlmsg_len, struct user_ent *proc)
 {
 	struct inet_socket *self;
 
@@ -349,6 +546,7 @@
 	self->ext_memory = NULL;
 	self->ext_protocol = NULL;
 	self->ext_congestion = NULL;
+    self->proc = NULL;
 
 	if (nlmsg_len) {
 		struct rtattr *tb[INET_DIAG_MAX + 1];
@@ -387,7 +585,18 @@
 			if (self->ext_congestion == NULL)
 				goto out_err;
 		}
+
+        if( proc != NULL) {
+            self->proc = malloc(sizeof(*self->proc));
+            if (self->proc == NULL)
+                goto out_err;
+            self->proc->ino = proc->ino;
+            self->proc->pid = proc->pid;
+            self->proc->fd = proc->fd;
+            strcpy(self->proc->process, proc->process);
+        }
 	}
+    free(proc);
 
 	return (PyObject *)self;
 out_err:
@@ -399,6 +608,8 @@
 struct inet_diag {
 	PyObject_HEAD
 	int		socket;		/* NETLINK socket */
+    char               *bytecode;   /* NETLINK filter */
+    struct diag_filter *filter;     /* NETLINK filter */
 	char		buf[8192];
 	struct nlmsghdr *h;
 	size_t		len;
@@ -408,6 +619,8 @@
 static void inet_diag__dealloc(struct inet_diag *self)
 {
 	close(self->socket);
+    clear_users();
+    free(self->bytecode);
 	PyObject_Del(self);
 }
 
@@ -491,9 +704,16 @@
 	const int nlmsg_len = self->h->nlmsg_len - NLMSG_LENGTH(sizeof(*r));
 	self->h = NLMSG_NEXT(self->h, self->len);
 
-	return inet_socket__new(r, nlmsg_len);
+    struct user_ent *found;
+    if (!(found=malloc(sizeof(struct user_ent)))) abort();
+    if ( show_users > 0 ) {
+        find_users(r->idiag_inode, found);
+    }
+
+    return inet_socket__new(r, nlmsg_len, found);
 }
 
+
 static struct PyMethodDef inet_diag__methods[] = {
 	{
 		.ml_name  = "get",
@@ -529,25 +749,193 @@
 	.tp_getattr	= (getattrfunc)inet_diag__getattr,
 };
 
+struct aafilter
+{
+    int               family;
+    int               mask;
+    unsigned long int prefix;
+    int               port;
+    struct aafilter   *next;
+};
+
+static void filter_patch(char *a, int len, int reloc)
+{
+    while (len > 0) {
+        struct inet_diag_bc_op *op = (struct inet_diag_bc_op*)a;
+        if (op->no == len+4)
+            op->no += reloc;
+        len -= op->yes;
+        a += op->yes;
+    }
+    if (len < 0)
+        abort();
+}
+
+static int filter_bytecompile(struct diag_filter *f, char **bytecode)
+{
+    switch (f->type) {
+        case DIAG_BC_AUTO:
+    {
+        if (!(*bytecode=malloc(4))) abort();
+        ((struct inet_diag_bc_op*)*bytecode)[0] = (struct inet_diag_bc_op){ INET_DIAG_BC_AUTO, 4, 8 };
+        return 4;
+    }
+        case DIAG_BC_D_COND:
+        case DIAG_BC_S_COND:
+    {
+        struct aafilter *a = (void*)f->pred;
+        struct aafilter *b;
+        char *ptr;
+        int  code = (f->type == DIAG_BC_D_COND ? INET_DIAG_BC_D_COND : INET_DIAG_BC_S_COND);
+        int len = 0;
+
+        for (b=a; b; b=b->next) {
+            len += 4 + sizeof(struct inet_diag_hostcond);
+            if (a->family == AF_INET6)
+                len += 16;
+            else
+                len += 4;
+            if (b->next)
+                len += 4;
+        }
+        if (!(ptr = malloc(len))) abort();
+        *bytecode = ptr;
+        for (b=a; b; b=b->next) {
+            struct inet_diag_bc_op *op = (struct inet_diag_bc_op *)ptr;
+            int alen = (a->family == AF_INET6 ? 16 : 4);
+            int oplen = alen + 4 + sizeof(struct inet_diag_hostcond);
+            struct inet_diag_hostcond *cond = (struct inet_diag_hostcond*)(ptr+4);
+
+            *op = (struct inet_diag_bc_op){ code, oplen, oplen+4 };
+            cond->family = a->family;
+            cond->port = a->port;
+            cond->prefix_len = a->mask;
+            memcpy(cond->addr, &a->prefix, alen);
+            ptr += oplen;
+            if (b->next) {
+                op = (struct inet_diag_bc_op *)ptr;
+                *op = (struct inet_diag_bc_op){ INET_DIAG_BC_JMP, 4, len - (ptr-*bytecode)};
+                ptr += 4;
+            }
+        }
+        return ptr - *bytecode;
+    }
+        case DIAG_BC_D_GE:
+    {
+        struct aafilter *x = (void*)f->pred;
+        if (!(*bytecode=malloc(8))) abort();
+        ((struct inet_diag_bc_op*)*bytecode)[0] = (struct inet_diag_bc_op){ INET_DIAG_BC_D_GE, 8, 12 };
+        ((struct inet_diag_bc_op*)*bytecode)[1] = (struct inet_diag_bc_op){ 0, 0, x->port };
+        return 8;
+    }
+        case DIAG_BC_D_LE:
+    {
+        struct aafilter *x = (void*)f->pred;
+        if (!(*bytecode=malloc(8))) abort();
+        ((struct inet_diag_bc_op*)*bytecode)[0] = (struct inet_diag_bc_op){ INET_DIAG_BC_D_LE, 8, 12 };
+        ((struct inet_diag_bc_op*)*bytecode)[1] = (struct inet_diag_bc_op){ 0, 0, x->port };
+        return 8;
+    }
+        case DIAG_BC_S_GE:
+    {
+        struct aafilter *x = (void*)f->pred;
+        if (!(*bytecode=malloc(8))) abort();
+        ((struct inet_diag_bc_op*)*bytecode)[0] = (struct inet_diag_bc_op){ INET_DIAG_BC_S_GE, 8, 12 };
+        ((struct inet_diag_bc_op*)*bytecode)[1] = (struct inet_diag_bc_op){ 0, 0, x->port };
+        return 8;
+    }
+        case DIAG_BC_S_LE:
+    {
+        struct aafilter *x = (void*)f->pred;
+        if (!(*bytecode=malloc(8))) abort();
+        ((struct inet_diag_bc_op*)*bytecode)[0] = (struct inet_diag_bc_op){ INET_DIAG_BC_S_LE, 8, 12 };
+        ((struct inet_diag_bc_op*)*bytecode)[1] = (struct inet_diag_bc_op){ 0, 0, x->port };
+        return 8;
+    }
+
+        case DIAG_FILTER_AND:
+    {
+        char *a1, *a2, *a, l1, l2;
+        l1 = filter_bytecompile(f->pred, &a1);
+        l2 = filter_bytecompile(f->post, &a2);
+        if (!(a = malloc(l1+l2))) abort();
+        memcpy(a, a1, l1);
+        memcpy(a+l1, a2, l2);
+        free(a1); free(a2);
+        filter_patch(a, l1, l2);
+        *bytecode = a;
+        return l1+l2;
+    }
+        case DIAG_FILTER_OR:
+    {
+        char *a1, *a2, *a, l1, l2;
+        l1 = filter_bytecompile(f->pred, &a1);
+        l2 = filter_bytecompile(f->post, &a2);
+        if (!(a = malloc(l1+l2+4))) abort();
+        memcpy(a, a1, l1);
+        memcpy(a+l1+4, a2, l2);
+        free(a1); free(a2);
+        *(struct inet_diag_bc_op*)(a+l1) = (struct inet_diag_bc_op){ INET_DIAG_BC_JMP, 4, l2+4 };
+        *bytecode = a;
+        return l1+l2+4;
+    }
+        case DIAG_FILTER_NOT:
+    {
+        char *a1, *a, l1;
+        l1 = filter_bytecompile(f->pred, &a1);
+        if (!(a = malloc(l1+4))) abort();
+        memcpy(a, a1, l1);
+        free(a1);
+        *(struct inet_diag_bc_op*)(a+l1) = (struct inet_diag_bc_op){ INET_DIAG_BC_JMP, 4, 8 };
+        *bytecode = a;
+        return l1+4;
+    }
+        default:
+        abort();
+    }
+}
+
 /* constructor */
 static char inet_diag_create__doc__[] =
-"create() -- creates a new inet_diag socket.";
+"create([states, extensions, socktype, src, sport, dst, dport, le_spt, le_dpt, ge_spt, ge_dpt, join=DIAG_FILTER_AND])\n\n\
+Creates a new inet_diag socket object. Filters include:\n\
+- socket states (SENT, RECV, etc.)\n\
+- source addr and source port (must be used together)\n\
+- dest addr and dest port (must be used together)\n\
+- <=, >= source and dest port\n\n\
+All specified source and dest filters can either be joined with DIAG_FILTER_AND or DIAG_FILTER_OR.";
 static PyObject *inet_diag__create(PyObject *mself __unused, PyObject *args,
 				   PyObject *keywds)
 {
 	int states = default_states;
 	int extensions = INET_DIAG_NONE;
-	int socktype = TCPDIAG_GETSOCK;
-	static char *kwlist[] = { "states", "extensions", "socktype" };
+    int socktype = IPPROTO_TCP;
+    const char *src;
+    const char *dst;
+    int sport  = -1;
+    int dport  = -1;
+    int le_spt = -1;
+    int le_dpt = -1;
+    int ge_spt = -1;
+    int ge_dpt = -1;
+    int proc   = 0;
+    int join   = DIAG_FILTER_AND;
+    static char *kwlist[] = { "states", "extensions", "socktype", "src", "dst", "sport", "dport", "le_spt", "le_dpt", "ge_spt", "ge_dpt", "join", "proc" };
 	struct inet_diag *self = PyObject_NEW(struct inet_diag,
 					      &inet_diag_type);
 	if (self == NULL)
 		return NULL;
 
-	if (!PyArg_ParseTupleAndKeywords(args, keywds, "|iii", kwlist,
-					 &states, &extensions, &socktype))
+    if (!PyArg_ParseTupleAndKeywords(args, keywds, "|iiissiiiiiiii", kwlist,
+                     &states, &extensions, &socktype, &src, &dst, &sport, &dport, &le_spt, &le_dpt, &ge_spt, &ge_dpt, &join, &proc))
 		goto out_err;
 
+    /* TODO: have different levels of process identification */
+    if ( proc > 0 ) {
+        show_users++;
+        user_ent_hash_build();
+    }
+
 	self->socket = socket(AF_NETLINK, SOCK_RAW, NETLINK_INET_DIAG);
 	if (self->socket < 0)
 		goto out_err;
@@ -555,35 +943,189 @@
 	struct sockaddr_nl nladdr = {
 		.nl_family = AF_NETLINK,
 	};
+
 	struct {
 		struct nlmsghdr nlh;
-		struct inet_diag_req r;
+        struct inet_diag_req_v2 r;
 	} req = {
 		.nlh = {
 			.nlmsg_len   = sizeof(req),
-			.nlmsg_type  = socktype,
+            .nlmsg_type  = SOCK_DIAG_BY_FAMILY,
 			.nlmsg_flags = NLM_F_ROOT | NLM_F_MATCH | NLM_F_REQUEST,
 			.nlmsg_seq   = 123456,
 		},
 		.r = {
-			.idiag_family = AF_INET,
+            .sdiag_family    = AF_INET,
+            .sdiag_protocol  = socktype,
 			.idiag_states = states,
 			.idiag_ext    = extensions,
 		},
 	};
-	struct iovec iov[1] = {
-		[0] = {
+
+    // filter preparation
+    struct diag_filter *filter;
+    int f_exists = 0;
+    if ( src != NULL && sport != -1 ) {
+        struct in_addr in_src = {};
+        inet_aton(src, &in_src);
+
+        filter = &(struct diag_filter){
+            .type = DIAG_BC_S_COND,
+            .pred = (void*)&(struct aafilter){
+                .family = AF_INET,
+                .mask   = 32,
+                .prefix = in_src.s_addr,
+                .port   = sport,
+                .next   = NULL,
+            },
+            .post = NULL
+        };
+        f_exists = 1;
+    }
+
+    if ( dst != NULL && dport != -1 ) {
+        struct in_addr in_dst = {};
+        inet_aton(dst, &in_dst);
+
+        struct diag_filter tmp_dst = (struct diag_filter){
+            .type = DIAG_BC_D_COND,
+            .pred = (void*)&(struct aafilter){
+                .family = AF_INET,
+                .mask   = 32,
+                .prefix = in_dst.s_addr,
+                .port   = dport,
+                .next   = NULL,
+            },
+            .post = NULL
+        };
+
+        if ( f_exists != 0 ) {
+            filter = &(struct diag_filter){
+                .type = join,
+                .pred = filter,
+                .post = &tmp_dst,
+            };
+        } else {
+            filter   = &tmp_dst;
+            f_exists = 1;
+        }
+    }
+
+    if ( le_spt != -1 ) {
+        struct diag_filter tmp_le_spt = {
+            .type = DIAG_BC_S_LE,
+            .pred = (void*)&(struct aafilter){
+                .port = le_spt,
+                .next = NULL,
+            },
+            .post = NULL
+        };
+
+        if ( f_exists != 0 ) {
+            filter = &(struct diag_filter){
+                .type = join,
+                .pred = filter,
+                .post = &tmp_le_spt,
+            };
+        } else {
+            filter   = &tmp_le_spt;
+            f_exists = 1;
+        }
+    }
+
+    if ( le_dpt != -1 ) {
+        struct diag_filter tmp_le_dpt = {
+            .type = DIAG_BC_D_LE,
+            .pred = (void*)&(struct aafilter){
+                .port = le_dpt,
+                .next = NULL,
+            },
+            .post = NULL
+        };
+
+        if ( f_exists != 0 ) {
+            filter = &(struct diag_filter){
+                .type = join,
+                .pred = filter,
+                .post = &tmp_le_dpt,
+            };
+        } else {
+            filter   = &tmp_le_dpt;
+            f_exists = 1;
+        }
+    }
+
+    if ( ge_spt != -1 ) {
+        struct diag_filter tmp_ge_spt = {
+            .type = DIAG_BC_S_GE,
+            .pred = (void*)&(struct aafilter){
+                .port = ge_spt,
+                .next = NULL,
+            },
+            .post = NULL
+        };
+
+        if ( f_exists != 0 ) {
+            filter = &(struct diag_filter){
+                .type = join,
+                .pred = filter,
+                .post = &tmp_ge_spt,
+            };
+        } else {
+            filter   = &tmp_ge_spt;
+            f_exists = 1;
+        }
+    }
+
+    if ( ge_dpt != -1 ) {
+        struct diag_filter tmp_ge_dpt = {
+            .type = DIAG_BC_D_GE,
+            .pred = (void*)&(struct aafilter){
+                .port = ge_dpt,
+                .next = NULL,
+            },
+            .post = NULL
+        };
+
+        if ( f_exists != 0 ) {
+            filter = &(struct diag_filter){
+                .type = join,
+                .pred = filter,
+                .post = &tmp_ge_dpt,
+            };
+        } else {
+            filter   = &tmp_ge_dpt;
+            f_exists = 1;
+        }
+    }
+
+    struct iovec iov[3];
+    iov[0] = (struct iovec){
 			.iov_base = &req,
 			.iov_len  = sizeof(req),
-		},
 	};
+
+    // append the filter    
+    struct rtattr rta; 
+
+    int filter_len = 0;
+    if ( f_exists != 0 ) {
+        filter_len   = filter_bytecompile(filter, &(self->bytecode));
+        rta.rta_type = INET_DIAG_REQ_BYTECODE;
+        rta.rta_len  = RTA_LENGTH(filter_len);
+
+        iov[1] = (struct iovec){ &rta, sizeof(rta) };
+        iov[2] = (struct iovec){ self->bytecode, filter_len };
+
+        req.nlh.nlmsg_len += RTA_LENGTH(filter_len);
+    }
+
 	struct msghdr msg = {
 		.msg_name    = &nladdr,
 		.msg_namelen = sizeof(nladdr),
 		.msg_iov     = iov,
-		.msg_iovlen  = 1,
+        .msg_iovlen  = ( f_exists != 0 ? 3 : 1 ),
 	};
-
 	if (sendmsg(self->socket, &msg, 0) < 0)
 		goto out_err;
 
@@ -608,7 +1150,16 @@
 PyMODINIT_FUNC initinet_diag(void)
 {
 	PyObject *m;
-	m = Py_InitModule("inet_diag", python_inet_diag__methods);
+    m = Py_InitModule3("inet_diag", python_inet_diag__methods, "Example:\n\n\
+    > import inet_diag\n\
+    > from socket import IPPROTO_TCP\n\
+    > idiag = inet_diag.create(states = inet_diag.default_states, extensions = inet_diag.EXT_MEMORY, socktype = IPPROTO_TCP, le_dpt = 500)\n\
+    > while True:\n\
+    >     try:\n\
+    >         s = idiag.get()\n\
+    >     except:\n\
+    >         break\n\
+    >     print s");
 	PyModule_AddIntConstant(m, "SS_ESTABLISHED", SS_ESTABLISHED);
 	PyModule_AddIntConstant(m, "SS_SYN_SENT",    SS_SYN_SENT);
 	PyModule_AddIntConstant(m, "SS_SYN_RECV",    SS_SYN_RECV);
@@ -622,6 +1173,7 @@
 	PyModule_AddIntConstant(m, "SS_CLOSING",     SS_CLOSING);
 	PyModule_AddIntConstant(m, "SS_ALL",	     SS_ALL);
 	PyModule_AddIntConstant(m, "default_states", default_states);
+    PyModule_AddIntConstant(m, "listen_states", listen_states);
 	PyModule_AddIntConstant(m, "EXT_MEMORY",     1 << (INET_DIAG_MEMINFO - 1));
 	PyModule_AddIntConstant(m, "EXT_PROTOCOL",   1 << (INET_DIAG_INFO - 1));
 	PyModule_AddIntConstant(m, "EXT_TCP_VEGAS",  1 << (INET_DIAG_VEGASINFO - 1));
@@ -632,4 +1184,6 @@
 	PyModule_AddIntConstant(m, "PROTO_OPT_ECN", TCPI_OPT_ECN);
 	PyModule_AddIntConstant(m, "TCPDIAG_GETSOCK", TCPDIAG_GETSOCK);
 	PyModule_AddIntConstant(m, "DCCPDIAG_GETSOCK", DCCPDIAG_GETSOCK);
+    PyModule_AddIntConstant(m, "DIAG_FILTER_AND", DIAG_FILTER_AND);
+    PyModule_AddIntConstant(m, "DIAG_FILTER_OR", DIAG_FILTER_OR);
 }