Java 中线程池,也支持自定义线程池,为啥 Golang 官方没有提供协程池的实现?Golang 官方偏向轻量级的并发, 希望通过 go func() 解决问题。
一般来说,用 waitGroup 结合 channel ,可以实现一个协程池的功能。一个协程池,一般要具有如下三个功能:
package main
import (
"fmt"
"sync"
"testing"
)
// 任务结构体
type Task struct {
ID int
// 任务
Job func()
}
// 协程池结构体
type Pool struct {
// 任务通道
taskChan chan Task
// 工作协程数量
workerCount int
// 等待组
wg sync.WaitGroup
}
// 创建协程池
func NewPool(workerCount int) *Pool {
workChannel := make(chan Task, workerCount)
return &Pool{
taskChan: workChannel,
workerCount: workerCount,
wg: sync.WaitGroup{},
}
}
// 向协程池提交任务
func (p *Pool) SubmitTask(task Task) {
p.taskChan <- task
p.wg.Add(1)
}
// 启动工作协程
func (p *Pool) StartWorkers() {
for i := 0; i < p.workerCount; i++ {
go p.worker()
}
}
// 工作协程
func (p *Pool) worker() {
for task := range p.taskChan {
defer p.wg.Done()
fmt.Printf("Worker received task %d\n", task.ID)
task.Job()
fmt.Printf("Worker completed task %d\n", task.ID)
}
}
func TestThreadPool(t *testing.T) {
// 创建一个协程池,设置工作协程数量为 5
pool := NewPool(5)
// 提交任务到协程池
for i := 1; i < 5; i++ {
task := Task{
ID: i,
Job: func() {
fmt.Printf("Task %d is running\n", i)
},
}
pool.SubmitTask(task)
}
// 启动工作协程
pool.StartWorkers()
// 等待所有任务完成
pool.wg.Wait()
}
执行结果:
=== RUN TestThreadPool
Worker received task 1
Task 5 is running
Worker completed task 1
Worker received task 4
Task 5 is running
Worker completed task 4
Worker received task 2
Task 5 is running
Worker completed task 2
Worker received task 3
Task 5 is running
Worker completed task 3
--- PASS: TestThreadPool (0.00s)
PASS
优化一下上面的代码:
package utils
import (
"context"
"sync"
)
// Semaphore 使用waitGroup和channel实现并发同时控制最大并发量
// 参考golang.org/x/sync.errgroup实现返回err功能
type Semaphore struct {
c chan struct{}
wg sync.WaitGroup
cancel func()
errOnce sync.Once
err error
}
func NewSemaphore(maxSize int) *Semaphore {
return &Semaphore{
c: make(chan struct{}, maxSize),
}
}
func NewSemaphoreWithContext(ctx context.Context, maxSize int) (*Semaphore, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Semaphore{
c: make(chan struct{}, maxSize),
cancel: cancel,
}, ctx
}
func (s *Semaphore) Go(f func() error) {
s.wg.Add(1)
s.c <- struct{}{}
go func() {
defer func() {
if err := recover(); err != nil {
}
}()
defer func() {
<-s.c
s.wg.Done()
}()
if err := f(); err != nil {
s.errOnce.Do(func() {
s.err = err
if s.cancel != nil {
s.cancel()
}
})
}
}()
}
func (s *Semaphore) Wait() error {
s.wg.Wait()
if s.cancel != nil {
s.cancel()
}
return s.err
}
测试代码:
package utils
import (
"math"
"testing"
"time"
"github.com/bmizerany/assert"
)
func sleep1s() error {
time.Sleep(time.Second)
return nil
}
func TestSemaphore(t *testing.T) {
// 最大并发 >= 执行任务数量
sema := NewSemaphore(4)
now := time.Now()
for i := 0; i < 4; i++ {
sema.Go(sleep1s)
}
err := sema.Wait()
assert.Equal(t, nil, err)
sec := math.Round(time.Since(now).Seconds())
assert.Equal(t, 1, int(sec))
// 设置最大并发为2
sema = NewSemaphore(2)
now = time.Now()
for i := 0; i < 4; i++ {
sema.Go(sleep1s)
}
err = sema.Wait()
assert.Equal(t, nil, err)
sec = math.Round(time.Since(now).Seconds())
assert.Equal(t, 2, int(sec))
}
https://github.com/bytedance/gopkg/tree/develop/util/gopool
原理和 Java 线程池原理有点类似
// 如果没使用 NewPool方法创建协程池 会默认 init 建一个 default pool
func init() {
initMetrics()
defaultPool = NewPool("gopool.DefaultPool", 10000, NewConfig())
}
func NewPool(name string, cap int32, config *Config) Pool {
p := &pool{
name: name,
cap: cap,
config: config,
}
return p
}
var taskPool sync.Pool
func init() {
taskPool.New = newTask
}
func newTask() interface{} {
return &task{}
}
func (p *pool) CtxGo(ctx context.Context, f func()) {
t := taskPool.Get().(*task)
t.ctx = ctx
t.f = f
p.taskLock.Lock()
if p.taskHead == nil {
p.taskHead = t
p.taskTail = t
} else {
p.taskTail.next = t
p.taskTail = t
}
p.taskLock.Unlock()
atomic.AddInt32(&p.taskCount, 1)
// 如果 pool 已经被关闭了,就 panic
if atomic.LoadInt32(&p.closed) == 1 {
panic("use closed pool")
}
// 满足以下两个条件:
// 1. task 数量大于阈值
// 2. 目前的 worker 数量小于上限 p.cap(工作协程数)
// 或者目前没有 worker
if (atomic.LoadInt32(&p.taskCount) >= p.config.ScaleThreshold && p.WorkerCount() < atomic.LoadInt32(&p.cap)) || p.WorkerCount() == 0 {
p.incWorkerCount()
w := workerPool.Get().(*worker)
w.pool = p
w.run()
}
}
return
func (w *worker) run() {
go func() {
for {
//select {
//case <-w.stopChan:
// w.close()
// return
//default:
var t *task
w.pool.taskLock.Lock()
if w.pool.taskHead != nil {
t = w.pool.taskHead
w.pool.taskHead = w.pool.taskHead.next
atomic.AddInt32(&w.pool.taskCount, -1)
}
if t == nil {
// 如果没有任务要做了,就释放资源,退出
w.close()
w.pool.taskLock.Unlock()
w.Recycle()
return
}
w.pool.taskLock.Unlock()
func() {
defer func() {
if r := recover(); r != nil {
logs.CtxFatal(t.ctx, "GOPOOL: panic in pool: %s: %v: %s", w.pool.name, r, debug.Stack())
if w.pool.config.EnablePanicMetrics {
panicMetricsClient.EmitCounter(panicKey, 1, metrics.T{Name: "pool", Value: w.pool.name})
}
w.pool.panicHandler(t.ctx, r)
}
}()
t.f()
}()
t.Recycle()
//}
}
}()
}
可能会问,为啥要写个死循环去遍历,假设不写 for 循环, 如果一个任务,run 一次,就创建一个工作协程,这个开销成本比较高,通过循环变了任务队列的方式,不断去取,可以避免创建一些不必要的工作协程。
举个例子,假设有 4个任务,任务1 执行,开启了一个工作协程1, 任务2 执行,开启了一个工作协程2,任务3执行,开启了一个工作协程3, 任务4来了,此时工作协程1执行完毕,去取任务4执行。这样的话,4个任务,只需要3个工作协程,如果工作协程执行足够快,工作协程数会更少。
场景:捞取2个月的数据,然后导出 捞取一个月的动账明细数据,然后进行导出,原流程是一个开始时间,一个结束时间,每次捞取10分钟的数据,每次加10分钟,循环处理。改为并发流程后,先将时间按10分钟分段,每一段做为一个任务,交给协程池去跑。最后再对结果进行汇总。项目实测,导出效率提升10倍以上。
https://github.com/bytedance/gopkg/tree/develop/util/gopool