[ARVADOS] created: 2.1.0-1972-gfb3484425

Git user git at public.arvados.org
Mon Feb 28 17:27:05 UTC 2022


        at  fb348442584fc62bb670c378f80ab4c943e875cc (commit)


commit fb348442584fc62bb670c378f80ab4c943e875cc
Author: Tom Clegg <tom at curii.com>
Date:   Mon Feb 28 12:25:18 2022 -0500

    18808: Handle concurrent uses of same previously unseen token.
    
    Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom at curii.com>

diff --git a/lib/controller/localdb/login_oidc.go b/lib/controller/localdb/login_oidc.go
index e076f7e12..6d6f80f39 100644
--- a/lib/controller/localdb/login_oidc.go
+++ b/lib/controller/localdb/login_oidc.go
@@ -31,6 +31,7 @@ import (
 	"github.com/coreos/go-oidc"
 	lru "github.com/hashicorp/golang-lru"
 	"github.com/jmoiron/sqlx"
+	"github.com/lib/pq"
 	"github.com/sirupsen/logrus"
 	"golang.org/x/oauth2"
 	"google.golang.org/api/option"
@@ -43,6 +44,7 @@ var (
 	tokenCacheNegativeTTL = time.Minute * 5
 	tokenCacheTTL         = time.Minute * 10
 	tokenCacheRaceWindow  = time.Minute
+	pqCodeUniqueViolation = pq.ErrorCode("23505")
 )
 
 type oidcLoginController struct {
@@ -479,7 +481,6 @@ func (ta *oidcTokenAuthorizer) registerToken(ctx context.Context, tok string) er
 	// it's expiring.
 	exp := time.Now().UTC().Add(tokenCacheTTL + tokenCacheRaceWindow)
 
-	var aca arvados.APIClientAuthorization
 	if updating {
 		_, err = tx.ExecContext(ctx, `update api_client_authorizations set expires_at=$1 where api_token=$2`, exp, hmac)
 		if err != nil {
@@ -487,23 +488,44 @@ func (ta *oidcTokenAuthorizer) registerToken(ctx context.Context, tok string) er
 		}
 		ctxlog.FromContext(ctx).WithField("HMAC", hmac).Debug("(*oidcTokenAuthorizer)registerToken: updated api_client_authorizations row")
 	} else {
-		aca, err = ta.ctrl.Parent.CreateAPIClientAuthorization(ctx, ta.ctrl.Cluster.SystemRootToken, *authinfo)
+		aca, err := ta.ctrl.Parent.CreateAPIClientAuthorization(ctx, ta.ctrl.Cluster.SystemRootToken, *authinfo)
 		if err != nil {
 			return err
 		}
-		_, err = tx.ExecContext(ctx, `update api_client_authorizations set api_token=$1, expires_at=$2 where uuid=$3`, hmac, exp, aca.UUID)
+		_, err = tx.ExecContext(ctx, `savepoint upd`)
 		if err != nil {
+			return err
+		}
+		_, err = tx.ExecContext(ctx, `update api_client_authorizations set api_token=$1, expires_at=$2 where uuid=$3`, hmac, exp, aca.UUID)
+		if e, ok := err.(*pq.Error); ok && e.Code == pqCodeUniqueViolation {
+			// unique_violation, given that the above
+			// query did not find a row with matching
+			// api_token, means another thread/process
+			// also received this same token and won the
+			// race to insert it -- in which case this
+			// thread doesn't need to update the database.
+			// Discard the redundant row.
+			_, err = tx.ExecContext(ctx, `rollback to savepoint upd`)
+			if err != nil {
+				return err
+			}
+			_, err = tx.ExecContext(ctx, `delete from api_client_authorizations where uuid=$1`, aca.UUID)
+			if err != nil {
+				return err
+			}
+			ctxlog.FromContext(ctx).WithField("HMAC", hmac).Debug("(*oidcTokenAuthorizer)registerToken: api_client_authorizations row inserted by another thread")
+		} else if err != nil {
+			ctxlog.FromContext(ctx).Errorf("%#v", err)
 			return fmt.Errorf("error adding OIDC access token to database: %w", err)
+		} else {
+			ctxlog.FromContext(ctx).WithFields(logrus.Fields{"UUID": aca.UUID, "HMAC": hmac}).Debug("(*oidcTokenAuthorizer)registerToken: inserted api_client_authorizations row")
 		}
-		aca.APIToken = hmac
-		ctxlog.FromContext(ctx).WithFields(logrus.Fields{"UUID": aca.UUID, "HMAC": hmac}).Debug("(*oidcTokenAuthorizer)registerToken: inserted api_client_authorizations row")
 	}
 	err = tx.Commit()
 	if err != nil {
 		return err
 	}
-	aca.ExpiresAt = exp
-	ta.cache.Add(tok, aca)
+	ta.cache.Add(tok, arvados.APIClientAuthorization{ExpiresAt: exp})
 	return nil
 }
 

commit 4a1826b9e9e7e425fdf5221aa04bc09dff9eb345
Author: Tom Clegg <tom at curii.com>
Date:   Mon Feb 28 11:45:46 2022 -0500

    18808: Test concurrent uses of same previously unseen token.
    
    Arvados-DCO-1.1-Signed-off-by: Tom Clegg <tom at curii.com>

diff --git a/lib/controller/localdb/login_oidc_test.go b/lib/controller/localdb/login_oidc_test.go
index 4778e45f5..b9f0f56e0 100644
--- a/lib/controller/localdb/login_oidc_test.go
+++ b/lib/controller/localdb/login_oidc_test.go
@@ -17,6 +17,7 @@ import (
 	"net/url"
 	"sort"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 
@@ -236,18 +237,49 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) {
 
 	ctx := auth.NewContext(context.Background(), &auth.Credentials{Tokens: []string{accessToken}})
 	var exp1 time.Time
-	oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) {
-		creds, ok := auth.FromContext(ctx)
-		c.Assert(ok, check.Equals, true)
-		c.Assert(creds.Tokens, check.HasLen, 1)
-		c.Check(creds.Tokens[0], check.Equals, accessToken)
 
-		err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp1)
-		c.Check(err, check.IsNil)
-		c.Check(exp1.Sub(time.Now()) > -time.Second, check.Equals, true)
-		c.Check(exp1.Sub(time.Now()) < time.Second, check.Equals, true)
-		return nil, nil
-	})(ctx, nil)
+	concurrent := 4
+	s.fakeProvider.HoldUserInfo = make(chan *http.Request)
+	s.fakeProvider.ReleaseUserInfo = make(chan struct{})
+	go func() {
+		for i := 0; ; i++ {
+			if i == concurrent {
+				close(s.fakeProvider.ReleaseUserInfo)
+			}
+			<-s.fakeProvider.HoldUserInfo
+		}
+	}()
+	var wg sync.WaitGroup
+	for i := 0; i < concurrent; i++ {
+		i := i
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			_, err := oidcAuthorizer.WrapCalls(func(ctx context.Context, opts interface{}) (interface{}, error) {
+				c.Logf("concurrent req %d/%d", i, concurrent)
+				var exp time.Time
+
+				creds, ok := auth.FromContext(ctx)
+				c.Assert(ok, check.Equals, true)
+				c.Assert(creds.Tokens, check.HasLen, 1)
+				c.Check(creds.Tokens[0], check.Equals, accessToken)
+
+				err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp)
+				c.Check(err, check.IsNil)
+				c.Check(exp.Sub(time.Now()) > -time.Second, check.Equals, true)
+				c.Check(exp.Sub(time.Now()) < time.Second, check.Equals, true)
+				if i == 0 {
+					exp1 = exp
+				}
+				return nil, nil
+			})(ctx, nil)
+			c.Check(err, check.IsNil)
+		}()
+	}
+	wg.Wait()
+	if c.Failed() {
+		c.Fatal("giving up")
+	}
 
 	// If the token is used again after the in-memory cache
 	// expires, oidcAuthorizer must re-check the token and update
