/*
   Copyright The containerd Authors.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/

package ttrpc

import (
	"context"
	"errors"
	"io"
	"math/rand"
	"net"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/sirupsen/logrus"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

type Server struct {
	config   *serverConfig
	services *serviceSet
	codec    codec

	mu          sync.Mutex
	listeners   map[net.Listener]struct{}
	connections map[*serverConn]struct{} // all active connections
	done        chan struct{}            // closed when the server stops serving requests
}

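// NewServer creates a ttrpc Server from the given options, installing the
// default server interceptor when none is configured.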
func NewServer(opts ...ServerOpt) (*Server, error) {
	config := &serverConfig{}
	for _, opt := range opts {
		if err := opt(config); err != nil {
			return nil, err
		}
	}
	if config.interceptor == nil {
		config.interceptor = defaultServerInterceptor
	}

	return &Server{
		config:      config,
		services:    newServiceSet(config.interceptor),
		done:        make(chan struct{}),
		listeners:   make(map[net.Listener]struct{}),
		connections: make(map[*serverConn]struct{}),
	}, nil
}

// Register registers a map of method names to method handlers under the
// given service name.
//
// TODO: Remove in 2.0, does not support streams
func (s *Server) Register(name string, methods map[string]Method) {
	s.services.register(name, &ServiceDesc{Methods: methods})
}

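// RegisterService registers a full ServiceDesc, including streaming methods,
// under the given service name.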
func (s *Server) RegisterService(name string, desc *ServiceDesc) {
	s.services.register(name, desc)
}

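// Serve accepts connections on the given listener and dispatches requests to
// the registered services until the listener fails or the server is shut
// down, in which case it returns ErrServerClosed. Temporary accept errors are
// retried with jittered exponential backoff, capped at one second.
//
// A minimal usage sketch (registerEchoService and echoService are
// hypothetical, standing in for a generated registration helper and a
// service implementation):
//
//	srv, err := NewServer()
//	if err != nil {
//		log.Fatal(err)
//	}
//	registerEchoService(srv, &echoService{})
//
//	l, err := net.Listen("unix", "/run/example.sock")
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer l.Close()
//
//	if err := srv.Serve(context.Background(), l); err != nil && err != ErrServerClosed {
//		log.Fatal(err)
//	}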
func (s *Server) Serve(ctx context.Context, l net.Listener) error {
	s.addListener(l)
	defer s.closeListener(l)

	var (
		backoff    time.Duration
		handshaker = s.config.handshaker
	)

	if handshaker == nil {
		handshaker = handshakerFunc(noopHandshake)
	}

	for {
		conn, err := l.Accept()
		if err != nil {
			select {
			case <-s.done:
				return ErrServerClosed
			default:
			}

			if terr, ok := err.(interface {
				Temporary() bool
			}); ok && terr.Temporary() {
				if backoff == 0 {
					backoff = time.Millisecond
				} else {
					backoff *= 2
				}

				if max := time.Second; backoff > max {
					backoff = max
				}

				sleep := time.Duration(rand.Int63n(int64(backoff)))
				logrus.WithError(err).Errorf("ttrpc: failed accept; backoff %v", sleep)
				time.Sleep(sleep)
				continue
			}

			return err
		}

		backoff = 0

		approved, handshake, err := handshaker.Handshake(ctx, conn)
		if err != nil {
			logrus.WithError(err).Error("ttrpc: refusing connection after handshake")
			conn.Close()
			continue
		}

		sc, err := s.newConn(approved, handshake)
		if err != nil {
			logrus.WithError(err).Error("ttrpc: create connection failed")
			conn.Close()
			continue
		}

		go sc.run(ctx)
	}
}

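// Shutdown stops the server gracefully: it closes all listeners, stops
// accepting new connections, and then reaps idle connections every 200ms
// until none remain or ctx is cancelled.
//
// An illustrative pairing with Serve (sigCh is a hypothetical shutdown
// trigger):
//
//	go func() {
//		<-sigCh
//		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
//		defer cancel()
//		if err := srv.Shutdown(ctx); err != nil {
//			log.Printf("shutdown: %v", err)
//		}
//	}()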
func (s *Server) Shutdown(ctx context.Context) error {
	s.mu.Lock()
	select {
	case <-s.done:
	default:
		// protected by mutex
		close(s.done)
	}
	lnerr := s.closeListeners()
	s.mu.Unlock()

	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()
	for {
		s.closeIdleConns()

		if s.countConnection() == 0 {
			break
		}

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
		}
	}

	return lnerr
}

// Close closes the server immediately, without waiting for active
// connections to drain.
func (s *Server) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.done:
	default:
		// protected by mutex
		close(s.done)
	}

	err := s.closeListeners()
	for c := range s.connections {
		c.close()
		delete(s.connections, c)
	}

	return err
}

func (s *Server) addListener(l net.Listener) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.listeners[l] = struct{}{}
}

func (s *Server) closeListener(l net.Listener) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	return s.closeListenerLocked(l)
}

func (s *Server) closeListenerLocked(l net.Listener) error {
	defer delete(s.listeners, l)
	return l.Close()
}

func (s *Server) closeListeners() error {
	var err error
	for l := range s.listeners {
		if cerr := s.closeListenerLocked(l); cerr != nil && err == nil {
			err = cerr
		}
	}
	return err
}

func (s *Server) addConnection(c *serverConn) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.done:
		return ErrServerClosed
	default:
	}

	s.connections[c] = struct{}{}
	return nil
}

func (s *Server) delConnection(c *serverConn) {
	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.connections, c)
}

func (s *Server) countConnection() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	return len(s.connections)
}

func (s *Server) closeIdleConns() {
	s.mu.Lock()
	defer s.mu.Unlock()

	for c := range s.connections {
		if st, ok := c.getState(); !ok || st == connStateActive {
			continue
		}
		c.close()
		delete(s.connections, c)
	}
}

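// connState tracks the lifecycle of a serverConn: active while it has
// outstanding requests, idle otherwise, and closed once torn down. Shutdown
// reaps only non-active connections; Close tears down all of them.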
type connState int

const (
	connStateActive = iota + 1 // outstanding requests
	connStateIdle              // no requests
	connStateClosed            // closed connection
)

func (cs connState) String() string {
	switch cs {
	case connStateActive:
		return "active"
	case connStateIdle:
		return "idle"
	case connStateClosed:
		return "closed"
	default:
		return "unknown"
	}
}

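// newConn wraps an accepted (and handshaken) net.Conn in a serverConn,
// registers it with the server, and leaves it in the idle state; the caller
// is expected to start it with run.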
func (s *Server) newConn(conn net.Conn, handshake interface{}) (*serverConn, error) {
	c := &serverConn{
		server:    s,
		conn:      conn,
		handshake: handshake,
		shutdown:  make(chan struct{}),
	}
	c.setState(connStateIdle)
	if err := s.addConnection(c); err != nil {
		c.close()
		return nil, err
	}
	return c, nil
}

type serverConn struct {
	server    *Server
	conn      net.Conn
	handshake interface{} // data from handshake, not used for now
	state     atomic.Value

	shutdownOnce sync.Once
	shutdown     chan struct{} // forced shutdown, used by close
}

func (c *serverConn) getState() (connState, bool) {
	cs, ok := c.state.Load().(connState)
	return cs, ok
}

func (c *serverConn) setState(newstate connState) {
	c.state.Store(newstate)
}

func (c *serverConn) close() error {
	c.shutdownOnce.Do(func() {
		close(c.shutdown)
	})

	return nil
}

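// run drives a single connection: a receiver goroutine decodes incoming
// messages and hands them to stream handlers, while this goroutine
// serializes responses back onto the wire and tracks whether the connection
// is active or idle. The forced-shutdown channel is only selected on while
// the connection is idle.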
func (c *serverConn) run(sctx context.Context) {
	type (
		response struct {
			id          uint32
			status      *status.Status
			data        []byte
			closeStream bool
			streaming   bool
		}
	)

	var (
		ch          = newChannel(c.conn)
		ctx, cancel = context.WithCancel(sctx)
		state       connState = connStateIdle
		responses   = make(chan response)
		recvErr     = make(chan error, 1)
		done        = make(chan struct{})
		streams     = sync.Map{}
		active      int32
		lastStreamID uint32
	)

	defer c.conn.Close()
	defer cancel()
	defer close(done)
	defer c.server.delConnection(c)

	sendStatus := func(id uint32, st *status.Status) bool {
		select {
		case responses <- response{
			// even though we've had an invalid stream id, we send it
			// back on the same stream id so the client knows which
			// stream id was bad.
			id:          id,
			status:      st,
			closeStream: true,
		}:
			return true
		case <-c.shutdown:
			return false
		case <-done:
			return false
		}
	}

	go func(recvErr chan error) {
		defer close(recvErr)
		for {
			select {
			case <-c.shutdown:
				return
			case <-done:
				return
			default: // proceed
			}

			mh, p, err := ch.recv()
			if err != nil {
				status, ok := status.FromError(err)
				if !ok {
					recvErr <- err
					return
				}

				// in this case, we send an error for that particular message
				// when the status is defined.
				if !sendStatus(mh.StreamID, status) {
					return
				}

				continue
			}

			if mh.StreamID%2 != 1 {
				// enforce odd client initiated identifiers.
				if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID must be odd for client initiated streams")) {
					return
				}
				continue
			}

			if mh.Type == messageTypeData {
				i, ok := streams.Load(mh.StreamID)
				if !ok {
					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID is no longer active")) {
						return
					}
					// without a handler, the type assertion below would
					// panic on the nil interface; skip this message.
					continue
				}
				sh := i.(*streamHandler)
				if mh.Flags&flagNoData != flagNoData {
					unmarshal := func(obj interface{}) error {
						err := protoUnmarshal(p, obj)
						ch.putmbuf(p)
						return err
					}

					if err := sh.data(unmarshal); err != nil {
						if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data handling error: %v", err)) {
							return
						}
					}
				}

				if mh.Flags&flagRemoteClosed == flagRemoteClosed {
					sh.closeSend()
					if len(p) > 0 {
						if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "data close message cannot include data")) {
							return
						}
					}
				}
			} else if mh.Type == messageTypeRequest {
				if mh.StreamID <= lastStreamID {
					// enforce increasing, never-reused stream identifiers.
					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "StreamID cannot be re-used and must increment")) {
						return
					}
					continue
				}
				lastStreamID = mh.StreamID

				// TODO: Make request type configurable
				// Unmarshaller which takes in a byte array and returns an interface?
				var req Request
				if err := c.server.codec.Unmarshal(p, &req); err != nil {
					ch.putmbuf(p)
					if !sendStatus(mh.StreamID, status.Newf(codes.InvalidArgument, "unmarshal request error: %v", err)) {
						return
					}
					continue
				}
				ch.putmbuf(p)

				id := mh.StreamID
				respond := func(status *status.Status, data []byte, streaming, closeStream bool) error {
					select {
					case responses <- response{
						id:          id,
						status:      status,
						data:        data,
						closeStream: closeStream,
						streaming:   streaming,
					}:
					case <-done:
						return ErrClosed
					}
					return nil
				}
				sh, err := c.server.services.handle(ctx, &req, respond)
				if err != nil {
					status, _ := status.FromError(err)
					if !sendStatus(mh.StreamID, status) {
						return
					}
					continue
				}

				streams.Store(id, sh)
				atomic.AddInt32(&active, 1)
			}
			// TODO: else we must ignore this for future compat. log this?
		}
	}(recvErr)

	for {
		var (
			newstate connState
			shutdown chan struct{}
		)

		activeN := atomic.LoadInt32(&active)
		if activeN > 0 {
			newstate = connStateActive
			shutdown = nil
		} else {
			newstate = connStateIdle
			shutdown = c.shutdown // only enable this branch in idle mode
		}
		if newstate != state {
			c.setState(newstate)
			state = newstate
		}

		select {
		case response := <-responses:
			if !response.streaming || response.status.Code() != codes.OK {
				p, err := c.server.codec.Marshal(&Response{
					Status:  response.status.Proto(),
					Payload: response.data,
				})
				if err != nil {
					logrus.WithError(err).Error("failed marshaling response")
					return
				}

				if err := ch.send(response.id, messageTypeResponse, 0, p); err != nil {
					logrus.WithError(err).Error("failed sending message on channel")
					return
				}
			} else {
				var flags uint8
				if response.closeStream {
					flags = flagRemoteClosed
				}
				if response.data == nil {
					flags = flags | flagNoData
				}
				if err := ch.send(response.id, messageTypeData, flags, response.data); err != nil {
					logrus.WithError(err).Error("failed sending message on channel")
					return
				}
			}

			if response.closeStream {
				// The ttrpc protocol currently does not support the case where
				// the server is localClosed but not remoteClosed. Once the server
				// is closing, the whole stream may be considered finished.
				streams.Delete(response.id)
				atomic.AddInt32(&active, -1)
			}
		case err := <-recvErr:
			// TODO(stevvooe): Not wildly clear what we should do in this
			// branch. Basically, it means that we are no longer receiving
			// requests due to a terminal error.
			recvErr = nil // connection is now "closing"
			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, syscall.ECONNRESET) {
				// The client went away and we should stop processing
				// requests, so that the client connection is closed
				return
			}
			logrus.WithError(err).Error("error receiving message")
			// else, initiate shutdown
		case <-shutdown:
			return
		}
	}
}

var noopFunc = func() {}

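// getRequestContext derives the context a handler runs under: request
// metadata, if any, is attached, and a deadline is applied when the request
// carries a non-zero TimeoutNano. The returned cancel func is always
// non-nil and must be called.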
func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) {
	if len(req.Metadata) > 0 {
		md := MD{}
		md.fromRequest(req)
		ctx = WithMetadata(ctx, md)
	}

	cancel = noopFunc
	if req.TimeoutNano == 0 {
		return ctx, cancel
	}

	ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano))
	return ctx, cancel
}