diff --git a/internal/api/external.go b/internal/api/external.go index 8392797d5..ff70339a5 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -725,6 +725,8 @@ func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, if q.Get("error_code") != "" { hq.Set("error_code", q.Get("error_code")) } + // Add Supabase Auth identifier to help clients distinguish Supabase Auth redirects + hq.Set("sb", "") u.Fragment = hq.Encode() http.Redirect(w, r, u.String(), http.StatusFound) } diff --git a/internal/api/external_test.go b/internal/api/external_test.go index d2adcac42..b9601572f 100644 --- a/internal/api/external_test.go +++ b/internal/api/external_test.go @@ -202,6 +202,8 @@ func assertAuthorizationSuccess(ts *ExternalTestSuite, u *url.URL, tokenCount in ts.NotEmpty(v.Get("refresh_token")) ts.NotEmpty(v.Get("expires_in")) ts.Equal("bearer", v.Get("token_type")) + // Verify Supabase Auth identifier is present + ts.Contains(v, "sb", "Fragment should contain Supabase Auth identifier 'sb'") ts.Equal(1, tokenCount) if userCount > -1 { @@ -243,6 +245,8 @@ func assertAuthorizationFailure(ts *ExternalTestSuite, u *url.URL, errorDescript ts.Empty(v.Get("refresh_token")) ts.Empty(v.Get("expires_in")) ts.Empty(v.Get("token_type")) + // Verify Supabase Auth identifier is present even in error responses + ts.Contains(v, "sb", "Fragment should contain Supabase Auth identifier 'sb' even in errors") // ensure user is nil user, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) diff --git a/internal/api/verify.go b/internal/api/verify.go index 6209774bb..e41813008 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -507,6 +507,8 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string, u.RawQuery = q.Encode() } // Left as hash fragment to comply with spec. + // Add Supabase Auth identifier to help clients distinguish Supabase Auth redirects + hq.Set("sb", "") u.Fragment = hq.Encode() return u.String(), nil } @@ -523,6 +525,8 @@ func (a *API) prepRedirectURL(message string, rurl string, flowType models.FlowT q.Set("message", message) } u.RawQuery = q.Encode() + // Add Supabase Auth identifier to help clients distinguish Supabase Auth redirects + hq.Set("sb", "") u.Fragment = hq.Encode() return u.String(), nil } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 4dc514791..025bda589 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -1148,28 +1148,28 @@ func (ts *VerifyTestSuite) TestPrepRedirectURL() { message: singleConfirmationAccepted, rurl: "https://example.com/?first=another&second=other", flowType: models.PKCEFlow, - expected: fmt.Sprintf("https://example.com/?first=another&message=%s&second=other#message=%s", escapedMessage, escapedMessage), + expected: fmt.Sprintf("https://example.com/?first=another&message=%s&second=other#message=%s&sb=", escapedMessage, escapedMessage), }, { desc: "(PKCE): Query params in redirect url are overriden", message: singleConfirmationAccepted, rurl: "https://example.com/?message=Valid+redirect+URL", flowType: models.PKCEFlow, - expected: fmt.Sprintf("https://example.com/?message=%s#message=%s", escapedMessage, escapedMessage), + expected: fmt.Sprintf("https://example.com/?message=%s#message=%s&sb=", escapedMessage, escapedMessage), }, { desc: "(Implicit): plain redirect url", message: singleConfirmationAccepted, rurl: "https://example.com/", flowType: models.ImplicitFlow, - expected: fmt.Sprintf("https://example.com/#message=%s", escapedMessage), + expected: fmt.Sprintf("https://example.com/#message=%s&sb=", escapedMessage), }, { desc: "(Implicit): query params retained", message: singleConfirmationAccepted, rurl: "https://example.com/?first=another", flowType: models.ImplicitFlow, - expected: fmt.Sprintf("https://example.com/?first=another#message=%s", escapedMessage), + expected: fmt.Sprintf("https://example.com/?first=another#message=%s&sb=", escapedMessage), }, } for _, c := range cases { @@ -1197,28 +1197,28 @@ func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { message: "Valid redirect URL", rurl: "https://example.com/", flowType: models.PKCEFlow, - expected: fmt.Sprintf("https://example.com/?%s#%s", redirectError, redirectError), + expected: fmt.Sprintf("https://example.com/?%s#%s&sb=", redirectError, redirectError), }, { desc: "(PKCE): Error with conflicting query params in redirect url", message: DefaultError, rurl: "https://example.com/?error=Error+to+be+overriden", flowType: models.PKCEFlow, - expected: fmt.Sprintf("https://example.com/?%s#%s", redirectError, redirectError), + expected: fmt.Sprintf("https://example.com/?%s#%s&sb=", redirectError, redirectError), }, { desc: "(Implicit): plain redirect url", message: DefaultError, rurl: "https://example.com/", flowType: models.ImplicitFlow, - expected: fmt.Sprintf("https://example.com/#%s", redirectError), + expected: fmt.Sprintf("https://example.com/#%s&sb=", redirectError), }, { desc: "(Implicit): query params preserved", message: DefaultError, rurl: "https://example.com/?test=param", flowType: models.ImplicitFlow, - expected: fmt.Sprintf("https://example.com/?test=param#%s", redirectError), + expected: fmt.Sprintf("https://example.com/?test=param#%s&sb=", redirectError), }, } for _, c := range cases { diff --git a/internal/tokens/service.go b/internal/tokens/service.go index 3e3d7cd08..02c640e87 100644 --- a/internal/tokens/service.go +++ b/internal/tokens/service.go @@ -145,6 +145,8 @@ func (r *AccessTokenResponse) AsRedirectURL(redirectURL string, extraParams url. extraParams.Set("expires_in", strconv.Itoa(r.ExpiresIn)) extraParams.Set("expires_at", strconv.FormatInt(r.ExpiresAt, 10)) extraParams.Set("refresh_token", r.RefreshToken) + // Add Supabase Auth identifier to help clients distinguish Supabase Auth redirects + extraParams.Set("sb", "") return redirectURL + "#" + extraParams.Encode() } diff --git a/internal/tokens/service_test.go b/internal/tokens/service_test.go index 8c409934b..0ec17d502 100644 --- a/internal/tokens/service_test.go +++ b/internal/tokens/service_test.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/json" "net/http" + "net/url" "strconv" "strings" "sync" @@ -1089,3 +1090,39 @@ func TestAMRClaimUnmarshal(t *testing.T) { require.Equal(t, "webauthn", claim[1].Provider, "provider should be preserved") }) } + +// TestAsRedirectURL tests that AsRedirectURL includes the Supabase Auth identifier +func TestAsRedirectURL(t *testing.T) { + response := &AccessTokenResponse{ + Token: "test_access_token", + TokenType: "bearer", + ExpiresIn: 3600, + ExpiresAt: 1234567890, + RefreshToken: "test_refresh_token", + } + + extraParams := url.Values{} + extraParams.Set("provider_token", "provider_access_token") + + redirectURL := response.AsRedirectURL("https://example.com/callback", extraParams) + + // Parse the URL + u, err := url.Parse(redirectURL) + require.NoError(t, err) + + // Parse the fragment + fragment, err := url.ParseQuery(u.Fragment) + require.NoError(t, err) + + // Verify all expected parameters are present + require.Equal(t, "test_access_token", fragment.Get("access_token")) + require.Equal(t, "bearer", fragment.Get("token_type")) + require.Equal(t, "3600", fragment.Get("expires_in")) + require.Equal(t, "1234567890", fragment.Get("expires_at")) + require.Equal(t, "test_refresh_token", fragment.Get("refresh_token")) + require.Equal(t, "provider_access_token", fragment.Get("provider_token")) + + // Verify Supabase Auth identifier is present + require.Contains(t, fragment, "sb", "Fragment should contain Supabase Auth identifier 'sb'") + require.Equal(t, "", fragment.Get("sb"), "Supabase Auth identifier should have empty value") +}