sso: normalize provider config via discovery

This commit is contained in:
Arunavo Ray
2025-10-22 16:33:33 +05:30
parent e4e54722cf
commit 847823bbf8
6 changed files with 507 additions and 68 deletions

View File

@@ -5,6 +5,9 @@ import { sso } from "@better-auth/sso";
import { db, users } from "./db"; import { db, users } from "./db";
import * as schema from "./db/schema"; import * as schema from "./db/schema";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { ensureValidSsoProviders } from "./sso/provider-maintenance";
await ensureValidSsoProviders();
export const auth = betterAuth({ export const auth = betterAuth({
// Database configuration // Database configuration
@@ -171,4 +174,4 @@ export const auth = betterAuth({
}); });
// Export type for use in other parts of the app // Export type for use in other parts of the app
export type Auth = typeof auth; export type Auth = typeof auth;

View File

@@ -0,0 +1,56 @@
import { describe, expect, it } from "bun:test";
import { normalizeOidcProviderConfig, OidcConfigError } from "./oidc-config";
const issuer = "https://auth.example.com";
describe("normalizeOidcProviderConfig", () => {
it("returns provided endpoints when complete", async () => {
const result = await normalizeOidcProviderConfig(issuer, {
clientId: "client",
clientSecret: "secret",
authorizationEndpoint: "https://auth.example.com/auth",
tokenEndpoint: "https://auth.example.com/token",
jwksEndpoint: "https://auth.example.com/jwks",
userInfoEndpoint: "https://auth.example.com/userinfo",
scopes: ["openid", "email"],
pkce: false,
}, async () => {
throw new Error("fetch should not be called when endpoints are provided");
});
expect(result.oidcConfig.authorizationEndpoint).toBe("https://auth.example.com/auth");
expect(result.oidcConfig.tokenEndpoint).toBe("https://auth.example.com/token");
expect(result.oidcConfig.jwksEndpoint).toBe("https://auth.example.com/jwks");
expect(result.oidcConfig.userInfoEndpoint).toBe("https://auth.example.com/userinfo");
expect(result.oidcConfig.scopes).toEqual(["openid", "email"]);
expect(result.oidcConfig.pkce).toBe(false);
});
it("derives missing fields from discovery", async () => {
const fetchMock = async () =>
new Response(JSON.stringify({
authorization_endpoint: "https://auth.example.com/auth",
token_endpoint: "https://auth.example.com/token",
jwks_uri: "https://auth.example.com/jwks",
userinfo_endpoint: "https://auth.example.com/userinfo",
scopes_supported: ["openid", "email", "profile"],
}));
const result = await normalizeOidcProviderConfig(issuer, {
clientId: "client",
clientSecret: "secret",
}, fetchMock);
expect(result.oidcConfig.authorizationEndpoint).toBe("https://auth.example.com/auth");
expect(result.oidcConfig.tokenEndpoint).toBe("https://auth.example.com/token");
expect(result.oidcConfig.jwksEndpoint).toBe("https://auth.example.com/jwks");
expect(result.oidcConfig.userInfoEndpoint).toBe("https://auth.example.com/userinfo");
expect(result.oidcConfig.scopes).toEqual(["openid", "email", "profile"]);
});
it("throws for invalid issuer URL", async () => {
await expect(
normalizeOidcProviderConfig("not-a-url", {}),
).rejects.toBeInstanceOf(OidcConfigError);
});
});

202
src/lib/sso/oidc-config.ts Normal file
View File

@@ -0,0 +1,202 @@
import { z } from "zod";
const DEFAULT_SCOPES = ["openid", "email", "profile"] as const;
const DISCOVERY_TIMEOUT_MS = 10000;
const discoverySchema = z.object({
issuer: z.string().url().optional(),
authorization_endpoint: z.string().url().optional(),
token_endpoint: z.string().url().optional(),
userinfo_endpoint: z.string().url().optional(),
jwks_uri: z.string().url().optional(),
scopes_supported: z.array(z.string()).optional(),
});
export class OidcConfigError extends Error {
constructor(message: string) {
super(message);
this.name = "OidcConfigError";
}
}
export type RawOidcConfig = {
clientId?: string;
clientSecret?: string;
authorizationEndpoint?: string;
tokenEndpoint?: string;
jwksEndpoint?: string;
userInfoEndpoint?: string;
discoveryEndpoint?: string;
scopes?: string[];
pkce?: boolean;
mapping?: ProviderMapping;
};
export type ProviderMapping = {
id: string;
email: string;
emailVerified?: string;
name?: string;
image?: string;
firstName?: string;
lastName?: string;
};
export type NormalizedOidcConfig = {
oidcConfig: {
clientId?: string;
clientSecret?: string;
authorizationEndpoint: string;
tokenEndpoint: string;
jwksEndpoint?: string;
userInfoEndpoint?: string;
discoveryEndpoint: string;
scopes: string[];
pkce: boolean;
};
mapping: ProviderMapping;
};
type FetchFn = typeof fetch;
function cleanUrl(value: string | undefined, field: string): string | undefined {
if (!value || typeof value !== "string") return undefined;
const trimmed = value.trim();
if (!trimmed) return undefined;
try {
return new URL(trimmed).toString();
} catch {
throw new OidcConfigError(`Invalid ${field} URL: ${value}`);
}
}
function sanitizeScopes(scopes: string[] | undefined, fallback: readonly string[]): string[] {
const candidates = Array.isArray(scopes) ? scopes : [];
const sanitized = candidates
.map(scope => scope?.trim())
.filter((scope): scope is string => Boolean(scope));
if (sanitized.length === 0) {
return [...fallback];
}
return Array.from(new Set(sanitized));
}
async function fetchDiscoveryDocument(url: string, fetchFn: FetchFn): Promise<z.infer<typeof discoverySchema>> {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), DISCOVERY_TIMEOUT_MS);
try {
const response = await fetchFn(url, {
signal: controller.signal,
headers: { Accept: "application/json" },
});
if (!response.ok) {
throw new OidcConfigError(`OIDC discovery request failed (${response.status} ${response.statusText})`);
}
let payload: unknown;
try {
payload = await response.json();
} catch {
throw new OidcConfigError("OIDC discovery response is not valid JSON");
}
const parsed = discoverySchema.parse(payload);
if (!parsed.authorization_endpoint || !parsed.token_endpoint) {
throw new OidcConfigError("OIDC discovery document is missing required endpoints");
}
return parsed;
} catch (error) {
if (error instanceof OidcConfigError) {
throw error;
}
if (error instanceof Error && error.name === "AbortError") {
throw new OidcConfigError(`OIDC discovery timed out after ${DISCOVERY_TIMEOUT_MS / 1000}s`);
}
throw new OidcConfigError(`Failed to fetch OIDC discovery document: ${error instanceof Error ? error.message : "unknown error"}`);
} finally {
clearTimeout(timeoutId);
}
}
export async function normalizeOidcProviderConfig(
issuer: string,
rawConfig: RawOidcConfig,
fetchFn: FetchFn = fetch,
): Promise<NormalizedOidcConfig> {
if (!issuer || typeof issuer !== "string") {
throw new OidcConfigError("Issuer is required");
}
let normalizedIssuer: string;
try {
const issuerUrl = new URL(issuer.trim());
normalizedIssuer = issuerUrl.toString().replace(/\/$/, "");
} catch {
throw new OidcConfigError(`Invalid issuer URL: ${issuer}`);
}
const discoveryEndpoint = cleanUrl(
rawConfig.discoveryEndpoint,
"discovery endpoint",
) ?? `${normalizedIssuer}/.well-known/openid-configuration`;
const authorizationEndpoint = cleanUrl(rawConfig.authorizationEndpoint, "authorization endpoint");
const tokenEndpoint = cleanUrl(rawConfig.tokenEndpoint, "token endpoint");
const jwksEndpoint = cleanUrl(rawConfig.jwksEndpoint, "JWKS endpoint");
const userInfoEndpoint = cleanUrl(rawConfig.userInfoEndpoint, "userinfo endpoint");
const providedScopes = Array.isArray(rawConfig.scopes) ? rawConfig.scopes : undefined;
let scopes = sanitizeScopes(providedScopes, DEFAULT_SCOPES);
const shouldFetchDiscovery =
!authorizationEndpoint ||
!tokenEndpoint ||
!jwksEndpoint ||
!userInfoEndpoint ||
!providedScopes ||
providedScopes.length === 0;
let resolvedAuthorization = authorizationEndpoint;
let resolvedToken = tokenEndpoint;
let resolvedJwks = jwksEndpoint;
let resolvedUserInfo = userInfoEndpoint;
if (shouldFetchDiscovery) {
const discovery = await fetchDiscoveryDocument(discoveryEndpoint, fetchFn);
resolvedAuthorization = resolvedAuthorization ?? discovery.authorization_endpoint;
resolvedToken = resolvedToken ?? discovery.token_endpoint;
resolvedJwks = resolvedJwks ?? discovery.jwks_uri;
resolvedUserInfo = resolvedUserInfo ?? discovery.userinfo_endpoint;
if (!providedScopes || providedScopes.length === 0) {
scopes = sanitizeScopes(discovery.scopes_supported, DEFAULT_SCOPES);
}
}
if (!resolvedAuthorization || !resolvedToken) {
throw new OidcConfigError("OIDC configuration must include authorization and token endpoints");
}
return {
oidcConfig: {
clientId: rawConfig.clientId,
clientSecret: rawConfig.clientSecret,
authorizationEndpoint: resolvedAuthorization,
tokenEndpoint: resolvedToken,
jwksEndpoint: resolvedJwks,
userInfoEndpoint: resolvedUserInfo,
discoveryEndpoint,
scopes,
pkce: rawConfig.pkce !== false,
},
mapping: rawConfig.mapping ?? {
id: "sub",
email: "email",
emailVerified: "email_verified",
name: "name",
image: "picture",
},
};
}

