Merge pull request #928 from tonistiigi/bake-named-contexts
bake: add named contexts keyspull/948/head
						commit
						60a025b227
					
				@ -0,0 +1,74 @@
 | 
			
		||||
package waitmap
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Map struct {
 | 
			
		||||
	mu sync.RWMutex
 | 
			
		||||
	m  map[string]interface{}
 | 
			
		||||
	ch map[string]chan struct{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func New() *Map {
 | 
			
		||||
	return &Map{
 | 
			
		||||
		m:  make(map[string]interface{}),
 | 
			
		||||
		ch: make(map[string]chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Map) Set(key string, value interface{}) {
 | 
			
		||||
	m.mu.Lock()
 | 
			
		||||
	defer m.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	m.m[key] = value
 | 
			
		||||
 | 
			
		||||
	if ch, ok := m.ch[key]; ok {
 | 
			
		||||
		if ch != nil {
 | 
			
		||||
			close(ch)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	m.ch[key] = nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Map) Get(ctx context.Context, keys ...string) (map[string]interface{}, error) {
 | 
			
		||||
	if len(keys) == 0 {
 | 
			
		||||
		return map[string]interface{}{}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(keys) > 1 {
 | 
			
		||||
		out := make(map[string]interface{})
 | 
			
		||||
		for _, key := range keys {
 | 
			
		||||
			mm, err := m.Get(ctx, key)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			out[key] = mm[key]
 | 
			
		||||
		}
 | 
			
		||||
		return out, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	key := keys[0]
 | 
			
		||||
	m.mu.Lock()
 | 
			
		||||
	ch, ok := m.ch[key]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		ch = make(chan struct{})
 | 
			
		||||
		m.ch[key] = ch
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ch != nil {
 | 
			
		||||
		m.mu.Unlock()
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ctx.Done():
 | 
			
		||||
			return nil, ctx.Err()
 | 
			
		||||
		case <-ch:
 | 
			
		||||
			m.mu.Lock()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := m.m[key]
 | 
			
		||||
	m.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	return map[string]interface{}{key: res}, nil
 | 
			
		||||
}
 | 
			
		||||
@ -0,0 +1,64 @@
 | 
			
		||||
package waitmap
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestGetAfter(t *testing.T) {
 | 
			
		||||
	m := New()
 | 
			
		||||
 | 
			
		||||
	m.Set("foo", "bar")
 | 
			
		||||
	m.Set("bar", "baz")
 | 
			
		||||
 | 
			
		||||
	ctx := context.TODO()
 | 
			
		||||
	v, err := m.Get(ctx, "foo", "bar")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	require.Equal(t, 2, len(v))
 | 
			
		||||
	require.Equal(t, "bar", v["foo"])
 | 
			
		||||
	require.Equal(t, "baz", v["bar"])
 | 
			
		||||
 | 
			
		||||
	v, err = m.Get(ctx, "foo")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, 1, len(v))
 | 
			
		||||
	require.Equal(t, "bar", v["foo"])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestTimeout(t *testing.T) {
 | 
			
		||||
	m := New()
 | 
			
		||||
 | 
			
		||||
	m.Set("foo", "bar")
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	_, err := m.Get(ctx, "bar")
 | 
			
		||||
	require.Error(t, err)
 | 
			
		||||
	require.True(t, errors.Is(err, context.DeadlineExceeded))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestBlocking(t *testing.T) {
 | 
			
		||||
	m := New()
 | 
			
		||||
 | 
			
		||||
	m.Set("foo", "bar")
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		time.Sleep(100 * time.Millisecond)
 | 
			
		||||
		m.Set("bar", "baz")
 | 
			
		||||
		time.Sleep(50 * time.Millisecond)
 | 
			
		||||
		m.Set("baz", "abc")
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	ctx := context.TODO()
 | 
			
		||||
	v, err := m.Get(ctx, "foo", "bar", "baz")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	require.Equal(t, 3, len(v))
 | 
			
		||||
	require.Equal(t, "bar", v["foo"])
 | 
			
		||||
	require.Equal(t, "baz", v["bar"])
 | 
			
		||||
	require.Equal(t, "abc", v["baz"])
 | 
			
		||||
}
 | 
			
		||||
					Loading…
					
					
				
		Reference in New Issue