diff --git a/docs/dev/DESIGN.md b/docs/dev/DESIGN.md index 12e9c3e..25f6073 100644 --- a/docs/dev/DESIGN.md +++ b/docs/dev/DESIGN.md @@ -91,7 +91,7 @@ Unified pattern across tabs: SettingsHeader + EntityListCard + EntityModal; Inli - Users: create/edit; role picker; delete via ConfirmModal. - Groups: create/edit; manage membership; one Tag per Group; delete via ConfirmModal. - Taxonomy: Tags, Tool Categories (by type), Tools, Threat Actors (attach ATT&CK techniques), Crown Jewels, Log Sources. -- Database: stats; Backup (download JSON, selectable sections), Restore (ConfirmModal + file, selectable sections, optional clear), Clear Data (ConfirmModal per selection). +- Data: overview metrics; export/import a combined operations + taxonomy backup (always replaces existing data); clear-all confirmation. ## Data & Validation diff --git a/src/app/(protected-routes)/settings/data/page.tsx b/src/app/(protected-routes)/settings/data/page.tsx new file mode 100644 index 0000000..6c5c6db --- /dev/null +++ b/src/app/(protected-routes)/settings/data/page.tsx @@ -0,0 +1,259 @@ +"use client"; + +import { useEffect, useRef, useState } from "react"; + +import { CheckCircle } from "lucide-react"; + +import ConfirmModal from "@components/ui/confirm-modal"; +import { Button, Card, CardContent, CardHeader, CardTitle } from "@components/ui"; +import { logger } from "@lib/logger"; +import { api } from "@/trpc/react"; + +export default function DataSettingsPage() { + const [restoreFile, setRestoreFile] = useState(null); + const [showRestoreConfirm, setShowRestoreConfirm] = useState(false); + const [showClearConfirm, setShowClearConfirm] = useState(false); + const [toastMessage, setToastMessage] = useState(null); + const [restoreError, setRestoreError] = useState(null); + const fileInputRef = useRef(null); + + const resetRestoreSelection = () => { + setRestoreFile(null); + if (fileInputRef.current) { + fileInputRef.current.value = ""; + } + }; + + useEffect(() => { + if (!toastMessage) return; + + const timeout = window.setTimeout(() => setToastMessage(null), 4000); + + return () => window.clearTimeout(timeout); + }, [toastMessage]); + + const utils = api.useUtils(); + const { data: stats, refetch: refetchStats } = api.data.getStats.useQuery(); + + const backupMutation = api.data.backup.useMutation({ + onSuccess: (data) => { + const blob = new Blob([data], { type: "application/json" }); + const url = URL.createObjectURL(blob); + const a = document.createElement("a"); + a.href = url; + a.download = `ttpx-data-${new Date().toISOString().split("T")[0]}.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + }, + }); + + const restoreMutation = api.data.restore.useMutation({ + onSuccess: () => { + void refetchStats(); + void utils.invalidate(); + resetRestoreSelection(); + setShowRestoreConfirm(false); + setToastMessage("Operations and taxonomy data have been imported."); + setRestoreError(null); + }, + onError: (error) => { + setShowRestoreConfirm(false); + setRestoreError(error.message); + resetRestoreSelection(); + }, + }); + + const clearDataMutation = api.data.clearData.useMutation({ + onSuccess: () => { + void refetchStats(); + void utils.invalidate(); + setShowClearConfirm(false); + setToastMessage("Operations and taxonomy data have been cleared."); + }, + onError: () => { + setShowClearConfirm(false); + }, + }); + + const handleRestore = async () => { + if (!restoreFile) return; + + try { + const text = await restoreFile.text(); + setRestoreError(null); + restoreMutation.mutate({ backupData: text }); + } catch (error) { + logger.error("Failed to read backup file", error); + } + }; + + return ( +
+ {toastMessage && ( +
+
+ +
+

Success

+

{toastMessage}

+
+
+
+ )} + +
+

Data Management

+

+ Export, import, or clear operations and taxonomy data. +

+
+ + {showRestoreConfirm && ( + { + resetRestoreSelection(); + setShowRestoreConfirm(false); + }} + loading={restoreMutation.isPending} + /> + )} + + + + Data Overview + + + {stats ? ( +
+ {[ + { label: "Operations", value: stats.operations }, + { label: "Techniques", value: stats.techniques }, + { label: "Outcomes", value: stats.outcomes }, + { label: "Threat Actors", value: stats.threatActors }, + { label: "Crown Jewels", value: stats.crownJewels }, + { label: "Tags", value: stats.tags }, + { label: "Tools", value: stats.tools }, + { label: "Log Sources", value: stats.logSources }, + ].map(({ label, value }) => ( +
+
{value}
+
{label}
+
+ ))} +
+ ) : ( +
Loading statistics...
+ )} +
+
+ +
+ + + Export Data + + +

+ Download a JSON file containing all operations and taxonomy records. +

+ +
+
+ + + + Import Data + + +

+ Import a backup to replace the current operations and taxonomy data set. +

+ { + const file = event.target.files?.[0] ?? null; + + if (!file) { + resetRestoreSelection(); + setShowRestoreConfirm(false); + return; + } + + setRestoreFile(file); + setRestoreError(null); + setShowRestoreConfirm(true); + event.target.value = ""; + }} + /> + + {restoreError &&
{restoreError}
} +
+
+ + + + ⚠️ Clear Data + + +

+ Permanently delete all operations and taxonomy data. This cannot be undone. +

+ + {clearDataMutation.error && ( +
{clearDataMutation.error.message}
+ )} +
+
+
+ + {showClearConfirm && ( + clearDataMutation.mutate()} + onCancel={() => setShowClearConfirm(false)} + loading={clearDataMutation.isPending} + /> + )} +
+ ); +} diff --git a/src/app/(protected-routes)/settings/database/page.tsx b/src/app/(protected-routes)/settings/database/page.tsx deleted file mode 100644 index bc8b5fb..0000000 --- a/src/app/(protected-routes)/settings/database/page.tsx +++ /dev/null @@ -1,344 +0,0 @@ -"use client"; - -import { useState } from "react"; -import { api } from "@/trpc/react"; -import { Button, Card, CardHeader, CardTitle, CardContent } from "@components/ui"; -import { logger } from "@lib/logger"; -import ConfirmModal from "@components/ui/confirm-modal"; - -export default function DatabaseSettingsPage() { - const [restoreFile, setRestoreFile] = useState(null); - // Backup options - const [includeTaxonomyAndOperations, setIncludeTaxonomyAndOperations] = useState(true); - const [includeUsersAndGroups, setIncludeUsersAndGroups] = useState(false); - // Restore options - const [restoreTaxonomyAndOperations, setRestoreTaxonomyAndOperations] = useState(true); - const [restoreUsersAndGroups, setRestoreUsersAndGroups] = useState(false); - const [clearBefore, setClearBefore] = useState(true); - const [clearOperations, setClearOperations] = useState(false); - const [clearTaxonomy, setClearTaxonomy] = useState(false); - const [showRestoreConfirm, setShowRestoreConfirm] = useState(false); - const [showClearConfirm, setShowClearConfirm] = useState(false); - // typed confirm state removed in favor of ConfirmModal - - // Store confirmation state for mutations - const [pendingClearOptions, setPendingClearOptions] = useState<{clearOperations: boolean, clearTaxonomy: boolean} | null>(null); - - const utils = api.useUtils(); - - // Queries - const { data: stats, refetch: refetchStats } = api.database.getStats.useQuery(); - - // Mutations - const backupMutation = api.database.backup.useMutation({ - onSuccess: (data) => { - // Create and download backup file - const blob = new Blob([data], { type: "application/json" }); - const url = URL.createObjectURL(blob); - const a = document.createElement("a"); - a.href = url; - a.download = `ttpx-backup-${new Date().toISOString().split("T")[0]}.json`; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - }, - }); - - const restoreMutation = api.database.restore.useMutation({ - onSuccess: () => { - void refetchStats(); - void utils.invalidate(); - setRestoreFile(null); - setShowRestoreConfirm(false); - // noop - }, - }); - - const clearDataMutation = api.database.clearData.useMutation({ - onSuccess: () => { - void refetchStats(); - void utils.invalidate(); - setShowClearConfirm(false); - setClearOperations(false); - setClearTaxonomy(false); - // noop - setPendingClearOptions(null); - }, - }); - - const handleRestore = async () => { - if (!restoreFile) return; - - try { - const text = await restoreFile.text(); - restoreMutation.mutate({ - backupData: text, - restoreTaxonomyAndOperations, - restoreUsersAndGroups, - clearBefore, - }); - } catch (error) { - logger.error("Failed to read file:", error); - } - }; - - const handleClearData = () => { - if (!pendingClearOptions) return; - clearDataMutation.mutate(pendingClearOptions); - }; - - return ( -
-
-

Database Management

-
- {showRestoreConfirm && ( - setShowRestoreConfirm(false)} - loading={restoreMutation.isPending} - /> - )} - - {/* Database Statistics */} - - - Database Statistics - - - {stats ? ( -
-
-
{stats.users}
-
Users
-
-
-
{stats.operations}
-
Operations
-
-
-
{stats.techniques}
-
Techniques
-
-
-
{stats.outcomes}
-
Outcomes
-
-
-
{stats.threatActors}
-
Threat Actors
-
-
-
{stats.tools}
-
Tools
-
-
- ) : ( -
- Loading statistics... -
- )} -
-
- - {/* Backup & Restore */} -
- - - Backup Database - - -