@@ -257,8 +289,8 @@ func (s *OIDCLoginSuite) TestOIDCAuthorizer(c *check.C) {
 		var exp time.Time
 		err := db.QueryRowContext(ctx, `select expires_at at time zone 'UTC' from api_client_authorizations where api_token=$1`, apiToken).Scan(&exp)
 		c.Check(err, check.IsNil)
-		c.Check(exp.Sub(exp1) > 0, check.Equals, true)
-		c.Check(exp.Sub(exp1) < time.Second, check.Equals, true)
+		c.Check(exp.Sub(exp1) > 0, check.Equals, true, check.Commentf("expect %v > 0", exp.Sub(exp1)))
+		c.Check(exp.Sub(exp1) < time.Second, check.Equals, true, check.Commentf("expect %v < 1s", exp.Sub(exp1)))
 		return nil, nil
 	})(ctx, nil)
 
diff --git a/sdk/go/arvadostest/oidc_provider.go b/sdk/go/arvadostest/oidc_provider.go
index fa5e55c42..087adc4b2 100644
--- a/sdk/go/arvadostest/oidc_provider.go
+++ b/sdk/go/arvadostest/oidc_provider.go
@@ -35,6 +35,12 @@ type OIDCProvider struct {
 
 	PeopleAPIResponse map[string]interface{}
 
+	// send incoming /userinfo requests to HoldUserInfo (if not
+	// nil), then receive from ReleaseUserInfo (if not nil),
+	// before responding (these are used to set up races)
+	HoldUserInfo    chan *http.Request
+	ReleaseUserInfo chan struct{}
+
 	key       *rsa.PrivateKey
 	Issuer    *httptest.Server
 	PeopleAPI *httptest.Server
@@ -126,6 +132,12 @@ func (p *OIDCProvider) serveOIDC(w http.ResponseWriter, req *http.Request) {
 	case "/auth":
 		w.WriteHeader(http.StatusInternalServerError)
 	case "/userinfo":
+		if p.HoldUserInfo != nil {
+			p.HoldUserInfo <- req
+		}
+		if p.ReleaseUserInfo != nil {
+			<-p.ReleaseUserInfo
+		}
 		authhdr := req.Header.Get("Authorization")
 		if _, err := jwt.ParseSigned(strings.TrimPrefix(authhdr, "Bearer ")); err != nil {
 			p.c.Logf("OIDCProvider: bad auth %q", authhdr)

-----------------------------------------------------------------------


hooks/post-receive
-- 




More information about the arvados-commits mailing list