[PATCH] Filter rule comparators

Currently, audit only supports the "=" and "!=" operators in the -F
filter rules.

This patch reworks the support for "=" and "!=", and adds support
for ">", ">=", "<", and "<=".

This turned out to be a pretty clean, and simply process.  I ended up
using the high order bits of the "field", as suggested by Steve and Amy.
This allowed for no changes whatsoever to the netlink communications.
See the documentation within the patch in the include/linux/audit.h
area, where there is a table that explains the reasoning of the bitmask
assignments clearly.

The patch adds a new function, audit_comparator(left, op, right).
This function will perform the specified comparison (op, which defaults
to "==" for backward compatibility) between two values (left and right).
If the negate bit is on, it will negate whatever that result was.  This
value is returned.

Signed-off-by: Dustin Kirkland <dustin.kirkland@us.ibm.com>
Signed-off-by: David Woodhouse <dwmw2@infradead.org>
diff --git a/include/linux/audit.h b/include/linux/audit.h
index da3c019..2408cb7 100644
--- a/include/linux/audit.h
+++ b/include/linux/audit.h
@@ -98,6 +98,13 @@
 #define AUDIT_WORD(nr) ((__u32)((nr)/32))
 #define AUDIT_BIT(nr)  (1 << ((nr) - AUDIT_WORD(nr)*32))
 
+/* This bitmask is used to validate user input.  It represents all bits that
+ * are currently used in an audit field constant understood by the kernel.
+ * If you are adding a new #define AUDIT_<whatever>, please ensure that
+ * AUDIT_UNUSED_BITS is updated if need be. */
+#define AUDIT_UNUSED_BITS	0x0FFFFC00
+
+
 /* Rule fields */
 				/* These are useful when checking the
 				 * task structure at task creation time
@@ -128,8 +135,28 @@
 #define AUDIT_ARG2      (AUDIT_ARG0+2)
 #define AUDIT_ARG3      (AUDIT_ARG0+3)
 
-#define AUDIT_NEGATE    0x80000000
+#define AUDIT_NEGATE			0x80000000
 
+/* These are the supported operators.
+ *	4  2  1
+ *	=  >  <
+ *	-------
+ *	0  0  0		0	nonsense
+ *	0  0  1		1	<
+ *	0  1  0		2	>
+ *	0  1  1		3	!=
+ *	1  0  0		4	=
+ *	1  0  1		5	<=
+ *	1  1  0		6	>=
+ *	1  1  1		7	all operators
+ */
+#define AUDIT_LESS_THAN			0x10000000
+#define AUDIT_GREATER_THAN		0x20000000
+#define AUDIT_NOT_EQUAL			0x30000000
+#define AUDIT_EQUAL			0x40000000
+#define AUDIT_LESS_THAN_OR_EQUAL	(AUDIT_LESS_THAN|AUDIT_EQUAL)
+#define AUDIT_GREATER_THAN_OR_EQUAL	(AUDIT_GREATER_THAN|AUDIT_EQUAL)
+#define AUDIT_OPERATORS			(AUDIT_EQUAL|AUDIT_NOT_EQUAL)
 
 /* Status symbols */
 				/* Mask values */
diff --git a/kernel/auditsc.c b/kernel/auditsc.c
index 51a4f58a..95076fa 100644
--- a/kernel/auditsc.c
+++ b/kernel/auditsc.c
@@ -2,6 +2,7 @@
  * Handles all system-call specific auditing features.
  *
  * Copyright 2003-2004 Red Hat Inc., Durham, North Carolina.
+ * Copyright (C) 2005 IBM Corporation
  * All Rights Reserved.
  *
  * This program is free software; you can redistribute it and/or modify
@@ -27,6 +28,9 @@
  * this file -- see entry.S) is based on a GPL'd patch written by
  * okir@suse.de and Copyright 2003 SuSE Linux AG.
  *
+ * The support of additional filter rules compares (>, <, >=, <=) was
+ * added by Dustin Kirkland <dustin.kirkland@us.ibm.com>, 2005.
+ *
  */
 
 #include <linux/init.h>
@@ -252,6 +256,7 @@
 				  struct list_head *list)
 {
 	struct audit_entry  *entry;
+	int i;
 
 	/* Do not use the _rcu iterator here, since this is the only
 	 * addition routine. */
@@ -261,6 +266,16 @@
 		}
 	}
 