- Export selected data to a JSON file. -

-
-
- setIncludeTaxonomyAndOperations(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
-
- setIncludeUsersAndGroups(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
-
- -
-
- - - - Restore Database - - -

- Restore selected data from a backup file. -

-
-
- setRestoreTaxonomyAndOperations(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
-
- setRestoreUsersAndGroups(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
-
- setClearBefore(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
- setRestoreFile(e.target.files?.[0] ?? null)} - className="w-full text-[var(--color-text-primary)] bg-[var(--color-surface)] border border-[var(--color-border)] rounded-[var(--radius-md)] p-2" - /> - -
- {restoreMutation.error && ( -
- {restoreMutation.error.message} -
- )} -
-
-
- {/* Clear Data */} -
- - - ⚠️ Clear Data - - -

- Permanently delete data from your database. -

-
-
- setClearOperations(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
-
- setClearTaxonomy(e.target.checked)} - className="rounded border-[var(--color-border)]" - /> - -
- -
- {clearDataMutation.error && ( -
- {clearDataMutation.error.message} -
- )} -
-
-
- {showClearConfirm && ( - { setShowClearConfirm(false); setPendingClearOptions(null); }} - loading={clearDataMutation.isPending} - /> - )} -
- ); -} diff --git a/src/features/shared/layout/sidebar-nav.tsx b/src/features/shared/layout/sidebar-nav.tsx index 86a5ab7..aa64db9 100644 --- a/src/features/shared/layout/sidebar-nav.tsx +++ b/src/features/shared/layout/sidebar-nav.tsx @@ -77,9 +77,9 @@ const navigation: NavItem[] = [ href: "/settings/groups", }, { - key: "database", - label: "Database", - href: "/settings/database", + key: "data", + label: "Data", + href: "/settings/data", }, ], diff --git a/src/server/api/root.ts b/src/server/api/root.ts index b031b59..b421c3a 100644 --- a/src/server/api/root.ts +++ b/src/server/api/root.ts @@ -4,7 +4,7 @@ import { groupsRouter } from "@/server/api/routers/groups"; import { operationsRouter } from "@/server/api/routers/operations"; import { techniquesRouter } from "@/server/api/routers/techniques"; import { outcomesRouter } from "@/server/api/routers/outcomes"; -import { databaseRouter } from "@/server/api/routers/database"; +import { dataRouter } from "@/server/api/routers/data"; import { analyticsRouter } from "@/server/api/routers/analytics"; import { importRouter } from "@/server/api/routers/import"; import { createCallerFactory, createTRPCRouter } from "@/server/api/trpc"; @@ -21,7 +21,7 @@ export const appRouter = createTRPCRouter({ operations: operationsRouter, techniques: techniquesRouter, outcomes: outcomesRouter, - database: databaseRouter, + data: dataRouter, analytics: analyticsRouter, import: importRouter, }); diff --git a/src/server/api/routers/data/index.ts b/src/server/api/routers/data/index.ts new file mode 100644 index 0000000..d0e53fe --- /dev/null +++ b/src/server/api/routers/data/index.ts @@ -0,0 +1,369 @@ +import { z } from "zod"; +import type { Prisma } from "@prisma/client"; +import { TRPCError } from "@trpc/server"; + +import { createTRPCRouter, adminProcedure } from "@/server/api/trpc"; +import { logger } from "@/server/logger"; + +const clearUserData = async (tx: Prisma.TransactionClient) => { + await tx.outcome.deleteMany(); + await tx.technique.deleteMany(); + await tx.attackFlowLayout.deleteMany(); + await tx.operation.deleteMany(); + + await tx.tool.deleteMany(); + await tx.toolCategory.deleteMany(); + await tx.logSource.deleteMany(); + await tx.tag.deleteMany(); + await tx.crownJewel.deleteMany(); + await tx.threatActor.deleteMany(); +}; + +const threatActorSchema = z.object({ + id: z.string().optional(), + name: z.string(), + description: z.string(), + topThreat: z.boolean().optional(), +}); + +const crownJewelSchema = z.object({ + id: z.string().optional(), + name: z.string(), + description: z.string(), +}); + +const tagSchema = z.object({ + id: z.string().optional(), + name: z.string(), + description: z.string(), + color: z.string().optional(), +}); + +const toolCategorySchema = z.object({ + id: z.string().optional(), + name: z.string(), + type: z.enum(["DEFENSIVE", "OFFENSIVE"]), +}); + +const toolSchema = z.object({ + id: z.string().optional(), + name: z.string(), + categoryId: z.string(), + type: z.enum(["DEFENSIVE", "OFFENSIVE"]), +}); + +const logSourceSchema = z.object({ + id: z.string().optional(), + name: z.string(), + description: z.string(), +}); + +const operationSchema = z.object({ + id: z.number().optional(), + name: z.string(), + description: z.string(), + status: z.enum(["PLANNING", "ACTIVE", "COMPLETED", "CANCELLED"]).optional(), + startDate: z.coerce.date().optional().nullable(), + endDate: z.coerce.date().optional().nullable(), + createdById: z.string(), + threatActorId: z.string().optional().nullable(), + tags: z.array(z.object({ id: z.string() })).optional(), + crownJewels: z.array(z.object({ id: z.string() })).optional(), + visibility: z.enum(["EVERYONE", "GROUPS_ONLY"]).optional(), +}); + +const techniqueSchema = z.object({ + id: z.string().optional(), + description: z.string(), + sortOrder: z.number().int().optional(), + startTime: z.coerce.date().optional().nullable(), + endTime: z.coerce.date().optional().nullable(), + sourceIp: z.string().optional().nullable(), + targetSystem: z.string().optional().nullable(), + crownJewelTargeted: z.boolean().optional(), + crownJewelCompromised: z.boolean().optional(), + executedSuccessfully: z.boolean().optional().nullable(), + operationId: z.number(), + mitreTechniqueId: z.string().optional().nullable(), + mitreSubTechniqueId: z.string().optional().nullable(), + tools: z.array(z.object({ id: z.string() })).optional(), +}); + +const outcomeSchema = z.object({ + id: z.string().optional(), + type: z.enum(["DETECTION", "PREVENTION", "ATTRIBUTION"]), + status: z.enum(["NOT_APPLICABLE", "MISSED", "DETECTED", "PREVENTED", "ATTRIBUTED"]), + detectionTime: z.coerce.date().optional().nullable(), + notes: z.string().optional().nullable(), + screenshotUrl: z.string().optional().nullable(), + logData: z.string().optional().nullable(), + techniqueId: z.string(), + tools: z.array(z.object({ id: z.string() })).optional(), + logSources: z.array(z.object({ id: z.string() })).optional(), +}); + +const attackFlowLayoutSchema = z.object({ + id: z.string().optional(), + operationId: z.number(), + nodes: z.custom(), + edges: z.custom(), +}); + +const threatActorTechniqueLinkSchema = z.object({ + threatActorId: z.string(), + mitreTechniqueId: z.string(), +}); + +const backupPayloadSchema = z.object({ + threatActors: z.array(threatActorSchema).optional(), + crownJewels: z.array(crownJewelSchema).optional(), + tags: z.array(tagSchema).optional(), + toolCategories: z.array(toolCategorySchema).optional(), + tools: z.array(toolSchema).optional(), + logSources: z.array(logSourceSchema).optional(), + operations: z.array(operationSchema).optional(), + techniques: z.array(techniqueSchema).optional(), + outcomes: z.array(outcomeSchema).optional(), + attackFlowLayouts: z.array(attackFlowLayoutSchema).optional(), + threatActorTechniqueLinks: z.array(threatActorTechniqueLinkSchema).optional(), +}); + +const backupEnvelopeSchema = z.object({ + version: z.string(), + timestamp: z.string(), + data: backupPayloadSchema, +}); + +export const dataRouter = createTRPCRouter({ + getStats: adminProcedure.query(async ({ ctx }) => { + const db = ctx.db; + + const [ + operationCount, + techniqueCount, + outcomeCount, + threatActorCount, + crownJewelCount, + tagCount, + toolCount, + logSourceCount, + ] = await Promise.all([ + db.operation.count(), + db.technique.count(), + db.outcome.count(), + db.threatActor.count(), + db.crownJewel.count(), + db.tag.count(), + db.tool.count(), + db.logSource.count(), + ]); + + return { + operations: operationCount, + techniques: techniqueCount, + outcomes: outcomeCount, + threatActors: threatActorCount, + crownJewels: crownJewelCount, + tags: tagCount, + tools: toolCount, + logSources: logSourceCount, + }; + }), + + backup: adminProcedure.mutation(async ({ ctx }) => { + const db = ctx.db; + + try { + const [ + crownJewels, + tags, + toolCategories, + tools, + logSources, + operations, + techniques, + outcomes, + attackFlowLayouts, + ] = await Promise.all([ + db.crownJewel.findMany(), + db.tag.findMany(), + db.toolCategory.findMany(), + db.tool.findMany(), + db.logSource.findMany(), + db.operation.findMany({ + include: { + tags: { select: { id: true } }, + crownJewels: { select: { id: true } }, + }, + }), + db.technique.findMany({ + include: { + tools: { select: { id: true } }, + }, + }), + db.outcome.findMany({ + include: { + tools: { select: { id: true } }, + logSources: { select: { id: true } }, + }, + }), + db.attackFlowLayout.findMany(), + ]); + + const threatActorRows = await db.threatActor.findMany({ + include: { mitreTechniques: { select: { id: true } } }, + }); + + const threatActors = threatActorRows.map(({ mitreTechniques: _mitreTechniques, ...actor }) => actor); + const threatActorTechniqueLinks = threatActorRows.flatMap((actor) => + actor.mitreTechniques.map((tech) => ({ + threatActorId: actor.id, + mitreTechniqueId: tech.id, + })), + ); + + return JSON.stringify( + { + version: "2.0", + timestamp: new Date().toISOString(), + data: { + threatActors, + crownJewels, + tags, + toolCategories, + tools, + logSources, + operations, + techniques, + outcomes, + attackFlowLayouts, + threatActorTechniqueLinks, + }, + }, + null, + 2, + ); + } catch (error) { + logger.error({ event: "data.backup_failed", error }, "Failed to create data backup"); + throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "Failed to create backup" }); + } + }), + + restore: adminProcedure + .input(z.object({ backupData: z.string() })) + .mutation(async ({ ctx, input }) => { + const db = ctx.db; + + let raw: unknown; + try { + raw = JSON.parse(input.backupData); + } catch { + throw new TRPCError({ code: "BAD_REQUEST", message: "Invalid backup file" }); + } + + let payload: z.infer; + const parsedEnvelope = backupEnvelopeSchema.safeParse(raw); + + if (parsedEnvelope.success) { + payload = parsedEnvelope.data.data; + } else { + const parsedPayload = backupPayloadSchema.safeParse(raw); + if (!parsedPayload.success) { + throw new TRPCError({ code: "BAD_REQUEST", message: "Backup file is missing data" }); + } + payload = parsedPayload.data; + } + + try { + await db.$transaction(async (tx: Prisma.TransactionClient) => { + await clearUserData(tx); + + if (payload.threatActors?.length) { + await tx.threatActor.createMany({ data: payload.threatActors }); + } + if (payload.crownJewels?.length) { + await tx.crownJewel.createMany({ data: payload.crownJewels }); + } + if (payload.tags?.length) { + await tx.tag.createMany({ data: payload.tags }); + } + if (payload.toolCategories?.length) { + await tx.toolCategory.createMany({ data: payload.toolCategories }); + } + if (payload.tools?.length) { + await tx.tool.createMany({ data: payload.tools }); + } + if (payload.logSources?.length) { + await tx.logSource.createMany({ data: payload.logSources }); + } + + for (const op of payload.operations ?? []) { + const { tags: opTags = [], crownJewels: opCrownJewels = [], ...operationFields } = op; + + await tx.operation.create({ + data: { + ...operationFields, + // Access groups are not restored; default all operations to everyone-visible. + visibility: "EVERYONE", + tags: opTags.length ? { connect: opTags.map(({ id }) => ({ id })) } : undefined, + crownJewels: opCrownJewels.length ? { connect: opCrownJewels.map(({ id }) => ({ id })) } : undefined, + }, + }); + } + + for (const technique of payload.techniques ?? []) { + const { tools: techniqueTools = [], ...techniqueFields } = technique; + + await tx.technique.create({ + data: { + ...techniqueFields, + tools: techniqueTools.length ? { connect: techniqueTools.map(({ id }) => ({ id })) } : undefined, + }, + }); + } + + for (const outcome of payload.outcomes ?? []) { + const { tools: outcomeTools = [], logSources: outcomeLogSources = [], ...outcomeFields } = outcome; + + await tx.outcome.create({ + data: { + ...outcomeFields, + tools: outcomeTools.length ? { connect: outcomeTools.map(({ id }) => ({ id })) } : undefined, + logSources: outcomeLogSources.length ? { connect: outcomeLogSources.map(({ id }) => ({ id })) } : undefined, + }, + }); + } + + for (const layout of payload.attackFlowLayouts ?? []) { + await tx.attackFlowLayout.create({ data: layout }); + } + + for (const link of payload.threatActorTechniqueLinks ?? []) { + await tx.threatActor.update({ + where: { id: link.threatActorId }, + data: { mitreTechniques: { connect: { id: link.mitreTechniqueId } } }, + }); + } + }); + + return { success: true }; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + logger.error({ event: "data.restore_failed", message }, "Restore error"); + throw new TRPCError({ code: "BAD_REQUEST", message: "Failed to restore backup" }); + } + }), + + clearData: adminProcedure.mutation(async ({ ctx }) => { + const db = ctx.db; + + try { + await db.$transaction(clearUserData); + return { success: true }; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + logger.error({ event: "data.clear_failed", message }, "Clear data error"); + throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "Failed to clear data" }); + } + }), +}); diff --git a/src/server/api/routers/database/backup.ts b/src/server/api/routers/database/backup.ts deleted file mode 100644 index fd05a4a..0000000 --- a/src/server/api/routers/database/backup.ts +++ /dev/null @@ -1,128 +0,0 @@ -import { z } from "zod"; -import { createTRPCRouter, adminProcedure } from "@/server/api/trpc"; -import type { AuthedContext } from "@/server/api/trpc"; -import type { DatabaseClient } from "@/server/db"; -import { TRPCError } from "@trpc/server"; - -interface BackupData { - version: string; - timestamp: string; - data: Record; -} - -const isDatabaseClient = (value: unknown): value is DatabaseClient => { - if (typeof value !== "object" || value === null) { - return false; - } - - const candidate = value as { - user?: { findMany?: unknown }; - operation?: { findMany?: unknown }; - }; - - return typeof candidate.user?.findMany === "function" && typeof candidate.operation?.findMany === "function"; -}; - -const getPrismaClient = (value: unknown): DatabaseClient => { - if (!isDatabaseClient(value)) { - throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "Database client not initialized" }); - } - return value; -}; - -export const databaseBackupRouter = createTRPCRouter({ - getStats: adminProcedure.query(async ({ ctx }: { ctx: AuthedContext }) => { - const db = getPrismaClient(ctx.db); - const [ - userCount, - operationCount, - techniqueCount, - outcomeCount, - threatActorCount, - crownJewelCount, - tagCount, - toolCount, - logSourceCount, - mitreTacticCount, - mitreTechniqueCount, - mitreSubTechniqueCount, - groupCount, - userGroupCount, - ] = await Promise.all([ - db.user.count(), - db.operation.count(), - db.technique.count(), - db.outcome.count(), - db.threatActor.count(), - db.crownJewel.count(), - db.tag.count(), - db.tool.count(), - db.logSource.count(), - db.mitreTactic.count(), - db.mitreTechnique.count(), - db.mitreSubTechnique.count(), - db.group.count(), - db.userGroup.count(), - ]); - - return { - users: userCount, - operations: operationCount, - techniques: techniqueCount, - outcomes: outcomeCount, - threatActors: threatActorCount, - crownJewels: crownJewelCount, - tags: tagCount, - tools: toolCount, - logSources: logSourceCount, - mitreTactics: mitreTacticCount, - mitreTechniques: mitreTechniqueCount, - mitreSubTechniques: mitreSubTechniqueCount, - groups: groupCount, - userGroups: userGroupCount, - }; - }), - - backup: adminProcedure - .input(z.object({ includeTaxonomyAndOperations: z.boolean().default(true), includeUsersAndGroups: z.boolean().default(false), includeMitre: z.boolean().default(true) }).optional()) - .mutation(async ({ ctx, input }: { ctx: AuthedContext; input: { includeTaxonomyAndOperations?: boolean; includeUsersAndGroups?: boolean; includeMitre?: boolean } | undefined }) => { - const db = getPrismaClient(ctx.db); - try { - const opts = input ?? { includeTaxonomyAndOperations: true, includeUsersAndGroups: false, includeMitre: true }; - const payload: Record = {}; - const data: BackupData = { version: "1.0", timestamp: new Date().toISOString(), data: payload }; - - if (opts.includeMitre) { - payload.mitreTactics = await db.mitreTactic.findMany(); - payload.mitreTechniques = await db.mitreTechnique.findMany(); - payload.mitreSubTechniques = await db.mitreSubTechnique.findMany(); - } - - if (opts.includeTaxonomyAndOperations) { - payload.threatActors = await db.threatActor.findMany(); - payload.crownJewels = await db.crownJewel.findMany(); - payload.tags = await db.tag.findMany(); - payload.toolCategories = await db.toolCategory.findMany(); - payload.tools = await db.tool.findMany(); - payload.logSources = await db.logSource.findMany(); - payload.operations = await db.operation.findMany({ include: { tags: true, crownJewels: true, accessGroups: { select: { groupId: true } }, } }); - payload.techniques = await db.technique.findMany({ include: { tools: true } }); - payload.outcomes = await db.outcome.findMany({ include: { tools: true, logSources: true } }); - payload.attackFlowLayouts = await db.attackFlowLayout.findMany(); - const taWithTechs = await db.threatActor.findMany({ select: { id: true, mitreTechniques: { select: { id: true } } } }); - payload.threatActorTechniqueLinks = taWithTechs.flatMap((ta) => ta.mitreTechniques.map((tech) => ({ threatActorId: ta.id, mitreTechniqueId: tech.id }))); - } - - if (opts.includeUsersAndGroups) { - payload.users = await db.user.findMany({ select: { id: true, email: true, name: true, role: true, lastLogin: true } }); - payload.authenticators = await db.authenticator.findMany(); - payload.groups = await db.group.findMany({ select: { id: true, name: true, description: true } }); - payload.userGroups = await db.userGroup.findMany({ select: { userId: true, groupId: true } }); - } - - return JSON.stringify(data, null, 2); - } catch { - throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "Failed to create backup" }); - } - }), -}); diff --git a/src/server/api/routers/database/index.ts b/src/server/api/routers/database/index.ts deleted file mode 100644 index 4c69812..0000000 --- a/src/server/api/routers/database/index.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { createTRPCRouter } from "@/server/api/trpc"; -import { databaseBackupRouter } from "./backup"; -import { databaseRestoreRouter } from "./restore"; - -export const databaseRouter = createTRPCRouter({ - ...databaseBackupRouter._def.procedures, - ...databaseRestoreRouter._def.procedures, -}); - diff --git a/src/server/api/routers/database/restore.ts b/src/server/api/routers/database/restore.ts deleted file mode 100644 index 0927c1e..0000000 --- a/src/server/api/routers/database/restore.ts +++ /dev/null @@ -1,355 +0,0 @@ -import { z } from "zod"; -import type { Prisma } from "@prisma/client"; -import { createTRPCRouter, adminProcedure } from "@/server/api/trpc"; -import type { AuthedContext } from "@/server/api/trpc"; -import type { DatabaseClient } from "@/server/db"; -import { TRPCError } from "@trpc/server"; -import { logger } from "@/server/logger"; - -const isDatabaseClient = (value: unknown): value is DatabaseClient => { - if (typeof value !== "object" || value === null) { - return false; - } - - const candidate = value as { - $transaction?: unknown; - operation?: { deleteMany?: unknown }; - threatActor?: { deleteMany?: unknown }; - }; - - return ( - typeof candidate.$transaction === "function" && - typeof candidate.operation?.deleteMany === "function" && - typeof candidate.threatActor?.deleteMany === "function" - ); -}; - -const getPrismaClient = (value: unknown): DatabaseClient => { - if (!isDatabaseClient(value)) { - throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "Database client not initialized" }); - } - return value; -}; - -const threatActorSchema = z.object({ id: z.string().optional(), name: z.string(), description: z.string(), topThreat: z.boolean().optional() }); -const crownJewelSchema = z.object({ id: z.string().optional(), name: z.string(), description: z.string() }); -const tagSchema = z.object({ id: z.string().optional(), name: z.string(), description: z.string(), color: z.string().optional() }); -const toolCategorySchema = z.object({ id: z.string().optional(), name: z.string(), type: z.enum(["DEFENSIVE", "OFFENSIVE"]) }); -const toolSchema = z.object({ id: z.string().optional(), name: z.string(), categoryId: z.string(), type: z.enum(["DEFENSIVE", "OFFENSIVE"]) }); -const logSourceSchema = z.object({ id: z.string().optional(), name: z.string(), description: z.string() }); -const operationSchema = z.object({ - id: z.number().optional(), - name: z.string(), - description: z.string(), - status: z.enum(["PLANNING", "ACTIVE", "COMPLETED", "CANCELLED"]).optional(), - startDate: z.coerce.date().optional().nullable(), - endDate: z.coerce.date().optional().nullable(), - createdById: z.string(), - threatActorId: z.string().optional().nullable(), - tags: z.array(z.object({ id: z.string() })).optional(), - crownJewels: z.array(z.object({ id: z.string() })).optional(), - visibility: z.enum(["EVERYONE", "GROUPS_ONLY"]).optional(), - accessGroups: z.array(z.object({ groupId: z.string() })).optional(), -}); -const techniqueSchema = z.object({ - id: z.string().optional(), - description: z.string(), - sortOrder: z.number().int().optional(), - startTime: z.coerce.date().optional().nullable(), - endTime: z.coerce.date().optional().nullable(), - sourceIp: z.string().optional().nullable(), - targetSystem: z.string().optional().nullable(), - crownJewelTargeted: z.boolean().optional(), - crownJewelCompromised: z.boolean().optional(), - operationId: z.number(), - mitreTechniqueId: z.string().optional().nullable(), - mitreSubTechniqueId: z.string().optional().nullable(), - tools: z.array(z.object({ id: z.string() })).optional(), -}); -const outcomeSchema = z.object({ - id: z.string().optional(), - type: z.enum(["DETECTION", "PREVENTION", "ATTRIBUTION"]), - status: z.enum(["NOT_APPLICABLE", "MISSED", "DETECTED", "PREVENTED", "ATTRIBUTED"]), - detectionTime: z.coerce.date().optional().nullable(), - notes: z.string().optional().nullable(), - screenshotUrl: z.string().optional().nullable(), - logData: z.string().optional().nullable(), - techniqueId: z.string(), - tools: z.array(z.object({ id: z.string() })).optional(), - logSources: z.array(z.object({ id: z.string() })).optional(), -}); -const attackFlowLayoutSchema = z.object({ id: z.string().optional(), operationId: z.number(), nodes: z.custom(), edges: z.custom() }); -const threatActorTechniqueLinkSchema = z.object({ threatActorId: z.string(), mitreTechniqueId: z.string() }); -const authenticatorSchema = z.object({ - id: z.string().optional(), - credentialID: z.string(), - userId: z.string(), - providerAccountId: z.string(), - credentialPublicKey: z.string(), - counter: z.number(), - credentialDeviceType: z.string(), - credentialBackedUp: z.boolean(), - transports: z.string().optional().nullable(), -}); - -const backupPayloadSchema = z.object({ - threatActors: z.array(threatActorSchema).optional(), - crownJewels: z.array(crownJewelSchema).optional(), - tags: z.array(tagSchema).optional(), - toolCategories: z.array(toolCategorySchema).optional(), - tools: z.array(toolSchema).optional(), - logSources: z.array(logSourceSchema).optional(), - operations: z.array(operationSchema).optional(), - techniques: z.array(techniqueSchema).optional(), - outcomes: z.array(outcomeSchema).optional(), - attackFlowLayouts: z.array(attackFlowLayoutSchema).optional(), - threatActorTechniqueLinks: z.array(threatActorTechniqueLinkSchema).optional(), - users: z.array(z.object({ - id: z.string().optional(), - email: z.string().email(), - name: z.string().nullable().optional(), - role: z.enum(["ADMIN", "OPERATOR", "VIEWER"]).optional(), - lastLogin: z.coerce.date().optional().nullable(), - })).optional(), - authenticators: z.array(authenticatorSchema).optional(), - groups: z.array(z.object({ id: z.string().optional(), name: z.string(), description: z.string() })).optional(), - userGroups: z.array(z.object({ userId: z.string(), groupId: z.string() })).optional(), - mitreTactics: z.array(z.unknown()).optional(), - mitreTechniques: z.array(z.unknown()).optional(), - mitreSubTechniques: z.array(z.unknown()).optional(), -}); - -type BackupPayload = z.infer; - -export const databaseRestoreRouter = createTRPCRouter({ - restore: adminProcedure - .input(z.object({ backupData: z.string(), restoreTaxonomyAndOperations: z.boolean().default(true), restoreUsersAndGroups: z.boolean().default(false), clearBefore: z.boolean().default(true) })) - .mutation(async ({ ctx, input }: { ctx: AuthedContext; input: { backupData: string; restoreTaxonomyAndOperations: boolean; restoreUsersAndGroups: boolean; clearBefore: boolean } }) => { - const db = getPrismaClient(ctx.db); - try { - const data = JSON.parse(input.backupData) as { version?: string; data?: unknown }; - if (!data.version || !data.data) throw new Error("Invalid backup format"); - if (!input.restoreTaxonomyAndOperations && !input.restoreUsersAndGroups) throw new Error("Select at least one section to restore"); - if (input.restoreTaxonomyAndOperations && !input.clearBefore) throw new Error("Restoring taxonomy and operations requires clearing existing data first"); - - if (input.clearBefore && input.restoreTaxonomyAndOperations) { - await db.$transaction(async (tx: Prisma.TransactionClient) => { - await tx.outcome.deleteMany(); - await tx.technique.deleteMany(); - await tx.operation.deleteMany(); - await tx.tool.deleteMany(); - await tx.toolCategory.deleteMany(); - await tx.logSource.deleteMany(); - await tx.tag.deleteMany(); - await tx.crownJewel.deleteMany(); - await tx.threatActor.deleteMany(); - }); - } - - const payload: BackupPayload = backupPayloadSchema.parse(data.data); - - await db.$transaction(async (tx: Prisma.TransactionClient) => { - if (input.restoreUsersAndGroups) { - if (payload.users?.length) { - for (const user of payload.users) { - await tx.user.upsert({ - where: { email: user.email }, - update: { - name: user.name ?? undefined, - role: user.role ?? undefined, - lastLogin: user.lastLogin ?? undefined, - }, - create: { - email: user.email, - name: user.name ?? null, - role: user.role ?? "VIEWER", - lastLogin: user.lastLogin ?? null, - }, - }); - } - } - - if (payload.authenticators) { - await tx.authenticator.deleteMany(); - for (const auth of payload.authenticators) { - await tx.authenticator.create({ - data: { - id: auth.id, - credentialID: auth.credentialID, - userId: auth.userId, - providerAccountId: auth.providerAccountId, - credentialPublicKey: auth.credentialPublicKey, - counter: auth.counter, - credentialDeviceType: auth.credentialDeviceType, - credentialBackedUp: auth.credentialBackedUp, - transports: auth.transports ?? null, - }, - }); - } - } - - if (payload.groups?.length) { - for (const group of payload.groups) { - await tx.group.upsert({ - where: group.id ? { id: group.id } : { name: group.name }, - update: { - name: group.name, - description: group.description, - }, - create: { - ...(group.id ? { id: group.id } : {}), - name: group.name, - description: group.description, - }, - }); - } - } - } - - if (input.restoreTaxonomyAndOperations) { - if (payload.threatActors?.length) await tx.threatActor.createMany({ data: payload.threatActors }); - if (payload.crownJewels?.length) await tx.crownJewel.createMany({ data: payload.crownJewels }); - if (payload.tags?.length) await tx.tag.createMany({ data: payload.tags }); - if (payload.toolCategories?.length) await tx.toolCategory.createMany({ data: payload.toolCategories }); - if (payload.tools?.length) await tx.tool.createMany({ data: payload.tools }); - if (payload.logSources?.length) await tx.logSource.createMany({ data: payload.logSources }); - - const accessGroupIds = new Set(); - for (const op of payload.operations ?? []) { - for (const ag of op.accessGroups ?? []) { - accessGroupIds.add(ag.groupId); - } - } - - if (accessGroupIds.size > 0) { - const groups = await tx.group.findMany({ where: { id: { in: Array.from(accessGroupIds) } } }); - if (groups.length !== accessGroupIds.size) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "One or more groups referenced by operations are missing. Restore groups before operations.", - }); - } - } - - for (const op of payload.operations ?? []) { - const { - tags: opTags = [], - crownJewels: opCrownJewels = [], - accessGroups: opAccessGroups, - visibility: opVisibility, - ...operationFields - } = op; - - const created = await tx.operation.create({ - data: { - ...operationFields, - visibility: opVisibility ?? "EVERYONE", - tags: opTags.length ? { connect: opTags.map(({ id }) => ({ id })) } : undefined, - crownJewels: opCrownJewels.length ? { connect: opCrownJewels.map(({ id }) => ({ id })) } : undefined, - }, - }); - - if (opAccessGroups?.length) { - await tx.operationAccessGroup.createMany({ - data: opAccessGroups.map(({ groupId }) => ({ operationId: created.id, groupId })), - }); - } - } - - for (const technique of payload.techniques ?? []) { - const { tools: techniqueTools = [], ...techniqueFields } = technique; - - await tx.technique.create({ - data: { - ...techniqueFields, - tools: techniqueTools.length ? { connect: techniqueTools.map(({ id }) => ({ id })) } : undefined, - }, - }); - } - - for (const outcome of payload.outcomes ?? []) { - const { - tools: outcomeTools = [], - logSources: outcomeLogSources = [], - ...outcomeFields - } = outcome; - - await tx.outcome.create({ - data: { - ...outcomeFields, - tools: outcomeTools.length ? { connect: outcomeTools.map(({ id }) => ({ id })) } : undefined, - logSources: outcomeLogSources.length ? { connect: outcomeLogSources.map(({ id }) => ({ id })) } : undefined, - }, - }); - } - - for (const layoutRecord of payload.attackFlowLayouts ?? []) { - await tx.attackFlowLayout.upsert({ - where: { operationId: layoutRecord.operationId }, - update: { nodes: layoutRecord.nodes, edges: layoutRecord.edges }, - create: { - operationId: layoutRecord.operationId, - nodes: layoutRecord.nodes, - edges: layoutRecord.edges, - }, - }); - } - - for (const linkRecord of payload.threatActorTechniqueLinks ?? []) { - await tx.threatActor.update({ - where: { id: linkRecord.threatActorId }, - data: { mitreTechniques: { connect: { id: linkRecord.mitreTechniqueId } } }, - }); - } - } - - if (input.restoreUsersAndGroups) { - for (const membership of payload.userGroups ?? []) { - const userExists = await tx.user.findUnique({ where: { id: membership.userId } }); - const groupExists = await tx.group.findUnique({ where: { id: membership.groupId } }); - if (!userExists || !groupExists) continue; - await tx.userGroup.upsert({ - where: { userId_groupId: { userId: membership.userId, groupId: membership.groupId } }, - update: {}, - create: { userId: membership.userId, groupId: membership.groupId }, - }); - } - } - }); - - return { success: true }; - } catch (error) { - const message = error instanceof Error ? error.message : String(error); - logger.error({ event: "database.restore_failed", message }, "Restore error"); - throw new TRPCError({ code: "BAD_REQUEST", message: error instanceof Error ? error.message : "Failed to restore backup" }); - } - }), - - clearData: adminProcedure - .input(z.object({ clearOperations: z.boolean(), clearTaxonomy: z.boolean() })) - .mutation(async ({ ctx, input }: { ctx: AuthedContext; input: { clearOperations: boolean; clearTaxonomy: boolean } }) => { - const db = getPrismaClient(ctx.db); - try { - await db.$transaction(async (tx: Prisma.TransactionClient) => { - if (input.clearOperations) { - await tx.outcome.deleteMany(); - await tx.technique.deleteMany(); - await tx.operation.deleteMany(); - } - if (input.clearTaxonomy) { - await tx.tool.deleteMany(); - await tx.toolCategory.deleteMany(); - await tx.logSource.deleteMany(); - await tx.tag.deleteMany(); - await tx.crownJewel.deleteMany(); - await tx.threatActor.deleteMany(); - } - }); - return { success: true }; - } catch (error) { - const message = error instanceof Error ? error.message : String(error); - logger.error({ event: "database.clear_failed", message }, "Clear data error"); - throw new TRPCError({ code: "INTERNAL_SERVER_ERROR", message: "Failed to clear data" }); - } - }), -}); diff --git a/src/server/auth/config.ts b/src/server/auth/config.ts index ea6f0f8..124ab6b 100644 --- a/src/server/auth/config.ts +++ b/src/server/auth/config.ts @@ -86,7 +86,7 @@ const isAdapter = (value: unknown): value is Adapter => { }); }; -const prismaAdapter = (() => { +const prismaAdapter: Adapter = (() => { if (!isAdapterFactory(PrismaAdapter)) { throw new Error("Invalid Prisma adapter export"); } diff --git a/src/test/data-restore.test.ts b/src/test/data-restore.test.ts new file mode 100644 index 0000000..5264025 --- /dev/null +++ b/src/test/data-restore.test.ts @@ -0,0 +1,99 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { UserRole } from "@prisma/client"; + +import { dataRouter } from "@/server/api/routers/data"; + +vi.mock("@/server/db", () => ({ + db: { + $transaction: vi.fn(), + threatActor: { createMany: vi.fn(), deleteMany: vi.fn(), update: vi.fn() }, + crownJewel: { createMany: vi.fn(), deleteMany: vi.fn() }, + tag: { createMany: vi.fn(), deleteMany: vi.fn() }, + toolCategory: { createMany: vi.fn(), deleteMany: vi.fn() }, + tool: { createMany: vi.fn(), deleteMany: vi.fn() }, + logSource: { createMany: vi.fn(), deleteMany: vi.fn() }, + operation: { create: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, + technique: { create: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, + outcome: { create: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, + attackFlowLayout: { create: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, + }, +})); + +const { db } = await import("@/server/db"); +const mockDb = vi.mocked(db, true); + +const createCaller = (role: UserRole) => + dataRouter.createCaller({ + headers: new Headers(), + session: { user: { id: "u1", role }, expires: "2099-01-01" }, + db: mockDb, + requestId: "data-restore-test", + }); + +describe("Data Restore", () => { + beforeEach(() => { + vi.clearAllMocks(); + mockDb.$transaction.mockImplementation(async (callback) => { + return await callback(mockDb); + }); + }); + + it("restores taxonomy and operations with clearing", async () => { + const caller = createCaller(UserRole.ADMIN); + + mockDb.threatActor.update.mockResolvedValue({}); + + const payload = { + threatActors: [{ id: "ta1", name: "APT29", description: "desc" }], + crownJewels: [{ id: "cj1", name: "DB", description: "desc" }], + tags: [{ id: "tag1", name: "Stealth", description: "d" }], + toolCategories: [{ id: "cat1", name: "EDR", type: "DEFENSIVE" as const }], + tools: [{ id: "tool1", name: "Falcon", categoryId: "cat1", type: "DEFENSIVE" as const }], + logSources: [{ id: "log1", name: "SIEM", description: "d" }], + operations: [ + { id: 1, name: "Op1", description: "d", createdById: "u1", tags: [{ id: "tag1" }], crownJewels: [{ id: "cj1" }] }, + ], + techniques: [{ id: "tech-inst", description: "d", operationId: 1, tools: [{ id: "tool1" }] }], + outcomes: [ + { + id: "out1", + type: "DETECTION" as const, + status: "DETECTED" as const, + techniqueId: "tech-inst", + tools: [{ id: "tool1" }], + logSources: [{ id: "log1" }], + }, + ], + attackFlowLayouts: [{ id: "layout1", operationId: 1, nodes: [], edges: [] }], + threatActorTechniqueLinks: [{ threatActorId: "ta1", mitreTechniqueId: "tech-1" }], + }; + + const backup = JSON.stringify({ version: "2.0", timestamp: new Date().toISOString(), data: payload }); + + await caller.restore({ backupData: backup }); + + expect(mockDb.operation.deleteMany).toHaveBeenCalled(); + expect(mockDb.tool.deleteMany).toHaveBeenCalled(); + expect(mockDb.threatActor.createMany).toHaveBeenCalledWith({ data: payload.threatActors }); + expect(mockDb.operation.create).toHaveBeenCalled(); + expect(mockDb.technique.create).toHaveBeenCalled(); + expect(mockDb.outcome.create).toHaveBeenCalled(); + expect(mockDb.attackFlowLayout.create).toHaveBeenCalled(); + expect(mockDb.threatActor.update).toHaveBeenCalledWith({ + where: { id: "ta1" }, + data: { mitreTechniques: { connect: { id: "tech-1" } } }, + }); + }); + + it("rejects invalid backup data", async () => { + const caller = createCaller(UserRole.ADMIN); + + await expect(caller.restore({ backupData: "not-json" })).rejects.toThrow("Invalid backup file"); + }); + + it("rejects non-admin access", async () => { + const caller = createCaller(UserRole.OPERATOR); + + await expect(caller.restore({ backupData: "{}" })).rejects.toThrow("Admin access required"); + }); +}); diff --git a/src/test/data.test.ts b/src/test/data.test.ts new file mode 100644 index 0000000..e9eb214 --- /dev/null +++ b/src/test/data.test.ts @@ -0,0 +1,129 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { UserRole } from "@prisma/client"; + +import { dataRouter } from "@/server/api/routers/data"; + +vi.mock("@/server/db", () => ({ + db: { + operation: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + technique: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + outcome: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + threatActor: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + crownJewel: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + tag: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + tool: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + toolCategory: { findMany: vi.fn(), deleteMany: vi.fn() }, + logSource: { count: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, + attackFlowLayout: { findMany: vi.fn(), deleteMany: vi.fn() }, + $transaction: vi.fn(), + }, +})); + +const { db } = await import("@/server/db"); +const mockDb = vi.mocked(db, true); + +const createMockContext = (role: UserRole) => ({ + headers: new Headers(), + session: { user: { id: "test-user-id", role }, expires: "2099-01-01" }, + db: mockDb, + requestId: "data-test", +}); + +const createCaller = (role: UserRole) => dataRouter.createCaller(createMockContext(role)); + +describe("Data Router", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe("getStats", () => { + it("returns overview counts for admins", async () => { + const caller = createCaller(UserRole.ADMIN); + + mockDb.operation.count.mockResolvedValue(4); + mockDb.technique.count.mockResolvedValue(9); + mockDb.outcome.count.mockResolvedValue(5); + mockDb.threatActor.count.mockResolvedValue(3); + mockDb.crownJewel.count.mockResolvedValue(2); + mockDb.tag.count.mockResolvedValue(6); + mockDb.tool.count.mockResolvedValue(7); + mockDb.logSource.count.mockResolvedValue(8); + + await expect(caller.getStats()).resolves.toEqual({ + operations: 4, + techniques: 9, + outcomes: 5, + threatActors: 3, + crownJewels: 2, + tags: 6, + tools: 7, + logSources: 8, + }); + }); + + it("rejects non-admin access", async () => { + const caller = createCaller(UserRole.VIEWER); + + await expect(caller.getStats()).rejects.toThrow("Admin access required"); + }); + }); + + describe("backup", () => { + it("creates a backup payload for admins", async () => { + const caller = createCaller(UserRole.ADMIN); + + mockDb.crownJewel.findMany.mockResolvedValue([]); + mockDb.tag.findMany.mockResolvedValue([]); + mockDb.toolCategory.findMany.mockResolvedValue([]); + mockDb.tool.findMany.mockResolvedValue([]); + mockDb.logSource.findMany.mockResolvedValue([]); + mockDb.operation.findMany.mockResolvedValue([]); + mockDb.technique.findMany.mockResolvedValue([]); + mockDb.outcome.findMany.mockResolvedValue([]); + mockDb.attackFlowLayout.findMany.mockResolvedValue([]); + mockDb.threatActor.findMany.mockResolvedValue([]); + + const result = await caller.backup(); + + expect(result).toContain('"version": "2.0"'); + expect(result).toContain('"data":'); + }); + + it("rejects non-admin access", async () => { + const caller = createCaller(UserRole.OPERATOR); + + await expect(caller.backup()).rejects.toThrow("Admin access required"); + }); + }); + + describe("clearData", () => { + beforeEach(() => { + mockDb.$transaction.mockImplementation(async (callback) => { + return await callback(mockDb); + }); + }); + + it("clears operations and taxonomy data together", async () => { + const caller = createCaller(UserRole.ADMIN); + + await caller.clearData(); + + expect(mockDb.outcome.deleteMany).toHaveBeenCalled(); + expect(mockDb.technique.deleteMany).toHaveBeenCalled(); + expect(mockDb.attackFlowLayout.deleteMany).toHaveBeenCalled(); + expect(mockDb.operation.deleteMany).toHaveBeenCalled(); + expect(mockDb.tool.deleteMany).toHaveBeenCalled(); + expect(mockDb.toolCategory.deleteMany).toHaveBeenCalled(); + expect(mockDb.logSource.deleteMany).toHaveBeenCalled(); + expect(mockDb.tag.deleteMany).toHaveBeenCalled(); + expect(mockDb.crownJewel.deleteMany).toHaveBeenCalled(); + expect(mockDb.threatActor.deleteMany).toHaveBeenCalled(); + }); + + it("rejects non-admin access", async () => { + const caller = createCaller(UserRole.VIEWER); + + await expect(caller.clearData()).rejects.toThrow("Admin access required"); + }); + }); +}); diff --git a/src/test/database-restore.test.ts b/src/test/database-restore.test.ts deleted file mode 100644 index ce85b03..0000000 --- a/src/test/database-restore.test.ts +++ /dev/null @@ -1,120 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from "vitest"; -import { UserRole } from "@prisma/client"; -import { databaseRouter } from "@/server/api/routers/database"; - -vi.mock("@/server/db", () => ({ - db: { - $transaction: vi.fn(), - // Taxonomy - threatActor: { createMany: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - crownJewel: { createMany: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - tag: { createMany: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - toolCategory: { createMany: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - tool: { createMany: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - logSource: { createMany: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - // Ops - operation: { create: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - technique: { create: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - outcome: { create: vi.fn(), findMany: vi.fn(), deleteMany: vi.fn() }, - // Users & groups - user: { findMany: vi.fn(), upsert: vi.fn() }, - group: { findMany: vi.fn(), upsert: vi.fn(), deleteMany: vi.fn(), findUnique: vi.fn() }, - userGroup: { findMany: vi.fn(), upsert: vi.fn(), deleteMany: vi.fn() }, - // MITRE (read-only in backup) - mitreTactic: { findMany: vi.fn() }, - mitreTechnique: { findMany: vi.fn() }, - mitreSubTechnique: { findMany: vi.fn() }, - }, -})); - -const { db } = await import("@/server/db"); -const mockDb = vi.mocked(db, true); - -const createCaller = (role: UserRole) => { - const ctx = { - headers: new Headers(), - session: { user: { id: "u1", role }, expires: "2099-01-01" }, - db: mockDb, - requestId: "database-restore-test", - }; - return databaseRouter.createCaller(ctx); -}; - -describe("Database Restore", () => { - beforeEach(() => vi.clearAllMocks()); - - it("restores taxonomy and operations after clearing when requested", async () => { - const caller = createCaller(UserRole.ADMIN); - - mockDb.$transaction.mockImplementation(async (cb: (tx: typeof mockDb) => Promise | void) => { - await cb(mockDb); - return undefined as unknown as void; - }); - - const payload = { - threatActors: [{ id: "ta1", name: "APT29", description: "desc" }], - crownJewels: [{ id: "cj1", name: "DB", description: "desc" }], - tags: [{ id: "tg1", name: "Stealth", description: "d" }], - toolCategories: [{ id: "tc1", name: "EDR", type: "DEFENSIVE" }], - tools: [{ id: "tl1", name: "Falcon", categoryId: "tc1", type: "DEFENSIVE" }], - logSources: [{ id: "ls1", name: "SIEM", description: "d" }], - operations: [{ id: 1, name: "Op1", description: "d", createdById: "u1", tags: [{ id: "tg1" }], crownJewels: [{ id: "cj1" }] }], - techniques: [{ id: "tech1", description: "d", operationId: 1 }], - outcomes: [{ id: "out1", type: "DETECTION", status: "DETECTED", techniqueId: "tech1" }], - }; - const backup = JSON.stringify({ version: "1.0", timestamp: new Date().toISOString(), data: payload }); - - await caller.restore({ backupData: backup, restoreTaxonomyAndOperations: true, restoreUsersAndGroups: false, clearBefore: true }); - - expect(mockDb.threatActor.createMany).toHaveBeenCalled(); - expect(mockDb.operation.create).toHaveBeenCalled(); - expect(mockDb.technique.create).toHaveBeenCalled(); - expect(mockDb.outcome.create).toHaveBeenCalled(); - }); - - it("rejects when neither section selected", async () => { - const caller = createCaller(UserRole.ADMIN); - const backup = JSON.stringify({ version: "1.0", timestamp: new Date().toISOString(), data: {} }); - await expect( - caller.restore({ backupData: backup, restoreTaxonomyAndOperations: false, restoreUsersAndGroups: false, clearBefore: false }) - ).rejects.toThrow("Select at least one section to restore"); - }); - - it("requires clearBefore when restoring taxonomy+operations", async () => { - const caller = createCaller(UserRole.ADMIN); - const backup = JSON.stringify({ version: "1.0", timestamp: new Date().toISOString(), data: {} }); - await expect( - caller.restore({ backupData: backup, restoreTaxonomyAndOperations: true, restoreUsersAndGroups: false, clearBefore: false }) - ).rejects.toThrow("requires clearing existing data first"); - }); - - it("fails when operations reference missing access groups", async () => { - const caller = createCaller(UserRole.ADMIN); - - mockDb.$transaction.mockImplementation(async (cb: (tx: typeof mockDb) => Promise | void) => { - await cb(mockDb); - return undefined as unknown as void; - }); - - const payload = { - operations: [ - { - id: 1, - name: "Op1", - description: "d", - createdById: "u1", - visibility: "GROUPS_ONLY", - accessGroups: [{ groupId: "g1" }], - }, - ], - }; - - const backup = JSON.stringify({ version: "1.0", timestamp: new Date().toISOString(), data: payload }); - - mockDb.group.findMany.mockResolvedValue([]); - - await expect( - caller.restore({ backupData: backup, restoreTaxonomyAndOperations: true, restoreUsersAndGroups: false, clearBefore: true }) - ).rejects.toThrow("One or more groups referenced by operations are missing"); - }); -}); diff --git a/src/test/database.test.ts b/src/test/database.test.ts deleted file mode 100644 index 9a5f13e..0000000 --- a/src/test/database.test.ts +++ /dev/null @@ -1,187 +0,0 @@ -import { describe, it, expect, vi, beforeEach } from "vitest"; -import { UserRole } from "@prisma/client"; -import { databaseRouter } from "@/server/api/routers/database"; - -// Mock database -vi.mock("@/server/db", () => ({ - db: { - user: { count: vi.fn(), findMany: vi.fn() }, - operation: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - mitreTactic: { count: vi.fn(), findMany: vi.fn() }, - mitreTechnique: { count: vi.fn(), findMany: vi.fn() }, - mitreSubTechnique: { count: vi.fn(), findMany: vi.fn() }, - technique: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - outcome: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - threatActor: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - crownJewel: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - tag: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - tool: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - toolCategory: { deleteMany: vi.fn(), findMany: vi.fn() }, - logSource: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - group: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - userGroup: { count: vi.fn(), deleteMany: vi.fn(), findMany: vi.fn() }, - attackFlowLayout: { findMany: vi.fn() }, - $transaction: vi.fn(), - }, -})); - -const { db } = await import("@/server/db"); -const mockDb = vi.mocked(db, true); - -// Add missing method (for prior tests); keep for safety -mockDb.user.findFirst = vi.fn(); - -// Helper to create mock context -const createMockContext = (role: UserRole) => ({ - headers: new Headers(), - session: { - user: { id: "test-user-id", role }, - expires: "2099-01-01", - }, - db: mockDb, - requestId: "database-test", -}); - -// Create tRPC caller with mock context -const createCaller = (role: UserRole) => { - const ctx = createMockContext(role); - return databaseRouter.createCaller(ctx); -}; - -describe("Database Router", () => { - beforeEach(() => { - vi.clearAllMocks(); - }); - - describe("getStats", () => { - it("should return database statistics for admin users", async () => { - const caller = createCaller(UserRole.ADMIN); - - // Mock count responses - mockDb.user.count.mockResolvedValue(5); - mockDb.operation.count.mockResolvedValue(10); - mockDb.mitreTactic.count.mockResolvedValue(14); - mockDb.mitreTechnique.count.mockResolvedValue(200); - mockDb.mitreSubTechnique.count.mockResolvedValue(400); - mockDb.technique.count.mockResolvedValue(50); - mockDb.outcome.count.mockResolvedValue(150); - mockDb.threatActor.count.mockResolvedValue(8); - mockDb.crownJewel.count.mockResolvedValue(12); - mockDb.tag.count.mockResolvedValue(6); - mockDb.tool.count.mockResolvedValue(15); - mockDb.logSource.count.mockResolvedValue(4); - mockDb.group.count.mockResolvedValue(3); - mockDb.userGroup.count.mockResolvedValue(10); - - const result = await caller.getStats(); - - expect(result).toEqual({ - users: 5, - operations: 10, - mitreTactics: 14, - mitreTechniques: 200, - mitreSubTechniques: 400, - techniques: 50, - outcomes: 150, - threatActors: 8, - crownJewels: 12, - tags: 6, - tools: 15, - logSources: 4, - groups: 3, - userGroups: 10, - }); - }); - - it("should throw FORBIDDEN for non-admin users", async () => { - const caller = createCaller(UserRole.VIEWER); - - await expect(caller.getStats()).rejects.toThrow("Admin access required"); - }); - }); - - describe("backup", () => { - it("should create backup for admin users", async () => { - const caller = createCaller(UserRole.ADMIN); - - // Mock findMany methods for backup data - mockDb.user.findMany = vi.fn().mockResolvedValue([ - { id: "1", email: "admin@example.com", name: "Admin" } - ]); - mockDb.operation.findMany = vi.fn().mockResolvedValue([]); - mockDb.mitreTactic.findMany = vi.fn().mockResolvedValue([]); - mockDb.mitreTechnique.findMany = vi.fn().mockResolvedValue([]); - mockDb.mitreSubTechnique.findMany = vi.fn().mockResolvedValue([]); - mockDb.threatActor.findMany = vi.fn().mockResolvedValue([]); - mockDb.crownJewel.findMany = vi.fn().mockResolvedValue([]); - mockDb.tag.findMany = vi.fn().mockResolvedValue([]); - mockDb.tool.findMany = vi.fn().mockResolvedValue([]); - mockDb.logSource.findMany = vi.fn().mockResolvedValue([]); - mockDb.attackFlowLayout.findMany = vi.fn().mockResolvedValue([]); - mockDb.group.findMany = vi.fn().mockResolvedValue([]); - mockDb.userGroup.findMany = vi.fn().mockResolvedValue([]); - mockDb.technique.findMany = vi.fn().mockResolvedValue([]); - mockDb.outcome.findMany = vi.fn().mockResolvedValue([]); - - const result = await caller.backup(); - - expect(result).toContain('"version": "1.0"'); - expect(result).toContain('"data":'); - expect(typeof result).toBe("string"); - }); - - it("should throw FORBIDDEN for non-admin users", async () => { - const caller = createCaller(UserRole.OPERATOR); - - await expect(caller.backup()).rejects.toThrow("Admin access required"); - }); - }); - - describe("clearData", () => { - it("should clear operations data for admin users", async () => { - const caller = createCaller(UserRole.ADMIN); - - mockDb.$transaction.mockImplementation(async (callback) => { - return await callback(mockDb); - }); - mockDb.outcome.deleteMany.mockResolvedValue({ count: 10 }); - mockDb.technique.deleteMany.mockResolvedValue({ count: 5 }); - mockDb.operation.deleteMany.mockResolvedValue({ count: 3 }); - - await caller.clearData({ clearOperations: true, clearTaxonomy: false }); - - expect(mockDb.outcome.deleteMany).toHaveBeenCalled(); - expect(mockDb.technique.deleteMany).toHaveBeenCalled(); - expect(mockDb.operation.deleteMany).toHaveBeenCalled(); - expect(mockDb.threatActor.deleteMany).not.toHaveBeenCalled(); - }); - - it("should clear taxonomy data for admin users", async () => { - const caller = createCaller(UserRole.ADMIN); - - mockDb.$transaction.mockImplementation(async (callback) => { - return await callback(mockDb); - }); - mockDb.userGroup.deleteMany.mockResolvedValue({ count: 5 }); - mockDb.group.deleteMany.mockResolvedValue({ count: 2 }); - mockDb.logSource.deleteMany.mockResolvedValue({ count: 3 }); - mockDb.tool.deleteMany.mockResolvedValue({ count: 8 }); - mockDb.tag.deleteMany.mockResolvedValue({ count: 4 }); - mockDb.crownJewel.deleteMany.mockResolvedValue({ count: 6 }); - mockDb.threatActor.deleteMany.mockResolvedValue({ count: 7 }); - - await caller.clearData({ clearOperations: false, clearTaxonomy: true }); - - expect(mockDb.threatActor.deleteMany).toHaveBeenCalled(); - expect(mockDb.tag.deleteMany).toHaveBeenCalled(); - expect(mockDb.tool.deleteMany).toHaveBeenCalled(); - expect(mockDb.operation.deleteMany).not.toHaveBeenCalled(); - }); - - it("should throw FORBIDDEN for non-admin users", async () => { - const caller = createCaller(UserRole.VIEWER); - - await expect(caller.clearData({ clearOperations: true, clearTaxonomy: false })).rejects.toThrow("Admin access required"); - }); - }); -});