diff --git a/server/middlewares.go b/server/middlewares.go index 95f0244..e32fde2 100644 --- a/server/middlewares.go +++ b/server/middlewares.go @@ -5,6 +5,8 @@ import ( "context" "os" "sync" + "sync/atomic" + "time" "github.com/Mrs4s/go-cqhttp/coolq" "github.com/Mrs4s/go-cqhttp/global" @@ -73,19 +75,39 @@ func longPolling(bot *coolq.CQBot, maxSize int) handler { if action != "get_updates" { return nil } - mutex.Lock() - defer mutex.Unlock() - if queue.Len() == 0 { - cond.Wait() + var ( + ok int32 + ch = make(chan []interface{}, 1) + timeout = time.Duration(p.Get("timeout").Int()) * time.Second + ) + defer close(ch) + go func() { + mutex.Lock() + defer mutex.Unlock() + if queue.Len() == 0 { + cond.Wait() + } + if atomic.CompareAndSwapInt32(&ok, 0, 1) { + limit := int(p.Get("limit").Int()) + if limit <= 0 || queue.Len() < limit { + limit = queue.Len() + } + ret := make([]interface{}, limit) + for i := 0; i < limit; i++ { + ret[i] = queue.Remove(queue.Front()) + } + ch <- ret + } + }() + if timeout != 0 { + select { + case <-time.After(timeout): + atomic.StoreInt32(&ok, 1) + return coolq.OK([]interface{}{}) + case ret := <-ch: + return coolq.OK(ret) + } } - limit := int(p.Get("limit").Int()) - if limit <= 0 || queue.Len() < limit { - limit = queue.Len() - } - ret := make([]interface{}, limit) - for i := 0; i < limit; i++ { - ret[i] = queue.Remove(queue.Front()) - } - return coolq.OK(ret) + return coolq.OK(<-ch) } }