Merge pull request #928 from tonistiigi/bake-named-contexts
bake: add named contexts keyspull/948/head
						commit
						60a025b227
					
				@ -0,0 +1,74 @@
 | 
				
			|||||||
 | 
					package waitmap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Map struct {
 | 
				
			||||||
 | 
						mu sync.RWMutex
 | 
				
			||||||
 | 
						m  map[string]interface{}
 | 
				
			||||||
 | 
						ch map[string]chan struct{}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func New() *Map {
 | 
				
			||||||
 | 
						return &Map{
 | 
				
			||||||
 | 
							m:  make(map[string]interface{}),
 | 
				
			||||||
 | 
							ch: make(map[string]chan struct{}),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *Map) Set(key string, value interface{}) {
 | 
				
			||||||
 | 
						m.mu.Lock()
 | 
				
			||||||
 | 
						defer m.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.m[key] = value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if ch, ok := m.ch[key]; ok {
 | 
				
			||||||
 | 
							if ch != nil {
 | 
				
			||||||
 | 
								close(ch)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						m.ch[key] = nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (m *Map) Get(ctx context.Context, keys ...string) (map[string]interface{}, error) {
 | 
				
			||||||
 | 
						if len(keys) == 0 {
 | 
				
			||||||
 | 
							return map[string]interface{}{}, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(keys) > 1 {
 | 
				
			||||||
 | 
							out := make(map[string]interface{})
 | 
				
			||||||
 | 
							for _, key := range keys {
 | 
				
			||||||
 | 
								mm, err := m.Get(ctx, key)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return nil, err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								out[key] = mm[key]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return out, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						key := keys[0]
 | 
				
			||||||
 | 
						m.mu.Lock()
 | 
				
			||||||
 | 
						ch, ok := m.ch[key]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							ch = make(chan struct{})
 | 
				
			||||||
 | 
							m.ch[key] = ch
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if ch != nil {
 | 
				
			||||||
 | 
							m.mu.Unlock()
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-ctx.Done():
 | 
				
			||||||
 | 
								return nil, ctx.Err()
 | 
				
			||||||
 | 
							case <-ch:
 | 
				
			||||||
 | 
								m.mu.Lock()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						res := m.m[key]
 | 
				
			||||||
 | 
						m.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return map[string]interface{}{key: res}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -0,0 +1,64 @@
 | 
				
			|||||||
 | 
					package waitmap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/require"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestGetAfter(t *testing.T) {
 | 
				
			||||||
 | 
						m := New()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.Set("foo", "bar")
 | 
				
			||||||
 | 
						m.Set("bar", "baz")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx := context.TODO()
 | 
				
			||||||
 | 
						v, err := m.Get(ctx, "foo", "bar")
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						require.Equal(t, 2, len(v))
 | 
				
			||||||
 | 
						require.Equal(t, "bar", v["foo"])
 | 
				
			||||||
 | 
						require.Equal(t, "baz", v["bar"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						v, err = m.Get(ctx, "foo")
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						require.Equal(t, 1, len(v))
 | 
				
			||||||
 | 
						require.Equal(t, "bar", v["foo"])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTimeout(t *testing.T) {
 | 
				
			||||||
 | 
						m := New()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.Set("foo", "bar")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
 | 
				
			||||||
 | 
						defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, err := m.Get(ctx, "bar")
 | 
				
			||||||
 | 
						require.Error(t, err)
 | 
				
			||||||
 | 
						require.True(t, errors.Is(err, context.DeadlineExceeded))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestBlocking(t *testing.T) {
 | 
				
			||||||
 | 
						m := New()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m.Set("foo", "bar")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							time.Sleep(100 * time.Millisecond)
 | 
				
			||||||
 | 
							m.Set("bar", "baz")
 | 
				
			||||||
 | 
							time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
 | 
							m.Set("baz", "abc")
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx := context.TODO()
 | 
				
			||||||
 | 
						v, err := m.Get(ctx, "foo", "bar", "baz")
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						require.Equal(t, 3, len(v))
 | 
				
			||||||
 | 
						require.Equal(t, "bar", v["foo"])
 | 
				
			||||||
 | 
						require.Equal(t, "baz", v["bar"])
 | 
				
			||||||
 | 
						require.Equal(t, "abc", v["baz"])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
					Loading…
					
					
				
		Reference in New Issue