实现线程池或者协程池是面试经常需要手写的题型。本文将介绍协程池如何实现。
协程池
池化技术是很重要的一种思想,将一些频繁使用但是创建开销比较大的对象自定义保存起来,反复使用,典型的有线程池、内存池、mysql连接池等等。协程是用户级别更加轻量化的线程,使用协程池来维护一个永远工作不回收的协程池也可以优化业务逻辑的效率。
笔者的协程在思路上借鉴了一些Java线程池七大参数的思想,分为核心协程和非核心协程,初始时只有核心协程工作,如果工作数量比较多还会起一些非核心协程来工作,但是这些非核心协程如果长时间没有工作也会被销毁,这其中涉及到的工作数量比较多的阈值、非核心协程多少时间不工作会被销毁都是可以由用户调整的参数。工作其实就是函数的抽象,本文的协程池通过对于信道的封装定义了一个future对象,用来保存一个任务的运行结果。协程池在启动后,会写一个协程daemon,用来监测协程池的状态,决定是否要起非核心协程。此外,还实现了一个优雅关闭协程池的功能,保证所有协程在完成自己手上的工作后立即退出,然后主程序里面将所有的信道关闭。任务的分发则是通过原生的信道来实现的。
不可否认,这里实现的这个协程池只是一个demo,后续可以参考golang优秀的第三方协程池库进行相应的优化。
代码实现
代码实现如下:
type Future struct {ch chan interface{}
}func (f *Future) Get() interface{} {var res interface{}res, ok := <- f.chfor !ok {time.Sleep(1 * time.Millisecond)res, ok = <- f.ch}return res
}type Work struct {work func(args []interface{}) interface{}args []interface{}res *Future
}type Worker struct {taskChannel chan WorkisCore boolquit chan struct{}keepAliveTime time.Durationwg *sync.WaitGroup
}func NewWorker(taskChannle chan Work, isCore bool, quit chan struct{}, keepAliveTime time.Duration, wg *sync.WaitGroup) *Worker {return &Worker{taskChannel: taskChannle, isCore: isCore, quit: quit, keepAliveTime: keepAliveTime, wg: wg}
}func (w *Worker) Work() {defer w.wg.Done()timer := time.NewTimer(w.keepAliveTime)for {select {case <- timer.C:if !w.isCore {return}case <- w.quit: returncase task := <- w.taskChannel:res := task.work(task.args)task.res.ch <- restimer.Reset(w.keepAliveTime)continue}}
}type ThreadPool struct {workers []*WorkercoreThreads intmaxThreads intkeepAliveTime time.DurationtaskChannel chan Workquit chan struct{}threshold intdiscount intwg *sync.WaitGroup
}func NewThreadPool(coreThreads, maxThreads int, keepAliveTime time.Duration, threshold, discount int) *ThreadPool {taskChannel := make(chan Work, maxThreads * 2 + 1)quit := make(chan struct{}, maxThreads)workers := make([]*Worker, 0)var wg sync.WaitGroupfor i := 0; i < coreThreads; i++ {worker := NewWorker(taskChannel, true, quit, keepAliveTime, &wg)go worker.Work()workers = append(workers, worker)wg.Add(1)}t := &ThreadPool{workers: workers, coreThreads: coreThreads, maxThreads: maxThreads, keepAliveTime: keepAliveTime, taskChannel: taskChannel, quit: quit, threshold: threshold, discount: discount, wg: &wg}go t.daemon()return t
}// Submit 往协程池中提交任务
func (t *ThreadPool) Submit(function func(args []interface{}) interface{}, args []interface{}) *Future {resChannel := make(chan interface{}, 1)future := &Future{ch: resChannel}work := Work{work: function, args: args, res: future}t.taskChannel <- workreturn future
}// 线程池后台监控程序
func (t *ThreadPool) daemon() {if len(t.taskChannel) > t.threshold {// 准备在起的协程数threadNum := min(t.maxThreads - len(t.workers), len(t.taskChannel) / t.discount)for i := 0; i < threadNum; i++ {worker := NewWorker(t.taskChannel, false, t.quit, t.keepAliveTime, t.wg)go worker.Work()t.workers = append(t.workers, worker)t.wg.Add(1)}}
}// Close 优雅关闭协程池
func (t *ThreadPool) Close() {for i := 0; i < t.maxThreads; i++ {t.quit <- struct{}{}}t.wg.Wait()close(t.quit)close(t.taskChannel)
}
一个简单的单元测试例子:
func TestThreadPool(t *testing.T) {tp := NewThreadPool(2, 5, 2 * time.Second, 8, 4)args := []interface{}{3, 4}for i := 0; i < 10; i++ {tp.Submit(func(args []interface{}) interface{} {a := args[0].(int)b := args[1].(int)c := a + btime.Sleep(5 * time.Second)return c}, args)}f := tp.Submit(func(args []interface{}) interface{} {a := args[0].(int)b := args[1].(int)c := a + btime.Sleep(5 * time.Second)return c}, args)res := f.Get()t.Logf("get 4+3=%d", res.(int))tp.Close()
}