From 847823bbf8edbf61a50645e4073e95f7ae59e82d Mon Sep 17 00:00:00 2001 From: Arunavo Ray Date: Wed, 22 Oct 2025 16:33:33 +0530 Subject: [PATCH] sso: normalize provider config via discovery --- src/lib/auth.ts | 5 +- src/lib/sso/oidc-config.test.ts | 56 ++++++++ src/lib/sso/oidc-config.ts | 202 ++++++++++++++++++++++++++++ src/lib/sso/provider-maintenance.ts | 65 +++++++++ src/pages/api/auth/sso/register.ts | 113 +++++++++++----- src/pages/api/sso/providers.ts | 134 +++++++++++++----- 6 files changed, 507 insertions(+), 68 deletions(-) create mode 100644 src/lib/sso/oidc-config.test.ts create mode 100644 src/lib/sso/oidc-config.ts create mode 100644 src/lib/sso/provider-maintenance.ts diff --git a/src/lib/auth.ts b/src/lib/auth.ts index 24ae5ce..e6812d5 100644 --- a/src/lib/auth.ts +++ b/src/lib/auth.ts @@ -5,6 +5,9 @@ import { sso } from "@better-auth/sso"; import { db, users } from "./db"; import * as schema from "./db/schema"; import { eq } from "drizzle-orm"; +import { ensureValidSsoProviders } from "./sso/provider-maintenance"; + +await ensureValidSsoProviders(); export const auth = betterAuth({ // Database configuration @@ -171,4 +174,4 @@ export const auth = betterAuth({ }); // Export type for use in other parts of the app -export type Auth = typeof auth; \ No newline at end of file +export type Auth = typeof auth; diff --git a/src/lib/sso/oidc-config.test.ts b/src/lib/sso/oidc-config.test.ts new file mode 100644 index 0000000..e549477 --- /dev/null +++ b/src/lib/sso/oidc-config.test.ts @@ -0,0 +1,56 @@ +import { describe, expect, it } from "bun:test"; +import { normalizeOidcProviderConfig, OidcConfigError } from "./oidc-config"; + +const issuer = "https://auth.example.com"; + +describe("normalizeOidcProviderConfig", () => { + it("returns provided endpoints when complete", async () => { + const result = await normalizeOidcProviderConfig(issuer, { + clientId: "client", + clientSecret: "secret", + authorizationEndpoint: "https://auth.example.com/auth", + tokenEndpoint: "https://auth.example.com/token", + jwksEndpoint: "https://auth.example.com/jwks", + userInfoEndpoint: "https://auth.example.com/userinfo", + scopes: ["openid", "email"], + pkce: false, + }, async () => { + throw new Error("fetch should not be called when endpoints are provided"); + }); + + expect(result.oidcConfig.authorizationEndpoint).toBe("https://auth.example.com/auth"); + expect(result.oidcConfig.tokenEndpoint).toBe("https://auth.example.com/token"); + expect(result.oidcConfig.jwksEndpoint).toBe("https://auth.example.com/jwks"); + expect(result.oidcConfig.userInfoEndpoint).toBe("https://auth.example.com/userinfo"); + expect(result.oidcConfig.scopes).toEqual(["openid", "email"]); + expect(result.oidcConfig.pkce).toBe(false); + }); + + it("derives missing fields from discovery", async () => { + const fetchMock = async () => + new Response(JSON.stringify({ + authorization_endpoint: "https://auth.example.com/auth", + token_endpoint: "https://auth.example.com/token", + jwks_uri: "https://auth.example.com/jwks", + userinfo_endpoint: "https://auth.example.com/userinfo", + scopes_supported: ["openid", "email", "profile"], + })); + + const result = await normalizeOidcProviderConfig(issuer, { + clientId: "client", + clientSecret: "secret", + }, fetchMock); + + expect(result.oidcConfig.authorizationEndpoint).toBe("https://auth.example.com/auth"); + expect(result.oidcConfig.tokenEndpoint).toBe("https://auth.example.com/token"); + expect(result.oidcConfig.jwksEndpoint).toBe("https://auth.example.com/jwks"); + expect(result.oidcConfig.userInfoEndpoint).toBe("https://auth.example.com/userinfo"); + expect(result.oidcConfig.scopes).toEqual(["openid", "email", "profile"]); + }); + + it("throws for invalid issuer URL", async () => { + await expect( + normalizeOidcProviderConfig("not-a-url", {}), + ).rejects.toBeInstanceOf(OidcConfigError); + }); +}); diff --git a/src/lib/sso/oidc-config.ts b/src/lib/sso/oidc-config.ts new file mode 100644 index 0000000..57af364 --- /dev/null +++ b/src/lib/sso/oidc-config.ts @@ -0,0 +1,202 @@ +import { z } from "zod"; + +const DEFAULT_SCOPES = ["openid", "email", "profile"] as const; +const DISCOVERY_TIMEOUT_MS = 10000; + +const discoverySchema = z.object({ + issuer: z.string().url().optional(), + authorization_endpoint: z.string().url().optional(), + token_endpoint: z.string().url().optional(), + userinfo_endpoint: z.string().url().optional(), + jwks_uri: z.string().url().optional(), + scopes_supported: z.array(z.string()).optional(), +}); + +export class OidcConfigError extends Error { + constructor(message: string) { + super(message); + this.name = "OidcConfigError"; + } +} + +export type RawOidcConfig = { + clientId?: string; + clientSecret?: string; + authorizationEndpoint?: string; + tokenEndpoint?: string; + jwksEndpoint?: string; + userInfoEndpoint?: string; + discoveryEndpoint?: string; + scopes?: string[]; + pkce?: boolean; + mapping?: ProviderMapping; +}; + +export type ProviderMapping = { + id: string; + email: string; + emailVerified?: string; + name?: string; + image?: string; + firstName?: string; + lastName?: string; +}; + +export type NormalizedOidcConfig = { + oidcConfig: { + clientId?: string; + clientSecret?: string; + authorizationEndpoint: string; + tokenEndpoint: string; + jwksEndpoint?: string; + userInfoEndpoint?: string; + discoveryEndpoint: string; + scopes: string[]; + pkce: boolean; + }; + mapping: ProviderMapping; +}; + +type FetchFn = typeof fetch; + +function cleanUrl(value: string | undefined, field: string): string | undefined { + if (!value || typeof value !== "string") return undefined; + const trimmed = value.trim(); + if (!trimmed) return undefined; + try { + return new URL(trimmed).toString(); + } catch { + throw new OidcConfigError(`Invalid ${field} URL: ${value}`); + } +} + +function sanitizeScopes(scopes: string[] | undefined, fallback: readonly string[]): string[] { + const candidates = Array.isArray(scopes) ? scopes : []; + const sanitized = candidates + .map(scope => scope?.trim()) + .filter((scope): scope is string => Boolean(scope)); + + if (sanitized.length === 0) { + return [...fallback]; + } + + return Array.from(new Set(sanitized)); +} + +async function fetchDiscoveryDocument(url: string, fetchFn: FetchFn): Promise> { + const controller = new AbortController(); + const timeoutId = setTimeout(() => controller.abort(), DISCOVERY_TIMEOUT_MS); + + try { + const response = await fetchFn(url, { + signal: controller.signal, + headers: { Accept: "application/json" }, + }); + + if (!response.ok) { + throw new OidcConfigError(`OIDC discovery request failed (${response.status} ${response.statusText})`); + } + + let payload: unknown; + try { + payload = await response.json(); + } catch { + throw new OidcConfigError("OIDC discovery response is not valid JSON"); + } + + const parsed = discoverySchema.parse(payload); + if (!parsed.authorization_endpoint || !parsed.token_endpoint) { + throw new OidcConfigError("OIDC discovery document is missing required endpoints"); + } + return parsed; + } catch (error) { + if (error instanceof OidcConfigError) { + throw error; + } + if (error instanceof Error && error.name === "AbortError") { + throw new OidcConfigError(`OIDC discovery timed out after ${DISCOVERY_TIMEOUT_MS / 1000}s`); + } + throw new OidcConfigError(`Failed to fetch OIDC discovery document: ${error instanceof Error ? error.message : "unknown error"}`); + } finally { + clearTimeout(timeoutId); + } +} + +export async function normalizeOidcProviderConfig( + issuer: string, + rawConfig: RawOidcConfig, + fetchFn: FetchFn = fetch, +): Promise { + if (!issuer || typeof issuer !== "string") { + throw new OidcConfigError("Issuer is required"); + } + + let normalizedIssuer: string; + try { + const issuerUrl = new URL(issuer.trim()); + normalizedIssuer = issuerUrl.toString().replace(/\/$/, ""); + } catch { + throw new OidcConfigError(`Invalid issuer URL: ${issuer}`); + } + + const discoveryEndpoint = cleanUrl( + rawConfig.discoveryEndpoint, + "discovery endpoint", + ) ?? `${normalizedIssuer}/.well-known/openid-configuration`; + + const authorizationEndpoint = cleanUrl(rawConfig.authorizationEndpoint, "authorization endpoint"); + const tokenEndpoint = cleanUrl(rawConfig.tokenEndpoint, "token endpoint"); + const jwksEndpoint = cleanUrl(rawConfig.jwksEndpoint, "JWKS endpoint"); + const userInfoEndpoint = cleanUrl(rawConfig.userInfoEndpoint, "userinfo endpoint"); + const providedScopes = Array.isArray(rawConfig.scopes) ? rawConfig.scopes : undefined; + let scopes = sanitizeScopes(providedScopes, DEFAULT_SCOPES); + + const shouldFetchDiscovery = + !authorizationEndpoint || + !tokenEndpoint || + !jwksEndpoint || + !userInfoEndpoint || + !providedScopes || + providedScopes.length === 0; + + let resolvedAuthorization = authorizationEndpoint; + let resolvedToken = tokenEndpoint; + let resolvedJwks = jwksEndpoint; + let resolvedUserInfo = userInfoEndpoint; + + if (shouldFetchDiscovery) { + const discovery = await fetchDiscoveryDocument(discoveryEndpoint, fetchFn); + resolvedAuthorization = resolvedAuthorization ?? discovery.authorization_endpoint; + resolvedToken = resolvedToken ?? discovery.token_endpoint; + resolvedJwks = resolvedJwks ?? discovery.jwks_uri; + resolvedUserInfo = resolvedUserInfo ?? discovery.userinfo_endpoint; + if (!providedScopes || providedScopes.length === 0) { + scopes = sanitizeScopes(discovery.scopes_supported, DEFAULT_SCOPES); + } + } + + if (!resolvedAuthorization || !resolvedToken) { + throw new OidcConfigError("OIDC configuration must include authorization and token endpoints"); + } + + return { + oidcConfig: { + clientId: rawConfig.clientId, + clientSecret: rawConfig.clientSecret, + authorizationEndpoint: resolvedAuthorization, + tokenEndpoint: resolvedToken, + jwksEndpoint: resolvedJwks, + userInfoEndpoint: resolvedUserInfo, + discoveryEndpoint, + scopes, + pkce: rawConfig.pkce !== false, + }, + mapping: rawConfig.mapping ?? { + id: "sub", + email: "email", + emailVerified: "email_verified", + name: "name", + image: "picture", + }, + }; +} diff --git a/src/lib/sso/provider-maintenance.ts b/src/lib/sso/provider-maintenance.ts new file mode 100644 index 0000000..a3e5d7b --- /dev/null +++ b/src/lib/sso/provider-maintenance.ts @@ -0,0 +1,65 @@ +import { db, ssoProviders } from "@/lib/db"; +import { normalizeOidcProviderConfig, OidcConfigError } from "./oidc-config"; +import { eq } from "drizzle-orm"; + +type Logger = Pick; + +export async function ensureValidSsoProviders(logger: Logger = console): Promise { + const providers = await db.select().from(ssoProviders); + + for (const provider of providers) { + if (!provider.oidcConfig) continue; + + let parsedConfig: any; + try { + parsedConfig = JSON.parse(provider.oidcConfig); + } catch (error) { + logger.warn(`[SSO] Skipping provider ${provider.providerId}: invalid JSON configuration`, error); + continue; + } + + const hasEndpoints = + parsedConfig.authorizationEndpoint && + parsedConfig.tokenEndpoint; + + const hasScopes = Array.isArray(parsedConfig.scopes) && parsedConfig.scopes.length > 0; + + if (hasEndpoints && hasScopes) { + continue; + } + + try { + const normalized = await normalizeOidcProviderConfig(provider.issuer, { + clientId: parsedConfig.clientId, + clientSecret: parsedConfig.clientSecret, + authorizationEndpoint: parsedConfig.authorizationEndpoint, + tokenEndpoint: parsedConfig.tokenEndpoint, + jwksEndpoint: parsedConfig.jwksEndpoint, + userInfoEndpoint: parsedConfig.userInfoEndpoint, + discoveryEndpoint: parsedConfig.discoveryEndpoint, + scopes: parsedConfig.scopes, + pkce: parsedConfig.pkce, + mapping: parsedConfig.mapping, + }); + + await db + .update(ssoProviders) + .set({ + oidcConfig: JSON.stringify({ + ...normalized.oidcConfig, + mapping: normalized.mapping, + }), + updatedAt: new Date(), + }) + .where(eq(ssoProviders.id, provider.id)); + + logger.info(`[SSO] Normalized OIDC configuration for provider ${provider.providerId}`); + } catch (error) { + if (error instanceof OidcConfigError) { + logger.warn(`[SSO] Unable to normalize provider ${provider.providerId}: ${error.message}`); + } else { + logger.error(`[SSO] Unexpected error normalizing provider ${provider.providerId}`, error); + } + } + } +} diff --git a/src/pages/api/auth/sso/register.ts b/src/pages/api/auth/sso/register.ts index c32b5b0..21b193a 100644 --- a/src/pages/api/auth/sso/register.ts +++ b/src/pages/api/auth/sso/register.ts @@ -2,6 +2,10 @@ import type { APIContext } from "astro"; import { createSecureErrorResponse } from "@/lib/utils"; import { requireAuth } from "@/lib/utils/auth-helpers"; import { auth } from "@/lib/auth"; +import { db, ssoProviders } from "@/lib/db"; +import { eq } from "drizzle-orm"; +import { nanoid } from "nanoid"; +import { normalizeOidcProviderConfig, OidcConfigError } from "@/lib/sso/oidc-config"; // POST /api/auth/sso/register - Register a new SSO provider using Better Auth export async function POST(context: APIContext) { @@ -104,43 +108,37 @@ export async function POST(context: APIContext) { userInfoEndpoint, scopes, pkce = true, - mapping = { - id: "sub", - email: "email", - emailVerified: "email_verified", - name: "name", - image: "picture", - } + mapping, } = body; - // Use provided scopes or default if not specified - const finalScopes = scopes || ["openid", "email", "profile"]; + try { + const normalized = await normalizeOidcProviderConfig(validatedIssuer, { + clientId, + clientSecret, + authorizationEndpoint, + tokenEndpoint, + jwksEndpoint, + userInfoEndpoint, + discoveryEndpoint, + scopes, + pkce, + mapping, + }); - // Validate endpoint URLs if provided - const validateUrl = (url: string | undefined, name: string): string | undefined => { - if (!url) return undefined; - if (typeof url !== 'string' || url.trim() === '') return undefined; - try { - const validatedUrl = new URL(url.trim()); - return validatedUrl.toString(); - } catch (e) { - console.warn(`Invalid ${name} URL: ${url}, skipping`); - return undefined; + registrationBody.oidcConfig = normalized.oidcConfig; + registrationBody.mapping = normalized.mapping; + } catch (error) { + if (error instanceof OidcConfigError) { + return new Response( + JSON.stringify({ error: error.message }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + } + ); } - }; - - registrationBody.oidcConfig = { - clientId: clientId || undefined, - clientSecret: clientSecret || undefined, - authorizationEndpoint: validateUrl(authorizationEndpoint, 'authorization endpoint'), - tokenEndpoint: validateUrl(tokenEndpoint, 'token endpoint'), - jwksEndpoint: validateUrl(jwksEndpoint, 'JWKS endpoint'), - discoveryEndpoint: validateUrl(discoveryEndpoint, 'discovery endpoint'), - userInfoEndpoint: validateUrl(userInfoEndpoint, 'userinfo endpoint'), - scopes: finalScopes, - pkce, - }; - registrationBody.mapping = mapping; + throw error; + } } // Get the user's auth headers to make the request @@ -168,7 +166,52 @@ export async function POST(context: APIContext) { } const result = await response.json(); - + + // Mirror provider entry into local SSO table for UI listing + try { + const existing = await db + .select() + .from(ssoProviders) + .where(eq(ssoProviders.providerId, registrationBody.providerId)) + .limit(1); + + const values: any = { + issuer: registrationBody.issuer, + domain: registrationBody.domain, + organizationId: registrationBody.organizationId, + updatedAt: new Date(), + }; + + if (registrationBody.oidcConfig) { + values.oidcConfig = JSON.stringify({ + ...registrationBody.oidcConfig, + mapping: registrationBody.mapping, + }); + } + + if (existing.length > 0) { + await db + .update(ssoProviders) + .set(values) + .where(eq(ssoProviders.id, existing[0].id)); + } else { + await db.insert(ssoProviders).values({ + id: nanoid(), + issuer: registrationBody.issuer, + domain: registrationBody.domain, + oidcConfig: JSON.stringify({ + ...registrationBody.oidcConfig, + mapping: registrationBody.mapping, + }), + userId: user.id, + providerId: registrationBody.providerId, + organizationId: registrationBody.organizationId, + }); + } + } catch (mirroringError) { + console.warn("Failed to mirror SSO provider to local DB:", mirroringError); + } + return new Response(JSON.stringify(result), { status: 201, headers: { "Content-Type": "application/json" }, @@ -199,4 +242,4 @@ export async function GET(context: APIContext) { } catch (error) { return createSecureErrorResponse(error, "SSO provider listing"); } -} \ No newline at end of file +} diff --git a/src/pages/api/sso/providers.ts b/src/pages/api/sso/providers.ts index 5b4eb60..5be56fd 100644 --- a/src/pages/api/sso/providers.ts +++ b/src/pages/api/sso/providers.ts @@ -4,6 +4,7 @@ import { requireAuth } from "@/lib/utils/auth-helpers"; import { db, ssoProviders } from "@/lib/db"; import { nanoid } from "nanoid"; import { eq } from "drizzle-orm"; +import { normalizeOidcProviderConfig, OidcConfigError, type RawOidcConfig } from "@/lib/sso/oidc-config"; // GET /api/sso/providers - List all SSO providers export async function GET(context: APIContext) { @@ -45,10 +46,12 @@ export async function POST(context: APIContext) { tokenEndpoint, jwksEndpoint, userInfoEndpoint, + discoveryEndpoint, mapping, providerId, organizationId, scopes, + pkce, } = body; // Validate required fields @@ -79,22 +82,51 @@ export async function POST(context: APIContext) { ); } - // Create OIDC config object - const oidcConfig = { - clientId, - clientSecret, - authorizationEndpoint, - tokenEndpoint, - jwksEndpoint, - userInfoEndpoint, - scopes: scopes || ["openid", "email", "profile"], - mapping: mapping || { - id: "sub", - email: "email", - emailVerified: "email_verified", - name: "name", - image: "picture", - }, + // Clean issuer URL (remove trailing slash); validate format + let cleanIssuer = issuer; + try { + const issuerUrl = new URL(issuer.toString().trim()); + cleanIssuer = issuerUrl.toString().replace(/\/$/, ""); + } catch { + return new Response( + JSON.stringify({ error: `Invalid issuer URL format: ${issuer}` }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + } + ); + } + + let normalized; + try { + normalized = await normalizeOidcProviderConfig(cleanIssuer, { + clientId, + clientSecret, + authorizationEndpoint, + tokenEndpoint, + jwksEndpoint, + userInfoEndpoint, + discoveryEndpoint, + scopes, + pkce, + mapping, + }); + } catch (error) { + if (error instanceof OidcConfigError) { + return new Response( + JSON.stringify({ error: error.message }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + } + ); + } + throw error; + } + + const storedOidcConfig = { + ...normalized.oidcConfig, + mapping: normalized.mapping, }; // Insert new provider @@ -102,9 +134,9 @@ export async function POST(context: APIContext) { .insert(ssoProviders) .values({ id: nanoid(), - issuer, + issuer: cleanIssuer, domain, - oidcConfig: JSON.stringify(oidcConfig), + oidcConfig: JSON.stringify(storedOidcConfig), userId: user.id, providerId, organizationId, @@ -156,7 +188,9 @@ export async function PUT(context: APIContext) { tokenEndpoint, jwksEndpoint, userInfoEndpoint, + discoveryEndpoint, scopes, + pkce, organizationId, } = body; @@ -179,26 +213,62 @@ export async function PUT(context: APIContext) { // Parse existing config const existingConfig = JSON.parse(existingProvider.oidcConfig); + const effectiveIssuer = issuer || existingProvider.issuer; - // Create updated OIDC config - const updatedOidcConfig = { - ...existingConfig, - clientId: clientId || existingConfig.clientId, - clientSecret: clientSecret || existingConfig.clientSecret, - authorizationEndpoint: authorizationEndpoint || existingConfig.authorizationEndpoint, - tokenEndpoint: tokenEndpoint || existingConfig.tokenEndpoint, - jwksEndpoint: jwksEndpoint || existingConfig.jwksEndpoint, - userInfoEndpoint: userInfoEndpoint || existingConfig.userInfoEndpoint, - scopes: scopes || existingConfig.scopes || ["openid", "email", "profile"], + let cleanIssuer = effectiveIssuer; + try { + const issuerUrl = new URL(effectiveIssuer.toString().trim()); + cleanIssuer = issuerUrl.toString().replace(/\/$/, ""); + } catch { + return new Response( + JSON.stringify({ error: `Invalid issuer URL format: ${effectiveIssuer}` }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + } + ); + } + + const mergedConfig: RawOidcConfig = { + clientId: clientId ?? existingConfig.clientId, + clientSecret: clientSecret ?? existingConfig.clientSecret, + authorizationEndpoint: authorizationEndpoint ?? existingConfig.authorizationEndpoint, + tokenEndpoint: tokenEndpoint ?? existingConfig.tokenEndpoint, + jwksEndpoint: jwksEndpoint ?? existingConfig.jwksEndpoint, + userInfoEndpoint: userInfoEndpoint ?? existingConfig.userInfoEndpoint, + discoveryEndpoint: discoveryEndpoint ?? existingConfig.discoveryEndpoint, + scopes: scopes ?? existingConfig.scopes, + pkce: pkce ?? existingConfig.pkce, + mapping: existingConfig.mapping, + }; + + let normalized; + try { + normalized = await normalizeOidcProviderConfig(cleanIssuer, mergedConfig); + } catch (error) { + if (error instanceof OidcConfigError) { + return new Response( + JSON.stringify({ error: error.message }), + { + status: 400, + headers: { "Content-Type": "application/json" }, + } + ); + } + throw error; + } + + const storedOidcConfig = { + ...normalized.oidcConfig, + mapping: normalized.mapping, }; - // Update provider const [updatedProvider] = await db .update(ssoProviders) .set({ - issuer: issuer || existingProvider.issuer, + issuer: cleanIssuer, domain: domain || existingProvider.domain, - oidcConfig: JSON.stringify(updatedOidcConfig), + oidcConfig: JSON.stringify(storedOidcConfig), organizationId: organizationId !== undefined ? organizationId : existingProvider.organizationId, updatedAt: new Date(), }) @@ -259,4 +329,4 @@ export async function DELETE(context: APIContext) { } catch (error) { return createSecureErrorResponse(error, "SSO providers API"); } -} \ No newline at end of file +}