Skip to content

Commit 741c2d2

Browse files
committed
Add oidc pkce support
1 parent 26e7995 commit 741c2d2

11 files changed

Lines changed: 372 additions & 54 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ OIDC Provider:
177177
--providers.oidc.client-id= Client ID [$PROVIDERS_OIDC_CLIENT_ID]
178178
--providers.oidc.client-secret= Client Secret [$PROVIDERS_OIDC_CLIENT_SECRET]
179179
--providers.oidc.resource= Optional resource indicator [$PROVIDERS_OIDC_RESOURCE]
180+
--providers.oidc.pkce_required= Optional pkce required indicator [$PROVIDERS_OIDC_PKCE_REQUIRED]
180181
181182
Generic OAuth2 Provider:
182183
--providers.generic-oauth.auth-url= Auth/Login URL [$PROVIDERS_GENERIC_OAUTH_AUTH_URL]

internal/cookie/cookie.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package cookie
2+
3+
import (
4+
"errors"
5+
"net/http"
6+
"time"
7+
)
8+
9+
// CookieStore interface defines methods for setting and getting cookies
10+
type CookieStore interface {
11+
SetCookie(name, value string)
12+
GetCookie(name string) (string, error)
13+
DeleteCookie(name string)
14+
}
15+
16+
// CookieStoreImpl is a concrete implementation of the CookieStore interface
17+
type CookieStoreImpl struct {
18+
writer http.ResponseWriter
19+
request *http.Request
20+
secure bool
21+
}
22+
23+
// NewCookieStore creates a new instance of CookieStoreImpl
24+
func NewCookieStore(w http.ResponseWriter, r *http.Request, secure bool) *CookieStoreImpl {
25+
return &CookieStoreImpl{
26+
writer: w,
27+
request: r,
28+
secure: secure,
29+
}
30+
}
31+
32+
// SetCookie sets a cookie with the given name, value, and attributes
33+
func (c *CookieStoreImpl) SetCookie(name, value string) {
34+
cookie := &http.Cookie{
35+
Name: name,
36+
Value: value,
37+
Path: "/",
38+
Secure: c.secure,
39+
HttpOnly: true,
40+
SameSite: http.SameSiteLaxMode,
41+
}
42+
43+
http.SetCookie(c.writer, cookie)
44+
}
45+
46+
// DeleteCookie removes a cookie with the given name
47+
func (c *CookieStoreImpl) DeleteCookie(name string) {
48+
cookie := &http.Cookie{
49+
Name: name,
50+
Value: "",
51+
Path: "/",
52+
MaxAge: -1,
53+
Expires: time.Unix(0, 0),
54+
}
55+
56+
http.SetCookie(c.writer, cookie)
57+
}
58+
59+
// GetCookie retrieves the value of the cookie with the given name
60+
func (c *CookieStoreImpl) GetCookie(name string) (string, error) {
61+
cookie, err := c.request.Cookie(name)
62+
if err != nil {
63+
return "", errors.New("cookie not found")
64+
}
65+
return cookie.Value, nil
66+
}

internal/pkce/verifier.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package pkce
2+
3+
import (
4+
"crypto/rand"
5+
"crypto/sha256"
6+
"encoding/base64"
7+
"fmt"
8+
"io"
9+
)
10+
11+
type CodeVerifier struct {
12+
Value string
13+
}
14+
15+
func CreateCodeVerifier() (*CodeVerifier, error) {
16+
secureRandomString, err := generateSecureRandomString(32)
17+
if err != nil {
18+
return nil, err
19+
}
20+
return &CodeVerifier{
21+
Value: secureRandomString,
22+
}, nil
23+
}
24+
25+
func CreateCodeVerifierWithCode(code string) *CodeVerifier {
26+
return &CodeVerifier{
27+
Value: code,
28+
}
29+
}
30+
31+
func (v *CodeVerifier) String() string {
32+
return v.Value
33+
}
34+
35+
func (v *CodeVerifier) CodeChallengeS256() string {
36+
h := sha256.New()
37+
h.Write([]byte(v.Value))
38+
hash := h.Sum(nil)
39+
40+
return encode(hash)
41+
}
42+
43+
func GenerateNonce() (string, error) {
44+
return generateSecureRandomString(32)
45+
}
46+
47+
func generateSecureRandomString(length int) (string, error) {
48+
bytes := make([]byte, length)
49+
if _, err := io.ReadFull(rand.Reader, bytes); err != nil {
50+
return "", fmt.Errorf("failed to generate secure random string: %w", err)
51+
}
52+
return base64.RawURLEncoding.EncodeToString(bytes), nil
53+
}
54+
55+
func encode(msg []byte) string {
56+
encoded := base64.RawURLEncoding.EncodeToString(msg)
57+
return encoded
58+
}

internal/provider/generic_oauth.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"net/http"
99

10+
"github.com/thomseddon/traefik-forward-auth/internal/cookie"
1011
"golang.org/x/oauth2"
1112
)
1213