View File

@@ -0,0 +1,65 @@
import { db, ssoProviders } from "@/lib/db";
import { normalizeOidcProviderConfig, OidcConfigError } from "./oidc-config";
import { eq } from "drizzle-orm";
type Logger = Pick<typeof console, "info" | "warn" | "error">;
export async function ensureValidSsoProviders(logger: Logger = console): Promise<void> {
const providers = await db.select().from(ssoProviders);
for (const provider of providers) {
if (!provider.oidcConfig) continue;
let parsedConfig: any;
try {
parsedConfig = JSON.parse(provider.oidcConfig);
} catch (error) {
logger.warn(`[SSO] Skipping provider ${provider.providerId}: invalid JSON configuration`, error);
continue;
}
const hasEndpoints =
parsedConfig.authorizationEndpoint &&
parsedConfig.tokenEndpoint;
const hasScopes = Array.isArray(parsedConfig.scopes) && parsedConfig.scopes.length > 0;
if (hasEndpoints && hasScopes) {
continue;
}
try {
const normalized = await normalizeOidcProviderConfig(provider.issuer, {
clientId: parsedConfig.clientId,
clientSecret: parsedConfig.clientSecret,
authorizationEndpoint: parsedConfig.authorizationEndpoint,
tokenEndpoint: parsedConfig.tokenEndpoint,
jwksEndpoint: parsedConfig.jwksEndpoint,
userInfoEndpoint: parsedConfig.userInfoEndpoint,
discoveryEndpoint: parsedConfig.discoveryEndpoint,
scopes: parsedConfig.scopes,
pkce: parsedConfig.pkce,
mapping: parsedConfig.mapping,
});
await db
.update(ssoProviders)
.set({
oidcConfig: JSON.stringify({
...normalized.oidcConfig,
mapping: normalized.mapping,
}),
updatedAt: new Date(),
})
.where(eq(ssoProviders.id, provider.id));
logger.info(`[SSO] Normalized OIDC configuration for provider ${provider.providerId}`);
} catch (error) {
if (error instanceof OidcConfigError) {
logger.warn(`[SSO] Unable to normalize provider ${provider.providerId}: ${error.message}`);
} else {
logger.error(`[SSO] Unexpected error normalizing provider ${provider.providerId}`, error);
}
}
}
}

