# ConcurrentHashMap

# 为什么需要ConcurrentHashMap,HashMap为什么有线程安全问题

在前面的HashMap源码分析文章中我们提到在多个线程访问同一个HashMap对象时要注意线程安全问题,那么HashMap究竟有什么样的线程安全问题呢?

线程安全问题源于多个线程对共享可变变量进行并发修改读取且缺少必要的同步操作,可能导致原子性、可见性、重排序等问题。

下面介绍几种可能出现的线程安全问题

# 线程安全问题举例1: 数据丢失问题

在HashMap源码分析文章中我们了解到在put方法写入key value数据时,如果没有找到对应的key,则会在链表的末端写入一个新的Node。

img.png

如果这时同时有两个线程都要在这个链表尾部添加新的Node节点可能出现什么情况呢?

我们再看一下HashMap中向结尾写入Node节点的代码,是先判断next==null,然后给结尾的Node对象的next字段赋值为新的Node对象。 当线程1和线程2同时执行到if判断语句,同时得到true,然后同时执行if内部的代码,就可能导致线程1先给p.next赋值Node1,然后线程2又给p.next赋值Node2, 导致Node1的数据丢失。

if ((e = p.next) == null){
    p.next=newNode(hash,key,value,null);
}
1
2
3

img.png

# 线程安全问题举例2:size不准确

在HashMap中,保存了一个size字段,来记录当前HashMap中元素的数量,

transient int size;
1

在put方法的实现中,如果增加了新的key value,会通过++size增加size值。

if (++size > threshold)
    resize();
1
2

阅读过《Java并发编程实战》的朋友应该了解Java代码中修改int字段的++操作不是一个原子操作。 相当于如下代码,先读取当前对象的size字段值,给size字段值加1,再赋值给当前对象的size字段。

int sizeVal = this.size;
this.size = sizeVal + 1;
1
2

也可以从JVM字节码的角度来解释一下,++size经过Java编译器编译后,会得到如下若干条指令, 包含了getfield, iadd, putfield几个步骤。由于size位于位于堆内存所有线程均可访问,所以这个多步骤的++size操作可能会出现数据覆盖的情况。

ALOAD 0
DUP
GETFIELD size : I
ICONST_1
IADD
PUTFIELD size : I
1
2
3
4
5
6

另外由于size字段没有使用volatile声明,是不能保证可见性的,也就是一个线程修改了size,另一个线程可能读取不到最新的size值,进一步导致size值出错或size()方法返回的值不准确。

# 线程安全问题举例3:可见性问题

HashMap中的Node数组字段的读取和修改没有使用必要的加锁等同步机制,所以可能出现可见性问题, 比如一个线程对table进行了初始化,但是另一个线程在之后读取还可能读取到null。

transient Node<K,V>[] table;
1

另外节点Node类中的字段,也没有volatile等机制保证可见性,可会出现线程间的可见性问题

static class Node<K,V> implements Map.Entry<K,V> {
    final int hash;
    final K key;
    V value;
    Node<K, V> next;
    /// ...
}
1
2
3
4
5
6
7

# 其他的线程安全问题

其他的线程安全问题还有很多,在实际的面试过程中大部分候选人都会提到HashMap resize形成链表死循环的问题,不过我认为大家牢记原子性、可见性、重排序几个问题即可,不需要死记硬背。

# ConcurrentHashMap源码分析

我们为什么选择JDK17版本的代码进行分析呢,因为ConcurrentHashMap代码还是出现了一些bug(比如sizeCtl和resizeStamp的换算),我们以最新版本的代码分析能够避免走一些弯路 JDK17整体上和JDK8的实现是一致的,除了部分bug的修复。

# ConcurrentHashMap的线程安全性保证

在前面我们提到了HashMap在多线程环境下使用可能出现的各种线程安全问题。 解决线程安全问题可以使用Hashtable或Collections.synchronizedMap,不过Hashtable或synchronizedMap 是在每个方法上增加对象级别的锁,并发性能较差(读请求之间也需要加锁等待)。

使用ConcurrentHashMap能够让我们获得与Hashtable一致的线程安全保证(除了Hashtable的对this加锁这一个特性),不会出现文章前面提到的原子性、可见性、重排序问题。 get读操作不会加锁.

在解决并发问题时,有一些优化的策略,比如分段锁(类似分库分表sharding)降低锁粒度、cas无锁优化、volatile保证可见性等等优化策略,在ConcurrentHashMap实现中展现的淋漓尽致。 学习ConcurrentHashMap能够让我们学习到并发编程的宝贵实践知识。

# ConcurrentHashMap总体设计

img.png

ConcurrentHashMap也是使用数组链表的基本存储方式。内部维护两个Node数组,一个table一个nextTable, nextTable用于扩容时迁移数据。 当一个链表过长时,会转换为红黑树。 为了支持并发环境使用,ConcurrentHashMap内部通过volatile、Unsafe方法读写数组元素、cas添加链表头结点等非常多的方式来解决并发环境下线程安全问题,

ConcurrentHashMap中大部分Node存储数据,还有一些类型的Node(TreeNode, ForwardingNode, ReservationNodes)用于红黑树、resize扩容、computeIfAbsent防止重复执行compute函数等功能。 这些特殊的Node的hash值为负数(普通Node hash值是正整数),在读取遍历时会特殊处理。

ConcurrentHashMap中访问、修改Node数组需要volatile或原子性的读写、cas等,使用的是sun.misc.Unsafe实现。

在Node[]数组中写入链表第一个Node使用cas来实现,其他的修改操作(insert, delete, replace等)需要加锁,使用的是对第一个Node节点加synchronized。

Map中元素数量超过threshold(数组长度 * loadFactor 0.75),会进行扩容。 注意和HashMap不同的是,ConcurrentHashMap虽然可以通过构造函数指定loadFactor,但是只是用来计算初始的数组capacity,对后续的threshold没有影响, 扩容的阈值是按照loadFactory0.75来计算的。

Map元素数量计数使用类似LongAdder的方式来统计。

ConcurrentHashMap为了减少内存使用,做了很多优化,所以一些不是很直观,不过不用担心,让我们对各个代码进行详细分析。

# ConcurrentHashMap字段

先看下ConcurrentHashMap中的默认配置

MAXIMUM_CAPACITY = 1 << 30: Node数组的最大长度 DEFAULT_CAPACITY = 16: Node数组默认初始长度16 LOAD_FACTOR = 0.75f: loadFactor, TREEIFY_THRESHOLD = 8: 链表长度>=8,转变为红黑树Node UNTREEIFY_THRESHOLD = 6: 红黑树节点数量<=6,从红黑树转回链表 MIN_TREEIFY_CAPACITY = 64: treeifyBin时如果数组长度<MIN_TREEIFY_CAPACITY,会优先扩容数组长度而不是转回红黑树 MIN_TRANSFER_STRIDE: 每个线程单次transfer最小的index range长度 RESIZE_STAMP_BITS: sizeCtl字段中用来保存stamp信息的bits位数 MAX_RESIZERS = (1 << (32 - RESIZE_STAMP_BITS)) - 1: 最多有多少个线程可以帮助resize RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS: sizeCtl为了记录stamp需要的bit shift

resizeStamp是什么

特殊的hash值

MOVED = -1: forwarding node,表示这个Node已经transfer到新的table TREEBIN = -2: 表示这个Node是红黑树节点的root RESERVED = -3: 保留节点,用于computeIfAbsent等方法

# 关键字段

ConcurrentHashMap中Node[] table用来是Node数组,nextTable用来扩容。

baseCount用来计数,如果有cas更新冲突会通过CounterCell[] counterCells计数降低冲突

CounterCell[] counterCells: 把计数分散到多个CounterCell上减少cas冲突

transferIndex记录resize时下一个要transfer的index。

sizeCtl: 初始时保存map的capacity初始值,然后用于resize控制。为负数的时候说明table正在初始化或resize,-1说明在初始化,如果是resize时,一部分bit位用来保存当前正在resize的线程的数量信息。

cellsBusy: resize、创建CounterCells时的cas spinlock

transient volatile Node<K,V>[] table;
private transient volatile Node<K,V>[] nextTable;
private transient volatile long baseCount;
private transient volatile int sizeCtl;
private transient volatile int transferIndex;
private transient volatile int cellsBusy;
private transient volatile CounterCell[] counterCells;
1
2
3
4
5
6
7

