Skip to content

Commit a35dff7

Browse files
authored
Added a cache lookup for AcquireTokenByCredential (#590)
* Added a cache lookup for AcquireTokenByCredential * Refactor: Extract acquireTokenSilentInternal for shared cache logic Replaced withInternalCCACallOnly() flag with acquireTokenSilentInternal() method to avoid polluting public API options with internal-only fields. - Added acquireTokenSilentInternal() shared by AcquireTokenSilent and AcquireTokenByCredential - Removed allowServicePrincipalRefresh flag from acquireTokenSilentOptions - AcquireTokenSilent validates account requirement before delegating - Updated tests to call internal method directly Cleaner separation between public API constraints and internal needs.
1 parent e59524a commit a35dff7

File tree

2 files changed

+125
-35
lines changed

2 files changed

+125
-35
lines changed

apps/confidential/confidential.go

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -596,16 +596,26 @@ func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts
596596
return AuthResult{}, errors.New("call another AcquireToken method to request a new token having these claims")
597597
}
598598

599+
// For service principal scenarios, require WithSilentAccount for public API
600+
if o.account.IsZero() {
601+
return AuthResult{}, errors.New("WithSilentAccount option is required")
602+
}
603+
604+
return cca.acquireTokenSilentInternal(ctx, scopes, o.account, o.claims, o.tenantID, o.authnScheme)
605+
}
606+
607+
// acquireTokenSilentInternal is the internal implementation shared by AcquireTokenSilent and AcquireTokenByCredential
608+
func (cca Client) acquireTokenSilentInternal(ctx context.Context, scopes []string, account Account, claims, tenantID string, authnScheme AuthenticationScheme) (AuthResult, error) {
599609
silentParameters := base.AcquireTokenSilentParameters{
600610
Scopes: scopes,
601-
Account: o.account,
611+
Account: account,
602612
RequestType: accesstokens.ATConfidential,
603613
Credential: cca.cred,
604-
IsAppCache: o.account.IsZero(),
605-
TenantID: o.tenantID,
606-
AuthnScheme: o.authnScheme,
614+
IsAppCache: account.IsZero(),
615+
TenantID: tenantID,
616+
AuthnScheme: authnScheme,
617+
Claims: claims,
607618
}
608-
609619
return cca.base.AcquireTokenSilent(ctx, silentParameters)
610620
}
611621

@@ -736,6 +746,14 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string,
736746
if o.authnScheme != nil {
737747
authParams.AuthnScheme = o.authnScheme
738748
}
749+
if o.claims == "" {
750+
// Use internal method with empty account (service principal scenario)
751+
cache, err := cca.acquireTokenSilentInternal(ctx, scopes, Account{}, o.claims, o.tenantID, authParams.AuthnScheme)
752+
if err == nil {
753+
return cache, nil
754+
}
755+
}
756+
739757
token, err := cca.base.Token.Credential(ctx, authParams, cca.cred)
740758
if err != nil {
741759
return AuthResult{}, err

apps/confidential/confidential_test.go

Lines changed: 102 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
3434
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
3535
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
36+
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
3637
)
3738

3839
// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
@@ -162,17 +163,65 @@ func TestAcquireTokenByCredential(t *testing.T) {
162163
if tk.AccessToken != token {
163164
t.Errorf("TestAcquireTokenByCredential(%s): unexpected access token %s", test.desc, tk.AccessToken)
164165
}
165-
// second attempt should return the cached token
166-
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
166+
// second attempt should return the cached token, AcquireTokenByCredential calls for cache
167+
tk, err = client.AcquireTokenByCredential(context.Background(), tokenScope)
167168
if err != nil {
168169
t.Errorf("TestAcquireTokenByCredential(%s): got err == %s, want err == nil", test.desc, err)
169170
}
170171
if tk.AccessToken != token {
171172
t.Errorf("TestAcquireTokenByCredential(%s): unexpected access token %s", test.desc, tk.AccessToken)
172173
}
174+
if tk.Metadata.TokenSource != TokenSourceCache {
175+
t.Errorf("TestAcquireTokenByCredential(%s): unexpected token source %d", test.desc, tk.Metadata.TokenSource)
176+
}
173177
}
174178
}
175179