+	for (i = 0; i < rule->field_count; i++) {
+		if (rule->fields[i] & AUDIT_UNUSED_BITS)
+			return -EINVAL;
+		if ( rule->fields[i] & AUDIT_NEGATE )
+			rule->fields[i] |= AUDIT_NOT_EQUAL;
+		else if ( (rule->fields[i] & AUDIT_OPERATORS) == 0 )
+			rule->fields[i] |= AUDIT_EQUAL;
+		rule->fields[i] &= (~AUDIT_NEGATE);
+	}
+
 	if (!(entry = kmalloc(sizeof(*entry), GFP_KERNEL)))
 		return -ENOMEM;
 	if (audit_copy_rule(&entry->rule, rule)) {
@@ -394,6 +409,26 @@
 	return err;
 }
 
+static int audit_comparator(const u32 left, const u32 op, const u32 right)
+{
+	switch (op) {
+	case AUDIT_EQUAL:
+		return (left == right);
+	case AUDIT_NOT_EQUAL:
+		return (left != right);
+	case AUDIT_LESS_THAN:
+		return (left < right);
+	case AUDIT_LESS_THAN_OR_EQUAL:
+		return (left <= right);
+	case AUDIT_GREATER_THAN:
+		return (left > right);
+	case AUDIT_GREATER_THAN_OR_EQUAL:
+		return (left >= right);
+	default:
+		return -EINVAL;
+	}
+}
+
 /* Compare a task_struct with an audit_rule.  Return 1 on match, 0
  * otherwise. */
 static int audit_filter_rules(struct task_struct *tsk,
@@ -404,62 +439,63 @@
 	int i, j;
 
 	for (i = 0; i < rule->field_count; i++) {
-		u32 field  = rule->fields[i] & ~AUDIT_NEGATE;
+		u32 field  = rule->fields[i] & ~AUDIT_OPERATORS;
+		u32 op  = rule->fields[i] & AUDIT_OPERATORS;
 		u32 value  = rule->values[i];
 		int result = 0;
 
 		switch (field) {
 		case AUDIT_PID:
-			result = (tsk->pid == value);
+			result = audit_comparator(tsk->pid, op, value);
 			break;
 		case AUDIT_UID:
-			result = (tsk->uid == value);
+			result = audit_comparator(tsk->uid, op, value);
 			break;
 		case AUDIT_EUID:
-			result = (tsk->euid == value);
+			result = audit_comparator(tsk->euid, op, value);
 			break;
 		case AUDIT_SUID:
-			result = (tsk->suid == value);
+			result = audit_comparator(tsk->suid, op, value);
 			break;
 		case AUDIT_FSUID:
-			result = (tsk->fsuid == value);
+			result = audit_comparator(tsk->fsuid, op, value);
 			break;
 		case AUDIT_GID:
-			result = (tsk->gid == value);
+			result = audit_comparator(tsk->gid, op, value);
 			break;
 		case AUDIT_EGID:
-			result = (tsk->egid == value);
+			result = audit_comparator(tsk->egid, op, value);
 			break;
 		case AUDIT_SGID:
-			result = (tsk->sgid == value);
+			result = audit_comparator(tsk->sgid, op, value);
 			break;
 		case AUDIT_FSGID:
-			result = (tsk->fsgid == value);
+			result = audit_comparator(tsk->fsgid, op, value);
 			break;
 		case AUDIT_PERS:
-			result = (tsk->personality == value);
+			result = audit_comparator(tsk->personality, op, value);
 			break;
 		case AUDIT_ARCH:
-			if (ctx) 
-				result = (ctx->arch == value);
+ 			if (ctx)
+				result = audit_comparator(ctx->arch, op, value);
 			break;
 
 		case AUDIT_EXIT:
 			if (ctx && ctx->return_valid)
-				result = (ctx->return_code == value);
+				result = audit_comparator(ctx->return_code, op, value);
 			break;
 		case AUDIT_SUCCESS:
 			if (ctx && ctx->return_valid) {
 				if (value)
-					result = (ctx->return_valid == AUDITSC_SUCCESS);
+					result = audit_comparator(ctx->return_valid, op, AUDITSC_SUCCESS);
 				else
-					result = (ctx->return_valid == AUDITSC_FAILURE);
+					result = audit_comparator(ctx->return_valid, op, AUDITSC_FAILURE);
 			}
 			break;
 		case AUDIT_DEVMAJOR:
 			if (ctx) {
 				for (j = 0; j < ctx->name_count; j++) {
-					if (MAJOR(ctx->names[j].dev)==value) {
+					if (audit_comparator(MAJOR(ctx->names[j].dev),	op, value)) {
 						++result;
 						break;
 					}
@@ -469,7 +505,7 @@
 		case AUDIT_DEVMINOR:
 			if (ctx) {
 				for (j = 0; j < ctx->name_count; j++) {
-					if (MINOR(ctx->names[j].dev)==value) {
+					if (audit_comparator(MINOR(ctx->names[j].dev), op, value)) {
 						++result;
 						break;
 					}
@@ -479,7 +515,7 @@
 		case AUDIT_INODE:
 			if (ctx) {
 				for (j = 0; j < ctx->name_count; j++) {
-					if (ctx->names[j].ino == value) {
+					if (audit_comparator(ctx->names[j].ino, op, value)) {
 						++result;
 						break;
 					}
@@ -489,19 +525,17 @@
 		case AUDIT_LOGINUID:
 			result = 0;
 			if (ctx)
-				result = (ctx->loginuid == value);
+				result = audit_comparator(ctx->loginuid, op, value);
 			break;
 		case AUDIT_ARG0:
 		case AUDIT_ARG1:
 		case AUDIT_ARG2:
 		case AUDIT_ARG3:
 			if (ctx)
-				result = (ctx->argv[field-AUDIT_ARG0]==value);
+				result = audit_comparator(ctx->argv[field-AUDIT_ARG0], op, value);
 			break;
 		}
 
-		if (rule->fields[i] & AUDIT_NEGATE)
-			result = !result;
 		if (!result)
 			return 0;
 	}
@@ -550,49 +584,48 @@
 
 	rcu_read_lock();
 	if (!list_empty(list)) {
-		    int word = AUDIT_WORD(ctx->major);
-		    int bit  = AUDIT_BIT(ctx->major);
+		int word = AUDIT_WORD(ctx->major);
+		int bit  = AUDIT_BIT(ctx->major);
 
-		    list_for_each_entry_rcu(e, list, list) {
-			    if ((e->rule.mask[word] & bit) == bit
-				&& audit_filter_rules(tsk, &e->rule, ctx, &state)) {
-				    rcu_read_unlock();
-				    return state;
-			    }
-		    }
+		list_for_each_entry_rcu(e, list, list) {
+			if ((e->rule.mask[word] & bit) == bit
+					&& audit_filter_rules(tsk, &e->rule, ctx, &state)) {
+				rcu_read_unlock();
+				return state;
+			}
+		}
 	}
 	rcu_read_unlock();
 	return AUDIT_BUILD_CONTEXT;
 }
 
 static int audit_filter_user_rules(struct netlink_skb_parms *cb,
-			      struct audit_rule *rule,
-			      enum audit_state *state)
+				   struct audit_rule *rule,
+				   enum audit_state *state)
 {
 	int i;
 
 	for (i = 0; i < rule->field_count; i++) {
-		u32 field  = rule->fields[i] & ~AUDIT_NEGATE;
+		u32 field  = rule->fields[i] & ~AUDIT_OPERATORS;
+		u32 op  = rule->fields[i] & AUDIT_OPERATORS;
 		u32 value  = rule->values[i];
 		int result = 0;
 
 		switch (field) {
 		case AUDIT_PID:
-			result = (cb->creds.pid == value);
+			result = audit_comparator(cb->creds.pid, op, value);
 			break;
 		case AUDIT_UID:
-			result = (cb->creds.uid == value);
+			result = audit_comparator(cb->creds.uid, op, value);
 			break;
 		case AUDIT_GID:
-			result = (cb->creds.gid == value);
+			result = audit_comparator(cb->creds.gid, op, value);
 			break;
 		case AUDIT_LOGINUID:
-			result = (cb->loginuid == value);
+			result = audit_comparator(cb->loginuid, op, value);
 			break;
 		}
 
-		if (rule->fields[i] & AUDIT_NEGATE)
-			result = !result;
 		if (!result)
 			return 0;
 	}