# Node节点定义

val和next字段增加了volatile,用来保证val和next可见性和防止重排序。

static class Node<K,V> implements Map.Entry<K,V> {
    final int hash;
    final K key;
    volatile V val;
    volatile Node<K, V> next;

    // ...
}
1
2
3
4
5
6
7
8

# get方法实现

get方法不会加锁,通过volatile关键字和tabAt中Unsafe.getObjectVolatile方法保证可见性 执行流程和HashMap类似,先计算key hash值,通过hash值定位到数组中的Node, 对首个Node进行判断,如果不是,进行链表遍历。 对于特殊的Node(hash值小于0),委托给对应的Node的find方法来查找。

get方法是否会因为扩容时链表迁移导致数据查询不正确呢?在ConcurrentHashMap中是不会的, ConcurrentHashMap在迁移时会这种情况进行处理,比如创建新的Node而不是复用之前的Node等来解决next引用的问题,在讲解transfer实现时会详细分析。

public V get(Object key) {
    Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
    // 计算hash
    int h = spread(key.hashCode());
    // 判断Node数组是否为空、长度是否为0、对应的index上的Node是否为空
    if ((tab = table) != null && (n = tab.length) > 0 &&
        (e = tabAt(tab, (n - 1) & h)) != null) {
        // 先判断一下第一个节点,判断hash值相同后判断引用值相等或equals
        if ((eh = e.hash) == h) {
            if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                // 如果找到直接返回
                return e.val;
        }
        // 如果hash小于0,说明是特殊节点,交给对应的Node的find方法
        else if (eh < 0)
            return (p = e.find(h, key)) != null ? p.val : null;
        // 其他情况,遍历链表,对每个Node进行判断,如果找到返回
        while ((e = e.next) != null) {
            if (e.hash == h &&
                ((ek = e.key) == key || (ek != null && key.equals(ek))))
                return e.val;
        }
    }
    // Node数组为空或数组对应的Node为空、或遍历链表没有找到,返回null
    return null;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

spread和HashMap中的hash方法类似,先把hash值和右移16位进行异或,不过又和HASH_BITS(0x7fffffff)进行了与操作,来保证hash值是正整数。

static final int spread(int h) {
    return (h ^ (h >>> 16)) & HASH_BITS;
}
1
2
3

tabAt方法用于获取table数组中某个index的Node,通过Unsafe.getObjectVolatile保证数组元素获取volatile语义。

static final <K,V> Node<K,V> tabAt(Node<K,V>[] tab, int i) {
    return (Node<K,V>)U.getObjectVolatile(tab, ((long)i << ASHIFT) + ABASE);
}
1
2
3

特殊类型的Node的find方法由各自实现

# ReservationNode

ReservationNode是computeIfAbsent和compute方法在写入数据时发现对应的Node数组的index上Node为null,则会在对应的数组上cas创建一个ReservationNode并加锁,加锁之后执行compute方法, ReservationNode的find方法返回null,是因为这时对应的value还没有计算写入完成

static final class ReservationNode<K,V> extends Node<K,V> {
    ReservationNode() {
        super(RESERVED, null, null);
    }

    Node<K,V> find(int h, Object k) {
        return null;
    }
}
1
2
3
4
5
6
7
8
9

# ForwardingNode

在resize时,如果数组中某个index(bucket)上的数据已经迁移到了新table(nextTable),则会在旧table的index上放置 一个ForwardingNode,这样get时如果发现Node是ForwardingNode,就会到新的table去查询数据。 ForwardingNode的find实现是到nextTable中查找数据,其中为了处理nextTable又发生了resize迁移,做了循环判断处理。

static final class ForwardingNode<K,V> extends Node<K,V> {
    final Node<K,V>[] nextTable;
    ForwardingNode(Node<K,V>[] tab) {
        super(MOVED, null, null);
        this.nextTable = tab;
    }

    Node<K,V> find(int h, Object k) {
        // loop to avoid arbitrarily deep recursion on forwarding nodes
        outer: for (Node<K,V>[] tab = nextTable;;) {
            Node<K,V> e; int n;
            if (k == null || tab == null || (n = tab.length) == 0 ||
                (e = tabAt(tab, (n - 1) & h)) == null)
                return null;
            for (;;) {
                int eh; K ek;
                if ((eh = e.hash) == h &&
                    ((ek = e.key) == k || (ek != null && k.equals(ek))))
                    return e;
                if (eh < 0) {
                    if (e instanceof ForwardingNode) {
                        tab = ((ForwardingNode<K,V>)e).nextTable;
                        continue outer;
                    }
                    else
                        return e.find(h, k);
                }
                if ((e = e.next) == null)
                    return null;
            }
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

# TreeBin

# TreeNode

# put方法实现流程

put和putIfAbsent都交给putVal方法来实现。

final V putVal(K key, V value, boolean onlyIfAbsent) {
    if (key == null || value == null) throw new NullPointerException();
    int hash = spread(key.hashCode());
    int binCount = 0;
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh; K fk; V fv;
        if (tab == null || (n = tab.length) == 0)
            tab = initTable();
        else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
            if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value)))
                break;                   // no lock when adding to empty bin
        }
        else if ((fh = f.hash) == MOVED)
            tab = helpTransfer(tab, f);
        else if (onlyIfAbsent // check first node without acquiring lock
                 && fh == hash
                 && ((fk = f.key) == key || (fk != null && key.equals(fk)))
                 && (fv = f.val) != null)
            return fv;
        else {
            V oldVal = null;
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    if (fh >= 0) {
                        binCount = 1;
                        for (Node<K,V> e = f;; ++binCount) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                oldVal = e.val;
                                if (!onlyIfAbsent)
                                    e.val = value;
                                break;
                            }
                            Node<K,V> pred = e;
                            if ((e = e.next) == null) {
                                pred.next = new Node<K,V>(hash, key, value);
                                break;
                            }
                        }
                    }
                    else if (f instanceof TreeBin) {
                        Node<K,V> p;
                        binCount = 2;
                        if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                       value)) != null) {
                            oldVal = p.val;
                            if (!onlyIfAbsent)
                                p.val = value;
                        }
                    }
                    else if (f instanceof ReservationNode)
                        throw new IllegalStateException("Recursive update");
                }
            }
            if (binCount != 0) {
                if (binCount >= TREEIFY_THRESHOLD)
                    treeifyBin(tab, i);
                if (oldVal != null)
                    return oldVal;
                break;
            }
        }
    }
    addCount(1L, binCount);
    return null;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68

分步骤拆解分析

计算完hash后,在一个循环中执行,每次循环重新读取table,这是因为可能有数组没创建、链表头结点Node没有创建等需要cas重试、当前Node数组在进行transfer等情况

for (Node<K,V>[] tab = table;;)
1

判断tab为null或tab.length==0的情况,由initTable方法负责初始化Node数组

if (tab == null || (n = tab.length) == 0)
    tab = initTable();
1
2

Node数组不为空后,通过tabAt读取到对应的Node节点,如果发现为null,则尝试通过cas写入新节点。 如果失败,再重新循环,成功则结束循环。

else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
    if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value)))
        break;                   // no lock when adding to empty bin
}
1
2
3
4

如果发现节点hash值为MOVED,说明是ForwardingNode,map正在进行resize,则调用 helpTransfer方法加快resize。

else if ((fh = f.hash) == MOVED)
    tab = helpTransfer(tab, f);
1
2

如果onlyIfAbsent为true并且第一个Node的key和要写入的相等(==或equals),则直接返回。

else if (onlyIfAbsent // check first node without acquiring lock
         && fh == hash
         && ((fk = f.key) == key || (fk != null && key.equals(fk)))
         && (fv = f.val) != null)
    return fv;
1
2
3
4
5

剩下的部分就是要对Node链表或红黑树进行修改了,需要先给头结点加synchronized锁 加完锁之后还需要在锁内再次判断头结点是否还是之前获取到的节点,因为在加锁之前,可能有其他线程对数据进行了修改,比如删除了之前的Node或这个头结点transfer到了新的Node数组,需要检查头节点是否还是同一个,避免不同线程加的锁不一样或写入到旧的已经迁移的数组bucket上。