View File

@@ -2,6 +2,10 @@ import type { APIContext } from "astro";
import { createSecureErrorResponse } from "@/lib/utils"; import { createSecureErrorResponse } from "@/lib/utils";
import { requireAuth } from "@/lib/utils/auth-helpers"; import { requireAuth } from "@/lib/utils/auth-helpers";
import { auth } from "@/lib/auth"; import { auth } from "@/lib/auth";
import { db, ssoProviders } from "@/lib/db";
import { eq } from "drizzle-orm";
import { nanoid } from "nanoid";
import { normalizeOidcProviderConfig, OidcConfigError } from "@/lib/sso/oidc-config";
// POST /api/auth/sso/register - Register a new SSO provider using Better Auth // POST /api/auth/sso/register - Register a new SSO provider using Better Auth
export async function POST(context: APIContext) { export async function POST(context: APIContext) {
@@ -104,43 +108,37 @@ export async function POST(context: APIContext) {
userInfoEndpoint, userInfoEndpoint,
scopes, scopes,
pkce = true, pkce = true,
mapping = { mapping,
id: "sub",
email: "email",
emailVerified: "email_verified",
name: "name",
image: "picture",
}
} = body; } = body;
// Use provided scopes or default if not specified try {
const finalScopes = scopes || ["openid", "email", "profile"]; const normalized = await normalizeOidcProviderConfig(validatedIssuer, {
clientId,
clientSecret,
authorizationEndpoint,
tokenEndpoint,
jwksEndpoint,
userInfoEndpoint,
discoveryEndpoint,
scopes,
pkce,
mapping,
});
// Validate endpoint URLs if provided registrationBody.oidcConfig = normalized.oidcConfig;
const validateUrl = (url: string | undefined, name: string): string | undefined => { registrationBody.mapping = normalized.mapping;
if (!url) return undefined; } catch (error) {
if (typeof url !== 'string' || url.trim() === '') return undefined; if (error instanceof OidcConfigError) {
try { return new Response(
const validatedUrl = new URL(url.trim()); JSON.stringify({ error: error.message }),
return validatedUrl.toString(); {
} catch (e) { status: 400,
console.warn(`Invalid ${name} URL: ${url}, skipping`); headers: { "Content-Type": "application/json" },
return undefined; }
);
} }
}; throw error;
}
registrationBody.oidcConfig = {
clientId: clientId || undefined,
clientSecret: clientSecret || undefined,
authorizationEndpoint: validateUrl(authorizationEndpoint, 'authorization endpoint'),
tokenEndpoint: validateUrl(tokenEndpoint, 'token endpoint'),
jwksEndpoint: validateUrl(jwksEndpoint, 'JWKS endpoint'),
discoveryEndpoint: validateUrl(discoveryEndpoint, 'discovery endpoint'),
userInfoEndpoint: validateUrl(userInfoEndpoint, 'userinfo endpoint'),
scopes: finalScopes,
pkce,
};
registrationBody.mapping = mapping;
} }
// Get the user's auth headers to make the request // Get the user's auth headers to make the request
@@ -168,7 +166,52 @@ export async function POST(context: APIContext) {
} }
const result = await response.json(); const result = await response.json();
// Mirror provider entry into local SSO table for UI listing
try {
const existing = await db
.select()
.from(ssoProviders)
.where(eq(ssoProviders.providerId, registrationBody.providerId))
.limit(1);
const values: any = {
issuer: registrationBody.issuer,
domain: registrationBody.domain,
organizationId: registrationBody.organizationId,
updatedAt: new Date(),
};
if (registrationBody.oidcConfig) {
values.oidcConfig = JSON.stringify({
...registrationBody.oidcConfig,
mapping: registrationBody.mapping,
});
}
if (existing.length > 0) {
await db
.update(ssoProviders)
.set(values)
.where(eq(ssoProviders.id, existing[0].id));
} else {
await db.insert(ssoProviders).values({
id: nanoid(),
issuer: registrationBody.issuer,
domain: registrationBody.domain,
oidcConfig: JSON.stringify({
...registrationBody.oidcConfig,
mapping: registrationBody.mapping,
}),
userId: user.id,
providerId: registrationBody.providerId,
organizationId: registrationBody.organizationId,
});
}
} catch (mirroringError) {
console.warn("Failed to mirror SSO provider to local DB:", mirroringError);
}
return new Response(JSON.stringify(result), { return new Response(JSON.stringify(result), {
status: 201, status: 201,
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
@@ -199,4 +242,4 @@ export async function GET(context: APIContext) {
} catch (error) { } catch (error) {
return createSecureErrorResponse(error, "SSO provider listing"); return createSecureErrorResponse(error, "SSO provider listing");
} }
} }

View File

@@ -4,6 +4,7 @@ import { requireAuth } from "@/lib/utils/auth-helpers";
import { db, ssoProviders } from "@/lib/db"; import { db, ssoProviders } from "@/lib/db";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { normalizeOidcProviderConfig, OidcConfigError, type RawOidcConfig } from "@/lib/sso/oidc-config";
// GET /api/sso/providers - List all SSO providers // GET /api/sso/providers - List all SSO providers
export async function GET(context: APIContext) { export async function GET(context: APIContext) {
@@ -45,10 +46,12 @@ export async function POST(context: APIContext) {
tokenEndpoint, tokenEndpoint,
jwksEndpoint, jwksEndpoint,
userInfoEndpoint, userInfoEndpoint,
discoveryEndpoint,
mapping, mapping,
providerId, providerId,
organizationId, organizationId,
scopes, scopes,
pkce,
} = body; } = body;
// Validate required fields // Validate required fields
@@ -79,22 +82,51 @@ export async function POST(context: APIContext) {
); );
} }
// Create OIDC config object // Clean issuer URL (remove trailing slash); validate format
const oidcConfig = { let cleanIssuer = issuer;
clientId, try {
clientSecret, const issuerUrl = new URL(issuer.toString().trim());
authorizationEndpoint, cleanIssuer = issuerUrl.toString().replace(/\/$/, "");
tokenEndpoint, } catch {
jwksEndpoint, return new Response(
userInfoEndpoint, JSON.stringify({ error: `Invalid issuer URL format: ${issuer}` }),
scopes: scopes || ["openid", "email", "profile"], {
mapping: mapping || { status: 400,
id: "sub", headers: { "Content-Type": "application/json" },
email: "email", }
emailVerified: "email_verified", );
name: "name", }
image: "picture",
}, let normalized;
try {
normalized = await normalizeOidcProviderConfig(cleanIssuer, {
clientId,
clientSecret,
authorizationEndpoint,
tokenEndpoint,
jwksEndpoint,
userInfoEndpoint,
discoveryEndpoint,
scopes,
pkce,
mapping,
});
} catch (error) {
if (error instanceof OidcConfigError) {
return new Response(
JSON.stringify({ error: error.message }),
{
status: 400,
headers: { "Content-Type": "application/json" },
}
);
}
throw error;
}
const storedOidcConfig = {
...normalized.oidcConfig,
mapping: normalized.mapping,
}; };
// Insert new provider // Insert new provider
@@ -102,9 +134,9 @@ export async function POST(context: APIContext) {
.insert(ssoProviders) .insert(ssoProviders)
.values({ .values({
id: nanoid(), id: nanoid(),
issuer, issuer: cleanIssuer,
domain, domain,
oidcConfig: JSON.stringify(oidcConfig), oidcConfig: JSON.stringify(storedOidcConfig),
userId: user.id, userId: user.id,
providerId, providerId,
organizationId, organizationId,
@@ -156,7 +188,9 @@ export async function PUT(context: APIContext) {
tokenEndpoint, tokenEndpoint,
jwksEndpoint, jwksEndpoint,
userInfoEndpoint, userInfoEndpoint,
discoveryEndpoint,
scopes, scopes,
pkce,
organizationId, organizationId,
} = body; } = body;
@@ -179,26 +213,62 @@ export async function PUT(context: APIContext) {
// Parse existing config // Parse existing config
const existingConfig = JSON.parse(existingProvider.oidcConfig); const existingConfig = JSON.parse(existingProvider.oidcConfig);
const effectiveIssuer = issuer || existingProvider.issuer;
// Create updated OIDC config let cleanIssuer = effectiveIssuer;
const updatedOidcConfig = { try {
...existingConfig, const issuerUrl = new URL(effectiveIssuer.toString().trim());
clientId: clientId || existingConfig.clientId, cleanIssuer = issuerUrl.toString().replace(/\/$/, "");
clientSecret: clientSecret || existingConfig.clientSecret, } catch {
authorizationEndpoint: authorizationEndpoint || existingConfig.authorizationEndpoint, return new Response(
tokenEndpoint: tokenEndpoint || existingConfig.tokenEndpoint, JSON.stringify({ error: `Invalid issuer URL format: ${effectiveIssuer}` }),
jwksEndpoint: jwksEndpoint || existingConfig.jwksEndpoint, {
userInfoEndpoint: userInfoEndpoint || existingConfig.userInfoEndpoint, status: 400,
scopes: scopes || existingConfig.scopes || ["openid", "email", "profile"], headers: { "Content-Type": "application/json" },
}
);
}
const mergedConfig: RawOidcConfig = {
clientId: clientId ?? existingConfig.clientId,
clientSecret: clientSecret ?? existingConfig.clientSecret,
authorizationEndpoint: authorizationEndpoint ?? existingConfig.authorizationEndpoint,
tokenEndpoint: tokenEndpoint ?? existingConfig.tokenEndpoint,
jwksEndpoint: jwksEndpoint ?? existingConfig.jwksEndpoint,
userInfoEndpoint: userInfoEndpoint ?? existingConfig.userInfoEndpoint,
discoveryEndpoint: discoveryEndpoint ?? existingConfig.discoveryEndpoint,
scopes: scopes ?? existingConfig.scopes,
pkce: pkce ?? existingConfig.pkce,
mapping: existingConfig.mapping,
};
let normalized;
try {
normalized = await normalizeOidcProviderConfig(cleanIssuer, mergedConfig);
} catch (error) {
if (error instanceof OidcConfigError) {
return new Response(
JSON.stringify({ error: error.message }),
{
status: 400,
headers: { "Content-Type": "application/json" },
}
);
}
throw error;
}
const storedOidcConfig = {
...normalized.oidcConfig,
mapping: normalized.mapping,
}; };
// Update provider
const [updatedProvider] = await db const [updatedProvider] = await db
.update(ssoProviders) .update(ssoProviders)
.set({ .set({
issuer: issuer || existingProvider.issuer, issuer: cleanIssuer,
domain: domain || existingProvider.domain, domain: domain || existingProvider.domain,
oidcConfig: JSON.stringify(updatedOidcConfig), oidcConfig: JSON.stringify(storedOidcConfig),
organizationId: organizationId !== undefined ? organizationId : existingProvider.organizationId, organizationId: organizationId !== undefined ? organizationId : existingProvider.organizationId,
updatedAt: new Date(), updatedAt: new Date(),
}) })
@@ -259,4 +329,4 @@ export async function DELETE(context: APIContext) {
} catch (error) { } catch (error) {
return createSecureErrorResponse(error, "SSO providers API"); return createSecureErrorResponse(error, "SSO providers API");
} }
} }