1
\$\begingroup\$

Could I get some guidance on the use of the `volatile` keyword and on the design of my concurrent dictionary implementation below, please?

Here are a few design considerations with my implementation:

  1. Initially I thought I could have a simple dictionary and guard it by a single lock. It would work but I think the throughput would be slow. So then, I decided on a lock per bucket solution.

  2. With a lock per bucket solution, it then occurred to me that another TryAdd operation could potentially resize the dictionary, which would change the underlying distribution of keys across buckets. To solve this problem, I wrapped the implementation of each public function (with the exception of Count) in a while loop that re-checks the table reference after acquiring the bucket lock and retries if the table was replaced in the meantime.

Questions for reviewers:

  1. The lock per bucket + while loop solution seems like it could work but requires some coordination in the code to get right. Is there any way I can better this solution?

  2. I am always confused by the usage of `volatile`. I would really appreciate some guidance on when `volatile` should be used in general and whether it is really needed in my implementation.

Here is the code:

I'm supporting 4 APIs: TryAdd, TryRemove, ContainsKey, Count


    namespace DictionaryImplementations
    {
        using System;
        using System.Collections.Generic;
        using System.Linq;
        using System.Threading;
    
        /// <summary>
        /// A small thread-safe hash map that guards every bucket with its own lock and grows
        /// by publishing a larger <see cref="Table"/> while all bucket locks are held.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Every operation works on a single snapshot of <see cref="table"/>. The bucket count is
        /// always taken from that snapshot (<c>snapshot.buckets.Length</c>) instead of a separate
        /// field, so the size and the bucket array can never be observed out of sync.
        /// </para>
        /// <para>
        /// <c>Entry&lt;TKey, TValue&gt;</c> is declared outside this class; only its <c>key</c> and
        /// <c>value</c> members are used here.
        /// </para>
        /// </remarks>
        /// <typeparam name="TKey">Key type; hashed with <c>GetHashCode</c> and compared with <c>Equals</c>.</typeparam>
        /// <typeparam name="TValue">Value type stored alongside each key.</typeparam>
        public class MyConcurrentDictionary<TKey, TValue>
        {
            /// <summary>
            /// Pairs a bucket array with the locks that guard it. Lock <c>i</c> guards bucket <c>i</c>.
            /// The arrays themselves are never replaced once a table has been constructed.
            /// </summary>
            internal class Table
            {
                internal readonly object[] locks;

                internal readonly List<Entry<TKey, TValue>>[] buckets;

                internal Table(object[] locks, List<Entry<TKey, TValue>>[] buckets)
                {
                    this.locks = locks;
                    this.buckets = buckets;
                }
            }

            // unchanging fields
            private const int initialBucketSize = 10;
            private const double loadFactorCeiling = 0.5;

            // assigned once in the constructor, hence readonly (no `volatile` needed)
            private readonly double loadFactor;

            // only modified through Interlocked while a bucket lock is held
            private int count;

            // `volatile` because operations read it before taking any lock (the optimistic
            // snapshot) and re-read it under the bucket lock to detect a concurrent resize
            private volatile Table table;

            /// <summary>Creates an empty dictionary with the initial number of buckets.</summary>
            public MyConcurrentDictionary()
            {
                object[] locks = new object[initialBucketSize];
                List<Entry<TKey, TValue>>[] buckets = new List<Entry<TKey, TValue>>[initialBucketSize];
                for (int i = 0; i < initialBucketSize; i++)
                {
                    locks[i] = new object();
                    buckets[i] = new List<Entry<TKey, TValue>>();
                }

                this.count = 0;
                this.loadFactor = loadFactorCeiling;
                this.table = new Table(locks, buckets);
            }

            /// <summary>Adds the key/value pair if the key is not present yet.</summary>
            /// <param name="key">Key to insert.</param>
            /// <param name="value">Value to associate with <paramref name="key"/>.</param>
            /// <returns>true if the k/v pair was added, false if key already exists</returns>
            /// <exception cref="NullReferenceException"><paramref name="key"/> is a null reference.</exception>
            public bool TryAdd(TKey key, TValue value)
            {
                int hashCode = key.GetHashCode();

                int index;
                object gate;
                Table current = this.EnterBucket(hashCode, out index, out gate);
                try
                {
                    List<Entry<TKey, TValue>> bucket = current.buckets[index];
                    if (FindEntryIndex(bucket, key) >= 0)
                    {
                        return false;
                    }

                    bucket.Add(new Entry<TKey, TValue>() { key = key, value = value });
                    Interlocked.Increment(ref this.count);
                }
                finally
                {
                    Monitor.Exit(gate);
                }

                // Only a successful add can push the load over the ceiling. The bucket lock has
                // been released, so the resize can take every lock in ascending order.
                this.TryResize();
                return true;
            }

            /// <summary>Removes the entry stored under <paramref name="key"/>, if any.</summary>
            /// <param name="key">Key to remove.</param>
            /// <param name="oldValue">The removed value, or <c>default(TValue)</c> when the key was not found.</param>
            /// <returns>true if the kvp was removed, false if the key wasn't found</returns>
            /// <exception cref="NullReferenceException"><paramref name="key"/> is a null reference.</exception>
            public bool TryRemove(TKey key, out TValue oldValue)
            {
                oldValue = default(TValue);

                int hashCode = key.GetHashCode();

                int index;
                object gate;
                Table current = this.EnterBucket(hashCode, out index, out gate);
                try
                {
                    List<Entry<TKey, TValue>> bucket = current.buckets[index];
                    int position = FindEntryIndex(bucket, key);
                    if (position < 0)
                    {
                        return false;
                    }

                    oldValue = bucket[position].value;
                    bucket.RemoveAt(position);

                    // several buckets can be modified at once, so the shared counter needs an
                    // atomic read-modify-write; `volatile` alone would not make `count--` atomic
                    Interlocked.Decrement(ref this.count);
                    return true;
                }
                finally
                {
                    Monitor.Exit(gate);
                }
            }

            /// <returns>the count of this dictionary</returns>
            public int Count()
            {
                // The counter is updated with Interlocked; Volatile.Read keeps this unlocked read
                // from being cached or reordered. The value is a moment-in-time snapshot.
                return Volatile.Read(ref this.count);
            }

            /// <summary>Checks whether an entry with the given key is present.</summary>
            /// <param name="key">Key to look up.</param>
            /// <returns>true if this key exists, otherwise false</returns>
            /// <exception cref="NullReferenceException"><paramref name="key"/> is a null reference.</exception>
            public bool ContainsKey(TKey key)
            {
                int hashCode = key.GetHashCode();

                int index;
                object gate;

                // the bucket is a plain List, so even a read has to hold the bucket lock
                Table current = this.EnterBucket(hashCode, out index, out gate);
                try
                {
                    return FindEntryIndex(current.buckets[index], key) >= 0;
                }
                finally
                {
                    Monitor.Exit(gate);
                }
            }

            /// <summary>
            /// Maps a hash code to a bucket index in <c>[0, bucketCount)</c>.
            /// The sign bit is cleared first because <c>%</c> keeps the sign of its left operand,
            /// which would yield a negative index for negative hash codes.
            /// </summary>
            private static int GetBucketIndex(int hashCode, int bucketCount)
            {
                return (hashCode & int.MaxValue) % bucketCount;
            }

            /// <summary>Returns the position of <paramref name="key"/> inside the bucket, or -1 when absent.</summary>
            private static int FindEntryIndex(List<Entry<TKey, TValue>> bucket, TKey key)
            {
                for (int i = 0; i < bucket.Count; i++)
                {
                    if (bucket[i].key.Equals(key))
                    {
                        return i;
                    }
                }

                return -1;
            }

            /// <summary>
            /// Locks the bucket responsible for <paramref name="hashCode"/> in the current table.
            /// </summary>
            /// <param name="hashCode">Hash code of the key being operated on.</param>
            /// <param name="index">Receives the bucket index inside the returned table.</param>
            /// <param name="gate">Receives the lock object that was entered; the caller must pass
            /// exactly this object to <see cref="Monitor.Exit(object)"/>.</param>
            /// <returns>The table that was current while <paramref name="gate"/> was acquired.</returns>
            private Table EnterBucket(int hashCode, out int index, out object gate)
            {
                while (true)
                {
                    // one snapshot supplies the size, the lock and the bucket
                    Table snapshot = this.table;
                    index = GetBucketIndex(hashCode, snapshot.buckets.Length);
                    gate = snapshot.locks[index];

                    Monitor.Enter(gate);
                    if (object.ReferenceEquals(snapshot, this.table))
                    {
                        // a resize needs this lock before it can publish a new table, so the
                        // snapshot stays current until the caller releases `gate`
                        return snapshot;
                    }

                    // the table was replaced while we were waiting for the lock; retry
                    Monitor.Exit(gate);
                }
            }

            /// <summary>Tells whether the entry count has reached the load ceiling for the given table.</summary>
            private bool IsOverloaded(Table candidate)
            {
                double currentLoad = (double)Volatile.Read(ref this.count) / candidate.buckets.Length;
                return currentLoad >= this.loadFactor;
            }

            /// <summary>
            /// Doubles the number of buckets when the load ceiling has been reached.
            /// Must be called without holding any bucket lock.
            /// </summary>
            private void TryResize()
            {
                Table observed = this.table;
                if (!this.IsOverloaded(observed))
                {
                    return;
                }

                object[] oldLocks = observed.locks;
                int held = 0;
                try
                {
                    // always acquire in ascending index order so that two resizing threads
                    // cannot deadlock each other
                    while (held < oldLocks.Length)
                    {
                        Monitor.Enter(oldLocks[held]);
                        held++;
                    }

                    // Re-validate under the locks: another thread may already have resized while
                    // we were waiting, in which case this table is stale or no longer overloaded.
                    if (!object.ReferenceEquals(observed, this.table) || !this.IsOverloaded(observed))
                    {
                        return;
                    }

                    // simple algo, double the hashtable size
                    int newBucketSize = observed.buckets.Length * 2;

                    // keep the existing lock objects at their indices and append fresh ones
                    object[] newLocks = new object[newBucketSize];
                    Array.Copy(oldLocks, newLocks, oldLocks.Length);
                    for (int i = oldLocks.Length; i < newBucketSize; i++)
                    {
                        newLocks[i] = new object();
                    }

                    List<Entry<TKey, TValue>>[] newBuckets = new List<Entry<TKey, TValue>>[newBucketSize];
                    for (int i = 0; i < newBuckets.Length; i++)
                    {
                        newBuckets[i] = new List<Entry<TKey, TValue>>();
                    }

                    // re-compute distribution
                    foreach (List<Entry<TKey, TValue>> bucket in observed.buckets)
                    {
                        foreach (Entry<TKey, TValue> entry in bucket)
                        {
                            int newIndex = GetBucketIndex(entry.key.GetHashCode(), newBucketSize);
                            newBuckets[newIndex].Add(entry);
                        }
                    }

                    // single volatile write publishes size, locks and buckets together
                    this.table = new Table(newLocks, newBuckets);
                }
                finally
                {
                    // release exactly the locks that were taken, even if acquisition stopped early
                    while (held > 0)
                    {
                        held--;
                        Monitor.Exit(oldLocks[held]);
                    }
                }
            }
        }
    }

