Golang sync.WaitGroup 源码分析


结构

// WaitGroup类型的数据不可以被复制
type WaitGroup struct {
	noCopy noCopy	// 用来禁止当前结构的类型复制

    // state1 是 64-bit变量:
    //   高32位是计数器counter,也就是活跃的g的个数
    //   低32位表示因执行Wait()而阻塞的g的数量,即waiters
    // state2 表示sema信号量,说明本章代码用到了原语
    // 64-bit的原子操作需要64-bit的对齐,但是32位的编译器只能保证64-bit字段是32位对齐
    // 因此,在32位架构上,我们需要在state()中检查state1是否对齐,
    // 并在需要时动态地 "交换 "字段顺序。
	state1 uint64
	state2 uint32
}

Add

Add

func (wg *WaitGroup) Add(delta int) {
    // 获取state1 和 state2
	statep, semap := wg.state()
    // 竞争检测代码不看
	if race.Enabled {
		_ = *statep // trigger nil deref early
		if delta < 0 {
			// Synchronize decrements with Wait.
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
    // counter加delta
	state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)	// 获取当前活跃的g的数量
    w := uint32(state)		// 获取当前Wait()的次数
    // 竞争检测代码,不看
	if race.Enabled && delta > 0 && v == int32(delta) {
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(semap))
	}
    // 活跃的g个数不能是负数个,有可能delta传的是负数
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
    // 说明先调用的Wait()再调用Add(),正常应该反过来
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
    // Add()执行成功返回
	if v > 0 || w == 0 {
		return
	}
	// 到这一步说明 counter == 0 且 waiters > 0.
	// 如果*statep != state,有可能发生了两种错误情况
    // - Add()和Wait()并发(concurrently)调用
	// - 当counter归零时,waiters数量还在增加
	// 继续检查,保证WaitGroup不被滥用
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 如果WaitGroup使用规范,到这一步counter为0说明无活跃g了
    // 将state1置0,同时释放所有的waiters
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false, 0)
	}
}

state

// 返回存state1和state2,即state和sema,
// 因为32位编译器不能直接对齐64位数据,需要这个函数做对齐工作
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		//如果state1是64-bit则直接原样返回
		return &wg.state1, &wg.state2
	} else {
		// 如果state1是32-bit对齐而不是64-bit对齐,
		// 那么(&state1)+4 就是 64-bit 对齐了
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}
  • unsafe.Alignof(T Type) 返回对齐值即对其边界,在64位系统中指针宽度为4字节,即对齐边界为8.

  • 下表是32位系统和64位系统各大小数据的对齐边界

    32位系统:

    Sizeof(x)   = 16  Alignof(x)   = 4
    Sizeof(x.a) = 1   Alignof(x.a) = 1 Offsetof(x.a) = 0
    Sizeof(x.b) = 2   Alignof(x.b) = 2 Offsetof(x.b) = 2
    Sizeof(x.c) = 12  Alignof(x.c) = 4 Offsetof(x.c) = 4
    

    64位系统:

    Sizeof(x)   = 32  Alignof(x)   = 8
    Sizeof(x.a) = 1   Alignof(x.a) = 1 Offsetof(x.a) = 0
    Sizeof(x.b) = 2   Alignof(x.b) = 2 Offsetof(x.b) = 2
    Sizeof(x.c) = 24  Alignof(x.c) = 8 Offsetof(x.c) = 8
    
  • unsafe.Pointer(&T) 是一个可以包含任意类型变量的地址的通用指针,所以unsafe.Pointer()常用于各种指针相互转换的桥梁。有四个特有的操作:

    1. 任何类型的指针都可以被转化为Pointer
    2. Pointer可以被转化为任何类型的指针
    3. uintptr可以被转化为Pointer
    4. Pointer可以被转化为uintptr

    不能直接通过 *p 方式来取得真实的变量值,因为不知道变量的具体类型。可以通过 uintptr(unsafe.Pointer(&T)) 来进行运算。

    unsafe.Pointer是可以比较的,并且支持和nil常量比较判断是否为空指针。

  • uintptr() 主要用来进行指针计算,本质是一个整型。一般用 uintptr(unsafe.Pointer(&T))来进行运算(T的类型未知)。

    注意:首先GC不认为uintptr是一个活引用,因此uintptr指向的对象可能被gc回收。其次,如果uintptr关联的对象移动,则其值也不会更新,即uintptr无法保持对变量的引用。

Done

// Done decrements the WaitGroup counter by one.
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait

Wait

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
    // 获取state1和state2的地址
	statep, semap := wg.state()
	if race.Enabled {
		_ = *statep // trigger nil deref early
		race.Disable()
	}
	for {
        // 原语:原子操作获取state1的值
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)	// counter,活跃的g的数量
		w := uint32(state)		// waiters
        // 如果没有活跃的g了,直接返回
		if v == 0 {
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// CAS操作增加waiters的数量
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
            // 竞争检测代码,不看
			if race.Enabled && w == 0 {
				// Wait must be synchronized with the first Add.
				// Need to model this is as a write to race with the read in Add.
				// As a consequence, can do the write only for the first waiter,
				// otherwise concurrent Waits will race with each other.
				race.Write(unsafe.Pointer(semap))
			}
            // 到这一步,说明当前g要阻塞等待了
            // 原语:根据semap的值也就是state2的地址找到相应的阻塞队列,把当前g放进去,并挂起
			runtime_Semacquire(semap)
            
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}