Web lists-archives.com

[PATCH 01/23] net, sunrpc: convert rpc_cred.cr_count from atomic_t to refcount_t




The refcount_t type and its corresponding API should be
used instead of atomic_t when the variable is used as
a reference counter. This helps avoid accidental
reference counter overflows that might lead to
use-after-free situations.

Signed-off-by: Elena Reshetova <elena.reshetova@xxxxxxxxx>
Signed-off-by: Hans Liljestrand <ishkamiel@xxxxxxxxx>
Signed-off-by: Kees Cook <keescook@xxxxxxxxxxxx>
Signed-off-by: David Windsor <dwindsor@xxxxxxxxx>
---
 include/linux/sunrpc/auth.h |  8 ++++----
 net/sunrpc/auth.c           | 12 ++++++------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/include/linux/sunrpc/auth.h b/include/linux/sunrpc/auth.h
index b1bc62b..bd36e0b 100644
--- a/include/linux/sunrpc/auth.h
+++ b/include/linux/sunrpc/auth.h
@@ -15,7 +15,7 @@
 #include <linux/sunrpc/msg_prot.h>
 #include <linux/sunrpc/xdr.h>
 
-#include <linux/atomic.h>
+#include <linux/refcount.h>
 #include <linux/rcupdate.h>
 #include <linux/uidgid.h>
 #include <linux/utsname.h>
@@ -68,7 +68,7 @@ struct rpc_cred {
 #endif
 	unsigned long		cr_expire;	/* when to gc */
 	unsigned long		cr_flags;	/* various flags */
-	atomic_t		cr_count;	/* ref count */
+	refcount_t		cr_count;	/* ref count */
 
 	kuid_t			cr_uid;
 
@@ -209,7 +209,7 @@ static inline
 struct rpc_cred *	get_rpccred(struct rpc_cred *cred)
 {
 	if (cred != NULL)
-		atomic_inc(&cred->cr_count);
+		refcount_inc(&cred->cr_count);
 	return cred;
 }
 
@@ -226,7 +226,7 @@ struct rpc_cred *	get_rpccred(struct rpc_cred *cred)
 static inline struct rpc_cred *
 get_rpccred_rcu(struct rpc_cred *cred)
 {
-	if (atomic_inc_not_zero(&cred->cr_count))
+	if (refcount_inc_not_zero(&cred->cr_count))
 		return cred;
 	return NULL;
 }
diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c
index 2bff63a..b6439b9 100644
--- a/net/sunrpc/auth.c
+++ b/net/sunrpc/auth.c
@@ -310,7 +310,7 @@ rpcauth_unhash_cred(struct rpc_cred *cred)
 
 	cache_lock = &cred->cr_auth->au_credcache->lock;
 	spin_lock(cache_lock);
-	ret = atomic_read(&cred->cr_count) == 0;
+	ret = refcount_read(&cred->cr_count) == 0;
 	if (ret)
 		rpcauth_unhash_cred_locked(cred);
 	spin_unlock(cache_lock);
@@ -470,12 +470,12 @@ rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
 		list_del_init(&cred->cr_lru);
 		number_cred_unused--;
 		freed++;
-		if (atomic_read(&cred->cr_count) != 0)
+		if (refcount_read(&cred->cr_count) != 0)
 			continue;
 
 		cache_lock = &cred->cr_auth->au_credcache->lock;
 		spin_lock(cache_lock);
-		if (atomic_read(&cred->cr_count) == 0) {
+		if (refcount_read(&cred->cr_count) == 0) {
 			get_rpccred(cred);
 			list_add_tail(&cred->cr_lru, free);
 			rpcauth_unhash_cred_locked(cred);
@@ -642,7 +642,7 @@ rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
 {
 	INIT_HLIST_NODE(&cred->cr_hash);
 	INIT_LIST_HEAD(&cred->cr_lru);
-	atomic_set(&cred->cr_count, 1);
+	refcount_set(&cred->cr_count, 1);
 	cred->cr_auth = auth;
 	cred->cr_ops = ops;
 	cred->cr_expire = jiffies;
@@ -715,12 +715,12 @@ put_rpccred(struct rpc_cred *cred)
 		return;
 	/* Fast path for unhashed credentials */
 	if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) == 0) {
-		if (atomic_dec_and_test(&cred->cr_count))
+		if (refcount_dec_and_test(&cred->cr_count))
 			cred->cr_ops->crdestroy(cred);
 		return;
 	}
 
-	if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
+	if (!refcount_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
 		return;
 	if (!list_empty(&cred->cr_lru)) {
 		number_cred_unused--;
-- 
2.7.4