# 资源池实现 (resource pool implementation)
package pool
import (
"sync"
"io"
"errors"
"log"
)
// Pool manages a set of io.Closer resources that goroutines can share.
// Idle resources wait in a buffered channel whose capacity is the pool size;
// when no idle resource is available, Acquire creates a new one with the
// factory. The zero value is not usable; create pools with New.
type Pool struct {
	m         sync.Mutex                // guards closed; serializes Release and Close
	resources chan io.Closer            // idle resources available for reuse
	factory   func() (io.Closer, error) // creates a resource when none is idle
	closed    bool                      // set once Close has run
}

// ErrPoolClosed is returned by Acquire after the pool has been closed.
var ErrPoolClosed = errors.New("pool has been closed")

// New creates a pool that keeps at most size idle resources and uses fn to
// create resources on demand. It returns an error if size is zero or fn is
// nil (a nil factory would otherwise panic inside Acquire).
func New(fn func() (io.Closer, error), size uint) (*Pool, error) {
	if size == 0 {
		return nil, errors.New("pool size too small")
	}
	if fn == nil {
		return nil, errors.New("pool factory is nil")
	}
	return &Pool{
		resources: make(chan io.Closer, size),
		factory:   fn,
	}, nil
}

// Acquire returns an idle resource if one is available; otherwise it creates
// a new resource with the pool's factory and returns the factory's result.
// After Close has returned, Acquire reports ErrPoolClosed.
func (p *Pool) Acquire() (resource io.Closer, err error) {
	select {
	case r, ok := <-p.resources:
		// A receive from the closed (and drained) channel yields ok == false.
		if !ok {
			return nil, ErrPoolClosed
		}
		log.Println("Acquire:", "shared resource")
		return r, nil
	default:
		log.Println("Acquire:", "new resource")
		return p.factory()
	}
}

// Release hands r back to the pool. If the pool is closed, or already holds
// as many idle resources as its size, r is closed instead. A nil r is ignored
// so that a deferred Release after a failed Acquire cannot put a nil resource
// into the pool or panic.
func (p *Pool) Release(r io.Closer) {
	if r == nil {
		return
	}
	// Holding the lock makes the closed check and the send atomic with respect
	// to Close, so we never send on a closed channel.
	p.m.Lock()
	defer p.m.Unlock()
	if p.closed {
		closeResource(r)
		return
	}
	select {
	case p.resources <- r:
		log.Println("Release:", "in queue")
	default:
		log.Println("Release:", "closeing")
		closeResource(r)
	}
}

// Close marks the pool closed and closes every idle resource. Resources still
// held by callers are closed when they are passed to Release. Calling Close
// more than once is a no-op.
func (p *Pool) Close() {
	p.m.Lock()
	defer p.m.Unlock()
	if p.closed {
		return
	}
	p.closed = true
	// Closing the channel first makes later Acquire calls report
	// ErrPoolClosed and lets the range below stop once it is drained.
	close(p.resources)
	for r := range p.resources {
		closeResource(r)
	}
}

// closeResource closes r and logs, rather than silently drops, any error.
func closeResource(r io.Closer) {
	if err := r.Close(); err != nil {
		log.Println("Close:", err)
	}
}
</pre>
# 测试 (test program)
<pre class="line-numbers prism-highlight" data-start="1">package main
import (
"log"
"io"
"sync/atomic"
"mypackages/pool"
"sync"
"time"
"math/rand"
)
// Parameters of the demo run.
const (
	// maxGoroutines is how many concurrent queries main launches.
	maxGoroutines = 200
	// poolResources is the size passed to pool.New for the connection pool.
	poolResources = 10
)
// dbConn is a stand-in for a database connection, identified by ID.
type dbConn struct {
	ID int32
}

// Close implements io.Closer; it only logs which connection was closed.
func (d *dbConn) Close() error {
	log.Println("close connection:", d.ID)
	return nil
}

// resourceInt counts the connections created so far. It is updated
// atomically because the pool may call createConnection from several
// goroutines at once.
var resourceInt int32

// createConnection is the pool's factory: it returns a new dbConn whose ID
// is the next value of resourceInt (1, 2, 3, ...).
func createConnection() (io.Closer, error) {
	id := atomic.AddInt32(&resourceInt, 1)
	log.Println("createConn:id=", id)
	// Use id as-is so the logged id matches the connection's ID.
	return &dbConn{
		ID: id,
	}, nil
}
func main() {
var wg sync.WaitGroup
wg.Add(maxGoroutines)
p, err := pool.New(createConnection, poolResources)
if err != nil {
log.Fatalln(err)
}
for query := 0; query < maxGoroutines; query++ {
go func(q int) {
selectQuery(q, p)
wg.Done()
}(query)
}
wg.Wait()
log.Println("close pool")
p.Close()
}
// selectQuery simulates query number q. After a random delay of up to one
// second, it acquires a connection from pool2 and holds it for another
// random delay. It then logs which connection served the query and returns
// the connection to the pool.
func selectQuery(q int, pool2 *pool.Pool) {
	time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
	dbconn, err := pool2.Acquire()
	if err != nil {
		// Without a connection there is nothing to release or report;
		// continuing would panic on the type assertion of a nil interface.
		log.Println("acquire conn error:", err)
		return
	}
	defer pool2.Release(dbconn)
	time.Sleep(time.Millisecond * time.Duration(rand.Intn(1000)))
	log.Printf("Qin:%d.Cin%d", q, dbconn.(*dbConn).ID)
}