package grpchijack

import (
	"context"
	"io"
	"net"
	"strings"
	"sync"
	"time"

	controlapi "github.com/moby/buildkit/api/services/control"
	"github.com/moby/buildkit/session"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

func Dialer(api controlapi.ControlClient) session.Dialer {
	return func(ctx context.Context, proto string, meta map[string][]string) (net.Conn, error) {
		meta = lowerHeaders(meta)
		md := metadata.MD(meta)
		ctx = metadata.NewOutgoingContext(ctx, md)

		stream, err := api.Session(ctx)
		if err != nil {
			return nil, err
		}

		c, _ := streamToConn(stream)
		return c, nil
	}
}

type stream interface {
	Context() context.Context
	SendMsg(m interface{}) error
	RecvMsg(m interface{}) error
}

func streamToConn(stream stream) (net.Conn, <-chan struct{}) {
	closeCh := make(chan struct{})
	c := &conn{stream: stream, buf: make([]byte, 32*1<<10), closeCh: closeCh}
	return c, closeCh
}

type conn struct {
	stream  stream
	buf     []byte
	lastBuf []byte

	closedOnce sync.Once
	readMu     sync.Mutex
	writeMu    sync.Mutex
	closeCh    chan struct{}
}

func (c *conn) Read(b []byte) (n int, err error) {
	c.readMu.Lock()
	defer c.readMu.Unlock()

	if c.lastBuf != nil {
		n := copy(b, c.lastBuf)
		c.lastBuf = c.lastBuf[n:]
		if len(c.lastBuf) == 0 {
			c.lastBuf = nil
		}
		return n, nil
	}
	m := new(controlapi.BytesMessage)
	m.Data = c.buf

	if err := c.stream.RecvMsg(m); err != nil {
		return 0, err
	}
	c.buf = m.Data[:cap(m.Data)]

	n = copy(b, m.Data)
	if n < len(m.Data) {
		c.lastBuf = m.Data[n:]
	}

	return n, nil
}

func (c *conn) Write(b []byte) (int, error) {
	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	m := &controlapi.BytesMessage{Data: b}
	if err := c.stream.SendMsg(m); err != nil {
		return 0, err
	}
	return len(b), nil
}

func (c *conn) Close() (err error) {
	c.closedOnce.Do(func() {
		defer func() {
			close(c.closeCh)
		}()

		if cs, ok := c.stream.(grpc.ClientStream); ok {
			c.writeMu.Lock()
			err = cs.CloseSend()
			c.writeMu.Unlock()
			if err != nil {
				return
			}
		}

		c.readMu.Lock()
		for {
			m := new(controlapi.BytesMessage)
			m.Data = c.buf
			err = c.stream.RecvMsg(m)
			if err != nil {
				if err != io.EOF {
					c.readMu.Unlock()
					return
				}
				err = nil
				break
			}
			c.buf = m.Data[:cap(m.Data)]
			c.lastBuf = append(c.lastBuf, c.buf...)
		}
		c.readMu.Unlock()
	})
	return nil
}

func (c *conn) LocalAddr() net.Addr {
	return dummyAddr{}
}
func (c *conn) RemoteAddr() net.Addr {
	return dummyAddr{}
}
func (c *conn) SetDeadline(t time.Time) error {
	return nil
}
func (c *conn) SetReadDeadline(t time.Time) error {
	return nil
}
func (c *conn) SetWriteDeadline(t time.Time) error {
	return nil
}

type dummyAddr struct {
}

func (d dummyAddr) Network() string {
	return "tcp"
}

func (d dummyAddr) String() string {
	return "localhost"
}

func lowerHeaders(in map[string][]string) map[string][]string {
	out := map[string][]string{}
	for k := range in {
		out[strings.ToLower(k)] = in[k]
	}
	return out
}