util: add waitmap for target synchronization

Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com>
pull/928/head
Tonis Tiigi 3 years ago
parent 0fc2b5ca85
commit ffa062dc95

@ -0,0 +1,74 @@
package waitmap
import (
"context"
"sync"
)
type Map struct {
mu sync.RWMutex
m map[string]interface{}
ch map[string]chan struct{}
}
func New() *Map {
return &Map{
m: make(map[string]interface{}),
ch: make(map[string]chan struct{}),
}
}
func (m *Map) Set(key string, value interface{}) {
m.mu.Lock()
defer m.mu.Unlock()
m.m[key] = value
if ch, ok := m.ch[key]; ok {
if ch != nil {
close(ch)
}
}
m.ch[key] = nil
}
func (m *Map) Get(ctx context.Context, keys ...string) (map[string]interface{}, error) {
if len(keys) == 0 {
return map[string]interface{}{}, nil
}
if len(keys) > 1 {
out := make(map[string]interface{})
for _, key := range keys {
mm, err := m.Get(ctx, key)
if err != nil {
return nil, err
}
out[key] = mm[key]
}
return out, nil
}
key := keys[0]
m.mu.Lock()
ch, ok := m.ch[key]
if !ok {
ch = make(chan struct{})
m.ch[key] = ch
}
if ch != nil {
m.mu.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ch:
m.mu.Lock()
}
}
res := m.m[key]
m.mu.Unlock()
return map[string]interface{}{key: res}, nil
}

@ -0,0 +1,64 @@
package waitmap
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestGetAfter(t *testing.T) {
m := New()
m.Set("foo", "bar")
m.Set("bar", "baz")
ctx := context.TODO()
v, err := m.Get(ctx, "foo", "bar")
require.NoError(t, err)
require.Equal(t, 2, len(v))
require.Equal(t, "bar", v["foo"])
require.Equal(t, "baz", v["bar"])
v, err = m.Get(ctx, "foo")
require.NoError(t, err)
require.Equal(t, 1, len(v))
require.Equal(t, "bar", v["foo"])
}
func TestTimeout(t *testing.T) {
m := New()
m.Set("foo", "bar")
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
defer cancel()
_, err := m.Get(ctx, "bar")
require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded))
}
func TestBlocking(t *testing.T) {
m := New()
m.Set("foo", "bar")
go func() {
time.Sleep(100 * time.Millisecond)
m.Set("bar", "baz")
time.Sleep(50 * time.Millisecond)
m.Set("baz", "abc")
}()
ctx := context.TODO()
v, err := m.Get(ctx, "foo", "bar", "baz")
require.NoError(t, err)
require.Equal(t, 3, len(v))
require.Equal(t, "bar", v["foo"])
require.Equal(t, "baz", v["bar"])
require.Equal(t, "abc", v["baz"])
}
Loading…
Cancel
Save