diff --git a/lib/HyperGraphPartialRippingSolver.ts b/lib/HyperGraphPartialRippingSolver.ts new file mode 100644 index 0000000..42bb330 --- /dev/null +++ b/lib/HyperGraphPartialRippingSolver.ts @@ -0,0 +1,45 @@ +import type { Connection, HyperGraph, SerializedConnection } from "./types" +import type { + Candidate, + Region, + RegionPort, + SerializedHyperGraph, +} from "./types" +import { HyperGraphSolver } from "./HyperGraphSolver" + +export type HyperGraphPartialRippingInput = { + inputGraph: HyperGraph | SerializedHyperGraph + inputConnections: (Connection | SerializedConnection)[] + greedyMultiplier?: number + ripCost?: number + rippingEnabled?: boolean + ripCostThreshold?: number +} + +export class HyperGraphPartialRippingSolver< + RegionType extends Region = Region, + RegionPortType extends RegionPort = RegionPort, + CandidateType extends Candidate = Candidate< + RegionType, + RegionPortType + >, +> extends HyperGraphSolver { + override getSolverName(): string { + return "HyperGraphPartialRippingSolver" + } + + ripCostThreshold = 0 + + constructor(input: HyperGraphPartialRippingInput) { + super({ + ...input, + rippingEnabled: input.rippingEnabled ?? true, + }) + this.ripCostThreshold = input.ripCostThreshold ?? this.ripCostThreshold + } + + override shouldAllowRip(candidate: CandidateType): boolean { + const priorCost = candidate.parent?.g ?? 0 + return priorCost >= this.ripCostThreshold + } +} diff --git a/lib/HyperGraphSolver.ts b/lib/HyperGraphSolver.ts index 923214c..d92db3d 100644 --- a/lib/HyperGraphSolver.ts +++ b/lib/HyperGraphSolver.ts @@ -147,6 +147,16 @@ export class HyperGraphSolver< return [] } + /** + * OPTIONALLY OVERRIDE THIS + * + * Return true if a candidate that requires ripping should be considered. + * This allows partial ripping strategies to gate when ripping is allowed. + */ + shouldAllowRip(_candidate: CandidateType): boolean { + return true + } + computeG(candidate: CandidateType): number { return ( candidate.parent!.g + @@ -194,6 +204,13 @@ export class HyperGraphSolver< if (!this.rippingEnabled && newCandidate.ripRequired) { continue } + if ( + this.rippingEnabled && + newCandidate.ripRequired && + !this.shouldAllowRip(newCandidate as CandidateType) + ) { + continue + } nextCandidatesByRegion[newCandidate.nextRegion!.regionId] ??= [] nextCandidatesByRegion[newCandidate.nextRegion!.regionId].push( diff --git a/lib/index.ts b/lib/index.ts index bf07763..5d0003b 100644 --- a/lib/index.ts +++ b/lib/index.ts @@ -3,6 +3,7 @@ export * from "./JumperGraphSolver/jumper-graph-generator/generateJumperGrid" export * from "./JumperGraphSolver/jumper-graph-generator/createGraphWithConnectionsFromBaseGraph" export * from "./JumperGraphSolver/JumperGraphSolver" export * from "./HyperGraphSolver" +export * from "./HyperGraphPartialRippingSolver" export * from "./convertHyperGraphToSerializedHyperGraph" export * from "./convertConnectionsToSerializedConnections" export * from "./JumperGraphSolver/geometry/applyTransformToGraph" diff --git a/tests/__snapshots__/hypergraph-partial-ripping.snap.svg b/tests/__snapshots__/hypergraph-partial-ripping.snap.svg new file mode 100644 index 0000000..a11873d --- /dev/null +++ b/tests/__snapshots__/hypergraph-partial-ripping.snap.svg @@ -0,0 +1 @@ +thresholdZeroconn-2 ripconn-1 okthresholdTwoconn-1 okconn-2 ok \ No newline at end of file diff --git a/tests/hypergraph-partial-ripping.test.ts b/tests/hypergraph-partial-ripping.test.ts new file mode 100644 index 0000000..dc162c7 --- /dev/null +++ b/tests/hypergraph-partial-ripping.test.ts @@ -0,0 +1,208 @@ +import { expect, test } from "bun:test" +import { HyperGraphPartialRippingSolver } from "lib/HyperGraphPartialRippingSolver" +import type { HyperGraph, Connection } from "lib/types" + +type BasicRegion = { + regionId: string + ports: BasicPort[] + d: Record +} + +type BasicPort = { + portId: string + region1: BasicRegion + region2: BasicRegion + d: Record + assignment?: never + ripCount?: number +} + +const buildGraph = (): { graph: HyperGraph; connections: Connection[] } => { + const regionA: BasicRegion = { regionId: "A", ports: [], d: {} } + const regionB: BasicRegion = { regionId: "B", ports: [], d: {} } + const regionC: BasicRegion = { regionId: "C", ports: [], d: {} } + const regionD: BasicRegion = { regionId: "D", ports: [], d: {} } + const regionE: BasicRegion = { regionId: "E", ports: [], d: {} } + const regionF: BasicRegion = { regionId: "F", ports: [], d: {} } + const regionG: BasicRegion = { regionId: "G", ports: [], d: {} } + + const port1: BasicPort = { + portId: "P1", + region1: regionA, + region2: regionC, + d: {}, + } + const port2: BasicPort = { + portId: "P2", + region1: regionC, + region2: regionB, + d: {}, + } + const port3: BasicPort = { + portId: "P3", + region1: regionA, + region2: regionD, + d: {}, + } + const port4: BasicPort = { + portId: "P4", + region1: regionD, + region2: regionE, + d: {}, + } + const port5: BasicPort = { + portId: "P5", + region1: regionE, + region2: regionB, + d: {}, + } + const port6: BasicPort = { + portId: "P6", + region1: regionA, + region2: regionF, + d: {}, + } + const port7: BasicPort = { + portId: "P7", + region1: regionF, + region2: regionG, + d: {}, + } + const port8: BasicPort = { + portId: "P8", + region1: regionG, + region2: regionB, + d: {}, + } + + regionA.ports.push(port1, port3, port6) + regionB.ports.push(port2, port5, port8) + regionC.ports.push(port1, port2) + regionD.ports.push(port3, port4) + regionE.ports.push(port4, port5) + regionF.ports.push(port6, port7) + regionG.ports.push(port7, port8) + + const graph: HyperGraph = { + regions: [regionA, regionB, regionC, regionD, regionE, regionF, regionG], + ports: [port1, port2, port3, port4, port5, port6, port7, port8], + } + + const connections: Connection[] = [ + { + connectionId: "conn-1", + mutuallyConnectedNetworkId: "net-1", + startRegion: regionA, + endRegion: regionB, + }, + { + connectionId: "conn-2", + mutuallyConnectedNetworkId: "net-2", + startRegion: regionA, + endRegion: regionB, + }, + ] + + return { graph, connections } +} + +class BasicPartialRippingSolver extends HyperGraphPartialRippingSolver< + BasicRegion, + BasicPort +> { + override estimateCostToEnd(): number { + return 0 + } + + override computeIncreasedRegionCostIfPortsAreUsed(): number { + return 1 + } + + override getPortUsagePenalty(port: BasicPort): number { + return (port.ripCount ?? 0) * 5 + } +} + +const solveWithThreshold = (ripCostThreshold: number) => { + const { graph, connections } = buildGraph() + const solver = new BasicPartialRippingSolver({ + inputGraph: graph, + inputConnections: connections, + ripCostThreshold, + }) + solver.solve() + return solver.solvedRoutes.map((route) => ({ + connectionId: route.connection.connectionId, + requiredRip: route.requiredRip, + portIds: route.path.map((candidate) => candidate.port.portId), + })) +} + +const renderSvg = (results: { + thresholdZero: Array<{ connectionId: string; requiredRip: boolean }> + thresholdTwo: Array<{ connectionId: string; requiredRip: boolean }> +}) => { + const rowHeight = 40 + const gap = 10 + const leftPadding = 20 + const topPadding = 20 + const barWidth = 260 + const barHeight = 18 + + const rows = [ + { label: "thresholdZero", items: results.thresholdZero }, + { label: "thresholdTwo", items: results.thresholdTwo }, + ] + + const height = topPadding + rows.length * rowHeight + (rows.length - 1) * gap + + let y = topPadding + const bars = rows + .map((row) => { + const rowY = y + y += rowHeight + gap + const label = `${row.label}` + const rects = row.items + .map((item, index) => { + const color = item.requiredRip ? "#e74c3c" : "#2ecc71" + const rectX = leftPadding + 120 + index * (barWidth + 12) + const rectY = rowY + const rect = `` + const text = `${ + item.connectionId + }${item.requiredRip ? " rip" : " ok"}` + return `${rect}${text}` + }) + .join("") + return `${label}${rects}` + }) + .join("") + + return `${bars}` +} + +test("hypergraph partial ripping defers ripping until threshold", () => { + const results = { + thresholdZero: solveWithThreshold(0), + thresholdTwo: solveWithThreshold(2), + } + + const svg = renderSvg({ + thresholdZero: results.thresholdZero.map( + ({ connectionId, requiredRip }) => ({ + connectionId, + requiredRip, + }), + ), + thresholdTwo: results.thresholdTwo.map(({ connectionId, requiredRip }) => ({ + connectionId, + requiredRip, + })), + }) + + expect(svg).toMatchSvgSnapshot(import.meta.path) +})