diff --git a/pkg/authserver/server/handlers/helpers_test.go b/pkg/authserver/server/handlers/helpers_test.go index 63367a44d1..9e824edcc6 100644 --- a/pkg/authserver/server/handlers/helpers_test.go +++ b/pkg/authserver/server/handlers/helpers_test.go @@ -80,7 +80,7 @@ func (m *mockIDPProvider) ExchangeCode(_ context.Context, code, codeVerifier str return m.exchangeTokens, nil } -func (m *mockIDPProvider) RefreshTokens(_ context.Context, _ string) (*upstream.Tokens, error) { +func (m *mockIDPProvider) RefreshTokens(_ context.Context, _, _ string) (*upstream.Tokens, error) { if m.refreshErr != nil { return nil, m.refreshErr } diff --git a/pkg/authserver/upstream/doc.go b/pkg/authserver/upstream/doc.go index 99b3f1cd9f..f2e43da409 100644 --- a/pkg/authserver/upstream/doc.go +++ b/pkg/authserver/upstream/doc.go @@ -24,7 +24,7 @@ // - Type: Returns the provider type identifier // - AuthorizationURL: Build redirect URL for user authentication // - ExchangeCode: Exchange authorization code for tokens -// - RefreshTokens: Refresh expired tokens +// - RefreshTokens: Refresh expired tokens (with subject validation for OIDC) // - ResolveIdentity: Resolve user identity from tokens // - FetchUserInfo: Fetch user claims // diff --git a/pkg/authserver/upstream/mocks/mock_provider.go b/pkg/authserver/upstream/mocks/mock_provider.go index d5a823c400..8a9048cc55 100644 --- a/pkg/authserver/upstream/mocks/mock_provider.go +++ b/pkg/authserver/upstream/mocks/mock_provider.go @@ -92,18 +92,18 @@ func (mr *MockOAuth2ProviderMockRecorder) FetchUserInfo(ctx, accessToken any) *g } // RefreshTokens mocks base method. -func (m *MockOAuth2Provider) RefreshTokens(ctx context.Context, refreshToken string) (*upstream.Tokens, error) { +func (m *MockOAuth2Provider) RefreshTokens(ctx context.Context, refreshToken, expectedSubject string) (*upstream.Tokens, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RefreshTokens", ctx, refreshToken) + ret := m.ctrl.Call(m, "RefreshTokens", ctx, refreshToken, expectedSubject) ret0, _ := ret[0].(*upstream.Tokens) ret1, _ := ret[1].(error) return ret0, ret1 } // RefreshTokens indicates an expected call of RefreshTokens. -func (mr *MockOAuth2ProviderMockRecorder) RefreshTokens(ctx, refreshToken any) *gomock.Call { +func (mr *MockOAuth2ProviderMockRecorder) RefreshTokens(ctx, refreshToken, expectedSubject any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTokens", reflect.TypeOf((*MockOAuth2Provider)(nil).RefreshTokens), ctx, refreshToken) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshTokens", reflect.TypeOf((*MockOAuth2Provider)(nil).RefreshTokens), ctx, refreshToken, expectedSubject) } // ResolveIdentity mocks base method. diff --git a/pkg/authserver/upstream/oauth2.go b/pkg/authserver/upstream/oauth2.go index 8b90ee9735..51f6925e26 100644 --- a/pkg/authserver/upstream/oauth2.go +++ b/pkg/authserver/upstream/oauth2.go @@ -77,7 +77,10 @@ type OAuth2Provider interface { ExchangeCode(ctx context.Context, code, codeVerifier string) (*Tokens, error) // RefreshTokens refreshes the upstream IDP tokens. - RefreshTokens(ctx context.Context, refreshToken string) (*Tokens, error) + // expectedSubject is the original sub claim; OIDC providers validate it per + // Section 12.2 when the response includes a new ID token. Pure OAuth2 providers + // ignore it. + RefreshTokens(ctx context.Context, refreshToken, expectedSubject string) (*Tokens, error) // ResolveIdentity validates tokens and returns the canonical subject. // For OIDC providers, it validates the ID token and nonce (ID token required). @@ -393,7 +396,7 @@ func (p *BaseOAuth2Provider) ExchangeCode(ctx context.Context, code, codeVerifie } // RefreshTokens refreshes the upstream IDP tokens. -func (p *BaseOAuth2Provider) RefreshTokens(ctx context.Context, refreshToken string) (*Tokens, error) { +func (p *BaseOAuth2Provider) RefreshTokens(ctx context.Context, refreshToken, _ string) (*Tokens, error) { if refreshToken == "" { return nil, errors.New("refresh token is required") } diff --git a/pkg/authserver/upstream/oauth2_test.go b/pkg/authserver/upstream/oauth2_test.go index 3508db75c8..501d7d96a7 100644 --- a/pkg/authserver/upstream/oauth2_test.go +++ b/pkg/authserver/upstream/oauth2_test.go @@ -711,7 +711,7 @@ func TestBaseOAuth2Provider_RefreshTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - tokens, err := provider.RefreshTokens(ctx, "old-refresh-token") + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "") require.NoError(t, err) // Verify request parameters @@ -756,7 +756,7 @@ func TestBaseOAuth2Provider_RefreshTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - _, err = provider.RefreshTokens(ctx, "expired-refresh-token") + _, err = provider.RefreshTokens(ctx, "expired-refresh-token", "") require.Error(t, err) assert.Contains(t, err.Error(), "invalid_grant") }) @@ -780,7 +780,7 @@ func TestBaseOAuth2Provider_RefreshTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - _, err = provider.RefreshTokens(ctx, "") + _, err = provider.RefreshTokens(ctx, "", "") require.Error(t, err) assert.Contains(t, err.Error(), "refresh token is required") }) @@ -808,7 +808,7 @@ func TestBaseOAuth2Provider_RefreshTokens(t *testing.T) { provider, err := NewOAuth2Provider(config) require.NoError(t, err) - _, err = provider.RefreshTokens(ctx, "refresh-token") + _, err = provider.RefreshTokens(ctx, "refresh-token", "") require.Error(t, err) assert.Contains(t, err.Error(), "token request failed") }) diff --git a/pkg/authserver/upstream/oidc.go b/pkg/authserver/upstream/oidc.go index 9bcde3c175..be997e2bcd 100644 --- a/pkg/authserver/upstream/oidc.go +++ b/pkg/authserver/upstream/oidc.go @@ -9,6 +9,8 @@ import ( "fmt" "net/http" "net/url" + "slices" + "time" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" @@ -47,6 +49,15 @@ func (c *OIDCConfig) Validate() error { // the expected nonce from the authorization request. var ErrNonceMismatch = errors.New("ID token nonce does not match expected value") +// ErrSubjectMismatch is returned when the sub claim in a refreshed ID token does not +// match the expected subject from the original token response. +// Per OIDC Core Section 12.2, the sub claim MUST be identical. +var ErrSubjectMismatch = errors.New("ID token subject does not match expected value") + +// ErrNonceMissing is returned when the ID token does not contain a nonce claim +// but one was expected (because a nonce was sent in the authorization request). +var ErrNonceMissing = errors.New("ID token missing nonce claim when nonce was expected") + // OIDCProviderImpl implements OAuth2Provider for OIDC-compliant identity providers. // It embeds BaseOAuth2Provider to share common OAuth 2.0 logic while adding // OIDC-specific functionality like discovery and ID token validation. @@ -69,6 +80,13 @@ func WithHTTPClient(client *http.Client) OIDCProviderOption { } } +// WithNonce adds an OIDC nonce parameter to the authorization request. +// The nonce is used to associate a client session with an ID Token and to +// prevent replay attacks. See OIDC Core Section 3.1.2.1. +func WithNonce(nonce string) AuthorizationOption { + return WithAdditionalParams(map[string]string{"nonce": nonce}) +} + // WithForceConsentScreen configures the provider to always request the consent screen // from the identity provider. When enabled, the "prompt=consent" parameter is added // to authorization requests, forcing the user to re-consent even if they have @@ -154,6 +172,13 @@ func NewOIDCProvider( scopes = []string{"openid", "profile", "email"} } + // Validate that openid scope is present for OIDC provider. + // Per OIDC Core, openid scope is mandatory for ID tokens. Without it, the IDP + // won't return an ID token, but OIDCProviderImpl requires one for identity resolution. + if !slices.Contains(scopes, "openid") { + return nil, errors.New("openid scope is required for OIDC provider; use BaseOAuth2Provider for pure OAuth 2.0 flows") + } + // Now create OAuth2Config from discovered endpoints + OIDC config. // This allows the embedded BaseOAuth2Provider to use the discovered endpoints // for token requests while preserving the original OIDC config. @@ -219,18 +244,38 @@ func (p *OIDCProviderImpl) ResolveIdentity(ctx context.Context, tokens *Tokens, return "", fmt.Errorf("%w: ID token required for OIDC provider", ErrIdentityResolutionFailed) } - claims, err := p.validateIDToken(ctx, tokens.IDToken, nonce) + validatedToken, err := p.validateIDToken(ctx, tokens.IDToken, nonce) if err != nil { - return "", fmt.Errorf("%w: ID token validation failed: %v", ErrIdentityResolutionFailed, err) + logger.Debugw("ID token validation failed", "error", err) + return "", fmt.Errorf("%w: ID token validation failed", ErrIdentityResolutionFailed) } - return claims.Subject, nil + return validatedToken.Subject, nil } // validateIDToken validates an ID token and returns the parsed token. -// TODO: Implement full validation using p.verifier in a follow-up PR. -func (*OIDCProviderImpl) validateIDToken(_ context.Context, _, _ string) (*oidc.IDToken, error) { - // Stub - full implementation in follow-up PR - return nil, errors.New("ID token validation not yet implemented") +func (p *OIDCProviderImpl) validateIDToken(ctx context.Context, idToken, nonce string) (*oidc.IDToken, error) { + if p.verifier == nil { + return nil, errors.New("ID token verifier not initialized") + } + + token, err := p.verifier.Verify(ctx, idToken) + if err != nil { + return nil, fmt.Errorf("failed to verify ID token: %w", err) + } + + // Validate nonce if expected (was sent in authorization request). + // This ensures that when a nonce is provided, the token MUST contain it + // and it MUST match, preventing replay attacks. + if nonce != "" { + if token.Nonce == "" { + return nil, ErrNonceMissing + } + if token.Nonce != nonce { + return nil, ErrNonceMismatch + } + } + + return token, nil } // supportsPKCE checks if the provider advertises S256 PKCE support. @@ -241,6 +286,140 @@ func (p *OIDCProviderImpl) supportsPKCE() bool { return p.endpoints.SupportsPKCE() } +// AuthorizationURL builds the URL to redirect the user to the upstream IDP. +// This overrides the base implementation to add OIDC-specific parameters (nonce, prompt) +// and use discovered endpoints. +func (p *OIDCProviderImpl) AuthorizationURL(state, codeChallenge string, opts ...AuthorizationOption) (string, error) { + if p.endpoints == nil { + return "", errors.New("OIDC endpoints not discovered") + } + + // Apply authorization options to extract nonce for logging + authOpts := &authorizationOptions{} + for _, opt := range opts { + opt(authOpts) + } + + // Extract nonce from additionalParams if present + nonce := "" + if authOpts.additionalParams != nil { + nonce = authOpts.additionalParams["nonce"] + } + + logger.Debugw("building authorization URL", + "authorization_endpoint", p.endpoints.AuthorizationEndpoint, + "has_pkce", codeChallenge != "", + "has_nonce", nonce != "", + ) + + // PKCE: Per RFC 7636 Section 5, clients SHOULD send PKCE parameters to all + // servers regardless of whether they advertise support. Servers that don't + // support PKCE will simply ignore the parameters. + if codeChallenge != "" && !p.supportsPKCE() { + logger.Debugw("sending PKCE to provider that does not advertise S256 support (per RFC 7636 Section 5)") + } + + // Merge caller's opts with OIDC-specific params + allOpts := append(opts, WithAdditionalParams(p.buildOIDCParams())) //nolint:gocritic // intentionally appending single element + + // Use the base implementation which uses oauth2Config (scopes already configured) + return p.buildAuthorizationURL(state, codeChallenge, allOpts...) +} + +// buildOIDCParams builds the OIDC-specific authorization parameters. +func (p *OIDCProviderImpl) buildOIDCParams() map[string]string { + params := make(map[string]string) + + // Add prompt=consent if configured to force the consent screen + if p.forceConsentScreen { + params["prompt"] = "consent" + } + + return params +} + +// ExchangeCode exchanges an authorization code for tokens with the upstream IDP. +// This overrides the base implementation to add OIDC-specific ID token validation. +func (p *OIDCProviderImpl) ExchangeCode(ctx context.Context, code, codeVerifier string) (*Tokens, error) { + if p.endpoints == nil { + return nil, errors.New("OIDC endpoints not discovered") + } + + logger.Debugw("exchanging authorization code for tokens", + "token_endpoint", p.endpoints.TokenEndpoint, + "has_pkce_verifier", codeVerifier != "", + ) + + // Use base provider's implementation for token exchange + tokens, err := p.BaseOAuth2Provider.ExchangeCode(ctx, code, codeVerifier) + if err != nil { + return nil, err + } + + // OIDC-specific: Validate ID token structure (signature, issuer, audience, expiry). + // Per Section 3.1.3.3, ID token MUST be present in OIDC token responses. + // Note: Nonce validation (Section 3.1.3.7) is deferred to ResolveIdentity, + // which has access to the expected nonce from the authorization request. + // Callers MUST call ResolveIdentity after ExchangeCode for full OIDC compliance. + if tokens.IDToken == "" { + return nil, errors.New("ID token required for OIDC provider") + } + if _, err := p.validateIDToken(ctx, tokens.IDToken, ""); err != nil { + return nil, fmt.Errorf("ID token validation failed: %w", err) + } + + logger.Debugw("authorization code exchange successful", + "has_refresh_token", tokens.RefreshToken != "", + "has_id_token", tokens.IDToken != "", + "expires_at", tokens.ExpiresAt.Format(time.RFC3339), + ) + + return tokens, nil +} + +// RefreshTokens refreshes the upstream IDP tokens. +// This overrides the base implementation to add OIDC-specific ID token validation. +func (p *OIDCProviderImpl) RefreshTokens(ctx context.Context, refreshToken, expectedSubject string) (*Tokens, error) { + if p.endpoints == nil { + return nil, errors.New("OIDC endpoints not discovered") + } + + logger.Debugw("refreshing tokens", + "token_endpoint", p.endpoints.TokenEndpoint, + ) + + // Use base provider's implementation for token refresh + tokens, err := p.BaseOAuth2Provider.RefreshTokens(ctx, refreshToken, expectedSubject) + if err != nil { + return nil, err + } + + // OIDC-specific: Validate ID token if present. + // Per OIDC Core Section 12.2, refresh responses MAY include a new ID token + // (unlike ExchangeCode where it's required per Section 3.1.3.3). + // Nonce validation is intentionally omitted: Section 12.2 states that + // refreshed ID tokens SHOULD NOT contain a nonce claim, and no new + // authorization request exists to provide an expected nonce value. + // Full nonce validation occurs in ResolveIdentity during the initial auth flow. + if tokens.IDToken != "" && p.verifier != nil { + token, err := p.validateIDToken(ctx, tokens.IDToken, "") + if err != nil { + return nil, fmt.Errorf("ID token validation failed: %w", err) + } + // OIDC Core Section 12.2: sub claim MUST be identical to the original. + if expectedSubject != "" && token.Subject != expectedSubject { + return nil, ErrSubjectMismatch + } + } + + logger.Debugw("token refresh successful", + "has_new_refresh_token", tokens.RefreshToken != "", + "expires_at", tokens.ExpiresAt.Format(time.RFC3339), + ) + + return tokens, nil +} + // validateDiscoveryDocument validates the OIDC discovery document. // // It first delegates to OIDCDiscoveryDocument.Validate() for spec-compliant field diff --git a/pkg/authserver/upstream/oidc_test.go b/pkg/authserver/upstream/oidc_test.go index 3216bd363f..5d35c43255 100644 --- a/pkg/authserver/upstream/oidc_test.go +++ b/pkg/authserver/upstream/oidc_test.go @@ -16,6 +16,8 @@ import ( "testing" "time" + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -122,7 +124,9 @@ func (*mockOIDCServer) handleUserInfo(w http.ResponseWriter, r *http.Request) { "email": "test@example.com", } w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(resp) + if err := json.NewEncoder(w).Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } func (m *mockOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) { @@ -145,6 +149,35 @@ func (m *mockOIDCServer) handleJWKS(w http.ResponseWriter, _ *http.Request) { } } +// signIDToken creates a signed JWT ID token. +// +//nolint:unparam // subject parameter kept for test flexibility +func (m *mockOIDCServer) signIDToken(audience, subject, nonce string, expiry time.Time) string { + signingKey := jose.SigningKey{Algorithm: jose.RS256, Key: m.privateKey} + signer, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", m.keyID)) + if err != nil { + panic(err) + } + + claims := map[string]any{ + "iss": m.issuer, + "sub": subject, + "aud": audience, + "exp": expiry.Unix(), + "iat": time.Now().Unix(), + } + if nonce != "" { + claims["nonce"] = nonce + } + + token, err := jwt.Signed(signer).Claims(claims).CompactSerialize() + if err != nil { + panic(err) + } + + return token +} + func TestNewOIDCProvider(t *testing.T) { t.Parallel() @@ -340,6 +373,28 @@ func TestNewOIDCProvider(t *testing.T) { require.NotNil(t, provider) // Force consent screen is tested in commit 2 with AuthorizationURL tests }) + + t.Run("scopes without openid returns error", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + Scopes: []string{"profile", "email"}, // missing openid + }, + Issuer: mock.issuer, + } + + ctx := context.Background() + _, err := NewOIDCProvider(ctx, config) + require.Error(t, err) + assert.Contains(t, err.Error(), "openid scope is required") + }) } func TestValidateDiscoveryDocument(t *testing.T) { @@ -434,12 +489,69 @@ func TestOIDCProviderImpl_ResolveIdentity(t *testing.T) { provider, err := NewOIDCProvider(ctx, config) require.NoError(t, err) + t.Run("valid ID token returns subject", func(t *testing.T) { + t.Parallel() + idToken := mock.signIDToken(testClientID, "user-123", "", time.Now().Add(time.Hour)) + tokens := &Tokens{ + AccessToken: "test-access-token", + IDToken: idToken, + } + subject, err := provider.ResolveIdentity(ctx, tokens, "") + require.NoError(t, err) + assert.Equal(t, "user-123", subject) + }) + + t.Run("valid ID token with nonce returns subject", func(t *testing.T) { + t.Parallel() + idToken := mock.signIDToken(testClientID, "user-456", "test-nonce", time.Now().Add(time.Hour)) + tokens := &Tokens{ + AccessToken: "test-access-token", + IDToken: idToken, + } + subject, err := provider.ResolveIdentity(ctx, tokens, "test-nonce") + require.NoError(t, err) + assert.Equal(t, "user-456", subject) + }) + + t.Run("nonce mismatch returns error", func(t *testing.T) { + t.Parallel() + idToken := mock.signIDToken(testClientID, "user-123", "token-nonce", time.Now().Add(time.Hour)) + tokens := &Tokens{ + AccessToken: "test-access-token", + IDToken: idToken, + } + _, err := provider.ResolveIdentity(ctx, tokens, "different-nonce") + require.Error(t, err) + require.ErrorIs(t, err, ErrIdentityResolutionFailed) + }) + + t.Run("missing nonce in token when expected returns error", func(t *testing.T) { + t.Parallel() + // Sign ID token without nonce + idToken := mock.signIDToken(testClientID, "user-123", "", time.Now().Add(time.Hour)) + tokens := &Tokens{ + AccessToken: "test-access-token", + IDToken: idToken, + } + // But caller expects a nonce - this should fail + // (detailed error logged at DEBUG, generic error returned for security) + _, err := provider.ResolveIdentity(ctx, tokens, "expected-nonce") + require.Error(t, err) + require.ErrorIs(t, err, ErrIdentityResolutionFailed) + + // Also verify the underlying error is ErrNonceMissing via direct validation + _, validationErr := provider.validateIDToken(ctx, idToken, "expected-nonce") + require.Error(t, validationErr) + require.ErrorIs(t, validationErr, ErrNonceMissing) + }) + + // Error cases table tests := []struct { name string tokens *Tokens wantContain string // empty means just check ErrorIs }{ - {"with ID token returns validation error", &Tokens{AccessToken: "test-access-token", IDToken: "dummy-id-token"}, "ID token validation failed"}, + {"invalid ID token returns validation error", &Tokens{AccessToken: "test-access-token", IDToken: "dummy-id-token"}, "ID token validation failed"}, {"without ID token returns error", &Tokens{AccessToken: "test-access-token"}, "ID token required"}, {"nil tokens returns error", nil, ""}, } @@ -455,3 +567,473 @@ func TestOIDCProviderImpl_ResolveIdentity(t *testing.T) { }) } } + +func TestOIDCProvider_AuthorizationURL(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + Scopes: []string{"openid", "profile"}, + }, + Issuer: mock.issuer, + } + + ctx := context.Background() + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + tests := []struct { + name string + state string + codeChallenge string + opts []AuthorizationOption + wantParams map[string]string // exact match + wantContains map[string]string // substring match + wantErr string + }{ + { + name: "builds correct URL with all parameters", + state: "test-state", + wantParams: map[string]string{ + "response_type": "code", + "client_id": testClientID, + "redirect_uri": testRedirectURI, + "state": "test-state", + }, + wantContains: map[string]string{"scope": "openid"}, + }, + { + name: "includes PKCE code_challenge when provided", + state: "test-state", + codeChallenge: "test-challenge-abc123", + wantParams: map[string]string{ + "code_challenge": "test-challenge-abc123", + "code_challenge_method": "S256", + }, + }, + { + name: "includes nonce with WithNonce option", + state: "test-state", + opts: []AuthorizationOption{WithNonce("test-nonce-123")}, + wantParams: map[string]string{ + "nonce": "test-nonce-123", + }, + }, + { + name: "includes additional params", + state: "test-state", + opts: []AuthorizationOption{WithAdditionalParams(map[string]string{ + "login_hint": "user@example.com", + "acr_values": "urn:mace:incommon:iap:silver", + })}, + wantParams: map[string]string{ + "login_hint": "user@example.com", + "acr_values": "urn:mace:incommon:iap:silver", + }, + }, + { + name: "returns error for empty state", + state: "", + wantErr: "state parameter is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + authURL, err := provider.AuthorizationURL(tt.state, tt.codeChallenge, tt.opts...) + + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + + require.NoError(t, err) + parsed, err := url.Parse(authURL) + require.NoError(t, err) + + query := parsed.Query() + for key, want := range tt.wantParams { + assert.Equal(t, want, query.Get(key), "param %s", key) + } + for key, want := range tt.wantContains { + assert.Contains(t, query.Get(key), want, "param %s", key) + } + }) + } +} + +func TestOIDCProvider_ExchangeCode(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("successful token exchange with ID token", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + var receivedParams url.Values + mock.tokenHandler = func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + receivedParams = r.PostForm + + idToken := mock.signIDToken(testClientID, "user-123", "", time.Now().Add(time.Hour)) + + resp := testTokenResponse{ + AccessToken: "exchanged-access-token", + TokenType: "Bearer", + RefreshToken: "exchanged-refresh-token", + IDToken: idToken, + ExpiresIn: 7200, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + tokens, err := provider.ExchangeCode(ctx, "test-auth-code", "test-verifier") + require.NoError(t, err) + + // Verify request parameters + assert.Equal(t, "authorization_code", receivedParams.Get("grant_type")) + assert.Equal(t, "test-auth-code", receivedParams.Get("code")) + assert.Equal(t, "test-verifier", receivedParams.Get("code_verifier")) + + // Verify response + assert.Equal(t, "exchanged-access-token", tokens.AccessToken) + assert.Equal(t, "exchanged-refresh-token", tokens.RefreshToken) + assert.NotEmpty(t, tokens.IDToken) + }) + + t.Run("empty code returns error", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + _, err = provider.ExchangeCode(ctx, "", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "authorization code is required") + }) + + t.Run("invalid ID token fails validation", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + mock.tokenHandler = func(w http.ResponseWriter, _ *http.Request) { + resp := testTokenResponse{ + AccessToken: "access-token", + TokenType: "Bearer", + IDToken: "invalid.token.here", + ExpiresIn: 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + _, err = provider.ExchangeCode(ctx, "test-code", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "ID token validation failed") + }) + + t.Run("token endpoint error", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + mock.tokenHandler = func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + resp := testTokenErrorResponse{ + Error: "invalid_grant", + ErrorDescription: "The authorization code has expired", + } + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + _, err = provider.ExchangeCode(ctx, "expired-code", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid_grant") + }) +} + +func TestOIDCProvider_RefreshTokens(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + t.Run("successful token refresh", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + var receivedParams url.Values + mock.tokenHandler = func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + receivedParams = r.PostForm + + resp := testTokenResponse{ + AccessToken: "refreshed-access-token", + TokenType: "Bearer", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "") + require.NoError(t, err) + + // Verify request parameters + assert.Equal(t, "refresh_token", receivedParams.Get("grant_type")) + assert.Equal(t, "old-refresh-token", receivedParams.Get("refresh_token")) + + // Verify response + assert.Equal(t, "refreshed-access-token", tokens.AccessToken) + assert.Equal(t, "new-refresh-token", tokens.RefreshToken) + }) + + t.Run("empty refresh token returns error", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + _, err = provider.RefreshTokens(ctx, "", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "refresh token is required") + }) + + t.Run("refresh with matching subject succeeds", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + mock.tokenHandler = func(w http.ResponseWriter, _ *http.Request) { + idToken := mock.signIDToken(testClientID, "user-123", "", time.Now().Add(time.Hour)) + resp := testTokenResponse{ + AccessToken: "refreshed-access-token", + TokenType: "Bearer", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + IDToken: idToken, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "user-123") + require.NoError(t, err) + assert.Equal(t, "refreshed-access-token", tokens.AccessToken) + }) + + t.Run("refresh with mismatched subject returns ErrSubjectMismatch", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + mock.tokenHandler = func(w http.ResponseWriter, _ *http.Request) { + idToken := mock.signIDToken(testClientID, "user-123", "", time.Now().Add(time.Hour)) + resp := testTokenResponse{ + AccessToken: "refreshed-access-token", + TokenType: "Bearer", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + IDToken: idToken, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + _, err = provider.RefreshTokens(ctx, "old-refresh-token", "different-user") + require.ErrorIs(t, err, ErrSubjectMismatch) + }) + + t.Run("refresh without ID token skips subject validation", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + mock.tokenHandler = func(w http.ResponseWriter, _ *http.Request) { + resp := testTokenResponse{ + AccessToken: "refreshed-access-token", + TokenType: "Bearer", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "user-123") + require.NoError(t, err) + assert.Equal(t, "refreshed-access-token", tokens.AccessToken) + }) + + t.Run("refresh with empty expectedSubject skips subject validation", func(t *testing.T) { + t.Parallel() + + mock := newMockOIDCServer(t) + t.Cleanup(mock.Close) + + mock.tokenHandler = func(w http.ResponseWriter, _ *http.Request) { + idToken := mock.signIDToken(testClientID, "user-123", "", time.Now().Add(time.Hour)) + resp := testTokenResponse{ + AccessToken: "refreshed-access-token", + TokenType: "Bearer", + RefreshToken: "new-refresh-token", + ExpiresIn: 3600, + IDToken: idToken, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + } + + config := &OIDCConfig{ + CommonOAuthConfig: CommonOAuthConfig{ + ClientID: testClientID, + ClientSecret: testClientSecret, + RedirectURI: testRedirectURI, + }, + Issuer: mock.issuer, + } + + provider, err := NewOIDCProvider(ctx, config) + require.NoError(t, err) + + tokens, err := provider.RefreshTokens(ctx, "old-refresh-token", "") + require.NoError(t, err) + assert.Equal(t, "refreshed-access-token", tokens.AccessToken) + }) +}