180+
func TestAcquireTokenByCredentialWithCache(t *testing.T) {
181+
cred, err := NewCredFromSecret(fakeSecret)
182+
if err != nil {
183+
t.Fatal(err)
184+
}
185+
tenant := "tenant"
186+
lmo := "login.microsoftonline.com"
187+
mockClient := mock.NewClient()
188+
189+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, "tenant")))
190+
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(mockClient))
191+
if err != nil {
192+
t.Fatal(err)
193+
}
194+
ctx := context.Background()
195+
196+
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
197+
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600, 0)))
198+
199+
token, err := client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
200+
if err != nil {
201+
t.Fatal(err)
202+
}
203+
if token.AccessToken != tenant {
204+
t.Fatalf("expected token to be %s, got %s", tenant, token.AccessToken)
205+
}
206+
207+
// calling the acquire token by credential again to get from cache
208+
token, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
209+
if err != nil {
210+
t.Fatal(err)
211+
}
212+
if token.AccessToken != tenant {
213+
t.Fatalf("expected token to be %s, got %s", tenant, token.AccessToken)
214+
}
215+
if token.Metadata.TokenSource != TokenSourceCache {
216+
t.Fatalf("expected token source to be cache, got %d", token.Metadata.TokenSource)
217+
}
218+
// calling silent should still be error
219+
if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil {
220+
t.Fatal("silent auth should fail because it is a service principal call")
221+
}
222+
223+
}
224+
176225
func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) {
177226
cred, err := NewCredFromSecret(fakeSecret)
178227
if err != nil {
@@ -495,7 +544,7 @@ func TestAcquireTokenSilentTenants(t *testing.T) {
495544
}
496545
// cache should return the correct access token for each tenant
497546
for _, tenant := range tenants {
498-
ar, err := client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant))
547+
ar, err := client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
499548
if err != nil {
500549
t.Fatal(err)
501550
}
@@ -816,7 +865,7 @@ func TestNewCredFromTokenProvider(t *testing.T) {
816865
if ar.AccessToken != expectedToken {
817866
t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
818867
}
819-
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope)
868+
ar, err = client.AcquireTokenByCredential(context.Background(), tokenScope)
820869
if err != nil {
821870
t.Fatal(err)
822871
}
@@ -902,7 +951,7 @@ func TestRefreshInMultipleRequests(t *testing.T) {
902951
wg.Add(2)
903952
go func() {
904953
defer wg.Done()
905-
_, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant"))
954+
_, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant"))
906955
if err != nil {
907956
select {
908957
case ch <- err:
@@ -913,7 +962,7 @@ func TestRefreshInMultipleRequests(t *testing.T) {
913962
}()
914963
go func() {
915964
defer wg.Done()
916-
_, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant"))
965+
_, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant"))
917966
if err != nil {
918967
select {
919968
case ch <- err:
@@ -1005,7 +1054,7 @@ func TestConcurrentRequests(t *testing.T) {
10051054
)
10061055
go func(id string) {
10071056
defer wg.Done()
1008-
if _, err := client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(id)); err != nil {
1057+
if _, err := client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(id)); err != nil {
10091058
t.Error("Unexpected error", err)
10101059
}
10111060
}(tenant)
@@ -1100,7 +1149,7 @@ func TestRefreshIn(t *testing.T) {
11001149
base.Now = func() time.Time {
11011150
return fixedTime
11021151
}
1103-
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope)
1152+
ar, err = client.AcquireTokenByCredential(context.Background(), tokenScope)
11041153
if err != nil {
11051154
t.Fatal(err)
11061155
}
@@ -1311,7 +1360,6 @@ func TestWithClaims(t *testing.T) {
13111360
for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} {
13121361
t.Run(method, func(t *testing.T) {
13131362
mockClient := mock.NewClient()
1314-
13151363
clientInfo, idToken, refreshToken := "", "", ""
13161364
if method == "obo" {
13171365
clientInfo = base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
@@ -1320,7 +1368,11 @@ func TestWithClaims(t *testing.T) {
13201368
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
13211369
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
13221370
}
1323-
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
1371+
if method == "credential" {
1372+
if test.claims == "" {
1373+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
1374+
}
1375+
}
13241376
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
13251377
mockClient.AppendResponse(
13261378
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600, 0)),
@@ -1331,11 +1383,14 @@ func TestWithClaims(t *testing.T) {
13311383
validate(t, r.Form)
13321384
}),
13331385
)
1386+
if method != "obo" {
1387+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
1388+
}
13341389
client, err := New(authority, fakeClientID, cred, WithClientCapabilities(test.capabilities), WithHTTPClient(mockClient))
13351390
if err != nil {
13361391
t.Fatal(err)
13371392
}
1338-
if _, err = client.AcquireTokenSilent(context.Background(), tokenScope); err == nil {
1393+
if _, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(shared.Account{})); err == nil {
13391394
t.Fatal("silent authentication should fail because the cache is empty")
13401395
}
13411396
ctx := context.Background()
@@ -1370,8 +1425,10 @@ func TestWithClaims(t *testing.T) {
13701425
// silent auth should now succeed, provided no claims are requested, because the client has cached an access token
13711426
if method == "obo" {
13721427
ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope)
1428+
} else if method == "credential" {
1429+
ar, err = client.AcquireTokenByCredential(ctx, tokenScope)
13731430
} else {
1374-
ar, err = client.AcquireTokenSilent(ctx, tokenScope)
1431+
ar, err = client.acquireTokenSilentInternal(ctx, tokenScope, shared.Account{}, "", tenant, client.base.AuthParams.AuthnScheme)
13751432
}
13761433
if err != nil {
13771434
t.Fatal(err)
@@ -1395,7 +1452,7 @@ func TestWithClaims(t *testing.T) {
13951452
// all token requests should include any specified claims
13961453
validate(t, r.Form)
13971454
if actual := r.Form.Get("refresh_token"); actual != refreshToken {
1398-
t.Fatalf(`unexpected refresh token "%s"`, actual)
1455+
t.Fatalf(`unexpected refresh token "%s ,, %s"`, actual, refreshToken)
13991456
}
14001457
}),
14011458
)
@@ -1445,20 +1502,23 @@ func TestWithTenantID(t *testing.T) {
14451502
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
14461503
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
14471504
}
1448-
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
1505+
if method == "credential" {
1506+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
1507+
}
1508+
// mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
14491509
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant)))
14501510
mockClient.AppendResponse(
14511511
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)),
14521512
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
14531513
)
1514+
if method != "obo" {
1515+
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
1516+
}
14541517
client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(mockClient))
14551518
if err != nil {
14561519
t.Fatal(err)
14571520
}
14581521
ctx := context.Background()
1459-
if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err == nil {
1460-
t.Fatal("silent auth should fail because the cache is empty")
1461-
}
14621522
var ar AuthResult
14631523
switch method {
14641524
case "authcode":
@@ -1495,7 +1555,11 @@ func TestWithTenantID(t *testing.T) {
14951555
if ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant)); err != nil {
14961556
t.Fatal(err)
14971557
}
1498-
} else if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err != nil {
1558+
} else if method == "credential" {
1559+
if ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(test.tenant)); err != nil {
1560+
t.Fatal(err)
1561+
}
1562+
} else if ar, err = client.acquireTokenSilentInternal(ctx, tokenScope, shared.Account{}, "", test.tenant, client.base.AuthParams.AuthnScheme); err != nil {
14991563
t.Fatal(err)
15001564
}
15011565
if ar.AccessToken != accessToken {
@@ -1562,7 +1626,7 @@ func TestWithTenantID(t *testing.T) {
15621626
t.Fatalf("unexpected access token %q", ar.AccessToken)
15631627
}
15641628
// silent authentication should now succeed for the given tenant...
1565-
if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err != nil {
1629+
if ar, err = client.acquireTokenSilentInternal(ctx, tokenScope, shared.Account{}, "", tenant, client.base.AuthParams.AuthnScheme); err != nil {
15661630
t.Fatal(err)
15671631
}
15681632
if ar.AccessToken != accessToken {
@@ -1634,7 +1698,11 @@ func TestWithInstanceDiscovery(t *testing.T) {
16341698
if ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope); err != nil {
16351699
t.Fatal(err)
16361700
}
1637-
} else if ar, err = client.AcquireTokenSilent(ctx, tokenScope); err != nil {
1701+
} else if method == "credential" {
1702+
if ar, err = client.AcquireTokenByCredential(ctx, tokenScope); err != nil {
1703+
t.Fatal(err)
1704+
}
1705+
} else if ar, err = client.acquireTokenSilentInternal(ctx, tokenScope, shared.Account{}, "", tenant, client.base.AuthParams.AuthnScheme); err != nil {
16381706
t.Fatal(err)
16391707
}
16401708
if ar.AccessToken != accessToken {
@@ -1656,18 +1724,18 @@ func TestWithPortAuthority(t *testing.T) {
16561724
if err != nil {
16571725
t.Fatal(err)
16581726
}
1659-
idToken, refreshToken, URL := "", "", ""
1727+
refreshToken, URL := "", ""
16601728
mockClient := mock.NewClient()
1729+
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"fakeuser","utid":"fakeuserid"}`))
16611730

1662-
//2 calls to instance discovery are made because Host is not trusted
1663-
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)))
1664-
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)))
16651731
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
16661732
mockClient.AppendResponse(
1667-
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600, 0)),
1733+
mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, authority), refreshToken, clientInfo, 36000, 0)),
16681734
mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
16691735
)
1670-
client, err := New(authority, fakeClientID, cred, WithHTTPClient(mockClient))
1736+
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
1737+
1738+
client, err := New(authority, fakeClientID, cred, WithHTTPClient(mockClient), WithInstanceDiscovery(false))
16711739
if err != nil {
16721740
t.Fatal(err)
16731741
}
@@ -1686,7 +1754,11 @@ func TestWithPortAuthority(t *testing.T) {
16861754
if ar.AccessToken != accessToken {
16871755
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
16881756
}
1689-
if ar, err = client.AcquireTokenSilent(ctx, tokenScope); err != nil {
1757+
account := ar.Account
1758+
if actual := account.Realm; actual != tenant {
1759+
t.Fatalf(`unexpected realm "%s"`, actual)
1760+
}
1761+
if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(account)); err != nil {
16901762
t.Fatal(err)
16911763
}
16921764
if ar.AccessToken != accessToken {
@@ -1798,7 +1870,7 @@ func TestWithAuthenticationScheme(t *testing.T) {
17981870
if result.AccessToken != fmt.Sprintf(mock.Authnschemeformat, token) {
17991871
t.Fatalf(`unexpected access token "%s"`, result.AccessToken)
18001872
}
1801-
result, err = client.AcquireTokenSilent(ctx, tokenScope, WithAuthenticationScheme(authScheme))
1873+
result, err = client.AcquireTokenByCredential(ctx, tokenScope, WithAuthenticationScheme(authScheme))
18021874
if err != nil {
18031875
t.Fatal(err)
18041876
}
@@ -1846,7 +1918,7 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
18461918
t.Errorf("unexpected access token %s", tk.AccessToken)
18471919
}
18481920

1849-
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
1921+
tk, err = client.AcquireTokenByCredential(context.Background(), tokenScope)
18501922
if err != nil {
18511923
t.Errorf("got err == %s, want err == nil", err)
18521924
}
@@ -1855,7 +1927,7 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
18551927
}
18561928

18571929
// fail for another tenant
1858-
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other"))
1930+
tk, err = client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("other"))
18591931
if err == nil {
18601932
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
18611933
}

0 commit comments

Comments
 (0)