synchronized (f) {
    if (tabAt(tab, i) == f) {
1
2

加完锁且判断头节点还是之前的节点后,判断如果hash值>=0说明Node是普通Nodel存放的普通的数据,也是链表Node。 通过和HashMap一样的方式遍历链表每个Node,如果发现有相等的key(==或equals),并且不是onlyIfAbsent,则修改Node的val值。 如果到了链表结尾(next==null)没有找到,则追加一个新的Node。 这些判断和操作在当前头结点Node锁的保护下,所以不会出现HashMap的原子性等线程安全问题。

if (fh >= 0) {
    binCount = 1;
    for (Node<K,V> e = f;; ++binCount) {
        K ek;
        if (e.hash == hash &&
            ((ek = e.key) == key ||
             (ek != null && key.equals(ek)))) {
            oldVal = e.val;
            if (!onlyIfAbsent)
                e.val = value;
            break;
        }
        Node<K,V> pred = e;
        if ((e = e.next) == null) {
            pred.next = new Node<K,V>(hash, key, value);
            break;
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

如果头结点Node的hash小于0,可能是TreeBin等类型Node,如果是TreeBin说明是红黑树节点,则委托给TreeBin的putTreeVal方法。

else if (f instanceof TreeBin) {
    Node<K,V> p;
    binCount = 2;
    if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                   value)) != null) {
        oldVal = p.val;
        if (!onlyIfAbsent)
            p.val = value;
    }
}
1
2
3
4
5
6
7
8
9
10

如果发现Node节点是ReservationNode,说明出现了在computeIfAbsent或compute的computeFunction内更新当前ConcurrentHashMap 的情况,且更新的节点在同一个bucket上,目前不支持这种递归更新,抛出异常

else if (f instanceof ReservationNode)
    throw new IllegalStateException("Recursive update");
1
2

最后判断一下binCount,binCount表示的是当前链表的Node数量(如果是红黑树则固定为2),如果达到了TREEIFY_THRESHOLD,调用treeifyBin把当前的链表转为红黑树。

如果oldVal != null,说明已有对应的key,不论是否更新,不会影响Map元素数量,不需要执行后面的resize,直接返回。

if (binCount != 0) {
    if (binCount >= TREEIFY_THRESHOLD)
        treeifyBin(tab, i);
    if (oldVal != null)
        return oldVal;
    break;
}
1
2
3
4
5
6
7

最后,调用addCount,给map元素计数加1,并且会在需要resize时进行扩容。

addCount(1L, binCount);
1

# 数组初始化

ConcurrentHashMap的Node数组table是延迟初始化的,在第一次写入数据时才创建数组,以降低内存占用。 初始化过程是cas更新sizeCtl字段,更新成功说明获取到了锁,再进行一次数组为空的double check,执行数组创建任务。 ConcurrentHashMap可以通过构造函数指定initialCapacity,计算出来的初始数组大小保存在sizeCtl字段,然后在initTable方法中复制到局部变量。 初始化线程通过把sizeCtl cas成-1说明抢占到了初始化的锁,并且在创建完成数组后,把sizeCtl改为n - (n >>> 2)也就是0.75 * n sizeCtl在后面会作为threshold,如果发现数组元素数量sumCount>=sizeCtl,则会进行扩容transfer。 如果一个线程cas成功确发现数组已经不为空,说明有其他线程已经完成了初始化,这时要把sizeCtl恢复成cas之前的值,避免覆盖掉sizeCtl。

private final Node<K,V>[] initTable() {
    Node<K,V>[] tab; int sc;
    // cas可能失败所以需要循环,类似加锁条件等待
    // Node数组不为空或且长度不为0时退出循环
    while ((tab = table) == null || tab.length == 0) {
        // sizeCtl默认是初始的capacity或0,如果<0说明其他线程cas成功在创建数组,当前线程yield让出调度优先级然后重试
        if ((sc = sizeCtl) < 0)
            Thread.yield(); // lost initialization race; just spin
        // 如果sizeCtl>=0,则cas
        else if (U.compareAndSetInt(this, SIZECTL, sc, -1)) {
            try {
                // 如果成功,还需要double check再判断table是否没有
                if ((tab = table) == null || tab.length == 0) {
                    // sc如果大于0说明通过构造函数指定了capacity,否则使用默认的capacity
                    int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                    // 创建Node数组
                    Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                    // 赋值
                    table = tab = nt;
                    // sc设置为扩容的threshold,也就是0.75 * n
                    sc = n - (n >>> 2);
                }
            } finally {
                sizeCtl = sc;
            }
            break;
        }
    }
    return tab;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

# 扩容

# 扩容触发时机

ConcurrentHashMap在元素数量>=sizeCtl的时候发起扩容,在一个线程扩容的过程中,其他线程在修改HashMap时,如果发现自己要修改的Node数组的头结点是ForwardingNode, 会发现现在ConcurrentHashMap正在迁移这个bucket,这个线程不会盲目等待,而是加入到ConcurrentHashMap的transfer任务中,transfer任务会把Node数组划分成几个分区,这样 每个线程可以独自负责各自的分区,共同完成transfer任务。

扩容的触发代码在addCount方法中,addCount既修改map元素计数,又负责发起扩容transfer。 当putVal、compute等修改map的方法,给map中新增了元素数,会通过传入增加的数量来让addCount 检查当前是否需要扩容,clear、replaceNode等方法不会增加元素传入的check小于0。

注意addCount时,如果有cas冲突说明并发比较高,只会放一部分线程去检查resize以及执行transfer,所以ConcurrentHashMap的扩容是有可能出现一些延迟的,也就是一个线程执行addCount, 即使增加了计数后超过了threshold,这个线程有可能也不会去执行resize。

binCount: binCount表示的是Node数组上

binCount什么时候为0: cas写入数组Node成功 binCount什么时候为1: put方法链表第一个元素就是要找的key或者链表只有一个key binCount什么时候>1: Node为TreeNode,或查找key时链表遍历的节点长度大于1 binCount什么时候<0: 删除数据的时候(clear, replaceNode删除)

扩容条件: 元素数量 >= sizeCtl(0.75 * n) 并且数组长度 小于 MAXIMUM_CAPACITY(1 << 30)

因为addCount可能出现并发调用,所以要处理并发情况,比如当前线程发现s >= sizeCtl,但是接下来要执行的时候,其他线程已经完成扩容了等情况。

在扩容开始后,sizeCtl会变为负数,并且左边16位保存扩容前的长度信息,也是具体某一次扩容的唯一标识(类似时间戳,所以叫resizeStamp), 右边16位保存resizer也就是参与到transfer的线程的数量。

对于addCount的调用线程,除了第一个进入到transfer的线程,其他线程并不一定需要去帮助transfer,而是可以直接返回,也可以看到addCount中发现在transfer时判断是否要去帮助现有的transfer的条件判断没有那么精确。

private final void addCount(long x, int check) {
    CounterCell[] cs; long b, s;
    if ((cs = counterCells) != null ||
        !U.compareAndSetLong(this, BASECOUNT, b = baseCount, s = b + x)) {
        CounterCell c; long v; int m;
        boolean uncontended = true;
        if (cs == null || (m = cs.length - 1) < 0 ||
            (c = cs[ThreadLocalRandom.getProbe() & m]) == null ||
            !(uncontended =
                U.compareAndSetLong(c, CELLVALUE, v = c.value, v + x))) {
            fullAddCount(x, uncontended);
            return;
        }
        if (check <= 1)
            return;
        s = sumCount();
    }    
    // putVal、compute等可能新增元素的check>=0
    if (check >= 0) {
        Node<K,V>[] tab, nt; int n, sc;
        // s是元素数量,如果s >= sizeCtl 并且table != null
        while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
               (n = tab.length) < MAXIMUM_CAPACITY) {
            // resizeStamp是16bit的数字,首位是1,右侧是n(2的次方)的左侧0的数量,也就是保存了n的数据,这个可以作为一次transfer的一个标识。
            int rs = resizeStamp(n) << RESIZE_STAMP_SHIFT;
            // sz < 0说明其他的线程已经在修改了sizeCtl
            if (sc < 0) {
                // 以下几个判断条件说明如下
                // sc == rs + MAX_RESIZERS: 说明现在resize的线程已经达到上限,当前线程不再参与
                // sc == rs + 1: 说明resize中所有的最后一个线程已经完成transfer任务,每个线程完成后会给sizeCtl-1,减到rs + 1说明都已经完成。
                // (nt == nextTable) == null: 说明要么transfer已经完成,要么transfer线程还没有创建完nextTable
                // transferIndex <= 0 说明已经transfer完了, transferIndex从n-1开始递减依次迁移。或者和nextTable == null一样,transfer线程还没有完成transferIndex = n的赋值。
                if (sc == rs + MAX_RESIZERS || sc == rs + 1 ||
                    (nt = nextTable) == null || transferIndex <= 0)
                    break;
                // 检查当前transfer条件满足,cas修改sizeCtl加1记录当前的transfer参与线程
                if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1))
                    // 调用transfer(tab, nt)参与到transfer中,nt != null
                    transfer(tab, nt);
            }
            // ConcurrentHashMap在初始化完成后sizeCtl保存的是threshold,所以是>0的
            // sz >0,会尝试通过cas,修改sizeCtl为 rs + 2,为什么是+2不是+1呢 TODO FIXME
            // 目前这里不好理解,+2 表明初始状态为 +1,不过目前看+0这个值作为初始和判断结束(所有线程完成任务)也没有问题,这里可能是历史设计,之前可能对+0有特殊判断处理,一直遗留下来。
            else if (U.compareAndSetInt(this, SIZECTL, sc,
                                         rs + 2))
                transfer(tab, null);
            // 重新读取count
            s = sumCount();
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

# helpTransfer

put、clear、remove、compute等修改方法如果发现对应的Node数组上的Node是ForwardingNode,说明当前map正在transfer,当前线程 会调用helpTransfer加入到transfer任务中加快transfer,不然当前线程还是要等待transfer完成。

else if ((fh = f.hash) == MOVED)
    tab = helpTransfer(tab, f);
1
2
final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
    Node<K,V>[] nextTab; int sc;
    // 从f.nextTable读取而不是ConcurrentHashMap中读取nextTable字段,避免其他线程transfer完成把nextTable又变为null,又会触发新一轮的transfer
    if (tab != null && (f instanceof ForwardingNode) &&
        (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
        // 计算resizeStamp
        int rs = resizeStamp(tab.length) << RESIZE_STAMP_SHIFT;
        // 判断nextTab还是nextTable,这些判断说明当前还在ForwardingNode对应的transfer中。
        while (nextTab == nextTable && table == tab &&
               (sc = sizeCtl) < 0) {
            // sc == rs + MAX_RESIZERS 避免sizeCtl低位16位溢出做最大线程限制
            // sc == rs + 1说明transfer已经完成
            // transferIndex说明transfer任务已经分配完成,不需要help
            if (sc == rs + MAX_RESIZERS || sc == rs + 1 ||
                transferIndex <= 0)
                break;
            // 如果成功修改sizeCtl为sizeCtl + 1, 则调用transfer(tab, nextTab)加入到transfer中。
            if (U.compareAndSetInt(this, SIZECTL, sc, sc + 1)) {
                transfer(tab, nextTab);
                break;
            }
        }
        return nextTab;
    }
    return table;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

# transfer

transfer是扩容的核心逻辑

stride: 线程每次执行迁移的任务区间大小,比如从index (n - 1)包含 到(n - 1 - stride)包含分配给一个线程,这个线程负责这个区间的bucket从旧table 迁移到nextTable stride计算方式为max(NCPU > 1 ? (CPU数量) / 8n : n, MIN_TRANSFER_STRIDE) n是旧的table数组长度,MIN_TRANSFER_STRIDE为16 一个CPU时stride使用n是因为没有多核,多个线程并不能提高transfer效率。

TODO: 为什么是从后向前遍历迁移

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    // stride计算方式为max(NCPU > 1 ? (CPU数量) / 8n : n, MIN_TRANSFER_STRIDE) n是旧的table数组长度,MIN_TRANSFER_STRIDE为16
    // 一个CPU时stride使用n是因为没有多核,多个线程并不能提高transfer效率。
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // subdivide range
    // 一次transfer(以resizeStamp区分)只会有一个线程传入的nextTab为null
    if (nextTab == null) {            // initiating
        try {
            // 扩容数组长度为两倍
            @SuppressWarnings("unchecked")
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
            nextTab = nt;
        } catch (Throwable ex) {      // try to cope with OOME
            // 如果出现OutOfMemory Exception,就不要继续扩容了,因为内存不够了
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        // 修改nextTable为新创建的数组,这样其他线程可以helpTransfer
        nextTable = nextTab;
        // 设置transferIndex为n
        transferIndex = n;
    }
    // 获取nextTable的长度
    int nextn = nextTab.length;
    // 创建一个共享的ForwardingNode,ForwardingNode中包含当前transfer的nextTable,在get方法就可以使用到
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    // advance控制是否要继续寻找任务区间
    boolean advance = true;
    // finishing表示是否完成了整体的transfer
    boolean finishing = false; // to ensure sweep before committing nextTab
    // i 和 bound初始为0;i是要迁移的Node数组index,每次从起始位置开始递减1迁移直到bound
    for (int i = 0, bound = 0;;) {
        Node<K,V> f; int fh;
        while (advance) {
            int nextIndex, nextBound;
            // --i >= bound说明当前区间还没迁移完,初始时i和bound=0,所以这里是i赋值了任务区间之后才会走到这个if条件内
            // 如果i < bound说明当前区间迁移完,这时再判断finishing如果true,也不需要再寻找了
            if (--i >= bound || finishing)
                advance = false;
            // 如果transferIndex <=0 说明也没有任务了,当前线程调用transfer可以返回了,设置i为-1、advance为false,后面的if (i < 0)判断就会退出
            else if ((nextIndex = transferIndex) <= 0) {
                i = -1;
                advance = false;
            }
            // 走到这里说明还有任务,则把transfer cas修改为nextBound,表明nextIndex不包含到nextBound包含这个区间由当前线程迁移
            else if (U.compareAndSetInt
                     (this, TRANSFERINDEX, nextIndex,
                      nextBound = (nextIndex > stride ?
                                   nextIndex - stride : 0))) {
                // 设置bound为nextBound
                bound = nextBound;
                // i遍历初始值为nextIndex - 1,因为nextIndex是从n数组结尾开始的,
                i = nextIndex - 1;
                // 已经找到了区间,退出找任务循环,执行下面的迁移逻辑
                advance = false;
            }
        }
        // 这里检查退出逻辑,
        // i < 0: 任务执行完(--i >= bound且bound=0)或((nextIndex = transferIndex) <= 0条件下设置i=-1)都会导致i < 0)
        // i >= n: i是通过transferIndex开始设置的,n是局部变量,为本次transfer的旧Node数组的长度,但是transferIndex是实例共享变量,如果当前轮次transfer已经完成,下一轮transfer开始,则会出现i > n的情况
        // i + n >= nextn: nextn是读取的传入的nextTab参数或创建的新数组的数组长度, n是传入的tab也就是旧数组的长度。能走到这里说明 0 <=0 < n, 也就是存在 2n > nextn的可能,说明nexttable的长度比oldTable的长度2倍小
        // 说明参数传入的nextTab一定不为null,因为如果为null,nextn=2n。也就是存在一种可能oldTable的长度2倍比nextTable长度长
        // 这个发生可能在helpTransfer中,虽然做了nextTab == nextTable && table == tab判断,但是在第一个判断 TODO
        if (i < 0 || i >= n || i + n >= nextn) {
            int sc;
            // 如果finishing为true,说明transfer整体完成
            if (finishing) {
                // 设置nextTable为null
                nextTable = null;
                // 设置table为nextTable
                table = nextTab;
                // 修改sizeCtl为2n的0.75,也就是3/2
                sizeCtl = (n << 1) - (n >>> 1);
                return;
            }
            // 任务完成但是finishing为false,则通过cas修改sizeCtl,表示任务中的线程数减1
            if (U.compareAndSetInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                // 如果cas之前的sc = resizeStamp(n) << RESIZE_STAMP_SHIFT) + 2,说明这是最后一个线程由这个线程收尾,否则其他的线程退出,
                if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                    return;
                // 最后一个线程收尾,设置finishing和advance为true,然后会执行到上面的修改table、sizeCtl的逻辑
                finishing = advance = true;
                // 将i修改为n,这样当前线程会重新对旧的table进行一遍遍历检查,然后再修改nextTable、table、sizeCtl
                // 为什么需要recheck? 在目前的版本是没有必要的,只是之前版本的逻辑没有删除,在maillist中doug lea进行了确认
                i = n; // recheck before commit
            }
        }
        // 如果对应的数组index上没有元素,则通过cas设置成ForwardingNode
        else if ((f = tabAt(tab, i)) == null)
            advance = casTabAt(tab, i, null, fwd);
        // 如果已经是MOVED,说明已经迁移完了,这个会出现在recheck上
        else if ((fh = f.hash) == MOVED)
            advance = true; // already processed
        else {
            // 其他情况需要对链表头节点加锁,加锁是避免迁移过程中有其他线程修改链表影响迁移的遍历
            synchronized (f) {
                // 加锁完成后需要double check,这是因为加锁之前可能有其他线程把头结点删掉等情况,需要在锁内重新确认f还是头结点
                if (tabAt(tab, i) == f) {
                    Node<K,V> ln, hn;
                    // 判断头结点hash值,>=0说明是普通的链表节点
                    if (fh >= 0) {
                        // runBit和lastRun用于找到一个Node节点,链表上这个Node(lastRun)及后面所有的Node迁移后都落到相同的index上
                        // 因为ConcurrentHashMap数组为2的次方以及每次扩容两倍的特性,可以计算出一个hash值在迁移后的index,要么等于之前的index,要么等于之前的index + 旧table长度
                        // 具体是在低位还是高位,只需看hash & n(n是旧table长度)是否为0,如果为0,是在低位(也就是index不变),如果不为0,是在高位。TODO 补充图示
                        // runBit也是记录的最后一串index相同的Node到底是在低位还是高位,为什么要找出lastRun呢?是为了在保证get等请求遍历能正常遍历不丢数据的情况下,尽量少创建Node。
                        // 对于lastRun前面的节点,如果要复用Node对象,则next指针可能会发生变化,会导致get请求等遍历旧链表时出错。ß
                        int runBit = fh & n;
                        Node<K,V> lastRun = f;
                        for (Node<K,V> p = f.next; p != null; p = p.next) {
                            // 对每个遍历到的Node和n计算&
                            int b = p.hash & n;
                            // 如果和runBit不同,说明之前的lastRun,不是要找的最后一串Node的起始Node,则更新runBit和lastRun
                            if (b != runBit) {
                                runBit = b;
                                lastRun = p;
                            }
                        }
                        // 上面的for循环完成后,lastRun和runBit确定,需要再判断下这一串Node是在低位还是高位
                        // hn是high node ln是low node,迁移时采用的是头插法,这两个指针是已经迁移的链表的头结点
                        // runBit为0,说明在低位
                        if (runBit == 0) {
                            // 保存lastRun为低位的已迁移的链表
                            ln = lastRun;
                            // 高位还没有迁移
                            hn = null;
                        }
                        else {
                            // 否则不等于0,说明lastRun这一串Node要迁移到高位,保存lastRun为hn
                            hn = lastRun;
                            // 低位还没有数据
                            ln = null;
                        }
                        // 然后从链表头还是遍历Node,直到lastRun,因为lastRun及后面的Node已经迁移走了
                        // 这里使用头插法,不会导致get请求遍历时多遍历节点、不会导致Iterator遍历时出现重复Node
                        for (Node<K,V> p = f; p != lastRun; p = p.next) {
                            int ph = p.hash; K pk = p.key; V pv = p.val;
                            // 和n做与运算,判断迁移后是低位还是高位
                            if ((ph & n) == 0)
                                // 等于0迁移到低位
                                ln = new Node<K,V>(ph, pk, pv, ln);
                            else
                                // 不等于0迁移到高位
                                hn = new Node<K,V>(ph, pk, pv, hn);
                        }
                        // 迁移完成后,通过setTabAt更新低位和高位的链表到nextTab上
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        // 迁移完成,给旧的table对应的index上设置为Forwarding节点,这样后续的get请求会转移到nextTable中
                        setTabAt(tab, i, fwd);
                        // advance = true表示当前的index已经完成,可以迁移下一个
                        advance = true;
                    }
                    else if (f instanceof TreeBin) {
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> lo = null, loTail = null;
                        TreeNode<K,V> hi = null, hiTail = null;
                        int lc = 0, hc = 0;
                        for (Node<K,V> e = t.first; e != null; e = e.next) {
                            int h = e.hash;
                            TreeNode<K,V> p = new TreeNode<K,V>
                                (h, e.key, e.val, null, null);
                            if ((h & n) == 0) {
                                if ((p.prev = loTail) == null)
                                    lo = p;
                                else
                                    loTail.next = p;
                                loTail = p;
                                ++lc;
                            }
                            else {
                                if ((p.prev = hiTail) == null)
                                    hi = p;
                                else
                                    hiTail.next = p;
                                hiTail = p;
                                ++hc;
                            }
                        }
                        ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                            (hc != 0) ? new TreeBin<K,V>(lo) : t;
                        hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                            (lc != 0) ? new TreeBin<K,V>(hi) : t;
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                    else if (f instanceof ReservationNode)
                        // 如果递归compute map会出现这种情况,compute方法内修改map触发addCount扩容,但是发现外面还有一个自己创建的ReservationNode,无法迁移处理, TODO WHY
                        throw new IllegalStateException("Recursive update");
                }
            }
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

# 元素计数实现

计数整体的实现和LongAdder非常相似,可以参考我的另一篇LongAdder的详细分析。

为什么不直接使用LongAdder呢?ConcurrentHashMap中有一段相关的解释如下

We need to incorporate a specialization rather than just use a LongAdder in order to access implicit contention-sensing that leads to creation of multiple CounterCells

这句话不太好理解,目前我的理解是可以借助LongAdder直到目前的竞争情况,因为在addCount方法中会根据计数增加时counterCells相关的竞争情况 对是否要判断扩容条件做额外的处理,这个在前面有提到。 不过在代码实现方面和LongAdder几乎完全一致,LongAdder相比AtomicLong更适合统计类的场景也就是写多读少,并且不能保证完全精确性(因为要把baseCount和CounterCell加起来,加的过程中这些数据还会不断变化), ConcurrentHashMap对计数的读操作是size()方法以及判断扩容条件等,读操作相对较少且对精确性不是那么高。

计数存储在volatile的baseCount和CounterCell[]数组上,没有出现高并发竞争时,通过cas修改baseCount值来计数,如果出现竞争(cas有失败),则开始创建CounterCell[]数组, 每个CounterCell内保存一个volatile的count值,CounterCell数组元素延迟创建,也通过case更新,具体使用数组哪个CounterCell使用ThreadLocalRandom.getProbe和数组长度取余获得(同样也是与运算得到因为数组长度也是2的n次方并且每次2倍扩容) 如果CounterCell再失败会修改ThreadLocal的probe值重新rehash尝试其他的CounterCell,rehash再冲突就会尝试扩容

其中使用cellsBusy字段(volatile)通过cas作为自旋锁(0未加锁,成功把0 cas成1的线程说明拿到了自旋锁)

# 增减计数

增加计数的触发来自put、compute、clear等对map进行修改的方法

addCount实现

private final void addCount(long x, int check) {
    CounterCell[] cs; long b, s;
    // counterCells不为null直接进入到if语句中,否则为null说明还没创建counterCells数组,并发不高,尝试用cas修改baseCount
    // 如果cas baseCount失败说明出现并发竞争,则进入if创建counterCells
    if ((cs = counterCells) != null ||
        !U.compareAndSetLong(this, BASECOUNT, b = baseCount, s = b + x)) {
        CounterCell c; long v; int m;
        // uncontended表示当前这个cell是否没有竞争,默认为true即没有竞争(cell数组为空的情况或cas cell值成功的情况)
        boolean uncontended = true;
        // 如果CounterCell数组为空或ThreadLocalRandom.getProbe和数组长度取余对应的CounterCell为null,需要调用fullAddCount创建CounterCell数组或数组上的CounterCell,uncontended为true。
        if (cs == null || (m = cs.length - 1) < 0 ||
            (c = cs[ThreadLocalRandom.getProbe() & m]) == null ||
            // 如果CounterCell不为null,尝试cas,cas失败也调用fullAddCount, uncontended为false
            !(uncontended =
              U.compareAndSetLong(c, CELLVALUE, v = c.value, v + x))) {
            // 调用fullAddCount,完成数组初始化或CounterCell创建或rehash等方式修改计数
            fullAddCount(x, uncontended);
            return;
        }
        if (check <= 1)
            return;
        s = sumCount();
    }
    // ... check检查扩容相关逻辑
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25

fullAddCount实现, fullAddCount负责初始化CounterCell数组以及创建对应的CounterCell对象,在CounterCell cas失败的情况下会尝试给线程probe rehash或给CounterCell数组扩容(数组超过CPU核数后不再扩容)

// See LongAdder version for explanation
private final void fullAddCount(long x, boolean wasUncontended) {
    int h;
    // 当前线程的ThreadLocalRandom.getProbe为0,说明还没初始化,要调用localInit初始化一下
    if ((h = ThreadLocalRandom.getProbe()) == 0) {
        ThreadLocalRandom.localInit();      // force initialization
        h = ThreadLocalRandom.getProbe();
        // 初始化之后,对应的数组index会变,所以可以尝试一下,把wasUncontended重新置为true
        wasUncontended = true;
    }
    // collide如果是true表明在cas 修改CounterCell的时候出现了竞争,对于CounterCell不存在的情况是false,这个字段主要用来控制是否要扩容。
    boolean collide = false;                // True if last slot nonempty
    for (;;) {
        CounterCell[] cs; CounterCell c; int n; long v;
        // counterCells不为空,则去找对应的CounterCell对象
        if ((cs = counterCells) != null && (n = cs.length) > 0) {
            // 如果数组对应的CounterCell对象为空
            if ((c = cs[(n - 1) & h]) == null) {
                // cas自旋锁,创建CounterCell对象
                if (cellsBusy == 0) {            // Try to attach new Cell
                    // 乐观创建,能够减少锁粒度,缺点是如果冲突比较大对象可能浪费一些
                    CounterCell r = new CounterCell(x); // Optimistic create
                    // 再check下cellBusy值然后cas自旋锁
                    if (cellsBusy == 0 &&
                        U.compareAndSetInt(this, CELLSBUSY, 0, 1)) {
                        // cas自旋锁成功再double check对应的数组元素是否还是null
                        boolean created = false;
                        try {               // Recheck under lock
                            CounterCell[] rs; int m, j;
                            // FIXME 这里为什么又判断了一些counterCells是不是空呢
                            if ((rs = counterCells) != null &&
                                (m = rs.length) > 0 &&
                                rs[j = (m - 1) & h] == null) {
                                // 创建CounterCell时顺便设置初始值
                                rs[j] = r;
                                created = true;
                            }
                        } finally {
                            // 修改cellsBusy为0,表明锁释放
                            cellsBusy = 0;
                        }
                        // created说明计数增加成功
                        if (created)
                            break;
                        // created false说明计数没有增加成功,需要重试写入
                        continue;           // Slot is now non-empty
                    }
                }
                // collide设置false因为CounterCell刚创建
                collide = false;
            }
            else if (!wasUncontended)       // CAS already known to fail
                // 如果wasUncontended为false,也就是对应的数组index位置有竞争,则跳出if执行后面的probe rehash
                wasUncontended = true;      // Continue after rehash
            // 走到这里说明CounterCell不为空,并且没有contended竞争,通过cas尝试更新
            else if (U.compareAndSetLong(c, CELLVALUE, v = c.value, v + x))
                break;
            // 如果cas修改CounterCell失败会走到这里,如果这时counterCells != cs说明其他线程进行了扩容,或者数组长度已经大于等于CPU了不再扩容,修改collide为false,重新尝试
            else if (counterCells != cs || n >= NCPU)
                collide = false;            // At max size or stale
            else if (!collide)
                // 走到这里说明有竞争冲突 collide修改为true,会再重新rehash一次重试,再出现冲突就会尝试扩容
                collide = true;
            // 如果collide为true,则会尝试cas加锁扩容数组
            else if (cellsBusy == 0 &&
                     U.compareAndSetInt(this, CELLSBUSY, 0, 1)) {
                try {
                    // double check没有其他线程修改了数组
                    if (counterCells == cs) // Expand table unless stale
                        // 双倍扩容
                        counterCells = Arrays.copyOf(cs, n << 1);
                } finally {
                    // 设置为0,释放自旋锁
                    cellsBusy = 0;
                }
                // 修改collide为false,表明不需要扩容
                collide = false;
                // hash值不变,重新尝试
                continue;                   // Retry with expanded table
            }
            // rehash
            h = ThreadLocalRandom.advanceProbe(h);
        }
        // CounterCell数组为空,通过cas cellsBusy字段加锁创建CounterCell数组
        else if (cellsBusy == 0 && counterCells == cs &&
                 U.compareAndSetInt(this, CELLSBUSY, 0, 1)) {
            // init表示是否把addCount传入的x初始化到CounterCell数组中去了
            boolean init = false;
            try {                           // Initialize table
                // double check,判断下引用是否还是之前的引用值,因为cas成功前可能其他线程已经完成了初始化
                if (counterCells == cs) {
                    CounterCell[] rs = new CounterCell[2];
                    rs[h & 1] = new CounterCell(x);
                    counterCells = rs;
                    init = true;
                }
            } finally {
                // 释放锁
                cellsBusy = 0;
            }
            if (init)
                // 如果初始化数组时,写入x完成,函数可以返回了
                break;
        }
        // 如果一个线程发现数组为空、且cas加锁尝试创建数组失败,再尝试下cas修改baseCount,不行再返回循环重试
        else if (U.compareAndSetLong(this, BASECOUNT, v = baseCount, v + x))
            break;                          // Fall back on using base
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109

# 获取计数

获取计数实现比较好理解,把baseCount和所有的CounterCell的值加起来即可,并且对溢出Integer.MAX_VALUE的情况进行了处理避免出现负数

public int size() {
    // 计算sumCount
    long n = sumCount();
    // 如果超过Integer.MAX_VALUE,设置为Integer.MAX_VALUE,因为size()方法返回的是int类型的值,要避免溢出出现负数
    return ((n < 0L) ? 0 :
            (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
            (int)n);
}
final long sumCount() {
    CounterCell[] cs = counterCells;
    // 把baseCount和CounterCell数组中各个数组元素的value加起来
    long sum = baseCount;
    if (cs != null) {
        for (CounterCell c : cs)
            if (c != null)
                sum += c.value;
    }
    return sum;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

# computeIfAbsent

V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction)方法,只在某个key不存在时,才会调用mappingFunction对 传入的key计算value值写入到map中,方法能保证mappingFunction只会计算一次,也就是不会出现计算出结果后没有写入的情况(mappingFunction返回null不会写入)

computeIfAbsent方法返回的是key对应的value,已有的value或通过mappingFunction计算的value。

大部分代码和putVal是类似的,我们通过注释理解。

public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
    if (key == null || mappingFunction == null)
        throw new NullPointerException();
    // 计算hash
    int h = spread(key.hashCode());
    V val = null;
    // binCount用于addCount方法辅助判断是否需要扩容用
    int binCount = 0;
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh; K fk; V fv;
        if (tab == null || (n = tab.length) == 0)
            // Node数组为空先初始化
            tab = initTable();
        // 如果数组对应index上的Node为null
        else if ((f = tabAt(tab, i = (n - 1) & h)) == null) {
            // 创建一个ReservationNode占位节点
            Node<K,V> r = new ReservationNode<K,V>();
            // 对这个Node对象加锁
            synchronized (r) {
                // 加完锁cas设置到数组上
                if (casTabAt(tab, i, null, r)) {
                    // binCount = 1 说明这个链表只有一个节点
                    binCount = 1;
                    // 调用mappingFunction计算value, 然后创建Node
                    Node<K,V> node = null;
                    try {
                        if ((val = mappingFunction.apply(key)) != null)
                            node = new Node<K,V>(h, key, val);
                    } finally {
                        // 把用计算出来的新Node替换ReservationNode,mappingFunction如果抛异常,也会替换成null
                        setTabAt(tab, i, node);
                    }
                }
            }
            // binCount != 0说明cas成功了,等于0说明需要重试
            if (binCount != 0)
                break;
        }
        else if ((fh = f.hash) == MOVED)
            // MOVED说明在ConcurrentHashMap在transfer,调用helpTransfer帮助扩容迁移
            tab = helpTransfer(tab, f);
        // 其他情况,先对头结点快速判断一下key是否相等(还是判断hash、引用、equals)
        else if (fh == h    // check first node without acquiring lock
                 && ((fk = f.key) == key || (fk != null && key.equals(fk)))
                 && (fv = f.val) != null)
            // 如果找到了key,返回oldValue,因为这个方法是computeIfAbsent
            return fv;
        else {
            boolean added = false;
            // 其他情况对头结点对象加锁
            synchronized (f) {
                // double check
                if (tabAt(tab, i) == f) {
                    // fh >=0说明是链表节点
                    if (fh >= 0) {
                        binCount = 1;
                        // 遍历链表,寻找匹配的key
                        for (Node<K,V> e = f;; ++binCount) {
                            K ek;
                            if (e.hash == h &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                // 如果找到了,记录oldValue,返回
                                val = e.val;
                                break;
                            }
                            Node<K,V> pred = e;
                            // next == null说明到结尾了,链表没有这个key,则调用mappingFunction创建Node
                            if ((e = e.next) == null) {
                                if ((val = mappingFunction.apply(key)) != null) {
                                    // pred.next不等于null,说明在mappingFunction中又修改了map,出现递归内嵌修改map,抛出异常。
                                    if (pred.next != null)
                                        throw new IllegalStateException("Recursive update");
                                    added = true;
                                    // 新创建的Node追加到结尾
                                    pred.next = new Node<K,V>(h, key, val);
                                }
                                break;
                            }
                        }
                    }
                    // 红黑树节点
                    else if (f instanceof TreeBin) {
                        binCount = 2;
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> r, p;
                        // 如果找到了对应的key,说明key存在,不用计算,记录oldValue
                        if ((r = t.root) != null &&
                            (p = r.findTreeNode(h, key, null)) != null)
                            val = p.val;
                        // 对应的key不存,则调用mappingFunction,然后通过TreeBin.putTreeVal写入
                        else if ((val = mappingFunction.apply(key)) != null) {
                            added = true;
                            t.putTreeVal(h, key, val);
                        }
                    }
                    // computeIfAbsent不能嵌套,也就是computeFunction中不能在ReservationNode上写入数据,所以不要在compute相关函数中修改自身map
                    else if (f instanceof ReservationNode)
                        throw new IllegalStateException("Recursive update");
                }
            }
            // 因为链表写入新元素会导致变长,判断下是否要treeifyBin变成红黑树
            if (binCount != 0) {
                // 超过TREEIFY_THRESHOLD调用treeifyBin
                if (binCount >= TREEIFY_THRESHOLD)
                    treeifyBin(tab, i);
                // 如果没有写入数据,不用调用后面的addCount,直接返回
                if (!added)
                    return val;
                break;
            }
        }
    }
    // 走到这里说明原有map中key不存在,val不为null说明写入了数据,调用addCount
    if (val != null)
        addCount(1L, binCount);
    return val;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

# replaceNode

replaceNode方法是remove、replace等方法的底层实现,执行逻辑是找到对应的key,如果value符合条件(比如和传入的参数equals),则替换为执行的值, 如果替换的值为null,说明要删除这个节点。

replaceNode中和putVal类似的代码逻辑这里就不重复介绍了

重点看下replace替换逻辑,给头结点加锁后,遍历链表 对每个Node进行判断,如果是要找的key,则判断cv,如果cv不是null判断和节点value是否equals,是的话替换,如果cv是null,直接替换。 替换Node的value为传入的参数的value值,如果传入的参数value是null,则要删除这个Node。

注意删除的方式,是将pred.next修改为当前node.next,不会设置当前node.next为null,这是为了避免影响到get、遍历等方法的执行,因为这些遍历逻辑可能恰好刚遍历到正在删除的Node, 我们要保证即使Node节点从链表中删除了,通过这个Node节点的next还能够继续遍历下去。 如果pred是null说明这个节点是头结点,则替换数组设置头结点为next。

for (Node<K,V> e = f, pred = null;;) {
    K ek;
    if (e.hash == hash &&
        ((ek = e.key) == key ||
         (ek != null && key.equals(ek)))) {
        V ev = e.val;
        if (cv == null || cv == ev ||
            (ev != null && cv.equals(ev))) {
            oldVal = ev;
            if (value != null)
                e.val = value;
            else if (pred != null)
                pred.next = e.next;
            else
                setTabAt(tab, i, e.next);
        }
        break;
    }
    pred = e;
    if ((e = e.next) == null)
        break;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

以下是replaceNode方法的实现代码。

final V replaceNode(Object key, V value, Object cv) {
    int hash = spread(key.hashCode());
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        if (tab == null || (n = tab.length) == 0 ||
            (f = tabAt(tab, i = (n - 1) & hash)) == null)
            break;
        else if ((fh = f.hash) == MOVED)
            tab = helpTransfer(tab, f);
        else {
            V oldVal = null;
            boolean validated = false;
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    if (fh >= 0) {
                        validated = true;
                        for (Node<K,V> e = f, pred = null;;) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                V ev = e.val;
                                if (cv == null || cv == ev ||
                                    (ev != null && cv.equals(ev))) {
                                    oldVal = ev;
                                    if (value != null)
                                        e.val = value;
                                    else if (pred != null)
                                        pred.next = e.next;
                                    else
                                        setTabAt(tab, i, e.next);
                                }
                                break;
                            }
                            pred = e;
                            if ((e = e.next) == null)
                                break;
                        }
                    }
                    else if (f instanceof TreeBin) {
                        validated = true;
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> r, p;
                        if ((r = t.root) != null &&
                            (p = r.findTreeNode(hash, key, null)) != null) {
                            V pv = p.val;
                            if (cv == null || cv == pv ||
                                (pv != null && cv.equals(pv))) {
                                oldVal = pv;
                                if (value != null)
                                    p.val = value;
                                else if (t.removeTreeNode(p))
                                    setTabAt(tab, i, untreeify(t.first));
                            }
                        }
                    }
                    else if (f instanceof ReservationNode)
                        throw new IllegalStateException("Recursive update");
                }
            }
            if (validated) {
                if (oldVal != null) {
                    if (value == null)
                        addCount(-1L, -1);
                    return oldVal;
                }
                break;
            }
        }
    }
    return null;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

# 遍历

entrySet()keySet()values()等方法都需要遍历所有的Node,这些实现都基于Traverser这个类。 在遍历的过程中,要考虑到可能正在resize的情况,如果有Node是ForwardingNode,需要到nextTable中遍历index和index+n这两个位置上的Node数据,然后再返回 到原有的Node数组中继续遍历。还有一种特殊情况,就是nextTable又发生了resize,也就是nextTable中也出现了ForwardingNode,则创建一个栈保存之前的遍历状态, 在nextTable中处理完两个index后出栈会到对应的旧table中。

# Traverser

Traverser结构。

base表示最初的Node数组,比如baseIndex表示最初的Node数组遍历的哪个index了,baseSize是最初的数组的长度。

tab: 当前正在traverse的Node数组。

next: 下一个Node

stack: 用来保存状态的stack栈

spare: spare字段用来复用对象,减少TableStack对象的创建

index: 当前要遍历的index,可能是nextTable中的

baseIndex: 最开始的Node数组的遍历idnex

baseLimit: Traverse还支持遍历map的一部分数据,通过baseLimit控制

baseSize: 初始数组的长度

static class Traverser<K,V> {
    Node<K, V>[] tab;        // current table; updated if resized
    Node<K, V> next;         // the next entry to use
    TableStack<K, V> stack, spare; // to save/restore on ForwardingNodes
    int index;              // index of bin to use next
    int baseIndex;          // current index of initial table
    int baseLimit;          // index bound for initial table
    final int baseSize;     // initial table size

    Traverser(Node<K, V>[] tab, int size, int index, int limit) {
        this.tab = tab;
        this.baseSize = size;
        this.baseIndex = this.index = index;
        this.baseLimit = limit;
        this.next = null;
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

advance方法是遍历Node的实现。

先考虑没有resize的情况

index从0开始,对Node数组上的每个Node遍历。方法是不断使用next。next为null后,通过index = ++baseIndex对index+1遍历数组下一个Node

再考虑存在resize的情况

存在resize的时候,如果遇到ForwardingNode则将当前遍历的位置等信息保存到栈上,然后修改当前状态到nextTable上遍历index和index+baseSize两个位置的Node链表,

如果遍历nextTable中发现nextTable也出现了ForwardingNode说明又出现了resize则会再保存状态入栈继续遍历nextTable的nextTable

这时tab为nextTable,index为oldTable的index,e为null,迁移后在nextTable上的数据位于index和index+baseSize两个位置上。

continue后重新运行for循环逻辑,给t赋值为nextTable, i赋值index,通过e = tabAt(t, i)获取到nextTable中index位置的Node链表。

暂时不考虑e.hash < 0的情况,稍后介绍。这时会走到判断stack的位置if (stack != null),此时stack不为空,因为我们保存了Stack入栈,

所以会执行recoverState(n)逻辑,n为nextTable的长度。

recoverState方法执行时会先给index += s.length,index 此时小于n,因为nextTable的长度是s.length的两倍且index<s.length。 所以第一次执行完recoverState后,index会加上s.length,其他状态不变。

然后重新回到if的开始,这时e指向的是nextTable上oldIndex的位置,也就是迁移到低位的Node链表,如果不为空,会不断next遍历。

遍历完oldIndex后,会重新给i赋值成新的index(此时是高位的index,因为recoverState方法增加了index的值),然后e = tabAt(t, index)获取到 高位的Node链表。

这时会再判断stack不为null,再次执行recoverState,这次执行recoverState,给index再次加上s.length,这时会超过n,因为n是n.length的两倍。

然后执行到if语句逻辑中,栈顶的元素出栈,(出栈的对象会保存到spare以便可以复用)并且把栈顶元素的index、table、length等信息恢复到当前的Traverse 状态中。

对于只有一层栈的情况,stack会变为null,recoverState最后还判断了下如果(index += length) >= n,就给index赋值为++baseIndex,说明继续回到初始的Node链表到下一个index遍历。

如果是nextTable中也出现了ForwardingNode,说明遍历过程中出现了多次transfer,执行逻辑也是类似的。

final Node<K,V> advance() {
    Node<K,V> e;
    // 如果现在next字段不为null
    if ((e = next) != null)
        // 修改e为next.next
        e = e.next;
    for (;;) {
        Node<K,V>[] t; int i, n;  // must use locals in checks
        // 如果e不为空,返回e
        if (e != null)
            // 并且更新next,下次还会从next继续遍历
            return next = e;
        // 判断结束条件,baseIndex >= baseLimit说明最原始的Node数组已经遍历完成
        // (n = t.length) <= (i = index) 说明 并且顺便给i赋值为index
        // i < 0 说明index < 0,可能出现了溢出
        if (baseIndex >= baseLimit || (t = tab) == null ||
            (n = t.length) <= (i = index) || i < 0)
            return next = null;
        // 这里会读取数组对应i位置的Node并且赋值给e,然后再判断hash如果为负数,说明可能是ForwardingNode, TreeBin等特殊Node
        if ((e = tabAt(t, i)) != null && e.hash < 0) {
            if (e instanceof ForwardingNode) {
                // 如果是ForwardingNode,需要前往ForwardingNode遍历nextTable上index和index+baseSize这两个位置的Node链表
                tab = ((ForwardingNode<K,V>)e).nextTable;
                // 修改e为null,开始到
                e = null;
                // 保存当前的遍历状态,然后continue开始从头循环以便遍历nextTable上的两个Node链表
                pushState(t, i, n);
                continue;
            }
            // 如果是TreeBin,则通过红黑树的节点遍历
            else if (e instanceof TreeBin)
                e = ((TreeBin<K,V>)e).first;
            // 头结点是ReservationNode跳过因为Node还没计算创建完
            else
                e = null;
        }
        // 如果有stack,则出栈
        if (stack != null)
            recoverState(n);
        // 没有栈,则给baseIndex递增到Node数组下一个位置继续遍历。
        else if ((index = i + baseSize) >= n)
            index = ++baseIndex; // visit upper slots if present
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

TableStack保存原有的遍历状态,这样遍历完nextTable之后可以返回到之前的状态继续遍历。

static final class TableStack<K,V> {
    int length;
    int index;
    Node<K,V>[] tab;
    TableStack<K,V> next;
}
1
2
3
4
5
6

pushState是保存当前状态入栈的逻辑,其中会判断spare是否为null,如果不是null,则获取spare的栈顶Stack对象使用,目的是尽量少创建TableStack对象。

private void pushState(Node<K,V>[] t, int i, int n) {
    TableStack<K,V> s = spare;  // reuse if possible
    if (s != null)
        // spare = s.next,让spare指向next,如果next不为null,还可以继续复用
        spare = s.next;
    else
        s = new TableStack<K,V>();
    s.tab = t;
    s.length = n;
    s.index = i;
    // s.next = stack入栈,也和spare切断了关联
    s.next = stack;
    stack = s;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14

出栈逻辑

private void recoverState(int n) {
    TableStack<K,V> s; int len;
    // 对于每一层Stack,会执行到这里两次,第一次给index += len,修改了index状态,以便advance中遍历完低位链表后可以遍历高位链表。
    // 读取nextTable数组的高位的Node后,会再次执行到这里,此时while条件成功,会执行出栈,恢复Traverser的index等字段,以便遍历完nextTable高位链表后继续遍历原有的Node数组。  
    //  recoverState(n)会继续执行这里,index +n (len = s.length)后会大于=n
    while ((s = stack) != null && (index += (len = s.length)) >= n) {
        // index >=n 后,说明nextTable上index和index+baseSize的已经遍历完,可以退出当前栈
        n = len;
        // 设置index为s.index,恢复到入栈之前的状态
        index = s.index;
        // 修改table为入栈当时的table
        tab = s.tab;
        // 清理TableStack对象的table引用,避免影响GC
        s.tab = null;
        // 获取下一个栈元素的引用
        TableStack<K,V> next = s.next;
        // s.next指向spare,然后再把spare指向s,spare里的对象可以重复使用了
        s.next = spare; // save for reuse
        // 更新stack相当于出栈
        stack = next;
        spare = s;
    }
    // s == null说明回到的最开始的Node数组,并且校验index += baseSize >=n 则++baseIndex遍历base数组下一个index
    if (s == null && (index += baseSize) >= n)
        index = ++baseIndex;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

# 使用注意事项

# computeIfAbsent加锁问题

computeIfAbsent,判断某个key不过不存在,则调用函数计算得到value并切入到map中。 但是要注意的是computeIfAbsent是会对数组对应的Node头结点加锁的。 如果性能要求比较高的场景,建议先get,如果没有再调用computeIfAbsent

比如使用ConcurrentHashMap做计数统计的场景

Map<String, AtomicInteger> counterMap = new ConcurrentHashMap<>();
public void count(String key) {
    AtomicInteger counter = counterMap.get(key);
    if (counter == null) {
        counter = counterMap.computeIfAbsent(key, k -> new AtomicInteger());
    }
    counter.getAndIncrement();
}
1
2
3
4
5
6
7
8

# 问题和解答

# 为什么ConcurrentHashMap的key和value不能为null?

目前我认为null的最大问题在于含义不明确,比如value为null、或compute方法computeFunction返回一个null究竟表示这个key对应的value为null还是key不存在呢,这两个边界很模糊。

# transfer为什么使用头插法?

为了避免Iterator遍历时出现重复元素

# transfer时为什么从n-1开始向0迁移?

# ConcurrentHashMap使用注意事项

# key和value均不能为null

# 总结

  • synchronized加锁或cas自旋锁加锁成功后,还需要再检查下状态,是一种常用的double check模式。
  • 有修改状态时使用cas或synchronized、volatile等方式解决线程安全问题
  • 分段锁减少锁粒度提高并发