@@ -52,12 +53,12 @@ func (o *GenericOAuth) Setup() error {
5253
}
5354

5455
// GetLoginURL provides the login url for the given redirect uri and state
55-
func (o *GenericOAuth) GetLoginURL(redirectURI, state string) string {
56-
return o.OAuthGetLoginURL(redirectURI, state)
56+
func (o *GenericOAuth) GetLoginURL(redirectURI, state string, _ cookie.CookieStore) (string, error) {
57+
return o.OAuthGetLoginURL(redirectURI, state), nil
5758
}
5859

5960
// ExchangeCode exchanges the given redirect uri and code for a token
60-
func (o *GenericOAuth) ExchangeCode(redirectURI, code string) (string, error) {
61+
func (o *GenericOAuth) ExchangeCode(redirectURI, code string, _ cookie.CookieStore) (string, error) {
6162
token, err := o.OAuthExchangeCode(redirectURI, code)
6263
if err != nil {
6364
return "", err

internal/provider/generic_oauth_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ func TestGenericOAuthGetLoginURL(t *testing.T) {
5353
}
5454

5555
// Check url
56-
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
56+
loginURL, _ := p.GetLoginURL("http://example.com/_oauth", "state", nil)
57+
uri, err := url.Parse(loginURL)
5758
assert.Nil(err)
5859
assert.Equal("https", uri.Scheme)
5960
assert.Equal("provider.com", uri.Host)
@@ -104,7 +105,7 @@ func TestGenericOAuthExchangeCode(t *testing.T) {
104105
// AuthStyleInHeader is attempted
105106
p.Config.Endpoint.AuthStyle = oauth2.AuthStyleInParams
106107

107-
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
108+
token, err := p.ExchangeCode("http://example.com/_oauth", "code", nil)
108109
assert.Nil(err)
109110
assert.Equal("123456789", token)
110111
}

internal/provider/google.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"fmt"
77
"net/http"
88
"net/url"
9+
10+
"github.com/thomseddon/traefik-forward-auth/internal/cookie"
911
)
1012

1113
// Google provider
@@ -53,7 +55,7 @@ func (g *Google) Setup() error {
5355
}
5456

5557
// GetLoginURL provides the login url for the given redirect uri and state
56-
func (g *Google) GetLoginURL(redirectURI, state string) string {
58+
func (g *Google) GetLoginURL(redirectURI, state string, _ cookie.CookieStore) (string, error) {
5759
q := url.Values{}
5860
q.Set("client_id", g.ClientID)
5961
q.Set("response_type", "code")
@@ -68,11 +70,11 @@ func (g *Google) GetLoginURL(redirectURI, state string) string {
6870
u = *g.LoginURL
6971
u.RawQuery = q.Encode()
7072

71-
return u.String()
73+
return u.String(), nil
7274
}
7375

7476
// ExchangeCode exchanges the given redirect uri and code for a token
75-
func (g *Google) ExchangeCode(redirectURI, code string) (string, error) {
77+
func (g *Google) ExchangeCode(redirectURI, code string, _ cookie.CookieStore) (string, error) {
7678
form := url.Values{}
7779
form.Set("client_id", g.ClientID)
7880
form.Set("client_secret", g.ClientSecret)

internal/provider/google_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ func TestGoogleGetLoginURL(t *testing.T) {
6868
}
6969

7070
// Check url
71-
uri, err := url.Parse(p.GetLoginURL("http://example.com/_oauth", "state"))
71+
loginUrl, _ := p.GetLoginURL("http://example.com/_oauth", "state", nil)
72+
uri, err := url.Parse(loginUrl)
7273
assert.Nil(err)
7374
assert.Equal("https", uri.Scheme)
7475
assert.Equal("google.com", uri.Host)
@@ -116,7 +117,7 @@ func TestGoogleExchangeCode(t *testing.T) {
116117
},
117118
}
118119

119-
token, err := p.ExchangeCode("http://example.com/_oauth", "code")
120+
token, err := p.ExchangeCode("http://example.com/_oauth", "code", nil)
120121
assert.Nil(err)
121122
assert.Equal("123456789", token)
122123
}

internal/provider/oidc.go

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,25 @@ package provider
33
import (
44
"context"
55
"errors"
6+
"strings"
67

78
"github.com/coreos/go-oidc"
9+
"github.com/thomseddon/traefik-forward-auth/internal/cookie"
10+
"github.com/thomseddon/traefik-forward-auth/internal/pkce"
811
"golang.org/x/oauth2"
912
)
1013

14+
const (
15+
CookieNameNonce = "oidc-nonce"
16+
CookieNamePkceCode = "oidc-pkce-code"
17+
)
18+
1119
// OIDC provider
1220
type OIDC struct {
1321
IssuerURL string `long:"issuer-url" env:"ISSUER_URL" description:"Issuer URL"`
1422
ClientID string `long:"client-id" env:"CLIENT_ID" description:"Client ID"`
1523
ClientSecret string `long:"client-secret" env:"CLIENT_SECRET" description:"Client Secret" json:"-"`
24+
PkceRequired bool `long:"pkce-required" env:"PKCE_REQUIRED" description:"Optional pkce required indicator"`
1625

1726
OAuthProvider
1827

@@ -27,9 +36,9 @@ func (o *OIDC) Name() string {
2736

2837
// Setup performs validation and setup
2938
func (o *OIDC) Setup() error {
30-
// Check parms
31-
if o.IssuerURL == "" || o.ClientID == "" || o.ClientSecret == "" {
32-
return errors.New("providers.oidc.issuer-url, providers.oidc.client-id, providers.oidc.client-secret must be set")
39+
// Check params
40+
if err := o.checkParams(); err != nil {
41+
return err
3342
}
3443

3544
var err error
@@ -60,13 +69,47 @@ func (o *OIDC) Setup() error {
6069
}
6170

6271
// GetLoginURL provides the login url for the given redirect uri and state
63-
func (o *OIDC) GetLoginURL(redirectURI, state string) string {
64-
return o.OAuthGetLoginURL(redirectURI, state)
72+
func (o *OIDC) GetLoginURL(redirectURI, state string, cookieStore cookie.CookieStore) (string, error) {
73+
var opts []oauth2.AuthCodeOption
74+
75+
// Generate and store nonce
76+
nonce, err := pkce.GenerateNonce()
77+
if err != nil {
78+
return "", err
79+
}
80+
81+
cookieStore.SetCookie(CookieNameNonce, nonce)
82+
83+
opts = append(opts, oauth2.SetAuthURLParam("nonce", nonce))
84+
85+
if o.PkceRequired {
86+
pkceVerifier, err := pkce.CreateCodeVerifier()
87+
if err != nil {
88+
return "", err
89+
}
90+
91+
opts = append(opts, oauth2.SetAuthURLParam("code_challenge_method", "S256"))
92+
opts = append(opts, oauth2.SetAuthURLParam("code_challenge", pkceVerifier.CodeChallengeS256()))
93+
94+
cookieStore.SetCookie(CookieNamePkceCode, pkceVerifier.String())
95+
}
96+
return o.OAuthGetLoginURL(redirectURI, state, opts...), nil
6597
}
6698

6799
// ExchangeCode exchanges the given redirect uri and code for a token
68-
func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
69-
token, err := o.OAuthExchangeCode(redirectURI, code)
100+
func (o *OIDC) ExchangeCode(redirectURI, code string, cookieStore cookie.CookieStore) (string, error) {
101+
var opts []oauth2.AuthCodeOption
102+
103+
if o.PkceRequired {
104+
pkceCode, err := cookieStore.GetCookie(CookieNamePkceCode)
105+
if err != nil {
106+
return "", err
107+
}
108+
cookieStore.DeleteCookie(CookieNamePkceCode)
109+
opts = append(opts, oauth2.SetAuthURLParam("code_verifier", pkceCode))
110+
}
111+
112+
token, err := o.OAuthExchangeCode(redirectURI, code, opts...)
70113
if err != nil {
71114
return "", err
72115
}
@@ -77,6 +120,23 @@ func (o *OIDC) ExchangeCode(redirectURI, code string) (string, error) {
77120
return "", errors.New("Missing id_token")
78121
}
79122

123+
// Verify nonce
124+
idToken, err := o.verifier.Verify(o.ctx, rawIDToken)
125+
if err != nil {
126+
return "", err
127+
}
128+
129+
nonce, err := cookieStore.GetCookie(CookieNameNonce)
130+
if err != nil {
131+
return "", errors.New("nonce not found")
132+
}
133+
134+
cookieStore.DeleteCookie(CookieNameNonce)
135+
136+
if idToken.Nonce != nonce {
137+
return "", errors.New("nonce verification failed")
138+
}
139+
80140
return rawIDToken, nil
81141
}
82142

@@ -97,3 +157,25 @@ func (o *OIDC) GetUser(token string) (User, error) {
97157

98158
return user, nil
99159
}
160+
161+
func (o *OIDC) checkParams() error {
162+
if o.IssuerURL == "" || o.ClientID == "" || (o.ClientSecret == "" && !o.PkceRequired) {
163+
var emptyFields []string
164+
165+
if o.IssuerURL == "" {
166+
emptyFields = append(emptyFields, "providers.oidc.issuer-url")
167+
}
168+
169+
if o.ClientID == "" {
170+
emptyFields = append(emptyFields, "providers.oidc.client-id")
171+
}
172+
173+
if o.ClientSecret == "" && !o.PkceRequired {
174+
emptyFields = append(emptyFields, "providers.oidc.client-secret")
175+
}
176+
177+
return errors.New(strings.Join(emptyFields, ", ") + " must be set")
178+
}
179+
180+
return nil
181+
}

0 commit comments

Comments
 (0)