And here is a set of basic unit tests to check the implementation. I boosted the number of add operations up to 10k and the unit tests didn't hang (a hang would indicate a deadlock). Are there any other corner cases I'm missing?


    using DictionaryImplementations;
    using Microsoft.VisualStudio.TestTools.UnitTesting;
    using System.Collections.Generic;
    using System.Threading.Tasks;
    
    namespace ConcurrentDictionaryTests
    {
        /// <summary>
        /// Concurrency smoke tests for <see cref="MyConcurrentDictionary{TKey, TValue}"/>:
        /// parallel adds, parallel add + lookup, and parallel removes after a bulk load.
        /// </summary>
        [TestClass]
        public class UnitTest1
        {
            /// <summary>Three parallel adds (keys 1, 11 and 2) must all be counted.</summary>
            [TestMethod]
            public async Task ThreeThreadAdd()
            {
                var dict = new MyConcurrentDictionary<int, int>();

                Task<bool>[] adds =
                {
                    // 1 and 11 share a bucket while the table still has its initial size
                    Task.Run(() => dict.TryAdd(1, 1)),
                    Task.Run(() => dict.TryAdd(11, 1)),

                    // lands in a different bucket
                    Task.Run(() => dict.TryAdd(2, 2)),
                };

                await Task.WhenAll(adds);

                Assert.AreEqual(3, dict.Count());
            }

            /// <summary>Each task adds its own key and must be able to see it immediately afterwards.</summary>
            [TestMethod]
            public async Task TestConcurrentAdd()
            {
                const int size = 1000;
                var dict = new MyConcurrentDictionary<int, int>();
                var pending = new Task[size];

                for (int i = 0; i < size; i++)
                {
                    // capture a per-iteration copy; the loop variable keeps changing
                    int key = i;
                    pending[i] = Task.Run(() =>
                    {
                        dict.TryAdd(key, key + 1);

                        Assert.IsTrue(dict.ContainsKey(key));
                    });
                }

                await Task.WhenAll(pending);

                Assert.AreEqual(size, dict.Count());
            }

            /// <summary>
            /// Loads the dictionary in parallel, then removes the first keys in parallel and checks
            /// the returned values, the absence of the removed keys and the final count.
            /// </summary>
            [TestMethod]
            public async Task TestConcurrentRemove()
            {
                const int addSize = 10000;
                const int removeSize = 100;
                var dict = new MyConcurrentDictionary<int, int>();

                // phase 1: populate
                var additions = new Task[addSize];
                for (int i = 0; i < addSize; i++)
                {
                    // capture a per-iteration copy; the loop variable keeps changing
                    int key = i;
                    additions[i] = Task.Run(() =>
                    {
                        dict.TryAdd(key, key + 1);
                    });
                }

                await Task.WhenAll(additions);

                // phase 2: remove the first `removeSize` keys
                var removals = new Task[removeSize];
                for (int i = 0; i < removeSize; i++)
                {
                    int key = i;
                    removals[i] = Task.Run(() =>
                    {
                        var result = dict.TryRemove(key, out int oldValue);

                        Assert.IsTrue(result);
                        Assert.AreEqual(key + 1, oldValue);
                        Assert.IsFalse(dict.ContainsKey(key));
                    });
                }

                await Task.WhenAll(removals);

                // final size
                Assert.AreEqual(addSize - removeSize, dict.Count());
            }
        }
    }

\$\endgroup\$

0