diff --git a/tstest/tstest.go b/tstest/tstest.go index 7ccba8004..59f46fc0b 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -6,6 +6,10 @@ package tstest import ( "context" + "math/rand" + "os" + "strconv" + "sync" "testing" "time" @@ -46,3 +50,38 @@ func WaitFor(maxWait time.Duration, try func() error) error { } return err } + +var ( + seed int64 + seedOnce sync.Once +) + +// GetSeed gets the current global random test seed, by default this is based on +// the current time, but can be fixed to a particular value using the +// TS_TEST_SEED environment variable. +func GetSeed(t testing.TB) int64 { + t.Helper() + + seedOnce.Do(func() { + if s := os.Getenv("TS_TEST_SEED"); s != "" { + var err error + seed, err = strconv.ParseInt(s, 10, 64) + if err != nil { + t.Fatalf("invalid TS_TEST_SEED: %v", err) + } + } else { + seed = time.Now().UnixNano() + } + }) + return seed +} + +// SeedRand seeds the standard library global rand with the current test seed. +func SeedRand(t testing.TB) { + t.Helper() + + // Seed is called every time, as other tests may execute code that reseeds + // the global rand. + rand.Seed(GetSeed(t)) + t.Logf("TS_TEST_SEED=%d", seed) +} diff --git a/tstest/tstest_test.go b/tstest/tstest_test.go index e988d5d56..e705ea0d6 100644 --- a/tstest/tstest_test.go +++ b/tstest/tstest_test.go @@ -3,7 +3,9 @@ package tstest -import "testing" +import ( + "testing" +) func TestReplace(t *testing.T) { before := "before" @@ -22,3 +24,10 @@ func TestReplace(t *testing.T) { t.Errorf("before = %q; want %q", before, "before") } } + +func TestGetSeed(t *testing.T) { + t.Setenv("TS_TEST_SEED", "1234") + if got, want := GetSeed(t), int64(1234); got != want { + t.Errorf("GetSeed = %v; want %v", got, want) + } +}