diff --git a/CMakeLists.txt b/CMakeLists.txt index c2239cdcb0..fc1a296dbe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,7 @@ option(FF_BUILD_ALL_EXAMPLES "build all examples. Overrides others" OFF) option(FF_BUILD_UNIT_TESTS "build non-operator unit tests" OFF) option(FF_BUILD_SUBSTITUTION_TOOL "build substitution conversion tool" OFF) option(FF_BUILD_VISUALIZATION_TOOL "build substitution visualization tool" ON) +option(FF_BUILD_SP_IZATION_BENCHMARKING "build sp-ization benchmarking" ON) option(FF_BUILD_ARG_PARSER "build command line argument parser" OFF) option(FF_BUILD_BIN_EXPORT_MODEL_ARCH "build export-model-arch utility" ON) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index 212cf68a17..ac19f9011e 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -6,6 +6,10 @@ if(FF_BUILD_VISUALIZATION_TOOL) add_subdirectory(substitution-to-dot) endif() +if(FF_BUILD_SP_IZATION_BENCHMARKING) + add_subdirectory(sp_ization_benchmarking) +endif() + if(FF_BUILD_ARG_PARSER) add_subdirectory(arg_parser) endif() diff --git a/bin/export-model-arch/src/export-model-arch/main.cc b/bin/export-model-arch/src/export-model-arch/main.cc index 82aebd2b2c..29be28b0ef 100644 --- a/bin/export-model-arch/src/export-model-arch/main.cc +++ b/bin/export-model-arch/src/export-model-arch/main.cc @@ -1,6 +1,6 @@ #include "compiler/series_parallel/computation_graph/computation_graph_binary_sp_decomposition.h" #include "compiler/series_parallel/computation_graph/get_computation_graph_series_parallel_decomposition.h" -#include "export_model_arch/json_sp_model_export.dtg.h" +#include "export-model-arch/json_sp_model_export.dtg.h" #include "models/bert/bert.h" #include "models/candle_uno/candle_uno.h" #include "models/dlrm/dlrm.h" diff --git a/bin/sp_ization_benchmarking/CMakeLists.txt b/bin/sp_ization_benchmarking/CMakeLists.txt new file mode 100644 index 0000000000..a24e84e31d --- /dev/null +++ b/bin/sp_ization_benchmarking/CMakeLists.txt @@ -0,0 +1,9 @@ 
+ff_add_executable( + NAME + sp-ization-benchmarking + SRC_PATTERNS + *.cc + DEPS + utils + rapidcheck +) diff --git a/bin/sp_ization_benchmarking/distributions.cc b/bin/sp_ization_benchmarking/distributions.cc new file mode 100644 index 0000000000..6c59f58a34 --- /dev/null +++ b/bin/sp_ization_benchmarking/distributions.cc @@ -0,0 +1,55 @@ +#include "distributions.h" + +namespace FlexFlow { + +Constant::Constant(float val) : val(val) {} + +float Constant::operator()() const { + return val; +} + +Uniform::Uniform(float a, float b) : a(a), b(b) {} + +float Uniform::operator()() const { + return a + ((static_cast(std::rand()) / RAND_MAX) * (b - a)); +} + +Bernoulli::Bernoulli(float p) : p(p) {} + +float Bernoulli::operator()() const { + return (Uniform(0, 1)() < p); +} + +Binary::Binary(float a, float b, float p) : a(a), b(b), p(p) {} + +float Binary::operator()() const { + return (Bernoulli(p)() ? a : b); +} + +Chooser::Chooser(std::vector items) : items(items) {} + +float Chooser::operator()() const { + return items[std::rand() % items.size()]; +} + +UniformNoise::UniformNoise(float lower, float upper) + : lower(lower), upper(upper) {} + +float UniformNoise::operator()() const { + return Uniform(lower, upper)(); +} + +float NoNoise::operator()() const { + return 1; +} + +GaussianNoise::GaussianNoise(float mean, float stddev) + : mean(mean), stddev(stddev) {} + +float GaussianNoise::operator()() const { + static std::default_random_engine generator; + static std::normal_distribution distribution(mean, stddev); + return distribution(generator); +} + +} // namespace FlexFlow diff --git a/bin/sp_ization_benchmarking/distributions.h b/bin/sp_ization_benchmarking/distributions.h new file mode 100644 index 0000000000..ea24d55898 --- /dev/null +++ b/bin/sp_ization_benchmarking/distributions.h @@ -0,0 +1,81 @@ +#ifndef _FLEXFLOW_DISTRIBUTIONS_H +#define _FLEXFLOW_DISTRIBUTIONS_H + +#include "utils/graph/node/node.dtg.h" +#include +#include +#include + +namespace FlexFlow { 
+ +struct Constant { + float val; + Constant(float val = 1); + float operator()() const; +}; + +struct Uniform { + float a, b; + Uniform(float a = 0, float b = 1); + float operator()() const; +}; + +struct Bernoulli { + float p; + Bernoulli(float p = 0.5); + float operator()() const; +}; + +struct Binary { + float a, b, p; + Binary(float a = 0, float b = 1, float p = 0.5); + float operator()() const; +}; + +struct Chooser { + std::vector items; + Chooser(std::vector); + float operator()() const; +}; + +struct UniformNoise { + float lower, upper; + UniformNoise(float lower = 0.9, float upper = 1.1); + float operator()() const; +}; + +struct NoNoise { + float operator()() const; +}; + +struct GaussianNoise { + float mean, stddev; + GaussianNoise(float mean = 1, float stddev = .1); + float operator()() const; +}; + +template +std::unordered_map + make_cost_map(std::unordered_set const &nodes, + Dist const &distribution) { + std::unordered_map cost_map; + for (Node const &node : nodes) { + cost_map[node] = distribution(); + } + return cost_map; +} + +template +std::unordered_map + add_noise_to_cost_map(std::unordered_map cost_map, + Noise const &noise) { + std::unordered_map noisy_cost_map; + for (auto const &[node, cost] : cost_map) { + noisy_cost_map[node] = noise() * cost; + } + return noisy_cost_map; +} + +} // namespace FlexFlow + +#endif diff --git a/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h b/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h new file mode 100644 index 0000000000..2769ffdcf4 --- /dev/null +++ b/bin/sp_ization_benchmarking/nasnet_bench_graph_generator.h @@ -0,0 +1,135 @@ +/** + * @brief Utilities for generating random DAGs based on the NASNet-A + * architecture. NASNet-A is composed of a series of cells, which we randomly generate. 
+ * + * For context, see: + * - Paper: https://arxiv.org/abs/1902.09635 + * - Reference implementation: https://github.com/google-research/nasbench/blob/master/nasbench/api.py + */ + +#include "utils/containers.h" +#include "utils/containers/all_of.h" +#include "utils/containers/repeat.h" +#include "utils/containers/transform.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/digraph_generation.h" +#include +#include + +constexpr size_t MIN_NODES = 6; +constexpr size_t MAX_NODES = 8; +constexpr size_t MIN_EDGES = 8; +constexpr size_t MAX_EDGES = 11; +constexpr size_t NUM_CELLS = 9; + +using AdjacencyMatrix = std::vector>; +namespace FlexFlow { +struct NasNetBenchConfig { + AdjacencyMatrix adjacency_matrix; +}; + +bool is_valid_config(NasNetBenchConfig const &config) { + AdjacencyMatrix const &matrix = config.adjacency_matrix; + const size_t size = matrix.size(); + + auto is_valid_size = [](size_t s) { + return s >= MIN_NODES && s <= MAX_NODES; + }; + + auto is_square_matrix = [&](auto const &m) { + return all_of(m, [&](const auto &row) { return row.size() == size; }); + }; + + auto is_upper_triangular = [&](auto const &m) { + for (size_t i = 0; i < size; ++i) { + for (size_t j = 0; j <= i; ++j) { + if (matrix[i][j]) { + return false; + } + } + } + return true; + }; + + return is_valid_size(size) && is_square_matrix(matrix) && + is_upper_triangular(matrix); +} + +bool is_valid_cell(DiGraphView const &g) { + size_t num_edges = get_edges(g).size(); + return (is_acyclic(g)) && 
(get_initial_nodes(g).size() == 1) && + (get_terminal_nodes(g).size() == 1) && (num_edges <= MAX_EDGES) && + (num_edges >= MIN_EDGES) && (num_edges <= MAX_NODES) && + (num_edges >= MIN_NODES) && + (num_edges > num_nodes(g)); // filter linear cell and diamond cell +} + +NasNetBenchConfig generate_random_config() { + static std::uniform_int_distribution<> size_dist(MIN_NODES, MAX_NODES); + Binary bin = Binary(0, 1); + + size_t num_nodes = Uniform(MIN_NODES, MAX_NODES)(); + std::vector> matrix(num_nodes, + std::vector(num_nodes, false)); + + for (size_t i = 0; i < num_nodes; ++i) { + for (size_t j = i + 1; j < num_nodes; ++j) { + matrix[i][j] = bin(); + } + } + + return {matrix}; +} + +std::optional + maybe_generate_nasnet_bench_cell(NasNetBenchConfig const &config) { + if (!is_valid_config(config)) { + return std::nullopt; + } + + DiGraph g = DiGraph::create(); + std::vector nodes = add_nodes(g, config.adjacency_matrix.size()); + + for (size_t i = 0; i < nodes.size(); ++i) { + for (size_t j = i + 1; j < nodes.size(); ++j) { + if (config.adjacency_matrix[i][j]) { + g.add_edge(DirectedEdge{nodes[i], nodes[j]}); + } + } + } + + g = materialize_digraph_view(transitive_reduction(g)); + + if (!is_valid_cell(g)) { + return std::nullopt; + } + + return g; +} + +DiGraph generate_nasnet_bench_cell() { + while (true) { + NasNetBenchConfig config = generate_random_config(); + std::optional maybe_cell = + maybe_generate_nasnet_bench_cell(config); + if (maybe_cell) { + return maybe_cell.value(); + } + } +} + +DiGraph generate_nasnet_bench_network() { + DiGraph g = series_composition( + transform(repeat(NUM_CELLS, generate_nasnet_bench_cell), + [](auto const cell) -> DiGraphView { return cell; })); + return g; +} +} // namespace FlexFlow diff --git a/bin/sp_ization_benchmarking/sample_graphs.h b/bin/sp_ization_benchmarking/sample_graphs.h new file mode 100644 index 0000000000..1ec53a80b8 --- /dev/null +++ b/bin/sp_ization_benchmarking/sample_graphs.h @@ -0,0 +1,351 @@ +#ifndef 
FLEXFLOW_GRAPH_GENERATION_H +#define FLEXFLOW_GRAPH_GENERATION_H + +#include "distributions.h" +#include "sample_graphs.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/series_parallel/digraph_generation.h" +#include + +namespace FlexFlow { + +std::tuple make_normal_taso_nasnet_cell() { + DiGraph g = DiGraph::create(); + std::vector inputs = add_nodes(g, 2); + std::vector sep = add_nodes(g, 5); + std::vector id = add_nodes(g, 2); + std::vector avg = add_nodes(g, 3); + std::vector add = add_nodes(g, 5); + std::vector concat = add_nodes(g, 1); + + std::vector edges = {DirectedEdge{inputs.at(0), sep.at(1)}, + DirectedEdge{inputs.at(0), id.at(1)}, + DirectedEdge{inputs.at(0), avg.at(1)}, + DirectedEdge{inputs.at(0), avg.at(2)}, + DirectedEdge{inputs.at(0), sep.at(3)}, + DirectedEdge{inputs.at(0), sep.at(4)}, + DirectedEdge{inputs.at(1), sep.at(0)}, + DirectedEdge{inputs.at(1), id.at(0)}, + DirectedEdge{inputs.at(1), avg.at(0)}, + DirectedEdge{inputs.at(1), sep.at(2)}, + DirectedEdge{sep.at(0), add.at(0)}, + DirectedEdge{id.at(0), add.at(0)}, + DirectedEdge{sep.at(1), add.at(1)}, + DirectedEdge{sep.at(2), add.at(1)}, + DirectedEdge{avg.at(0), add.at(2)}, + DirectedEdge{id.at(1), add.at(2)}, + DirectedEdge{avg.at(1), add.at(3)}, + DirectedEdge{avg.at(2), add.at(3)}, + DirectedEdge{sep.at(3), add.at(4)}, + DirectedEdge{sep.at(4), add.at(4)}}; + add_edges(g, edges); + + for (Node const &a : add) { + g.add_edge(DirectedEdge{a, concat.at(0)}); + } + + assert(get_terminal_nodes(g).size() == 1); + assert(get_initial_nodes(g).size() == 2); + assert(is_acyclic(g)); + return {g, inputs.at(0), inputs.at(1)}; +} + +std::tuple make_reduction_taso_nasnet_cell() { + DiGraph g = 
DiGraph::create(); + std::vector inputs = add_nodes(g, 2); + std::vector sep = add_nodes(g, 5); + std::vector id = add_nodes(g, 1); + std::vector avg = add_nodes(g, 2); + std::vector max = add_nodes(g, 2); + std::vector add = add_nodes(g, 5); + std::vector concat = add_nodes(g, 1); + + std::vector edges = {DirectedEdge{inputs.at(0), sep.at(0)}, + DirectedEdge{inputs.at(0), sep.at(2)}, + DirectedEdge{inputs.at(0), sep.at(3)}, + DirectedEdge{inputs.at(1), max.at(1)}, + DirectedEdge{inputs.at(1), sep.at(1)}, + DirectedEdge{inputs.at(1), max.at(0)}, + DirectedEdge{inputs.at(1), avg.at(0)}, + DirectedEdge{sep.at(0), add.at(0)}, + DirectedEdge{sep.at(1), add.at(0)}, + DirectedEdge{max.at(0), add.at(1)}, + DirectedEdge{sep.at(2), add.at(1)}, + DirectedEdge{avg.at(0), add.at(2)}, + DirectedEdge{sep.at(3), add.at(2)}, + DirectedEdge{max.at(1), add.at(3)}, + DirectedEdge{sep.at(4), add.at(3)}, + DirectedEdge{avg.at(1), add.at(4)}, + DirectedEdge{id.at(0), add.at(4)}, + DirectedEdge{add.at(0), sep.at(4)}, + DirectedEdge{add.at(0), avg.at(1)}, + DirectedEdge{add.at(1), id.at(0)}, + DirectedEdge{add.at(2), concat.at(0)}, + DirectedEdge{add.at(3), concat.at(0)}, + DirectedEdge{add.at(4), concat.at(0)}}; + + add_edges(g, edges); + + assert(get_terminal_nodes(g).size() == 1); + assert(get_initial_nodes(g).size() == 2); + assert(is_acyclic(g)); + return {g, inputs.at(0), inputs.at(1)}; +} + +DiGraph make_full_taso_nasnet(size_t num_reduction_cells, size_t N) { + DiGraph g = DiGraph::create(); + Node input = g.add_node(); + std::deque outputting = {input, input, input}; + std::deque inputting; + size_t num_cells = num_reduction_cells + N * (num_reduction_cells + 1); + for (int i = 0; i < num_cells; i++) { + auto [s, earlier_input, later_input] = + (i % (N + 1) == N) ? 
make_reduction_taso_nasnet_cell() + : make_normal_taso_nasnet_cell(); + Node cell_output = get_only(get_terminal_nodes(s)); + std::unordered_map node_map = parallel_extend(g, s); + later_input = node_map.at(later_input); + earlier_input = node_map.at(earlier_input); + cell_output = node_map.at(cell_output); + + outputting.push_back(cell_output); + outputting.push_back(cell_output); + inputting.push_back(earlier_input); + inputting.push_back(later_input); + + Node a = outputting.front(); + Node b = inputting.front(); + inputting.pop_front(); + outputting.pop_front(); + g.add_edge(DirectedEdge{a, b}); + + a = outputting.front(); + b = inputting.front(); + inputting.pop_front(); + outputting.pop_front(); + g.add_edge(DirectedEdge{a, b}); + + assert(is_2_terminal_dag(g)); + assert(inputting.size() == 0); + assert(outputting.size() == 3); + } + return g; +} + +DiGraph make_linear(size_t length) { + DiGraph g = DiGraph::create(); + if (length == 0) { + return g; + } + std::vector nodes = add_nodes(g, length); + + for (size_t i = 0; i < length - 1; ++i) { + g.add_edge(DirectedEdge{nodes.at(i), nodes.at(i + 1)}); + } + + return g; +} + +DiGraph make_rhombus() { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}; + + add_edges(g, edges); + return g; +} + +DiGraph make_diamond() { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }; + + add_edges(g, edges); + return g; +} + +DiGraph make_fully_connected(std::vector layer_sizes) { + DiGraph g = DiGraph::create(); + std::vector> layers = + transform(layer_sizes, [&g](size_t 
size) { return add_nodes(g, size); }); + + std::vector edges; + + for (size_t i = 0; i < layers.size() - 1; ++i) { + for (Node const &n1 : layers.at(i)) { + for (Node const &n2 : layers.at(i + 1)) { + edges.push_back(DirectedEdge{n1, n2}); + } + } + } + + add_edges(g, edges); + return g; +} + +DiGraph make_parallel_chains(size_t chain_length, size_t chain_num) { + DiGraph g = DiGraph::create(); + assert(chain_length >= 3); + assert(chain_num >= 1); + std::vector> chains; + + for (size_t i = 0; i < chain_num; i++) { + std::vector chain_nodes = add_nodes(g, chain_length - 2); + chains.push_back(chain_nodes); + + for (size_t j = 0; j < chain_length - 3; j++) { + g.add_edge(DirectedEdge{chain_nodes.at(j), chain_nodes.at(j + 1)}); + } + } + + Node source = g.add_node(); + Node sink = g.add_node(); + + for (std::vector const &chain : chains) { + g.add_edge(DirectedEdge{source, chain.front()}); + g.add_edge(DirectedEdge{chain.back(), sink}); + } + + return g; +} + +DiGraph make_sample_dag_1() { + + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(0), n.at(6)}, + DirectedEdge{n.at(2), n.at(6)}, + DirectedEdge{n.at(6), n.at(5)}}; + add_edges(g, edges); + assert(is_2_terminal_dag(g)); + return g; +} + +DiGraph make_sample_dag_2() { + NOT_IMPLEMENTED(); +} + +DiGraph make_sample_dag_3() { + // Taken by "A New Algorithm for Mapping DAGs to Series-ParallelSplit Form, + // Escribano et Al, 2002" + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 18); + + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(10)}, DirectedEdge{n.at(2), n.at(11)}, + 
DirectedEdge{n.at(2), n.at(12)}, DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(3), n.at(6)}, DirectedEdge{n.at(4), n.at(6)}, + DirectedEdge{n.at(4), n.at(7)}, DirectedEdge{n.at(4), n.at(10)}, + DirectedEdge{n.at(5), n.at(8)}, DirectedEdge{n.at(6), n.at(8)}, + DirectedEdge{n.at(6), n.at(9)}, DirectedEdge{n.at(7), n.at(8)}, + DirectedEdge{n.at(8), n.at(17)}, DirectedEdge{n.at(9), n.at(17)}, + DirectedEdge{n.at(10), n.at(16)}, DirectedEdge{n.at(11), n.at(16)}, + DirectedEdge{n.at(12), n.at(13)}, DirectedEdge{n.at(12), n.at(14)}, + DirectedEdge{n.at(13), n.at(15)}, DirectedEdge{n.at(14), n.at(15)}, + DirectedEdge{n.at(15), n.at(16)}, DirectedEdge{n.at(16), n.at(17)}}; + + add_edges(g, edges); + return g; +} + +DiGraph make_taso_nasnet_cell() { + // From the TASO paper, pg 57 + DiGraph g = DiGraph::create(); + Node root = g.add_node(); + std::vector input = add_nodes(g, 2); + std::vector dwc = add_nodes(g, 5); + std::vector conv = add_nodes(g, 5); + std::vector avg = add_nodes(g, 3); + std::vector add = add_nodes(g, 5); + Node concat = g.add_node(); + + std::vector edges = {DirectedEdge{root, input.at(0)}, + DirectedEdge{root, input.at(1)}, + DirectedEdge{input.at(0), dwc.at(0)}, + DirectedEdge{input.at(0), dwc.at(1)}, + DirectedEdge{input.at(0), avg.at(0)}, + DirectedEdge{input.at(0), avg.at(1)}, + DirectedEdge{input.at(0), avg.at(2)}, + DirectedEdge{input.at(0), dwc.at(2)}, + DirectedEdge{input.at(1), add.at(2)}, + DirectedEdge{input.at(1), dwc.at(3)}, + DirectedEdge{input.at(1), dwc.at(4)}, + DirectedEdge{input.at(1), add.at(4)}, + DirectedEdge{dwc.at(0), conv.at(0)}, + DirectedEdge{dwc.at(1), conv.at(1)}, + DirectedEdge{dwc.at(2), conv.at(2)}, + DirectedEdge{dwc.at(3), conv.at(3)}, + DirectedEdge{dwc.at(4), conv.at(4)}, + DirectedEdge{conv.at(0), add.at(0)}, + DirectedEdge{conv.at(1), add.at(0)}, + DirectedEdge{avg.at(0), add.at(1)}, + DirectedEdge{avg.at(1), add.at(1)}, + DirectedEdge{avg.at(2), add.at(2)}, + DirectedEdge{conv.at(2), add.at(3)}, + 
DirectedEdge{conv.at(3), add.at(3)}, + DirectedEdge{conv.at(4), add.at(4)}}; + + add_edges(g, edges); + + for (Node const &a : add) { + g.add_edge(DirectedEdge{a, concat}); + } + return g; +} + +DiGraph make_2_terminal_random_dag(size_t num_nodes, float p, size_t step) { + DiGraph g = DiGraph::create(); + Bernoulli sampler = Bernoulli(p); + std::vector n = add_nodes(g, num_nodes - 2); + for (int i = 0; i < n.size(); i++) { + for (int j = i + step + 1; j < n.size(); j++) { + if (sampler()) { + g.add_edge(DirectedEdge{n.at(i), n.at(j)}); + } + } + } + std::unordered_set sinks = get_terminal_nodes(g); + std::unordered_set sources = get_initial_nodes(g); + Node sink = g.add_node(); + Node source = g.add_node(); + for (Node s : sources) { + g.add_edge(DirectedEdge{source, s}); + } + for (Node s : sinks) { + g.add_edge(DirectedEdge{s, sink}); + } + assert(is_2_terminal_dag(g)); + return g; +} + +} // namespace FlexFlow + +#endif diff --git a/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc b/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc new file mode 100644 index 0000000000..f8139d8f1f --- /dev/null +++ b/bin/sp_ization_benchmarking/sp_ization_benchmarking.cc @@ -0,0 +1,565 @@ +/** + * @file sp_ization_benchmarking.cpp + * @brief Benchmarking different SP-ization techniques on various graphs. + * + * @details + * Algorithms: + * - work_duplicating_spization_with_coalescing + * - stratum_sync_sp_ization + * Weight distributions: + * - Constant + * - Uniform(0, 1) + * - Binary(0, 100) + * - Chooser({1.0, 25.0, 500.0}) //sample uniformly from the given weights + * Noise distributions: + * - NoNoise + * - GaussianNoise(1, 0.1) + * - UniformNoise(0.8, 1.25) + * Graph types: + * ... 
+ * + * @note To run the benchmark, go to build/normal/bin/sp_ization_benchmarking, + * run make and then ./sp_ization_benchmarking + */ + +#include "distributions.h" +#include "nasnet_bench_graph_generator.h" +#include "sample_graphs.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/sp_ization/naive_work_duplicating_spization.h" +#include "utils/graph/series_parallel/sp_ization/naive_stratum_sync.h" +#include +#include +#include +#include + +constexpr size_t REPEAT = 500; + +using namespace FlexFlow; +using Result = std::tuple; +using CombinedResult = std::tuple; + +template +CombinedResult perform_benchmark_given_graph(DiGraphView const &g, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + Result work_duplicating = {0, 0, 0}; + Result barrier_sync = {0, 0, 0}; + + for (int i = 0; i < repeat; i++) { + auto cost_map = make_cost_map(get_nodes(g), Dist); + + SeriesParallelDecomposition sp1 = + work_duplicating_spization_with_coalescing(g); + SeriesParallelDecomposition sp2 = stratum_sync_sp_ization(g); + + auto noisy_cost_map = add_noise_to_cost_map(cost_map, Noise); + + std::get<0>(work_duplicating) += + relative_work_increase(g, sp1, noisy_cost_map); + std::get<1>(work_duplicating) += + relative_critical_path_cost_increase(g, sp1, noisy_cost_map); + std::get<2>(work_duplicating) += + relative_num_dependencies_increase(g, sp1); + + std::get<0>(barrier_sync) += relative_work_increase(g, sp2, noisy_cost_map); + std::get<1>(barrier_sync) += + relative_critical_path_cost_increase(g, sp2, noisy_cost_map); + std::get<2>(barrier_sync) += relative_num_dependencies_increase(g, sp2); + } + + std::vector results = { + work_duplicating, barrier_sync}; + + for (Result 
&r : results) { + std::get<0>(r) /= repeat; + std::get<1>(r) /= repeat; + std::get<2>(r) /= repeat; + } + + return {results[0], results[1]}; +} + +template +CombinedResult + perform_benchmark_given_graph_generator(G const &graph_generator, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + Result work_duplicating = {0, 0, 0}; + Result barrier_sync = {0, 0, 0}; + + for (int i = 0; i < repeat; i++) { + DiGraphView g = graph_generator(); + auto cost_map = make_cost_map(get_nodes(g), Dist); + + SeriesParallelDecomposition sp1 = + work_duplicating_spization_with_coalescing(g); + SeriesParallelDecomposition sp2 = stratum_sync_sp_ization(g); + + auto noisy_cost_map = add_noise_to_cost_map(cost_map, Noise); + + std::get<0>(work_duplicating) += + relative_work_increase(g, sp1, noisy_cost_map); + std::get<1>(work_duplicating) += + relative_critical_path_cost_increase(g, sp1, noisy_cost_map); + std::get<2>(work_duplicating) += + relative_num_dependencies_increase(g, sp1); + + std::get<0>(barrier_sync) += relative_work_increase(g, sp2, noisy_cost_map); + std::get<1>(barrier_sync) += + relative_critical_path_cost_increase(g, sp2, noisy_cost_map); + std::get<2>(barrier_sync) += relative_num_dependencies_increase(g, sp2); + } + + std::vector results = { + work_duplicating, barrier_sync}; + + for (Result &r : results) { + std::get<0>(r) /= repeat; + std::get<1>(r) /= repeat; + std::get<2>(r) /= repeat; + } + + return {results[0], results[1]}; +} + +void output_benchmark(CombinedResult const &combined_result, + std::string const &title) { + auto [work_dup, stratum_sync] = combined_result; + std::cout << std::fixed << std::setprecision(3); + std::cout << "Benchmark for " << title << std::endl; + std::cout << "Technique | Work-Increase | Critical-Path-Increase | " + "Dependencies-Increase" + << std::endl; + std::cout << "Barrier Sync | " << std::get<0>(stratum_sync) << " | " + << std::get<1>(stratum_sync) << " | " << std::get<2>(stratum_sync) + << std::endl; 
+ std::cout << "Work Duplication | " << std::get<0>(work_dup) << " | " + << std::get<1>(work_dup) << " | " << std::get<2>(work_dup) + << std::endl; + std::cout << std::endl; +} + +template +void bench_mark_given_graph(std::string title, + DiGraphView const &g, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + output_benchmark(perform_benchmark_given_graph(g, Dist, Noise, repeat), + title); +} + +template +void bench_mark_given_graph_generator(std::string title, + G const &generator, + D const &Dist, + N const &Noise = NoNoise(), + size_t repeat = REPEAT) { + output_benchmark( + perform_benchmark_given_graph_generator(generator, Dist, Noise, repeat), + title); +} + +int main() { + { + DiGraph g = make_sample_dag_3(); + bench_mark_given_graph("sample_dag_3, Constant(1)", g, Constant(1)); + bench_mark_given_graph("sample_dag_3, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph("sample_dag_3, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("sample_dag_3, Uniform(0,1)", g, Uniform(0, 1)); + bench_mark_given_graph( + "sample_dag_3, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph("sample_dag_3, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("sample_dag_3, Binary(1, 80)", g, Binary(1, 80)); + bench_mark_given_graph( + "sample_dag_3, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph("sample_dag_3, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("sample_dag_3, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0})); + bench_mark_given_graph( + "sample_dag_3, Chooser({1.0, 20.0, 500.0}), UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + 
UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "sample_dag_3, Chooser({1.0, 20.0, 500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1)); + } + + { + DiGraph g = make_taso_nasnet_cell(); + bench_mark_given_graph("taso_nasnet_cell, Constant(1)", g, Constant(1)); + bench_mark_given_graph( + "taso_nasnet_cell, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("taso_nasnet_cell, Uniform(0,1)", g, Uniform(0, 1)); + bench_mark_given_graph( + "taso_nasnet_cell, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("taso_nasnet_cell, Binary(1, 80)", g, Binary(1, 80)); + bench_mark_given_graph( + "taso_nasnet_cell, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("taso_nasnet_cell, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0})); + bench_mark_given_graph("taso_nasnet_cell, Chooser({1.0, 20.0, 500.0}), " + "UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "taso_nasnet_cell, Chooser({1.0, 20.0, 500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1)); + } + + { + DiGraph g = make_parallel_chains(10, 5); + bench_mark_given_graph("parallel_chains, Constant(1)", g, Constant(1)); + bench_mark_given_graph( + "parallel_chains, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25)); + 
bench_mark_given_graph( + "parallel_chains, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("parallel_chains, Uniform(0,1)", g, Uniform(0, 1)); + bench_mark_given_graph( + "parallel_chains, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("parallel_chains, Binary(1, 80)", g, Binary(1, 80)); + bench_mark_given_graph( + "parallel_chains, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1)); + + bench_mark_given_graph("parallel_chains, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0})); + bench_mark_given_graph( + "parallel_chains, Chooser({1.0, 20.0, 500.0}), UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25)); + bench_mark_given_graph( + "parallel_chains, Chooser({1.0, 20.0, 500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1)); + } + + { + + auto generate_2_terminal_random_dag = []() { + return make_2_terminal_random_dag(60, .12, 5); + }; + size_t repeat = 100; + bench_mark_given_graph_generator("make_2_terminal_random_dag, Constant(1)", + generate_2_terminal_random_dag, + Constant(1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Constant(1), UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Constant(1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Constant(1), GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Constant(1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator("make_2_terminal_random_dag, 
Uniform(0,1)", + generate_2_terminal_random_dag, + Uniform(0, 1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Uniform(0,1), UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Uniform(0, 1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Uniform(0,1), GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Uniform(0, 1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Binary(1, 80)", + generate_2_terminal_random_dag, + Binary(1, 80), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Binary(1, 80), UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Binary(1, 80), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Binary(1, 80), GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Binary(1, 80), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Chooser({1.0, 20.0, 500.0})", + generate_2_terminal_random_dag, + Chooser({1.0, 20.0, 500.0}), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Chooser({1.0, 20.0, 500.0}), " + "UniformNoise(0.8, 1.25)", + generate_2_terminal_random_dag, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "make_2_terminal_random_dag, Chooser({1.0, 20.0, 500.0}), " + "GaussianNoise(1, 0.1)", + generate_2_terminal_random_dag, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1), + repeat); + } + + { + size_t repeat = 100; + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Constant(1)", + generate_nasnet_bench_network, + Constant(1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Constant(1), UniformNoise(0.8, 1.25)", + 
generate_nasnet_bench_network, + Constant(1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Constant(1), GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Constant(1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Uniform(0,1)", + generate_nasnet_bench_network, + Uniform(0, 1), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Uniform(0,1), UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Uniform(0, 1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Uniform(0,1), GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Uniform(0, 1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Binary(1, 80)", + generate_nasnet_bench_network, + Binary(1, 80), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Binary(1, 80), UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Binary(1, 80), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Binary(1, 80), GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Binary(1, 80), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Chooser({1.0, 20.0, 500.0})", + generate_nasnet_bench_network, + Chooser({1.0, 20.0, 500.0}), + NoNoise(), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Chooser({1.0, 20.0, 500.0}), " + "UniformNoise(0.8, 1.25)", + generate_nasnet_bench_network, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph_generator( + "generate_nasnet_bench_network, Chooser({1.0, 20.0, 500.0}), " + "GaussianNoise(1, 0.1)", + generate_nasnet_bench_network, + Chooser({1.0, 20.0, 500.0}), + 
GaussianNoise(1, 0.1), + repeat); + } + + { + size_t repeat = 10; + DiGraph g = make_full_taso_nasnet(1, 1); + bench_mark_given_graph("make_full_taso_nasnet, Constant(1)", + g, + Constant(1), + NoNoise(), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Constant(1), UniformNoise(0.8, 1.25)", + g, + Constant(1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Constant(1), GaussianNoise(1, 0.1)", + g, + Constant(1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph("make_full_taso_nasnet, Uniform(0,1)", + g, + Uniform(0, 1), + NoNoise(), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Uniform(0,1), UniformNoise(0.8, 1.25)", + g, + Uniform(0, 1), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Uniform(0,1), GaussianNoise(1, 0.1)", + g, + Uniform(0, 1), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph("make_full_taso_nasnet, Binary(1, 80)", + g, + Binary(1, 80), + NoNoise(), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Binary(1, 80), UniformNoise(0.8, 1.25)", + g, + Binary(1, 80), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph( + "make_full_taso_nasnet, Binary(1, 80), GaussianNoise(1, 0.1)", + g, + Binary(1, 80), + GaussianNoise(1, 0.1), + repeat); + + bench_mark_given_graph("make_full_taso_nasnet, Chooser({1.0, 20.0, 500.0})", + g, + Chooser({1.0, 20.0, 500.0}), + NoNoise(), + repeat); + bench_mark_given_graph("make_full_taso_nasnet, Chooser({1.0, 20.0, " + "500.0}), UniformNoise(0.8, 1.25)", + g, + Chooser({1.0, 20.0, 500.0}), + UniformNoise(0.8, 1.25), + repeat); + bench_mark_given_graph("make_full_taso_nasnet, Chooser({1.0, 20.0, " + "500.0}), GaussianNoise(1, 0.1)", + g, + Chooser({1.0, 20.0, 500.0}), + GaussianNoise(1, 0.1), + repeat); + } +} diff --git a/lib/utils/include/utils/containers/invert_map.h b/lib/utils/include/utils/containers/invert_map.h new file mode 100644 index 
0000000000..6f0c04a189 --- /dev/null +++ b/lib/utils/include/utils/containers/invert_map.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INVERT_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_INVERT_MAP_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_map> + invert_map(std::unordered_map const &m) { + std::unordered_map> result; + for (auto const &[key, value] : m) { + result[value].insert(key); + } + return result; +} +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_set_of.h b/lib/utils/include/utils/containers/unordered_set_of.h index 74c7683460..5a8dcd9f55 100644 --- a/lib/utils/include/utils/containers/unordered_set_of.h +++ b/lib/utils/include/utils/containers/unordered_set_of.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNIQUE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNIQUE_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_SET_OF_H #include "utils/hash/pair.h" #include diff --git a/lib/utils/include/utils/graph/algorithms.h b/lib/utils/include/utils/graph/algorithms.h index 4fdec740f8..c2e969d13a 100644 --- a/lib/utils/include/utils/graph/algorithms.h +++ b/lib/utils/include/utils/graph/algorithms.h @@ -59,6 +59,7 @@ std::vector std::vector get_bfs_ordering(DiGraphView const &, std::unordered_set const &starting_points); + std::vector get_unchecked_topological_ordering(DiGraphView const &); std::unordered_set diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_ancestors.h b/lib/utils/include/utils/graph/digraph/algorithms/get_ancestors.h new file mode 100644 index 0000000000..77f756f11b --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_ancestors.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_ANCESTORS_H +#define 
_FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_ANCESTORS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { +/** + * @brief Computes the set of all ancestors of a given node `n` in a directed + * graph, which is the set of all nodes `m` for which a directed path from `m` to + * `n` exists. + * + * @note `n` is not considered to be its own ancestor, and is thus not + * included in the returned set. + **/ +std::unordered_set get_ancestors(DiGraphView const &g, Node const &n); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_bottlenecks.h b/lib/utils/include/utils/graph/digraph/algorithms/get_bottlenecks.h new file mode 100644 index 0000000000..69eb435144 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_bottlenecks.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_BOTTLENECKS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_BOTTLENECKS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +/** + * @brief Returns the bottlenecks of the graph. + * + * A bottleneck is a node through which all paths from any source to any sink + must pass. + + * @note + * The graph must be acyclic and singly connected. + * Note that, under the definition of bottleneck, a source / sink is a + bottleneck if and only if it's the unique source / sink of the graph.
+ */ +std::unordered_set get_bottlenecks(DiGraphView const &g); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_descendants.h b/lib/utils/include/utils/graph/digraph/algorithms/get_descendants.h new file mode 100644 index 0000000000..2e4c9eb5a3 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_descendants.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_DESCENDANTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_DESCENDANTS_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +/** + * @brief Computes the set of all descendants of a given node in a directed + * graph. + * + * @note `starting_node` is not considered to be its own descendant, and is thus + * not included in the returned set. + **/ +std::unordered_set get_descendants(DiGraphView const &g, + Node const &starting_node); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h b/lib/utils/include/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h new file mode 100644 index 0000000000..16ab2798a5 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_LONGEST_PATH_LENGTHS_FROM_ROOT_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_LONGEST_PATH_LENGTHS_FROM_ROOT_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +/** + * @brief Computes the longest path lengths from the root in directed acyclic + * graph. + * + * @return std::unordered_map For each node n, returns the length + * (i.e. number of nodes) of the longest path from the root to n. + * + * @note The root has a path length of 1. 
g must be acyclic. + */ +std::unordered_map + get_longest_path_lengths_from_root(DiGraphView const &g); + +/** + * @brief Computes the weighted longest path lengths from the root in a directed + * acyclic graph. + * + * @return std::unordered_map For each node n, returns the length + * (i.e. the sum of the weights of all the nodes) of the longest path from the + * root to n. + * + * @note The root has a path length equal to its weight. g must be acyclic. + */ +std::unordered_map get_weighted_longest_path_lengths_from_root( + DiGraphView const &g, std::unordered_map const &node_costs); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_lowest_common_ancestors.h b/lib/utils/include/utils/graph/digraph/algorithms/get_lowest_common_ancestors.h new file mode 100644 index 0000000000..60b0d32ae2 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_lowest_common_ancestors.h @@ -0,0 +1,41 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_LOWEST_COMMON_ANCESTORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_LOWEST_COMMON_ANCESTORS_H + +#include "utils/graph/digraph/digraph_view.h" +#include + +namespace FlexFlow { + +/** + * @brief Finds the lowest common ancestor (LCA) of a set of nodes in a directed + * graph. + * + * @details + * Within this function, we consider the set of ancestors of a given node to + * include the node itself, so the lowest common ancestor of a set of nodes can + * be contained in the input set of nodes itself. + * + * For example, consider the following directed graph: + * + * ``` + * digraph { + * 0 -> 1; + * 0 -> 2; + * 1 -> 3; + * 1 -> 4; + * } + * ``` + * + * The lowest common ancestor of nodes 3 and 1 is 1. + * + * @note + * In a Directed Acyclic Graph, a set of nodes can have no LCA, a unique node as + * LCA, or a set of nodes as LCA. 
+ */ +std::optional> + get_lowest_common_ancestors(DiGraphView const &g, + std::unordered_set const &nodes); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h b/lib/utils/include/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h new file mode 100644 index 0000000000..db017c11da --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TOPOLOGICAL_ORDERING_FROM_STARTING_NODE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TOPOLOGICAL_ORDERING_FROM_STARTING_NODE_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/node.dtg.h" + +namespace FlexFlow { + +/** + * @brief Returns a topologically ordered vector of nodes, with the topological + * traversal starting from the starting node. + * + * @note Nodes present within the graph that are not reachable by a traversal + * starting from the starting_node will not be included in the returned vector. 
+ * g must be an acyclic graph + */ +std::vector + get_topological_ordering_from_starting_node(DiGraphView const &g, + Node const &starting_node); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_GET_TOPOLOGICAL_ORDERING_FROM_STARTING_NODE_H diff --git a/lib/utils/include/utils/graph/digraph/algorithms/is_2_terminal_dag.h b/lib/utils/include/utils/graph/digraph/algorithms/is_2_terminal_dag.h new file mode 100644 index 0000000000..3b588c7984 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/is_2_terminal_dag.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_IS_2_TERMINAL_DAG_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_IS_2_TERMINAL_DAG_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_2_terminal_dag(DiGraphView const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h b/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h index 909dc3aef4..ce63f75395 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/is_acyclic.h @@ -5,7 +5,7 @@ namespace FlexFlow { -std::optional is_acyclic(DiGraphView const &); +bool is_acyclic(DiGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/digraph/algorithms/is_tree.h b/lib/utils/include/utils/graph/digraph/algorithms/is_tree.h new file mode 100644 index 0000000000..340d54aab4 --- /dev/null +++ b/lib/utils/include/utils/graph/digraph/algorithms/is_tree.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_IS_TREE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_IS_TREE_H + +#include "utils/graph/digraph/digraph_view.h" + +namespace FlexFlow { + +bool is_tree(DiGraphView const &g); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h b/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h index b4cdc62f83..ad11c6388c 100644 --- a/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h +++ b/lib/utils/include/utils/graph/digraph/algorithms/transitive_reduction.h @@ -1,6 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_REDUCTION_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_DIGRAPH_ALGORITHMS_TRANSITIVE_REDUCTION_H +#include "utils/graph/digraph/digraph.h" #include "utils/graph/digraph/digraph_view.h" namespace FlexFlow { @@ -21,7 +22,7 @@ struct DirectedEdgeMaskView final : public IDiGraphView { std::unordered_set edge_mask; }; -DiGraphView transitive_reduction(DiGraphView const &); +DiGraph transitive_reduction(DiGraphView const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/series_parallel/digraph_generation.h b/lib/utils/include/utils/graph/series_parallel/digraph_generation.h new file mode 100644 index 0000000000..a3a9edba80 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/digraph_generation.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_DIGRAPH_GENERATION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_DIGRAPH_GENERATION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +std::unordered_map parallel_extend(DiGraph &g, + DiGraphView const &ext); +std::unordered_map serial_extend(DiGraph &g, + DiGraphView const &ext); +DiGraph series_composition(DiGraphView const &g1, DiGraphView const &g2); +DiGraph parallel_composition(DiGraphView const &g1, DiGraphView const &g2); +DiGraph series_composition(std::vector const &graphs); +DiGraph parallel_composition(std::vector const &graphs); + +/** + * @brief Constructs a directed DiGraph from a series-parallel 
decomposition. + * + * @details The transformation is performed recursively as follows: + * - Nodes in the decomposition remain the same in the resulting graph (but the node ids are fresh) + * - For serial composition between graphs, an all-to-all connection is created + * between the terminal nodes of one graph and the initial nodes of the + * following one. + * - For parallel composition between graphs, the union of the graphs is taken + * without adding any additional edges. + * + */ +DiGraph digraph_from_sp_decomposition(SeriesParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h new file mode 100644 index 0000000000..f8e9f52cb5 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h @@ -0,0 +1,49 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLLEL_GET_ANCESTORS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLLEL_GET_ANCESTORS_H + +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +/** + * @brief For a given node \p node and a series-parallel decomposition \p sp, + * returns the set of nodes that are ancestors of the given node. We require + * that \p node is a node in \p sp. + * + * @details The ancestors are + * computed recursively as follows: + * + * If \p sp is a single node, then the ancestors are always empty. + * + * If \p sp = S(a_1, ..., a_n), where a_i are series-parallel structures, + * and suppose that the node \p node is in a_j. Then: + * - All the nodes in a_1, ..., a_{j-1} are ancestors of \p node. + * - There is some subset of nodes of a_j that are ancestors of \p node + * (which we recursively compute). + * - The nodes in a_{j+1}, ..., a_n are NOT ancestors of \p node. 
+ * + * If \p sp = P(a_1, ..., a_n), where a_i are series-parallel structures, + * then there is exactly one branch a_j of \p sp that contains \p node. + * All the other branches are not ancestors of \p node (since they are + * parallel to it). So we recursively compute the ancestors of \p node + * within a_j. + * + * @example + * For sp = S(n0, P(S(n1, n2), n3), n4, n5): + * + * node | ancestors + * -----|---------- + * n0 | {} + * n1 | {n0} + * n2 | {n0, n1} + * n3 | {n0} + * n4 | {n0, n1, n2, n3} + * n5 | {n0, n1, n2, n3, n4} + * + */ +std::unordered_set get_ancestors(SeriesParallelDecomposition const &sp, + Node const &node); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml new file mode 100644 index 0000000000..b4b975bb4e --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "NonNormalParallelSplit" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +fwd_decls = [ + "struct NonNormalSeriesSplit" +] + +post_includes = [ + "utils/graph/series_parallel/non_normal_series_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/unordered_multiset.h", + "utils/hash/unordered_multiset.h", +] + +[[fields]] +name = "children" +type = "std::unordered_multiset>" +indirect = true + diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml new file mode 100644 index 0000000000..008e58dc3f --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "NonNormalSeriesSplit" +type = "struct" +features = [ + "eq", + "hash", + 
"fmt", +] + +fwd_decls = [ + "struct NonNormalParallelSplit" +] + +post_includes = [ + "utils/graph/series_parallel/non_normal_parallel_split.dtg.h", +] + +includes = [ + "", + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/variant.h", + "utils/fmt/vector.h", + "utils/hash/vector.h", +] + +[[fields]] +name = "children" +type = "std::vector>" + diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml new file mode 100644 index 0000000000..c82e771385 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "NonNormalSPDecomposition" +type = "variant" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "utils/graph/series_parallel/non_normal_parallel_split.dtg.h", + "utils/graph/series_parallel/non_normal_series_split.dtg.h", + "utils/graph/node/node.dtg.h", +] + +[[values]] +type = "::FlexFlow::NonNormalSeriesSplit" + +[[values]] +type = "::FlexFlow::NonNormalParallelSplit" + +[[values]] +type = "::FlexFlow::Node" + diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.h b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.h new file mode 100644 index 0000000000..eeb4590c79 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_NON_NORMAL_SP_DECOMPOSITION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_NON_NORMAL_SP_DECOMPOSITION_H + +#include "utils/graph/series_parallel/non_normal_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include +#include + +namespace FlexFlow { + +bool is_empty_non_normal(NonNormalSPDecomposition const &sp); + +NonNormalSPDecomposition 
non_normal_series_composition( + std::vector const &sp_compositions); + +NonNormalSPDecomposition non_normal_parallel_composition( + std::unordered_multiset const &sp_compositions); + +NonNormalSPDecomposition as_non_normal(SeriesParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h b/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h new file mode 100644 index 0000000000..c9f40b8f4b --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/normalize_sp_decomposition.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_normalize_sp_decomposition_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_normalize_sp_decomposition_H + +#include "utils/graph/series_parallel/non_normal_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +/** + * @brief Normalizes a series-parallel decomposition into its canonical form. + * + * @details A normalized series-parallel decomposition satisfies the following + * invariants: + * - No empty SeriesSplit or ParallelSplit nodes (i.e., nodes with zero children) + * - No SeriesSplit or ParallelSplit nodes with exactly one child + * (these are replaced by their child) + * + * These invariants ensure a unique canonical representation for any given + * series-parallel structure. 
+ * + * Examples: + * - S(P(S()), Node(1), Node(2)) -> S(Node(1), Node(2)) + * - S(S(Node(1)), P(Node(2))) -> S(Node(1), Node(2)) + * + */ +SeriesParallelDecomposition + normalize_sp_decomposition(NonNormalSPDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h index b3fc201ca5..06db05b8aa 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.h @@ -3,6 +3,7 @@ #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.dtg.h" #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/nonnegative_int/nonnegative_int.h" #include namespace FlexFlow { @@ -17,18 +18,14 @@ std::unordered_multiset get_nodes(SeriesSplit const &); std::unordered_multiset get_nodes(ParallelSplit const &); std::unordered_multiset get_nodes(Node const &); -bool is_empty(Node const &node); -bool is_empty(SeriesSplit const &serial); -bool is_empty(ParallelSplit const ¶llel); -bool is_empty(SeriesParallelDecomposition const &sp); - bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp); -SeriesParallelDecomposition delete_node(SeriesParallelDecomposition sp, - Node const &node); - -// duplicate nodes within `sp` are counted multiple times -size_t num_nodes(SeriesParallelDecomposition const &sp); +/** + * @brief Counts the total number of nodes in a series-parallel decomposition + * @note Nodes that appear multiple times in the decomposition are counted + * multiple times + */ +nonnegative_int num_nodes(SeriesParallelDecomposition const &sp); SeriesParallelDecomposition series_composition( std::vector const &sp_compositions); diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_metrics.h 
b/lib/utils/include/utils/graph/series_parallel/series_parallel_metrics.h new file mode 100644 index 0000000000..935b8a9e52 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_metrics.h @@ -0,0 +1,73 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_METRICS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_METRICS_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +/** + * @brief Maps each node to the number of times it appears in the decomposition. + * + */ +std::unordered_map + get_num_occurrences_of_nodes(SeriesParallelDecomposition const &sp); + +/** + * @brief Calculates the total cumulative cost of all nodes in the + * decomposition. + * + */ +float work_cost(SeriesParallelDecomposition const &sp, + std::unordered_map cost_map); + +float work_cost(DiGraphView const &g, + std::unordered_map const &cost_map); + +/** + * @brief Computes the total number of edges the decomposition has when viewed + * as a DiGraph where Series connections are all to all. + * + */ +nonnegative_int num_dependencies(SeriesParallelDecomposition const &sp); + +nonnegative_int num_dependencies(DiGraphView const &g); + +float critical_path_cost(SeriesParallelDecomposition const &sp, + std::unordered_map const &cost_map); + +float critical_path_cost(DiGraphView const &g, + std::unordered_map const &cost_map); + +/** + * @brief Calculates the relative increase in total work cost between the + * original (possibly non-series-parallel) graph and a possible series-parallel + * decomposition of that graph. 
+ */ +float relative_work_increase(DiGraphView const &g, + SeriesParallelDecomposition const &sp, + std::unordered_map const &cost_map); + +/** + * @brief Calculates the relative increase in critical path cost between the + * original (possibly non-series-parallel) graph and a possible series-parallel + * decomposition of that graph. + */ +float relative_critical_path_cost_increase( + DiGraphView const &g, + SeriesParallelDecomposition const &sp, + std::unordered_map const &cost_map); + +/** + * @brief Calculates the relative increase in the number of dependencies between + * the original (possibly non-series-parallel) graph and a possible + * series-parallel decomposition of that graph. + */ +float relative_num_dependencies_increase(DiGraphView const &g, + SeriesParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif // _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIES_PARALLEL_METRICS_H diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/README.md b/lib/utils/include/utils/graph/series_parallel/sp_ization/README.md new file mode 100644 index 0000000000..3630c052b8 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/README.md @@ -0,0 +1,196 @@ +# SP-ization + +As a refresher, a series-parallel decomposition (SPD) is an algebraic datatype that looks as follows: +```haskell +data ParallelChild = Node | Series SeriesChild SeriesChild +data SeriesChild = Node | Parallel ParallelChild ParallelChild + +data SPD = Series SeriesChild SeriesChild | + Parallel ParallelChild ParallelChild | + Node +``` + +SP-ization is the process of transforming a DAG into a series-parallel decomposition (SPD), such that the dependencies present in the original DAG are preserved in the SPD. +Each node in the SPD may optionally have an associated cost, which is a positive scalar value. + +Note that SP-ization is itself trivial (e.g. a topological sort of the nodes is an SPD). 
But we generally care about preserving as much of the structure / parallelism of the original DAG as possible. So, there are 2 properties that we care about: +1. **Work**: the sum of cost of nodes in the SPD +2. **Critical Path Cost**: the cost of the longest path in the SPD + +We have 2 main ways of achieving this: +1. **Work (Node) Duplicating SP-ization**: preserves the critical path, but may duplicate nodes +2. **Dependency (Edge) Addition SP-ization**: preserves the set of nodes, but may add edges + +## Node (Work) Duplicating SP-ization + +### Naive ([work_duplicating_spization.h](work_duplicating_spization.h)) + +Transforms a directed acyclic graph (DAG) into a Series Parallel (SP) graph. The critical path cost is unchanged, and the SP-ization is done solely through node duplication. + +The resulting graph, encoded as a SeriesParallelDecomposition, is a tree whose critical path is the same as that of the original graph. The tree is constructed as follows: +- Denote SP(n) as the SeriesParallelDecomposition of the subgraph of g whose nodes are all the ancestors of n. +- Denote the predecessors of n as M. +- Then: + - SP(n) = S(n, P({SP(m) for m in M})) + - SP(root) = root + - SP(sink) = SP(g) +Where P, S represent parallel, serial composition respectively. + +Example: +```dot +digraph G { + n1 -> n2; + n1 -> n3; + n2 -> n4; + n2 -> n5; + n3 -> n5; + n5 -> n6; + n4 -> n6; +} +``` +becomes +```dot +digraph SP { + n1 -> n2; + n2 -> n3; + n3 -> n4; + n5 -> n6; + n6 -> n7; + n7 -> n4; + n8 -> n9; + n9 -> n7; +} +``` + +We can roughly think of it as the parallel composition of all the possible paths from source to sink. + +### With Coalescing ([work_duplicating_spization.h](work_duplicating_spization.h)) + +Transforms a directed acyclic graph (DAG) into a Series Parallel (SP) graph with coalescing. The critical path cost is unchanged, and the SP-ization is done solely through node (work) duplication. 
+ +This SP-ization technique, compared to the previous step, adds an additional coalescing step during parallel composition to reduce node duplication. The recursive formulation is equivalent, but the parallelization performs an additional coalescing step, where parallel strands with common heads are merged together. Example: P(S(1,2), S(1,3)) -> S(1, P(2,3)). + +Example: +```dot +digraph G { + n1 -> n2; + n1 -> n3; + n2 -> n4; + n2 -> n5; + n3 -> n5; + n5 -> n6; + n4 -> n6; +} +``` +becomes +```dot +digraph SP { + n1 -> n2; + n2 -> n3; + n3 -> n4; + n1 -> n6; + n6 -> n7; + n7 -> n4; + n1 -> n9; + n9 -> n7; +} +``` + +## Dependency Addition SP-ization + +### Naive Stratum Sync ([naive_stratum_sync.h](naive_stratum_sync.h)) + +Transforms a directed acyclic graph (DAG) into a Series Parallel (SP) graph. The total number of nodes remains unchanged, and the SP-ization is done solely through edge (dependency) addition. + +The graph is first partitioned into strata: the i_th stratum contains all the nodes whose critical path has length i. The nodes in a given stratum are composed in parallel, and the strata are serially composed in succession. + +Example: +```dot +digraph G { + n1 -> n2; + n1 -> n3; + n2 -> n4; + n2 -> n5; + n3 -> n5; + n5 -> n6; + n4 -> n6; +} +``` +becomes +```dot +digraph SP { + n1 -> n2; + n1 -> n3; + n2 -> n4; + n2 -> n5; + n3 -> n5; + n4 -> n6; + n5 -> n6; +} +``` + +### Escribano Algorithm ([escribano_algo.h](escribano_algo.h)) + +Paper is present here: https://www.infor.uva.es/wp-content/uploads/2016/10/IT-DI-2002-0002.pdf. +In the naive stratum sync algorithm, we add an all-to-all connection between all nodes in one stratum and the next. The Escribano algorithm, by contrast, leverages the fact that it might be possible to synchronize consecutive strata by adding smaller, more local connections that still yield a valid SP-ization graph.
+ +Example: +```dot +digraph G { + 0 -> 1; + 0 -> 2; + 0 -> 3; + 1 -> 4; + 1 -> 5; + 2 -> 5; + 3 -> 6; + 4 -> 7; + 5 -> 7; + 6 -> 7; +} +``` + +The strata are: {0}, {1, 2, 3}, {4, 5, 6}, {7}. + +The naive stratum sync yields the following, adding an all-to-all connection between consecutive strata: +``` +S(0, P(1, 2, 3), P(4, 5, 6), 7) +``` + +While the escribano algorithm is able to identify that strata 1 and 2 can be synced without adding an all-to-all connection: nodes {1, 2} only connect to {4, 5}, and node {3} only connects to {6}. It thus yields the following: +``` +S(0, P(S(P(1, 2), P(4, 5)), S(3, 6)), 7) +``` + +Our implementation, rather than building the SPD one stratum at a time, builds it one node at a time. + +### Flexible Algorithm ([flexible_algo.h](flexible_algo.h)) + +Consider the following N-graph: + +```dot +digraph N { + a -> c; + a -> d; + b -> d; +} +``` + +Note that there are multiple valid SP-izations for this. +1) S(P(a, b), P(c, d)) — adds edge (b, c) +2) S(P(S(a, c), b), d) — adds edge (c, d) +3) S(a, P(c, S(b, d))) — adds edge (a, b) +(you could also simply turn the graph into a straight line, but it's a strictly worse SP-ization than the others present here, so we'll ignore it) + +Depending on the cost map, each option can potentially be the best: +- {a:1, b:1, c:1, d:1}: SP1 is optimal (CP=2), SP2 and SP3 worsen it (CP=3) +- {a:1, b:3, c:2, d:1}: SP2 is optimal (CP=4), SP1 and SP3 worsen it (CP=5) +- {a:1, b:2, c:3, d:1}: SP3 is optimal (CP=4), SP1 and SP2 worsen it (CP=5) + +Thus, even for this simple graph, the best SP-ization depends on the cost map. + +The flexible algorithm expands the escribano algorithm by generalizing it to such weighted DAGs. + +In the escribano algorithm, once the sync area (the "forest") is identified, the partition into up and down sets is fixed: up is everything but the last layer, down is the last layer. 
But this is an arbitrary choice; there are multiple valid ways to partition the forest into an up set and a down set (across which we sync). + +The flexible algorithm exploits this by searching across all valid up/down partitions of the forest and selecting the one that minimizes the sum of critical path costs of the up and down subgraphs (i.e., the critical path cost of the resulting SP-ized subgraph after the sync). diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h new file mode 100644 index 0000000000..a71bacc05a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_IS_VALID_SP_IZATION_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_IS_VALID_SP_IZATION_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { +/** + * @brief Checks if dependencies are maintained between a directed graph and its + * series-parallel decomposition. + * + * @details This function ensures that the series-parallel decomposition is a + * valid sp-ization of the given directed graph, by checking that dependencies + * are maintained. Dependencies are considered maintained if: + * - Both the directed graph and the series-parallel decomposition contain the + * same set of nodes. + * - For every node in the directed graph, all its ancestors are also ancestors + * within the series-parallel decomposition. 
+ * + */ +bool dependencies_are_maintained(DiGraphView const &g, + SeriesParallelDecomposition const &sp); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/escribano_algo.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/escribano_algo.h new file mode 100644 index 0000000000..b50a1d63eb --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/escribano_algo.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_ESCRIBANO_ALGO_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_ESCRIBANO_ALGO_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include +namespace FlexFlow { + +DiGraph add_dummy_nodes(DiGraph g, + std::unordered_map &node_roles); + +std::unordered_set + get_component(DiGraph const &g, + Node const &node, + std::unordered_map const &depth_map, + std::unordered_map const &node_roles); + +/** + * @brief See @ref lib/utils/include/utils/graph/series_parallel/sp_ization/README.md "README.md" for explanation. 
+ */ +SeriesParallelDecomposition escribano_sp_ization(DiGraph g); + +} // namespace FlexFlow + +#endif + diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/flexible_algo.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/flexible_algo.h new file mode 100644 index 0000000000..7065f09ddf --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/flexible_algo.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_FLEXIBLE_ALGO_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_FLEXIBLE_ALGO_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" +#include "utils/graph/series_parallel/sp_ization/up_down_partition.dtg.h" +#include +#include + +namespace FlexFlow { + +/** + * @brief See @ref lib/utils/include/utils/graph/series_parallel/sp_ization/README.md "README.md" for explanation. 
+ */ +SeriesParallelDecomposition + flexible_sp_ization(DiGraphView const &g, + std::unordered_map const &cost_map); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/naive_stratum_sync.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/naive_stratum_sync.h new file mode 100644 index 0000000000..d7ac00a563 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/naive_stratum_sync.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_NAIVE_STRATUM_SYNC_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_NAIVE_STRATUM_SYNC_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +namespace FlexFlow { + +/** + * @brief See @ref lib/utils/include/utils/graph/series_parallel/sp_ization/README.md "README.md" for explanation. + **/ +SeriesParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g); + +} // namespace FlexFlow + +#endif + diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.dtg.toml b/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.dtg.toml new file mode 100644 index 0000000000..47c35412d5 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.dtg.toml @@ -0,0 +1,32 @@ +# Enum to categorize the role of each node in the 2 following spization algorithms: +# - escribano_algo +# - flexible_sync +# PURE: Original nodes from the input graph. +# SYNC: Synchronization nodes added by the algorithm. A sync node s expresses an +# all-to-all connection: all "up" nodes have an edge to s, and s has an edge to +# all "down" nodes. These nodes are contracted out at the end of the algorithm. +# DUMMY: Needed for the node-by-node version of the escribano algorithm. 
These nodes +# are added such that, for every edge (n1, n2), the difference in depth between +# n1 and n2 is at most 1 (we do this by taking all the edges where the difference +# in depth between the 2 nodes is greater than 1, and we "break up" that edge +# into a chain of dummy nodes). These nodes are contracted out at the end of the +# algorithm. + +namespace = "FlexFlow" +name = "NodeRole" +type = "enum" +features = [ + "hash", + "json", + "rapidcheck", + "fmt", +] + +[[values]] +name = "PURE" + +[[values]] +name = "SYNC" + +[[values]] +name = "DUMMY" diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.h new file mode 100644 index 0000000000..b8b8e5fe3a --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/node_role.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_NODE_ROLE_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_NODE_ROLE_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" +#include + +namespace FlexFlow { + +std::unordered_map + get_initial_node_role_map(DiGraphView const &g); + +/** + * @brief Contracts out nodes of a given role from the graph. + * + * When a node is contracted out, all its predecessors are connected to all + * its successors, and then the node is removed. 
+ */ +DiGraph contract_out_nodes_of_given_role( + DiGraph g, + NodeRole const &role, + std::unordered_map const &node_roles); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/up_down_partition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/sp_ization/up_down_partition.dtg.toml new file mode 100644 index 0000000000..8eeb7a7d40 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/up_down_partition.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "UpDownPartition" +type = "struct" +features = [ + "eq", + "hash", + "fmt", +] + +includes = [ + "", + "utils/graph/node/node.dtg.h", +] + +src_includes = [ + "utils/fmt/unordered_set.h", + "utils/hash/unordered_set.h", +] + +[[fields]] +name = "up" +type = "std::unordered_set<::FlexFlow::Node>" + +[[fields]] +name = "down" +type = "std::unordered_set<::FlexFlow::Node>" + diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/up_down_partition.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/up_down_partition.h new file mode 100644 index 0000000000..7458ae3c51 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/up_down_partition.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_UP_DOWN_PARTITION_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_SP_IZATION_UP_DOWN_PARTITION_H + +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/series_parallel/sp_ization/up_down_partition.dtg.h" +#include + +namespace FlexFlow { + +/** + * @brief Returns the nodes in the up set that have no outgoing edges within + * the up subgraph. + */ +std::unordered_set get_up_frontier(DiGraph const &sp, + UpDownPartition const &partition); + +/** + * @brief Returns the nodes in the down set that have no incoming edges within + * the down subgraph. 
+ */ +std::unordered_set get_down_frontier(DiGraph const &sp, + UpDownPartition const &partition); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/series_parallel/sp_ization/work_duplicating_spization.h b/lib/utils/include/utils/graph/series_parallel/sp_ization/work_duplicating_spization.h new file mode 100644 index 0000000000..b5d8ecc128 --- /dev/null +++ b/lib/utils/include/utils/graph/series_parallel/sp_ization/work_duplicating_spization.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_WORK_DUPLICATING_SPIZATION_H +#define _FLEXFLOW_UTILS_GRAPH_SERIES_PARALLEL_WORK_DUPLICATING_SPIZATION_H + +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +namespace FlexFlow { + +/** + * @brief See @ref lib/utils/include/utils/graph/series_parallel/sp_ization/README.md "README.md" for explanation. + */ +SeriesParallelDecomposition naive_work_duplicating_spization(DiGraphView const &g); + +/** + * @brief See @ref lib/utils/include/utils/graph/series_parallel/sp_ization/README.md "README.md" for explanation. 
+ */ +SeriesParallelDecomposition + work_duplicating_spization_with_coalescing(DiGraphView const &g); + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/src/utils/containers/invert_map.cc b/lib/utils/src/utils/containers/invert_map.cc new file mode 100644 index 0000000000..699503ff96 --- /dev/null +++ b/lib/utils/src/utils/containers/invert_map.cc @@ -0,0 +1,11 @@ +#include "utils/containers/invert_map.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_map> + invert_map(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/algorithms.cc b/lib/utils/src/utils/graph/algorithms.cc index a6a9ca0ae2..85bd33bb9f 100644 --- a/lib/utils/src/utils/graph/algorithms.cc +++ b/lib/utils/src/utils/graph/algorithms.cc @@ -11,6 +11,7 @@ #include "utils/graph/digraph/algorithms/get_incoming_edges.h" #include "utils/graph/digraph/algorithms/get_node_with_greatest_topo_rank.h" #include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/digraph/directed_edge_query.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/node/node_query.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_ancestors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_ancestors.cc new file mode 100644 index 0000000000..96a34e6f0b --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_ancestors.cc @@ -0,0 +1,12 @@ +#include "utils/graph/digraph/algorithms/get_ancestors.h" +#include "utils/graph/digraph/algorithms/flipped.h" +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" + +namespace FlexFlow { +std::unordered_set get_ancestors(DiGraphView const &g, + Node const &starting_node) { + assert(is_acyclic(g)); + return get_descendants(flipped(g), starting_node); +} +} 
// namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_bottlenecks.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_bottlenecks.cc new file mode 100644 index 0000000000..84e6e2a57b --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_bottlenecks.cc @@ -0,0 +1,34 @@ +#include "utils/graph/digraph/algorithms/get_bottlenecks.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_only.h" +#include "utils/containers/set_difference.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { +std::unordered_set get_bottlenecks(DiGraphView const &g) { + ASSERT(is_acyclic(g)); + ASSERT(get_weakly_connected_components(g).size() == + 1); // must be singly connected + + std::unordered_set bottlenecks = filter(get_nodes(g), [&](Node const &n) { + DiGraphView subgraph = get_subgraph(g, set_difference(get_nodes(g), {n})); + return get_weakly_connected_components(subgraph).size() == 2; + }); + + if (get_initial_nodes(g).size() == 1) { + bottlenecks.insert(get_only(get_initial_nodes(g))); + } + + if (get_terminal_nodes(g).size() == 1) { + bottlenecks.insert(get_only(get_terminal_nodes(g))); + } + + return bottlenecks; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc new file mode 100644 index 0000000000..7abe5b6616 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_descendants.cc @@ -0,0 +1,19 @@ +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/containers/contains.h" +#include "utils/containers/unordered_set_of.h" +#include 
"utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { +std::unordered_set get_descendants(DiGraphView const &g, + Node const &starting_node) { + assert(is_acyclic(g)); + assert(contains(get_nodes(g), starting_node)); + + return unordered_set_of(get_bfs_ordering(g, get_successors(g, starting_node))); +}; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index a81158b3e5..2da1b208f4 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -9,6 +9,7 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include #include namespace FlexFlow { diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_initial_nodes.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_initial_nodes.cc index 71a6e18cbc..6be4f8a592 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_initial_nodes.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_initial_nodes.cc @@ -1,5 +1,6 @@ #include "utils/graph/digraph/algorithms/get_initial_nodes.h" #include "utils/containers/set_minus.h" +#include "utils/containers/transform.h" #include "utils/graph/digraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc new file mode 100644 index 0000000000..036d320f05 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc @@ -0,0 +1,58 
@@ +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/containers/all_of.h" +#include "utils/containers/maximum.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include + +namespace FlexFlow { + +std::unordered_map get_weighted_longest_path_lengths_from_root( + DiGraphView const &g, std::unordered_map const &node_costs) { + + assert(is_acyclic(g)); + assert(all_of(values(node_costs), [&](float cost) { return cost >= 0; })); + + std::vector topo_order = get_topological_ordering(g); + std::unordered_map longest_path_lengths; + + for (Node const &n : topo_order) { + std::unordered_set predecessor_path_lengths = + transform(get_predecessors(g, n), [&](Node const &pred) { + return longest_path_lengths.at(pred); + }); + longest_path_lengths[n] = + (predecessor_path_lengths.size() == 0) + ? node_costs.at(n) + : maximum(predecessor_path_lengths) + node_costs.at(n); + } + return longest_path_lengths; +} + +std::unordered_map + get_longest_path_lengths_from_root(DiGraphView const &g) { + + assert(is_acyclic(g)); + + std::vector topo_order = get_topological_ordering(g); + std::unordered_map longest_path_lengths; + + for (Node const &n : topo_order) { + std::unordered_set predecessor_path_lengths = + transform(get_predecessors(g, n), [&](Node const &pred) { + return longest_path_lengths.at(pred); + }); + nonnegative_int new_value = (predecessor_path_lengths.size() == 0) + ? 
1_n + : maximum(predecessor_path_lengths) + 1_n; + + longest_path_lengths.emplace(n, new_value); + } + + return longest_path_lengths; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc new file mode 100644 index 0000000000..0d5705854d --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc @@ -0,0 +1,46 @@ +#include "utils/containers/filter.h" +#include "utils/containers/intersection.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/maximum.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/graph/digraph/algorithms/get_ancestors.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/node/algorithms.h" +#include "utils/hash/unordered_set.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +namespace FlexFlow { + +std::optional> + get_lowest_common_ancestors(DiGraphView const &g, + std::unordered_set const &nodes) { + ASSERT(is_acyclic(g)); + ASSERT(is_subseteq_of(nodes, get_nodes(g))); + if (num_nodes(g) == 0 || nodes.size() == 0) { + return std::nullopt; + } + std::unordered_set> ancestors = + transform(nodes, [&](Node const &n) { + return set_union(get_ancestors(g, n), {n}); + }); + std::unordered_set common_ancestors = intersection(ancestors).value(); + + if (common_ancestors.empty()) { + return std::unordered_set{}; + } + + std::unordered_map depth_levels = + get_longest_path_lengths_from_root(g); + + nonnegative_int largest_depth_for_common_ancestors = maximum(transform( + common_ancestors, [&](Node const &n) { return depth_levels.at(n); })); + + return filter(common_ancestors, [&](Node const &n) { + return depth_levels.at(n) == largest_depth_for_common_ancestors; + }); +} + +} // 
namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc index 8df16efe4f..58ae8de110 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering.cc @@ -1,6 +1,9 @@ #include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/containers/contains.h" +#include "utils/exception.h" #include "utils/graph/digraph/algorithms/get_initial_nodes.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/digraph/algorithms/is_acyclic.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/traversal.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.cc new file mode 100644 index 0000000000..590548af66 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_topological_ordering_from_starting_node.cc @@ -0,0 +1,28 @@ +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/traversal.h" + +namespace FlexFlow { + +static std::vector get_unchecked_topological_ordering_from_starting_node( + DiGraphView const &g, Node const &starting_node) { + + std::unordered_set descendants = get_descendants(g, starting_node); + descendants.insert(starting_node); + return get_topological_ordering(get_subgraph(g, descendants)); +} + +std::vector + 
get_topological_ordering_from_starting_node(DiGraphView const &g, + Node const &starting_node) { + assert(is_acyclic(g)); + return get_unchecked_topological_ordering_from_starting_node(g, + starting_node); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_2_terminal_dag.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_2_terminal_dag.cc new file mode 100644 index 0000000000..979fb6838c --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_2_terminal_dag.cc @@ -0,0 +1,12 @@ +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" + +namespace FlexFlow { + +bool is_2_terminal_dag(DiGraphView const &g) { + return (is_acyclic(g) && (get_initial_nodes(g).size() == 1) && + get_terminal_nodes(g).size() == 1); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index 5757c11250..0f6a57e52c 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,31 +1,49 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" -#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/containers/generate_map.h" +#include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/node/algorithms.h" -#include "utils/graph/traversal.h" +#include namespace FlexFlow { -std::optional is_acyclic(DiGraphView const &g) { - if (num_nodes(g) == 0) { - return std::nullopt; - } - std::unordered_set initial_nodes = get_initial_nodes(g); - if (initial_nodes.size() == 0) { +enum class ExplorationStatus { NOT_EXPLORED, BEING_EXPLORED, FULLY_EXPLORED }; + +bool is_acyclic(DiGraphView const &g) { + + std::unordered_map status = + generate_map(get_nodes(g), [](Node const &n) { + return 
ExplorationStatus::NOT_EXPLORED; + }); + + // Recursively explore a given node and all its successors + // A node is fully explored once we have fully explored all of its successors + // If, while exploring, we find a node that was already being explored, then there is a + // cycle + std::function cycle_downstream_from_node = + [&](Node const &n) -> bool { + status[n] = ExplorationStatus::BEING_EXPLORED; + + for (Node const &successor : get_successors(g, n)) { + if (status.at(successor) == ExplorationStatus::NOT_EXPLORED) { + if (cycle_downstream_from_node( + successor)) { // one of the descendants is part of a cycle + return true; + } + } else if (status.at(successor) == ExplorationStatus::BEING_EXPLORED) { + return true; // we're exploring a node we were already exploring: we + // have hit a cycle + } + } + + status[n] = ExplorationStatus::FULLY_EXPLORED; return false; - } - auto dfs_view = unchecked_dfs(g, initial_nodes); - std::unordered_set seen; - for (unchecked_dfs_iterator it = dfs_view.begin(); it != dfs_view.end(); - it++) { - if (contains(seen, *it)) { + }; + + for (Node const &node : get_nodes(g)) { + if ((status.at(node) == ExplorationStatus::NOT_EXPLORED) && cycle_downstream_from_node(node)) { return false; - } else { - seen.insert(*it); } } - if (seen != get_nodes(g)) { - return false; - } return true; } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_tree.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_tree.cc new file mode 100644 index 0000000000..5028450ac4 --- /dev/null +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_tree.cc @@ -0,0 +1,19 @@ +#include "utils/graph/digraph/algorithms/is_tree.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/undirected/algorithms/get_connected_components.h" + +namespace FlexFlow { + +bool is_tree(DiGraphView const &g) { + 
ASSERT(num_nodes(g) > 0); + + bool has_single_root = get_initial_nodes(g).size() == 1; + bool is_connected = get_connected_components(as_undirected(g)).size() == 1; + bool node_edge_diff_is_1 = (get_edges(g).size() == num_nodes(g) - 1); + return has_single_root && is_connected && node_edge_diff_is_1; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc index bd865644eb..1e5d0b0ae7 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -6,6 +6,7 @@ #include "utils/graph/digraph/algorithms/get_edges.h" #include "utils/graph/digraph/algorithms/materialize_digraph_view.h" #include "utils/graph/digraph/algorithms/transitive_closure.h" +#include "utils/graph/digraph/digraph.h" #include "utils/graph/instances/adjacency_digraph.h" #include "utils/graph/node/algorithms.h" @@ -29,16 +30,14 @@ DirectedEdgeMaskView *DirectedEdgeMaskView::clone() const { return new DirectedEdgeMaskView(this->g, this->edge_mask); } -DiGraphView transitive_reduction(DiGraphView const &g) { - /** - * Logic dropped down to raw adjacency matrix for performance. - * The version going through the full graph abstraction was - * incredibly slow (> minutes) for even moderately sized graphs - * (i.e., 200 nodes) without optimization enabled. - * - * transitive_closure inlined to avoid any drifts in node numbering - * between transitive_closure and transitive_reduction - */ +DiGraph transitive_reduction(DiGraphView const &g) { + // Logic dropped down to raw adjacency matrix for performance. + // The version going through the full graph abstraction was + // incredibly slow (> minutes) for even moderately sized graphs + // (i.e., 200 nodes) without optimization enabled. 
+ // + // transitive_closure inlined to avoid any drifts in node numbering + // between transitive_closure and transitive_reduction bidict nodes = transform_keys(bidict_from_enumerating(get_nodes(g)), diff --git a/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc b/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc new file mode 100644 index 0000000000..818ef5f966 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/digraph_generation.cc @@ -0,0 +1,104 @@ +#include "utils/graph/series_parallel/digraph_generation.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/variant.h" + +namespace FlexFlow { + +std::unordered_map parallel_extend(DiGraph &g, + DiGraphView const &ext) { + std::unordered_map node_map; + for (Node const &node : get_nodes(ext)) { + node_map.emplace(node, g.add_node()); + } + for (DirectedEdge const &edge : get_edges(ext)) { + g.add_edge(DirectedEdge{node_map.at(edge.src), node_map.at(edge.dst)}); + } + return node_map; +} + +std::unordered_map serial_extend(DiGraph &g, + DiGraphView const &ext) { + std::unordered_set original_sinks = get_terminal_nodes(g); + std::unordered_map node_map = parallel_extend(g, ext); + for (Node const &node1 : original_sinks) { + for (Node const &node2 : get_initial_nodes(ext)) { + g.add_edge(DirectedEdge{node1, node_map.at(node2)}); + } + } + return node_map; +} + +DiGraph series_composition(DiGraphView const &g1, DiGraphView const &g2) { 
+ DiGraph g = materialize_digraph_view(g1); + serial_extend(g, g2); + return g; +} + +DiGraph parallel_composition(DiGraphView const &g1, DiGraphView const &g2) { + DiGraph g = materialize_digraph_view(g1); + parallel_extend(g, g2); + return g; +} + +DiGraph series_composition(std::vector const &graphs) { + DiGraph g = DiGraph::create(); + for (DiGraphView const &gs : graphs) { + g = materialize_digraph_view(series_composition(g, gs)); + } + return g; +} + +// TODO(@pietro): should be std::unordered_set, but DiGraphs are +// currently non-hashable +DiGraph parallel_composition(std::vector const &graphs) { + DiGraph g = DiGraph::create(); + for (DiGraphView const &gs : graphs) { + g = materialize_digraph_view(parallel_composition(g, gs)); + } + return g; +} + +static DiGraph digraph_from_sp_decomposition(Node const &node) { + DiGraph g = DiGraph::create(); + g.add_node(); + return g; +} + +static DiGraph digraph_from_sp_decomposition(SeriesSplit const &serial) { + std::vector children = + transform(serial.children, [](std::variant const &child) { + return widen(child); + }); + return series_composition( + transform(children, [](SeriesParallelDecomposition const &child) -> DiGraphView { + return digraph_from_sp_decomposition(child); + })); +} + +static DiGraph digraph_from_sp_decomposition(ParallelSplit const ¶llel) { + std::vector children = + transform(vector_of(parallel.get_children()), [](std::variant const &child) { + return widen(child); + }); + return parallel_composition( + transform(children, [](SeriesParallelDecomposition const &child) -> DiGraphView { + return digraph_from_sp_decomposition(child); + })); +} + +DiGraph digraph_from_sp_decomposition(SeriesParallelDecomposition const &sp) { + return sp.visit( + [](auto const &x) { return digraph_from_sp_decomposition(x); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc b/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc new file mode 100644 
index 0000000000..b9f475777e --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/get_ancestors.cc @@ -0,0 +1,55 @@ +#include "utils/graph/series_parallel/get_ancestors.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_only.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/variant.h" +#include + +namespace FlexFlow { + +std::unordered_set get_ancestors(SeriesParallelDecomposition const &sp, + Node const &node); + +static std::unordered_set get_ancestors(Node const &, Node const &node) { + return {}; +} + +static std::unordered_set get_ancestors(SeriesSplit const &serial, + Node const &node) { + std::unordered_set ancestors{}; + for (std::variant const &child : serial.children) { + SeriesParallelDecomposition child_sp = + widen(child); + if (contains(get_nodes(child_sp), node)) { + return set_union(ancestors, get_ancestors(child_sp, node)); + } + ancestors = set_union(ancestors, unordered_set_of(get_nodes(child_sp))); + } + throw std::runtime_error("node not found in SeriesSplit"); +} + +static std::unordered_set get_ancestors(ParallelSplit const ¶llel, + Node const &node) { + SeriesParallelDecomposition branch = + get_only(filter(transform(parallel.get_children(), + [](std::variant const &c) { + return widen(c); + }), + [&](SeriesParallelDecomposition const &child) { + return contains(get_nodes(child), node); + })); + return get_ancestors(branch, node); +} + +std::unordered_set get_ancestors(SeriesParallelDecomposition const &sp, + Node const &node) { + assert(contains(get_nodes(sp), node)); + return sp.visit>( + [&](auto const &t) { return get_ancestors(t, node); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc 
b/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc new file mode 100644 index 0000000000..557678ad35 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc @@ -0,0 +1,100 @@ +#include "utils/graph/series_parallel/non_normal_sp_decomposition.h" +#include "utils/containers/all_of.h" +#include "utils/containers/extend.h" +#include "utils/containers/multiset_union.h" +#include "utils/containers/transform.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/series_parallel/non_normal_parallel_split.dtg.h" +#include "utils/graph/series_parallel/non_normal_series_split.dtg.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/overload.h" +#include "utils/variant.h" + +namespace FlexFlow { + +NonNormalSPDecomposition non_normal_series_composition( + std::vector const &sp_compositions) { + + std::vector> composition{}; + + for (NonNormalSPDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + extend(composition, sp_comp.get().children); + } else if (sp_comp.has()) { + composition.push_back(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.push_back(sp_comp.get()); + } + } + + return NonNormalSPDecomposition{NonNormalSeriesSplit{composition}}; +} + +NonNormalSPDecomposition non_normal_parallel_composition( + std::unordered_multiset const &sp_compositions) { + + std::unordered_multiset< + std::variant<::FlexFlow::NonNormalSeriesSplit, ::FlexFlow::Node>> + composition{}; + + for (NonNormalSPDecomposition const &sp_comp : sp_compositions) { + if (sp_comp.has()) { + composition = multiset_union( + composition, sp_comp.get().get_children()); + } else if (sp_comp.has()) { + composition.insert(sp_comp.get()); + } else { + assert(sp_comp.has()); + composition.insert(sp_comp.get()); + } + } + return 
NonNormalSPDecomposition(NonNormalParallelSplit{composition}); +} + +static Node as_non_normal(Node const &n) { return n; } + +static NonNormalSeriesSplit as_non_normal(SeriesSplit const &s) { + return non_normal_series_composition( + transform(s.children, + [](std::variant const &child) { + return as_non_normal( + widen(child)); + })) + .get(); +} + +static NonNormalParallelSplit as_non_normal(ParallelSplit const &p) { + return non_normal_parallel_composition( + transform(p.get_children(), + [](std::variant const &child) { + return as_non_normal( + widen(child)); + })) + .get(); +} + +NonNormalSPDecomposition as_non_normal(SeriesParallelDecomposition const &sp) { + return sp.visit( + [](auto const &t) { return NonNormalSPDecomposition{as_non_normal(t)}; }); +} + +bool is_empty_non_normal(NonNormalSPDecomposition const &sp) { + return sp.visit(overload{ + [](Node const &) { return false; }, + [](NonNormalSeriesSplit const &serial) { + return all_of(serial.children, [](auto const &child) { + return is_empty_non_normal(widen(child)); + }); + }, + [](NonNormalParallelSplit const ¶llel) { + return all_of(parallel.get_children(), [](auto const &child) { + return is_empty_non_normal(widen(child)); + }); + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc new file mode 100644 index 0000000000..14187e879a --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc @@ -0,0 +1,68 @@ +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/containers/filter.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/exception.h" +#include "utils/graph/series_parallel/non_normal_sp_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/variant.h" + +namespace FlexFlow { + +template 
+static auto filter_empty(T const &container) { + return filter(container, [](auto const &child) { + return !is_empty_non_normal(widen(child)); + }); +} + +static SeriesParallelDecomposition + normalize_sp_decomposition(Node const &node) { + return SeriesParallelDecomposition(node); +} + +static SeriesParallelDecomposition + normalize_sp_decomposition(NonNormalSeriesSplit const &serial) { + std::vector normalized_children = transform( + filter_empty(serial.children), + [](std::variant const &child) { + return normalize_sp_decomposition( + widen(child)); + }); + + if (normalized_children.empty()) { + throw mk_runtime_error( + "Cannot normalize empty SeriesSplit"); + } + if (normalized_children.size() == 1) { + return get_only(normalized_children); + } + return series_composition(normalized_children); +} + +static SeriesParallelDecomposition + normalize_sp_decomposition(NonNormalParallelSplit const ¶llel) { + std::unordered_multiset normalized_children = + transform(filter_empty(parallel.get_children()), + [](std::variant const &child) { + return normalize_sp_decomposition( + widen(child)); + }); + + if (normalized_children.empty()) { + throw mk_runtime_error( + "Cannot normalize empty ParallelSplit (should be filtered out)"); + } + if (normalized_children.size() == 1) { + return get_only(normalized_children); + } + return parallel_composition(normalized_children); +} + +SeriesParallelDecomposition + normalize_sp_decomposition(NonNormalSPDecomposition const &sp) { + return sp.visit( + [](auto const &t) { return normalize_sp_decomposition(t); }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index 937fc1254e..615564ff83 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -1,6 +1,7 @@ #include 
"utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/containers/all_of.h" #include "utils/containers/extend.h" +#include "utils/containers/get_only.h" #include "utils/containers/multiset_union.h" #include "utils/containers/set_union.h" #include "utils/containers/sum.h" @@ -8,8 +9,11 @@ #include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" +#include "utils/exception.h" #include "utils/graph/series_parallel/intermediate_sp_decomposition_tree.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" #include "utils/hash/unordered_set.h" +#include "utils/nonnegative_int/nonnegative_int.h" #include "utils/variant.h" #include @@ -83,24 +87,26 @@ bool is_empty(Node const &node) { return false; } -bool is_empty(SeriesSplit const &serial) { - return all_of(serial.children, [](auto const &child) { - return is_empty(widen(child)); - }); +nonnegative_int num_nodes(SeriesParallelDecomposition const &sp) { + return sum(values(get_num_occurrences_of_nodes(sp))); } -bool is_empty(ParallelSplit const ¶llel) { - return all_of(parallel.get_children(), [](auto const &child) { - return is_empty(widen(child)); - }); -} - -bool is_empty(SeriesParallelDecomposition const &sp) { - return sp.visit([](auto const &t) { return is_empty(t); }); +bool has_no_duplicate_nodes(SeriesParallelDecomposition const &sp) { + return all_of(values(get_num_occurrences_of_nodes(sp)), + [](nonnegative_int count) { return count == 1_n; }); } SeriesParallelDecomposition series_composition( std::vector const &sp_compositions) { + + if (sp_compositions.empty()) { + throw mk_runtime_error("series_composition: cannot create series composition with zero elements"); + } + + if (sp_compositions.size() == 1) { + return get_only(sp_compositions); + } + std::vector> composition{}; for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { if (sp_comp.has()) { @@ -118,6 +124,14 @@ 
SeriesParallelDecomposition series_composition( SeriesParallelDecomposition parallel_composition( std::unordered_multiset const &sp_compositions) { + if (sp_compositions.empty()) { + throw mk_runtime_error("parallel_composition: cannot create parallel composition with zero elements"); + } + + if (sp_compositions.size() == 1) { + return get_only(sp_compositions); + } + std::unordered_multiset< std::variant<::FlexFlow::SeriesSplit, ::FlexFlow::Node>> composition{}; diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc new file mode 100644 index 0000000000..f52dc6a394 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc @@ -0,0 +1,125 @@ +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/containers/maximum.h" +#include "utils/containers/sum.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/digraph_generation.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/variant.h" +#include +namespace FlexFlow { + +static std::unordered_map + get_num_occurrences_of_nodes(Node const &node) { + return {{node, 1_n}}; +} + +template +static std::unordered_map + get_num_occurrences_of_nodes_impl(T const &t) { + std::unordered_map counter; + for (Node const &node : get_nodes(t)) { + counter.emplace(node, 0_n).first->second += 1_n; + } + return counter; +} + +static std::unordered_map 
+ get_num_occurrences_of_nodes(ParallelSplit const ¶llel) { + return get_num_occurrences_of_nodes_impl(parallel); +} + +static std::unordered_map + get_num_occurrences_of_nodes(SeriesSplit const &serial) { + return get_num_occurrences_of_nodes_impl(serial); +} + +std::unordered_map + get_num_occurrences_of_nodes(SeriesParallelDecomposition const &sp) { + return get_num_occurrences_of_nodes_impl(sp); +} + +float work_cost(SeriesParallelDecomposition const &sp, + std::unordered_map cost_map) { + return sum(transform(get_nodes(sp), + [&](Node const &node) { return cost_map.at(node); })); +} + +float work_cost(DiGraphView const &g, + std::unordered_map const &cost_map) { + return sum(transform(vector_of(get_nodes(g)), + [&](Node const &node) { return cost_map.at(node); })); +} + +static float critical_path_cost(Node const &node, + std::unordered_map const &cost_map) { + return cost_map.at(node); +} + +static float critical_path_cost(SeriesSplit const &serial, + std::unordered_map const &cost_map) { + return sum(transform( + serial.children, [&](std::variant const &child) { + return critical_path_cost(widen(child), + cost_map); + })); +} + +static float critical_path_cost(ParallelSplit const ¶llel, + std::unordered_map const &cost_map) { + return maximum(transform(parallel.get_children(), + [&](std::variant const &child) { + return critical_path_cost( + widen(child), + cost_map); + })); +} + +float critical_path_cost(SeriesParallelDecomposition const &sp, + std::unordered_map const &cost_map) { + return sp.visit( + [&](auto const &t) { return critical_path_cost(t, cost_map); }); +} + +float critical_path_cost(DiGraphView const &g, + std::unordered_map const &cost_map) { + return maximum( + values(get_weighted_longest_path_lengths_from_root(g, cost_map))); +} + +nonnegative_int num_dependencies(SeriesParallelDecomposition const &sp) { + return num_dependencies(digraph_from_sp_decomposition(sp)); +} + +nonnegative_int num_dependencies(DiGraphView const &g) { + return 
nonnegative_int{get_edges(g).size()}; +} + +float relative_work_increase(DiGraphView const &g, + SeriesParallelDecomposition const &sp, + std::unordered_map const &cost_map) { + return work_cost(sp, cost_map) / work_cost(g, cost_map); +} + +float relative_critical_path_cost_increase( + DiGraphView const &g, + SeriesParallelDecomposition const &sp, + std::unordered_map const &cost_map) { + return critical_path_cost(sp, cost_map) / critical_path_cost(g, cost_map); +} + +float relative_num_dependencies_increase( + DiGraphView const &g, SeriesParallelDecomposition const &sp) { + return static_cast(num_dependencies(sp).unwrap_nonnegative()) / + static_cast(num_dependencies(g).unwrap_nonnegative()); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc new file mode 100644 index 0000000000..b7d2952fdc --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc @@ -0,0 +1,29 @@ +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/graph/digraph/algorithms/get_ancestors.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/get_ancestors.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include + +namespace FlexFlow { + +bool dependencies_are_maintained(DiGraphView const &g, + SeriesParallelDecomposition const &sp) { + ASSERT(has_no_duplicate_nodes(sp)); + if (unordered_set_of(get_nodes(sp)) != get_nodes(g)) { + return false; + } + + for (Node const &n : get_nodes(g)) { + std::unordered_set ancestors_in_g = get_ancestors(g, n); + std::unordered_set ancestors_in_sp = get_ancestors(sp, n); + if (!is_subseteq_of(ancestors_in_g, ancestors_in_sp)) { + return false; + } + } + return true; 
+} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc new file mode 100644 index 0000000000..93a35028b9 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc @@ -0,0 +1,218 @@ +#include "utils/graph/series_parallel/sp_ization/escribano_algo.h" +#include "utils/containers/filter_keys.h" +#include "utils/containers/get_only.h" +#include "utils/containers/group_by.h" +#include "utils/containers/intersection.h" +#include "utils/containers/map_values.h" +#include "utils/containers/maximum.h" +#include "utils/containers/range.h" +#include "utils/containers/set_union.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/algorithms/get_lowest_common_ancestors.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/get_weakly_connected_components.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include 
"utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/sp_ization/node_role.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include + +#include +#include + +namespace FlexFlow { + +static std::unordered_set + filter_sync_nodes(std::unordered_set const &nodes, + std::unordered_map const &node_roles) { + return filter( + nodes, [&](Node const &n) { return node_roles.at(n) != NodeRole::SYNC; }); +} + +static nonnegative_int get_max_depth(DiGraph const &sp, + std::unordered_map const &depth_map) { + return maximum(values(filter_keys( + depth_map, [&](Node const &n) { return contains(get_nodes(sp), n); }))); +} + +DiGraph add_dummy_nodes(DiGraph g, + std::unordered_map &node_roles) { + std::unordered_map depth_map = get_longest_path_lengths_from_root(g); + for (DirectedEdge const &e : get_edges(g)) { + Node src = e.src; + Node dst = e.dst; + int depth_diff = depth_map.at(dst).unwrap_nonnegative() - depth_map.at(src).unwrap_nonnegative(); + if (depth_diff > 1) { + g.remove_edge(e); + Node prev_node = src; + Node intermediate_node = Node{0}; + for (int i = 1; i < depth_diff; i++) { + intermediate_node = g.add_node(); + node_roles[intermediate_node] = NodeRole::DUMMY; + g.add_edge(DirectedEdge{prev_node, intermediate_node}); + prev_node = intermediate_node; + } + g.add_edge(DirectedEdge{prev_node, dst}); + } + } + return g; +} + +std::unordered_set + get_component(DiGraph const &g, + Node const &node, + std::unordered_map const &depth_map, + std::unordered_map const &node_roles) { + + nonnegative_int max_depth = get_max_depth(g, depth_map); + auto is_in_last_2_layers = [&](Node const &n) { + if (node_roles.at(n) == NodeRole::SYNC) { + if (get_successors(g, n).empty()) { + return true; + } + nonnegative_int successors_depth = + get_only(transform(get_successors(g, n), + [&](Node const &n) { return depth_map.at(n); })); + return successors_depth == max_depth; + } else { + return (depth_map.at(n) == max_depth) || + 
(depth_map.at(n) + 1_n == max_depth); + } + }; + std::unordered_set last_two_layers_nodes = + filter(get_nodes(g), is_in_last_2_layers); + + DiGraph subgraph = materialize_digraph_view( + get_subgraph(g, last_two_layers_nodes)); + std::unordered_set component = + get_only(filter(get_weakly_connected_components(subgraph), + [&](std::unordered_set const &component) { + return contains(component, node); + })); + std::unordered_set component_without_sync_nodes = + filter_sync_nodes(component, node_roles); + return component_without_sync_nodes; +} + +static std::unordered_set + get_forest_escribano(DiGraph const &g, + Node const &handle, + std::unordered_set const &component, + std::unordered_map const &node_roles) { + std::unordered_set> subtrees = + transform(get_successors(g, handle), [&](Node const &n) { + return set_union(get_descendants(g, n), {n}); + }); + auto subtrees_overlapping_with_component = + filter(subtrees, [&](std::unordered_set subtree) { + return intersection(subtree, component).size() > 0; + }); + std::unordered_set forest = + set_union(subtrees_overlapping_with_component); + forest.insert(handle); + return filter_sync_nodes(forest, node_roles); +} + +static std::pair, std::unordered_set> + get_up_and_down(DiGraph const &g, + std::unordered_set const &forest, + std::unordered_map const &depth_map) { + + nonnegative_int max_depth = get_max_depth(g, depth_map); + auto grouped_by_depth = + group_by(forest, [&](Node const &n) { return depth_map.at(n); }); + return {grouped_by_depth.at_l(nonnegative_int{max_depth.unwrap_nonnegative() - 1}), + grouped_by_depth.at_l(max_depth)}; +} + +static std::unordered_set + edges_to_remove(DiGraph const &g, + std::unordered_set const &up, + std::unordered_set const &down) { + std::unordered_set to_remove; + + for (Node const &u : up) { + to_remove = set_union(to_remove, get_outgoing_edges(g, u)); + } + for (Node const &d : down) { + to_remove = set_union(to_remove, get_incoming_edges(g, d)); + } + + return to_remove; 
+} + +static std::unordered_set + edges_to_add_escribano(std::unordered_set const &up, + std::unordered_set const &down, + Node const &sync_node) { + return set_union(transform(up, + [&](Node const &u) { + return DirectedEdge{u, sync_node}; + }), + transform(down, [&](Node const &d) { + return DirectedEdge{sync_node, d}; + })); +} + +SeriesParallelDecomposition escribano_sp_ization(DiGraph g) { + ASSERT(is_2_terminal_dag(g)); + ASSERT(is_acyclic(g)); + + std::unordered_map node_roles = get_initial_node_role_map(g); + + g = add_dummy_nodes(g, node_roles); + std::unordered_map depth_map = get_longest_path_lengths_from_root(g); + + DiGraph sp = DiGraph::create(); + Node root = get_only(get_initial_nodes(g)); + sp.add_node_unsafe(root); + size_t sync_node_counter = maximum( + transform(get_nodes(g), [&](Node const &n) { return n.raw_uid; })); + for (Node const &node : get_bfs_ordering(g, {root})) { + if (node == root) { + continue; + } + sp.add_node_unsafe(node); + add_edges(sp, vector_of(get_incoming_edges(g, node))); + std::unordered_set component = + get_component(sp, node, depth_map, node_roles); + Node handle = get_only(get_lowest_common_ancestors(sp, component).value()); + std::unordered_set forest = + get_forest_escribano(sp, handle, component, node_roles); + auto [up, down] = get_up_and_down(sp, forest, depth_map); + + for (DirectedEdge const &e : edges_to_remove(sp, up, down)) { + sp.remove_edge(e); + } + + Node sync_node = Node{++sync_node_counter}; + node_roles[sync_node] = NodeRole::SYNC; + sp.add_node_unsafe(sync_node); + for (DirectedEdge const &e : edges_to_add_escribano(up, down, sync_node)) { + sp.add_edge(e); + } + } + sp = contract_out_nodes_of_given_role(sp, NodeRole::DUMMY, node_roles); + sp = transitive_reduction(sp); + sp = contract_out_nodes_of_given_role(sp, NodeRole::SYNC, node_roles); + return get_series_parallel_decomposition(sp).value(); +} +} // namespace FlexFlow + diff --git 
a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc new file mode 100644 index 0000000000..4f3de2a4da --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc @@ -0,0 +1,328 @@ +#include "utils/graph/series_parallel/sp_ization/flexible_algo.h" +#include "utils/containers/all_of.h" +#include "utils/containers/compare_by.h" +#include "utils/containers/contains.h" +#include "utils/containers/filter.h" +#include "utils/containers/generate_map.h" +#include "utils/containers/get_only.h" +#include "utils/containers/intersection.h" +#include "utils/containers/is_subseteq_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/maximum.h" +#include "utils/containers/set_difference.h" +#include "utils/containers/set_union.h" +#include "utils/containers/sorted_by.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_ancestors.h" +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/algorithms/get_lowest_common_ancestors.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/digraph/algorithms/materialize_digraph_view.h" +#include "utils/graph/digraph/algorithms/transitive_reduction.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/digraph/directed_edge.dtg.h" 
+#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/get_series_parallel_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/sp_ization/node_role.h" +#include "utils/graph/series_parallel/sp_ization/up_down_partition.h" +#include + +#include +#include + +namespace FlexFlow { + +static std::unordered_set + get_component(DiGraph const &sp, std::unordered_set const &nodes) { + std::unordered_set parents = set_union( + transform(nodes, [&](Node const &n) { return get_predecessors(sp, n); })); + std::unordered_set children = set_union(transform( + parents, [&](Node const &p) { return get_descendants(sp, p); })); + std::unordered_set other_parents = set_union(transform( + children, [&](Node const &c) { return get_predecessors(sp, c); })); + return set_union(set_union(parents, children), other_parents); +} + +static std::unordered_set + get_forest_flexible(DiGraph const &sp, + Node const &handle, + std::unordered_set const &component, + std::unordered_map const &node_roles) { + std::unordered_set> subtrees = + transform(get_successors(sp, handle), [&](Node const &n) { + return set_union(get_descendants(sp, n), {n}); + }); + + std::unordered_set> overlapping_subtrees = + filter(subtrees, [&](std::unordered_set const &subtree) { + return !intersection(subtree, component).empty(); + }); + + std::unordered_set forest = set_union(overlapping_subtrees); + forest.insert(handle); + + return filter(forest, [&](Node const &n) { + return node_roles.at(n) != NodeRole::SYNC; + }); +} + +static UpDownPartition + get_up_and_down(DiGraph const &sp, + std::unordered_set const &nodes, + std::unordered_set const &forest, + std::unordered_map const &cost_map, + std::unordered_map const &node_roles) { + DiGraph sp_pure = + 
contract_out_nodes_of_given_role(materialize_digraph_view(sp), + NodeRole::SYNC, + node_roles); + + std::unordered_set base_down = nodes; + std::unordered_set base_up = intersection( + set_union(transform( + nodes, [&](Node const &n) { return get_ancestors(sp_pure, n); })), + forest); + std::unordered_set assignable_nodes = + set_difference(forest, set_union(base_up, base_down)); + + DiGraphView forest_subgraph = get_subgraph(sp_pure, forest); + std::unordered_map critical_path_cost_map = + get_weighted_longest_path_lengths_from_root(forest_subgraph, cost_map); + + auto get_partition_with_max_up_cost = + [&](float reference_cost) -> UpDownPartition { + std::unordered_set up = + set_union(base_up, filter(assignable_nodes, [&](Node const &n) { + return critical_path_cost_map.at(n) <= reference_cost; + })); + std::unordered_set down = + set_difference(set_union(base_down, assignable_nodes), up); + return UpDownPartition{up, down}; + }; + + auto is_valid = [&](UpDownPartition const &partition) -> bool { + if (!is_subseteq_of(nodes, partition.down)) { + return false; + } + + for (Node const &node : get_nodes(sp_pure)) { + if (contains(partition.down, node)) { + for (Node const &child : get_successors(sp_pure, node)) { + if (contains(partition.up, child)) { + return false; + } + } + for (Node const &parent : get_predecessors(sp_pure, node)) { + if (contains(forest, parent) && !contains(partition.up, parent) && + !contains(partition.down, parent)) { + return false; + } + } + } + } + return true; + }; + + std::unordered_set partitions = + transform(assignable_nodes, [&](Node const &n) { + return get_partition_with_max_up_cost(critical_path_cost_map.at(n)); + }); + partitions.insert( + UpDownPartition{base_up, set_union(base_down, assignable_nodes)}); + + std::unordered_set valid_partitions = + filter(partitions, is_valid); + ASSERT(!valid_partitions.empty()); + + auto partition_cost = [&](UpDownPartition const &p) { + float up_cost = 
critical_path_cost(get_subgraph(sp_pure, p.up), cost_map); + float down_cost = + critical_path_cost(get_subgraph(sp_pure, p.down), cost_map); + return std::make_tuple(up_cost + down_cost, down_cost, p.down.size()); + }; + + return sorted_by(valid_partitions, + compare_by(partition_cost)) + .at(0); +} + +static std::unordered_set edges_to_remove_flexible( + DiGraph const &sp, + std::unordered_set const &up, + std::unordered_set const &down, + std::unordered_map const &node_roles) { + std::unordered_set to_remove; + + // from up to down + for (Node const &u : up) { + for (DirectedEdge const &e : get_outgoing_edges(sp, u)) { + if (contains(down, e.dst)) { + to_remove.insert(e); + } + } + } + + for (Node const &node : get_nodes(sp)) { + if (node_roles.at(node) == NodeRole::SYNC) { + std::unordered_set preds = get_predecessors(sp, node); + std::unordered_set succs = get_successors(sp, node); + if (is_subseteq_of(preds, up) && is_subseteq_of(succs, down)) { + to_remove = set_union(to_remove, get_incoming_edges(sp, node)); + to_remove = set_union(to_remove, get_outgoing_edges(sp, node)); + } + } + } + + return to_remove; +} + +static std::unordered_set + edges_to_add_flexible(DiGraph const &sp, + UpDownPartition const &partition, + Node const &sync_node) { + std::unordered_set up_frontier = get_up_frontier(sp, partition); + std::unordered_set down_frontier = get_down_frontier(sp, partition); + + return set_union(transform(up_frontier, + [&](Node const &u) { + return DirectedEdge{u, sync_node}; + }), + transform(down_frontier, [&](Node const &d) { + return DirectedEdge{sync_node, d}; + })); +} + +static std::unordered_set + get_next_nodes(DiGraph const &sp, + DiGraph const &g, + std::unordered_map const &cost_map) { + std::unordered_map sp_longest_paths = + get_weighted_longest_path_lengths_from_root(sp, cost_map); + + std::unordered_set sp_nodes = get_nodes(sp); + std::unordered_set g_nodes = get_nodes(g); + + // candidate nodes: not in sp but all predecessors in sp + 
std::unordered_set candidate_nodes = + filter(g_nodes, [&](Node const &node) { + if (contains(sp_nodes, node)) { + return false; + } + std::unordered_set preds = get_predecessors(g, node); + return is_subseteq_of(preds, sp_nodes); + }); + + ASSERT(!candidate_nodes.empty()); + + std::unordered_map critical_path_costs = + generate_map(candidate_nodes, [&](Node const &node) { + std::unordered_set preds = get_predecessors(g, node); + float max_parent_cost = maximum(transform(preds, [&](Node const &pred) { + return sp_longest_paths.at(pred); + })); + return cost_map.at(node) + max_parent_cost; + }); + + Node ref_node = + sorted_by(candidate_nodes, compare_by([&](Node const &n) { + return std::make_pair(critical_path_costs.at(n), n.raw_uid); + })) + .at(0); + + std::unordered_set ref_preds = get_predecessors(g, ref_node); + return filter(candidate_nodes, [&](Node const &node) { + return get_predecessors(g, node) == ref_preds; + }); +} + +static bool cost_map_is_valid(DiGraphView const &g, + std::unordered_map const &cost_map) { + bool has_correct_nodes = get_nodes(g) == keys(cost_map); + bool has_nonnegative_costs = + all_of(values(cost_map), [&](float const &cost) { return cost >= 0.0f; }); + return has_correct_nodes && has_nonnegative_costs; +} + +SeriesParallelDecomposition + flexible_sync_unchecked(DiGraphView const &g, + std::unordered_map cost_map) { + DiGraph g_reduced = + materialize_digraph_view(transitive_reduction(g)); + + std::unordered_map node_roles = + get_initial_node_role_map(g_reduced); + + DiGraph sp = DiGraph::create(); + Node root = get_only(get_initial_nodes(g_reduced)); + sp.add_node_unsafe(root); + + while (!is_subseteq_of(get_nodes(g_reduced), get_nodes(sp))) { + std::unordered_set nodes = get_next_nodes(sp, g_reduced, cost_map); + + for (Node const &node : nodes) { + // @colin: not sure if this matches the spec, the counter for the node uid + // is global and since we have generated these nodes already, we are + // guaranteed that the uid of the 
sync nodes will not overlap with them. + sp.add_node_unsafe(node); + add_edges(sp, vector_of(get_incoming_edges(g_reduced, node))); + } + // TODO(@pietro): ideally optimize this by selectively removing previously + // added edges + sp = transitive_reduction(sp); + + std::unordered_set component = get_component(sp, nodes); + Node handle = get_only(get_lowest_common_ancestors(sp, component).value()); + std::unordered_set forest = + get_forest_flexible(sp, handle, component, node_roles); + + UpDownPartition partition = + get_up_and_down(sp, nodes, forest, cost_map, node_roles); + + Node sync_node = sp.add_node(); + node_roles[sync_node] = NodeRole::SYNC; + cost_map[sync_node] = 0.0f; + + for (DirectedEdge const &e : edges_to_remove_flexible( + sp, partition.up, partition.down, node_roles)) { + sp.remove_edge(e); + } + for (DirectedEdge const &e : + edges_to_add_flexible(sp, partition, sync_node)) { + sp.add_edge(e); + } + } + + sp = transitive_reduction(sp); + sp = contract_out_nodes_of_given_role(sp, NodeRole::SYNC, node_roles); + + SeriesParallelDecomposition decomp = + get_series_parallel_decomposition(sp).value(); + ASSERT(dependencies_are_maintained(g, decomp)); + + return decomp; +} + +SeriesParallelDecomposition + flexible_sp_ization(DiGraphView const &g, + std::unordered_map const &cost_map) { + ASSERT(is_2_terminal_dag(g)); + ASSERT(is_acyclic(g)); + ASSERT(cost_map_is_valid(g, cost_map)); + + return flexible_sync_unchecked(g, cost_map); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc new file mode 100644 index 0000000000..f8d4430548 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc @@ -0,0 +1,54 @@ +#include "utils/graph/series_parallel/sp_ization/naive_stratum_sync.h" +#include "utils/containers/maximum.h" +#include "utils/containers/transform.h" +#include 
"utils/containers/values.h" +#include "utils/fmt/unordered_multiset.h" +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/digraph/algorithms/is_acyclic.h" +#include "utils/graph/series_parallel/non_normal_sp_decomposition.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include + +namespace FlexFlow { + +std::vector> +stratum_split_assuming_unit_cost(DiGraphView const &g) { + std::unordered_map node_to_stratum = + get_longest_path_lengths_from_root(g); + std::vector> result( + maximum(values(node_to_stratum)).unwrap_nonnegative()); + for (auto const &[node, depth] : node_to_stratum) { + result[depth.unwrap_nonnegative() - 1].insert(node); + } + return result; +} + +static SeriesParallelDecomposition +naive_stratum_merge(std::vector> stratum_split) { + std::vector strata = transform( + stratum_split, [](std::unordered_multiset const &stratum_nodes) { + return parallel_composition(transform(stratum_nodes, [](Node const &n) { + return SeriesParallelDecomposition{n}; + })); + }); + return normalize_sp_decomposition(as_non_normal(series_composition(strata))); +} + +SeriesParallelDecomposition +stratum_sync_sp_ization_unchecked(DiGraphView const &g) { + + std::vector> stratum_split = + stratum_split_assuming_unit_cost(g); + return naive_stratum_merge(stratum_split); +} + +SeriesParallelDecomposition stratum_sync_sp_ization(DiGraphView const &g) { + ASSERT(is_acyclic(g)); + SeriesParallelDecomposition sp = stratum_sync_sp_ization_unchecked(g); + ASSERT(dependencies_are_maintained(g, sp)); + return sp; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc new file mode 100644 index 0000000000..97b8f11ec3 --- /dev/null +++ 
b/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc @@ -0,0 +1,34 @@ +#include "utils/graph/series_parallel/sp_ization/node_role.h" +#include "utils/containers/generate_map.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_successors.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/node/algorithms.h" + +namespace FlexFlow { + +std::unordered_map + get_initial_node_role_map(DiGraphView const &g) { + return generate_map(get_nodes(g), + [](Node const &) { return NodeRole::PURE; }); +} + +DiGraph contract_out_nodes_of_given_role( + DiGraph g, + NodeRole const &role, + std::unordered_map const &node_roles) { + for (Node const &n : get_nodes(g)) { + if (node_roles.at(n) == role) { + for (Node const &pred : get_predecessors(g, n)) { + for (Node const &succ : get_successors(g, n)) { + g.add_edge(DirectedEdge{pred, succ}); + } + } + remove_node(g, n); + } + } + return g; +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/up_down_partition.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/up_down_partition.cc new file mode 100644 index 0000000000..87a7eaa9b4 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/up_down_partition.cc @@ -0,0 +1,25 @@ +#include "utils/graph/series_parallel/sp_ization/up_down_partition.h" +#include "utils/containers/filter.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" + +namespace FlexFlow { + +std::unordered_set get_up_frontier(DiGraph const &sp, + UpDownPartition const &partition) { + DiGraphView up_subgraph = get_subgraph(sp, partition.up); + return filter(partition.up, [&](Node const &node) { + return get_outgoing_edges(up_subgraph, node).empty(); + }); +} + +std::unordered_set get_down_frontier(DiGraph const 
&sp, + UpDownPartition const &partition) { + DiGraphView down_subgraph = get_subgraph(sp, partition.down); + return filter(partition.down, [&](Node const &node) { + return get_incoming_edges(down_subgraph, node).empty(); + }); +} + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_spization.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_spization.cc new file mode 100644 index 0000000000..80c8bbfae5 --- /dev/null +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/work_duplicating_spization.cc @@ -0,0 +1,133 @@ +#include "utils/graph/series_parallel/sp_ization/work_duplicating_spization.h" +#include "utils/containers/get_only.h" +#include "utils/containers/transform.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/vector_of.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_predecessors.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" +#include "utils/graph/digraph/algorithms/is_2_terminal_dag.h" +#include "utils/graph/digraph/digraph_view.h" +#include "utils/graph/series_parallel/non_normal_sp_decomposition.h" +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.h" +#include "utils/variant.h" +#include +#include + +namespace FlexFlow { + +static NonNormalSeriesSplit cut_off_head(NonNormalSeriesSplit const &s) { + ASSERT(s.children.size() > 0); + return NonNormalSeriesSplit{ + std::vector>{ + s.children.begin() + 1, s.children.end()}}; +} + +/* Performs a parallel composition with coalescing, where components with a + * common starting child are merged together + * Example: to parallel compose S(1, 2, 5), S(1, 3, 4): + * without 
coalescing: P(S(1, 2, 5), S(1, 3, 4)) + * with coalescing: S(1, P( S(2,5), S(3,4) )) + */ +static NonNormalSPDecomposition parallel_composition_with_coalescing( + std::unordered_set const &strands) { + if (strands.size() == 1) { + return NonNormalSPDecomposition{get_only(strands)}; + } + + // group strands by their first ("head") node + std::unordered_map, + std::unordered_set> + grouped_strands; + for (NonNormalSeriesSplit predecessor : + filter(strands, [](NonNormalSeriesSplit const &serial) { + return !is_empty_non_normal(NonNormalSPDecomposition{serial}); + })) { + grouped_strands[predecessor.children.at(0)].insert( + cut_off_head(predecessor)); + } + + // recursively coalesce the strands + std::unordered_multiset coalesced_strands; + for (auto const &[head, tails] : grouped_strands) { + NonNormalSPDecomposition parallel_comp = + parallel_composition_with_coalescing(tails); + + NonNormalSPDecomposition series_comp = non_normal_series_composition( + {widen(head), parallel_comp}); + coalesced_strands.insert( + as_non_normal(normalize_sp_decomposition(series_comp))); + } + + return non_normal_parallel_composition(coalesced_strands); +} + +static SeriesParallelDecomposition +work_duplicating_spization_unchecked_with_coalescing(DiGraphView const &g) { + std::unordered_map node_to_sp; + + Node source = get_only(get_initial_nodes(g)); + node_to_sp.emplace(source, NonNormalSeriesSplit{{source}}); + + for (Node const &node : get_topological_ordering(g)) { + if (node == source) { + continue; + } + std::unordered_set predecessors_as_sp = + transform(get_predecessors(g, node), + [&](Node const &p) { return node_to_sp.at(p); }); + + NonNormalSPDecomposition parallel_composed_predecessors = + as_non_normal(normalize_sp_decomposition( + parallel_composition_with_coalescing(predecessors_as_sp))); + NonNormalSeriesSplit sp_decomp = + non_normal_series_composition( + {parallel_composed_predecessors, NonNormalSPDecomposition{node}}) + .get(); + node_to_sp.emplace(node, 
sp_decomp); + } + + Node sink = get_only(get_terminal_nodes(g)); + return normalize_sp_decomposition( + NonNormalSPDecomposition{node_to_sp.at(sink)}); +} + +static SeriesParallelDecomposition +work_duplicating_spization_unchecked(DiGraphView const &g) { + std::unordered_map node_to_sp; + + for (Node const &node : get_topological_ordering(g)) { + + std::unordered_multiset predecessors_as_sp = + unordered_multiset_of( + transform(get_predecessors(g, node), + [&](Node const &p) { return node_to_sp.at(p); })); + + NonNormalSPDecomposition sp_decomp = non_normal_series_composition( + {non_normal_parallel_composition(predecessors_as_sp), + NonNormalSPDecomposition{node}}); + + node_to_sp.emplace(node, sp_decomp); + } + + Node sink = get_only(get_terminal_nodes(g)); + return normalize_sp_decomposition(node_to_sp.at(sink)); +} + +SeriesParallelDecomposition +naive_work_duplicating_spization(DiGraphView const &g) { + ASSERT(is_2_terminal_dag(g)); + return work_duplicating_spization_unchecked(g); +} + +SeriesParallelDecomposition +work_duplicating_spization_with_coalescing(DiGraphView const &g) { + ASSERT(is_2_terminal_dag(g)); + return work_duplicating_spization_unchecked_with_coalescing(g); +} + + +} // namespace FlexFlow diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_bottlenecks.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_bottlenecks.cc new file mode 100644 index 0000000000..e66767ffc6 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_bottlenecks.cc @@ -0,0 +1,75 @@ +#include "utils/graph/digraph/algorithms/get_bottlenecks.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_bottlenecks") { + DiGraph g = DiGraph::create(); + + SUBCASE("single node") { + std::vector n = add_nodes(g, 1); + std::unordered_set expected = {n.at(0)}; + CHECK(get_bottlenecks(g) == expected); + } + + 
SUBCASE("linear graph") { + std::vector n = add_nodes(g, 3); + add_edges( + g, {DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(1), n.at(2)}}); + + std::unordered_set expected = {n.at(0), n.at(1), n.at(2)}; + CHECK(get_bottlenecks(g) == expected); + } + + SUBCASE("rhombus") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}); + + std::unordered_set expected = {n.at(0), n.at(3)}; + CHECK(get_bottlenecks(g) == expected); + } + + SUBCASE("two rhombuses in serial") { + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}}); + + std::unordered_set expected = {n.at(0), n.at(3), n.at(5)}; + CHECK(get_bottlenecks(g) == expected); + } + + SUBCASE("middle bottleneck") { + std::vector n = add_nodes(g, 5); + add_edges(g, + {DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}}); + + std::unordered_set expected = {}; + CHECK(get_bottlenecks(g) == expected); + } + + SUBCASE("single source, multiple sinks") { + std::vector n = add_nodes(g, 3); + add_edges( + g, {DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(0), n.at(2)}}); + + std::unordered_set expected = {n.at(0)}; + CHECK(get_bottlenecks(g) == expected); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_descendants.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_descendants.cc new file mode 100644 index 0000000000..a115569139 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_descendants.cc @@ -0,0 +1,127 @@ +#include "utils/graph/digraph/algorithms/get_descendants.h" +#include "utils/graph/algorithms.h" +#include 
"utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_descendants") { + DiGraph g = DiGraph::create(); + + SUBCASE("single node") { + std::vector n = add_nodes(g, 1); + + std::unordered_set correct = {}; + std::unordered_set result = get_descendants(g, n.at(0)); + CHECK(correct == result); + } + + SUBCASE("linear graph") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); + + SUBCASE("n.at(0)") { + std::unordered_set correct = {n.at(1), n.at(2), n.at(3)}; + std::unordered_set result = get_descendants(g, n.at(0)); + CHECK(correct == result); + } + + SUBCASE("n.at(1)") { + std::unordered_set correct = {n.at(2), n.at(3)}; + std::unordered_set result = get_descendants(g, n.at(1)); + CHECK(correct == result); + } + + SUBCASE("n.at(2)") { + std::unordered_set correct = {n.at(3)}; + std::unordered_set result = get_descendants(g, n.at(2)); + CHECK(correct == result); + } + + SUBCASE("n.at(3)") { + std::unordered_set correct = {}; + std::unordered_set result = get_descendants(g, n.at(3)); + CHECK(correct == result); + } + } + + SUBCASE("rhombus") { + std::vector n = add_nodes(g, 5); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + SUBCASE("n.at(0)") { + std::unordered_set correct = {n.at(1), n.at(2), n.at(3)}; + std::unordered_set result = get_descendants(g, n.at(0)); + CHECK(correct == result); + } + + SUBCASE("n.at(1)") { + std::unordered_set correct = {n.at(3)}; + std::unordered_set result = get_descendants(g, n.at(1)); + CHECK(correct == result); + } + + SUBCASE("n.at(2)") { + std::unordered_set correct = {n.at(3)}; + std::unordered_set result = get_descendants(g, n.at(2)); + CHECK(correct == result); + } + + SUBCASE("n.at(3)") { + std::unordered_set correct = {}; + 
std::unordered_set result = get_descendants(g, n.at(3)); + CHECK(correct == result); + } + } + + SUBCASE("disconnected graph") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + SUBCASE("n.at(0)") { + std::unordered_set correct = {n.at(1), n.at(2)}; + std::unordered_set result = get_descendants(g, n.at(0)); + CHECK(correct == result); + } + + SUBCASE("n.at(1)") { + std::unordered_set correct = {n.at(2)}; + std::unordered_set result = get_descendants(g, n.at(1)); + CHECK(correct == result); + } + + SUBCASE("n.at(2)") { + std::unordered_set correct = {}; + std::unordered_set result = get_descendants(g, n.at(2)); + CHECK(correct == result); + } + + SUBCASE("n.at(3)") { + std::unordered_set correct = {n.at(4)}; + std::unordered_set result = get_descendants(g, n.at(3)); + CHECK(correct == result); + } + + SUBCASE("n.at(4)") { + std::unordered_set correct = {}; + std::unordered_set result = get_descendants(g, n.at(4)); + CHECK(correct == result); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc new file mode 100644 index 0000000000..9dd44fe4ec --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.cc @@ -0,0 +1,62 @@ +#include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_longest_path_lengths_from_root") { + SUBCASE("linear graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + 
DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + + add_edges(g, edges); + + std::unordered_map expected_lengths = { + {n.at(0), 1_n}, + {n.at(1), 2_n}, + {n.at(2), 3_n}, + {n.at(3), 4_n}, + {n.at(4), 5_n}, + }; + + CHECK(get_longest_path_lengths_from_root(g) == expected_lengths); + } + + SUBCASE("more complex graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + std::vector edges = {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(0), n.at(6)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(5), n.at(6)}}; + + add_edges(g, edges); + + std::unordered_map expected_lengths = { + {n.at(0), 1_n}, + {n.at(1), 2_n}, + {n.at(2), 3_n}, + {n.at(3), 4_n}, + {n.at(4), 2_n}, + {n.at(5), 5_n}, + {n.at(6), 6_n}, + }; + + CHECK(get_longest_path_lengths_from_root(g) == expected_lengths); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc new file mode 100644 index 0000000000..ebc5b6e19f --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_lowest_common_ancestors.cc @@ -0,0 +1,198 @@ +#include "utils/graph/digraph/algorithms/get_lowest_common_ancestors.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_lowest_common_ancestors") { + DiGraph g = DiGraph::create(); + + SUBCASE("returns nullopt for empty input") { + SUBCASE("empty graph") { + std::optional> correct = std::nullopt; + std::optional> result = + get_lowest_common_ancestors(g, {}); + CHECK(correct == result); + } + + SUBCASE("non-empty graph with empty set") { + std::vector n = add_nodes(g, 3); + add_edges( + g, + 
{DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(0), n.at(2)}}); + std::optional> correct = std::nullopt; + std::optional> result = + get_lowest_common_ancestors(g, {}); + CHECK(correct == result); + } + } + + SUBCASE("trees") { + SUBCASE("single node") { + std::vector n = add_nodes(g, 1); + std::optional> correct = + std::unordered_set{n.at(0)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(0)}); + CHECK(correct == result); + } + + SUBCASE("simple tree") { + std::vector n = add_nodes(g, 3); + add_edges( + g, + {DirectedEdge{n.at(0), n.at(1)}, DirectedEdge{n.at(0), n.at(2)}}); + + SUBCASE("LCA of siblings is parent") { + std::optional> correct = + std::unordered_set{n.at(0)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(1), n.at(2)}); + CHECK(correct == result); + } + + SUBCASE("LCA of a single node is itself") { + std::optional> correct = + std::unordered_set{n.at(1)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(1)}); + CHECK(correct == result); + } + + SUBCASE("LCA of another single node is itself") { + std::optional> correct = + std::unordered_set{n.at(2)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(2)}); + CHECK(correct == result); + } + } + + SUBCASE("nodes at different heights") { + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}}); + + SUBCASE("LCA of nodes at different depths (root is LCA)") { + std::optional> correct = + std::unordered_set{n.at(0)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(5), n.at(2)}); + CHECK(correct == result); + } + + SUBCASE("LCA of node and its ancestor is the ancestor") { + std::optional> correct = + std::unordered_set{n.at(3)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(5), n.at(3)}); + CHECK(correct == result); + } + + SUBCASE("LCA of 
siblings at depth 2") { + std::optional> correct = + std::unordered_set{n.at(1)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(3), n.at(4)}); + CHECK(correct == result); + } + + SUBCASE("LCA of multiple nodes across different branches") { + std::optional> correct = + std::unordered_set{n.at(0)}; + std::optional> result = + get_lowest_common_ancestors( + g, {n.at(1), n.at(2), n.at(3), n.at(4), n.at(5)}); + CHECK(correct == result); + } + } + + SUBCASE("straight path") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}}); + + SUBCASE("LCA of adjacent nodes in a path") { + std::optional> correct = + std::unordered_set{n.at(2)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(2), n.at(3)}); + CHECK(correct == result); + } + + SUBCASE("LCA of non-adjacent nodes in a path") { + std::optional> correct = + std::unordered_set{n.at(1)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(1), n.at(3)}); + CHECK(correct == result); + } + + SUBCASE("LCA of multiple nodes in a path") { + std::optional> correct = + std::unordered_set{n.at(1)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(1), n.at(2), n.at(3)}); + CHECK(correct == result); + } + } + } + + SUBCASE("general dags") { + + SUBCASE("no LCA") { + std::vector n = add_nodes(g, 3); + add_edges( + g, + {DirectedEdge{n.at(0), n.at(2)}, DirectedEdge{n.at(1), n.at(2)}}); + + std::optional> correct = + std::unordered_set{}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(0), n.at(1)}); + CHECK(correct == result); + } + + SUBCASE("multiple LCAs") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}}); + + std::optional> correct = + std::unordered_set{n.at(0), n.at(1)}; + std::optional> result = + 
get_lowest_common_ancestors(g, {n.at(2), n.at(3)}); + CHECK(correct == result); + } + + SUBCASE("single LCA") { + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(1), n.at(5)}}); + + std::optional> correct = + std::unordered_set{n.at(3)}; + std::optional> result = + get_lowest_common_ancestors(g, {n.at(4), n.at(5)}); + CHECK(correct == result); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc index e675e6903f..82e0ea61dd 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -9,23 +9,119 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("is_acyclic") { DiGraph g = DiGraph::create(); + SUBCASE("empty graph") { + CHECK(is_acyclic(g)); + } - std::vector n = add_nodes(g, 6); + SUBCASE("single node") { + add_nodes(g, 1); + CHECK(is_acyclic(g)); + } - add_edges(g, - { - DirectedEdge{n.at(0), n.at(1)}, - DirectedEdge{n.at(1), n.at(2)}, - DirectedEdge{n.at(1), n.at(3)}, - DirectedEdge{n.at(1), n.at(5)}, - DirectedEdge{n.at(2), n.at(4)}, - DirectedEdge{n.at(3), n.at(1)}, - DirectedEdge{n.at(3), n.at(4)}, - }); + SUBCASE("simple acyclic graph") { + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + }); + CHECK(is_acyclic(g)); + } - std::optional correct = false; - std::optional result = is_acyclic(g); + SUBCASE("simple cyclic graph") { + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + }); + CHECK_FALSE(is_acyclic(g)); + } - CHECK(result == correct); + 
SUBCASE("diamond graph") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + CHECK(is_acyclic(g)); + } + + SUBCASE("2 parallel chains") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + CHECK(is_acyclic(g)); + } + SUBCASE("graph has a root and a cycle") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}}); + CHECK_FALSE(is_acyclic(g)); + } + + SUBCASE("graph has no root and a cycle") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(0)}}); + CHECK_FALSE(is_acyclic(g)); + } + + SUBCASE("straight line graph with a transitive edge") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}}); + CHECK(is_acyclic(g)); + } + + SUBCASE("complex cyclic graph") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(5), n.at(1)}, + }); + CHECK_FALSE(is_acyclic(g)); + } + + SUBCASE("complex cyclic graph #2") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(5)}, + 
DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(1)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + CHECK_FALSE(is_acyclic(g)); + } } } diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/is_tree.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/is_tree.cc new file mode 100644 index 0000000000..fe42ba1a62 --- /dev/null +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/is_tree.cc @@ -0,0 +1,95 @@ +#include "utils/graph/digraph/algorithms/is_tree.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_tree") { + DiGraph g = DiGraph::create(); + + SUBCASE("single node") { + add_nodes(g, 1); + CHECK(is_tree(g)); + } + + SUBCASE("simple tree") { + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + }); + CHECK(is_tree(g)); + } + + SUBCASE("simple cycle") { + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(0)}, + }); + CHECK_FALSE(is_tree(g)); + } + + SUBCASE("diamond pattern") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}); + CHECK_FALSE(is_tree(g)); + } + + SUBCASE("dowstream cycle") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(2)}, + }); + CHECK_FALSE(is_tree(g)); + } + + SUBCASE("multiple roots") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + CHECK_FALSE(is_tree(g)); + } + + SUBCASE("multiple incoming edges") { + std::vector n = add_nodes(g, 4); + 
add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + CHECK_FALSE(is_tree(g)); + } + + SUBCASE("crossing") { + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(3)}, + }); + CHECK_FALSE(is_tree(g)); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc index be912d0011..7e47bc470c 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/transitive_reduction.cc @@ -40,6 +40,50 @@ TEST_SUITE(FF_TEST_SUITE) { } } + SUBCASE("linear graph with additional edge") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(0), n.at(3)}, + }); + + DiGraphView result = transitive_reduction(g); + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + }; + CHECK(result_edges == correct_edges); + } + + SUBCASE("linear graph with 2 additional edges") { + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, + + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + }); + + DiGraphView result = transitive_reduction(g); + std::unordered_set result_edges = get_edges(result); + std::unordered_set correct_edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}, 
+ }; + CHECK(result_edges == correct_edges); + } + SUBCASE("nontrivial graph") { // from // https://en.wikipedia.org/w/index.php?title=Transitive_reduction&oldid=1226082357#In_directed_acyclic_graphs diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc index 8f085f8f64..df4cb475f2 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/binary_sp_decomposition_tree.cc @@ -15,7 +15,9 @@ TEST_SUITE(FF_TEST_SUITE) { return BinarySPDecompositionTree{BinaryParallelSplit{lhs, rhs}}; }; - auto make_leaf = [](int n) { return BinarySPDecompositionTree{Node{n}}; }; + auto make_leaf = [](size_t n) { + return BinarySPDecompositionTree{Node{n}}; + }; SUBCASE("leaf only") { BinarySPDecompositionTree input = make_leaf(5); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc index 0e9415525f..08ffc44ff3 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/left_associative_binary_sp_tree_from_nary.cc @@ -62,7 +62,7 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree result = left_associative_binary_sp_tree_from_nary(input); - // we use multiple checks here because SerialParallelDecomposition's + // we use multiple checks here because SeriesParallelDecomposition's // ParallelSplit is unordered, so there are multiple possible // left-associative binary SP trees 
CHECK(is_binary_sp_tree_left_associative(result)); diff --git a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc index 532ff86c90..7b43f52b8f 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/binary_sp_decomposition_tree/right_associative_binary_sp_tree_from_nary.cc @@ -60,7 +60,7 @@ TEST_SUITE(FF_TEST_SUITE) { BinarySPDecompositionTree result = right_associative_binary_sp_tree_from_nary(input); - // we use multiple checks here because SerialParallelDecomposition's + // we use multiple checks here because SeriesParallelDecomposition's // ParallelSplit is unordered, so there are multiple possible // right-associative binary SP trees CHECK(is_binary_sp_tree_right_associative(result)); diff --git a/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc b/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc new file mode 100644 index 0000000000..7c2f9bba5b --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/digraph_generation.cc @@ -0,0 +1,114 @@ +#include "utils/graph/series_parallel/digraph_generation.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("digraph_from_sp_decomposition") { + SUBCASE("Empty") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition(ParallelSplit{{}}); + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 0); + 
CHECK(get_edges(result).size() == 0); + } + + SUBCASE("Complex Empty") { + SeriesParallelDecomposition input = SeriesParallelDecomposition( + ParallelSplit{{SeriesSplit{{}}, SeriesSplit{{ParallelSplit{{}}}}}}); + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 0); + CHECK(get_edges(result).size() == 0); + } + + SUBCASE("Single Node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition(Node(1)); + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 1); + CHECK(get_edges(result).size() == 0); + } + + SUBCASE("Simple SeriesSplit") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{Node(1), Node(2), Node(3)}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 3); + CHECK(get_edges(result).size() == 2); + CHECK(get_initial_nodes(result).size() == 1); + CHECK(get_terminal_nodes(result).size() == 1); + } + + SUBCASE("Simple ParallelSplit") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{Node(1), Node(2), Node(3)}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 3); + CHECK(get_edges(result).size() == 0); + CHECK(get_initial_nodes(result).size() == 3); + CHECK(get_terminal_nodes(result).size() == 3); + } + + SUBCASE("Mixed Series-Parallel") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{Node(1), Node(2)}}, + ParallelSplit{{Node(3), Node(4)}}}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(get_edges(result).size() == 4); + CHECK(get_initial_nodes(result).size() == 2); + CHECK(get_terminal_nodes(result).size() == 2); + } + + SUBCASE("Mixed Parallel-Series") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{SeriesSplit{{Node(1), Node(2)}}, + SeriesSplit{{Node(3), Node(4)}}}}}; + DiGraph result = 
digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(get_edges(result).size() == 2); + CHECK(get_initial_nodes(result).size() == 2); + CHECK(get_terminal_nodes(result).size() == 2); + } + + SUBCASE("Rhombus") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{Node{1}, ParallelSplit{{Node{2}, Node{3}}}, Node{4}}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(get_edges(result).size() == 4); + CHECK(get_initial_nodes(result).size() == 1); + CHECK(get_terminal_nodes(result).size() == 1); + } + + SUBCASE("Duplicate Nodes") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{Node(1), ParallelSplit{{Node(1), Node(2)}}, Node(1)}}}; + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 4); + CHECK(get_edges(result).size() == 4); + CHECK(get_initial_nodes(result).size() == 1); + CHECK(get_terminal_nodes(result).size() == 1); + } + + SUBCASE("Complex Graph") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{ + {ParallelSplit{{SeriesSplit{{ParallelSplit{{Node(1), Node(2)}}, + ParallelSplit{{Node(3), Node(4)}}, + Node(5)}}, + SeriesSplit{{Node(6), Node(7)}}}}, + Node(8)}}}; + + DiGraph result = digraph_from_sp_decomposition(input); + CHECK(num_nodes(result) == 8); + CHECK(get_edges(result).size() == 9); + CHECK(get_initial_nodes(result).size() == 3); + CHECK(get_terminal_nodes(result).size() == 1); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc b/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc new file mode 100644 index 0000000000..22e1fe51cf --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/get_ancestors.cc @@ -0,0 +1,78 @@ +#include "utils/graph/series_parallel/get_ancestors.h" +#include "utils/fmt/unordered_set.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + 
TEST_CASE("get_ancestors") { + std::vector n = { + Node{0}, Node{1}, Node{2}, Node{3}, Node{4}, Node{5}, Node{6}, Node{7}}; + + SUBCASE("Single Node") { + SeriesParallelDecomposition sp = SeriesParallelDecomposition{n.at(0)}; + std::unordered_set correct = {}; + std::unordered_set result = get_ancestors(sp, n.at(0)); + CHECK(correct == result); + } + + SUBCASE("Simple Series") { + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n.at(0), n.at(1), n.at(2)}}}; + std::unordered_set correct = {n.at(0), n.at(1)}; + std::unordered_set result = get_ancestors(sp, n.at(2)); + CHECK(correct == result); + } + + SUBCASE("Simple Parallel") { + SeriesParallelDecomposition sp = SeriesParallelDecomposition{ + ParallelSplit{{n.at(0), n.at(1), n.at(2)}}}; + std::unordered_set correct = {}; + std::unordered_set result = get_ancestors(sp, n.at(1)); + CHECK(correct == result); + } + + SUBCASE("Tree") { + SeriesParallelDecomposition sp = SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + ParallelSplit{{SeriesSplit{{n.at(1), n.at(2)}}, n.at(3)}}}}}; + std::unordered_set correct = {n.at(0), n.at(1)}; + std::unordered_set result = get_ancestors(sp, n.at(2)); + CHECK(correct == result); + } + + SUBCASE("Rhombus") { + SeriesParallelDecomposition sp = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3)}}}; + std::unordered_set correct = {n.at(0), n.at(1), n.at(2)}; + std::unordered_set result = get_ancestors(sp, n.at(3)); + CHECK(correct == result); + } + + SUBCASE("Complex Structure") { + SeriesParallelDecomposition sp = SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + ParallelSplit{ + {SeriesSplit{ + {n.at(1), ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}, + SeriesSplit{{n.at(5), n.at(6)}}}}, + n.at(7)}}}; + std::unordered_set correct = {n.at(0), n.at(1), n.at(2), n.at(3)}; + std::unordered_set result = get_ancestors(sp, n.at(4)); + CHECK(correct == result); + + correct = {n.at(0), n.at(1)}; + result = 
get_ancestors(sp, n.at(3)); + CHECK(correct == result); + + correct = {n.at(0), n.at(5)}; + result = get_ancestors(sp, n.at(6)); + CHECK(correct == result); + + correct = {n.at(0), n.at(1), n.at(2), n.at(3), n.at(4), n.at(5), n.at(6)}; + result = get_ancestors(sp, n.at(7)); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc new file mode 100644 index 0000000000..5f41d77305 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc @@ -0,0 +1,139 @@ +#include "utils/graph/series_parallel/non_normal_sp_decomposition.h" +#include "doctest/doctest.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("is_empty_non_normal(NonNormalSPDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + + SUBCASE("Node Decomposition") { + NonNormalSPDecomposition sp = NonNormalSPDecomposition{n1}; + CHECK_FALSE(is_empty_non_normal(sp)); + } + + SUBCASE("Empty Series") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalSeriesSplit{ + std::vector>{}}}; + CHECK(is_empty_non_normal(sp)); + } + + SUBCASE("Empty Parallel") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalParallelSplit{{}}}; + CHECK(is_empty_non_normal(sp)); + } + + SUBCASE("Series with Node") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalSeriesSplit{{n1}}}; + CHECK_FALSE(is_empty_non_normal(sp)); + } + + SUBCASE("Parallel with Node") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalParallelSplit{{n1}}}; + CHECK_FALSE(is_empty_non_normal(sp)); + } + + SUBCASE("Nested Series") { + NonNormalSPDecomposition sp = NonNormalSPDecomposition{ + NonNormalSeriesSplit{{NonNormalParallelSplit{{}}}}}; + CHECK(is_empty_non_normal(sp)); + } + + 
SUBCASE("Nested Parallel") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalParallelSplit{{NonNormalSeriesSplit{ + std::vector>{}}}}}; + CHECK(is_empty_non_normal(sp)); + } + + SUBCASE("Sparse") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalSeriesSplit{ + {NonNormalParallelSplit{{}}, + NonNormalParallelSplit{{NonNormalSeriesSplit{std::vector< + std::variant>{}}}}}}}; + CHECK(is_empty_non_normal(sp)); + } + + SUBCASE("Sparse with Node") { + NonNormalSPDecomposition sp = + NonNormalSPDecomposition{NonNormalSeriesSplit{ + {NonNormalParallelSplit{{}}, + NonNormalParallelSplit{ + {NonNormalSeriesSplit{std::vector< + std::variant>{}}, + n2}}}}}; + CHECK_FALSE(is_empty_non_normal(sp)); + } + } + + TEST_CASE("as_non_normal(SeriesParallelDecomposition)") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + Node n4 = Node{4}; + + SUBCASE("Node") { + SeriesParallelDecomposition input = SeriesParallelDecomposition{n1}; + NonNormalSPDecomposition result = as_non_normal(input); + NonNormalSPDecomposition correct = NonNormalSPDecomposition{n1}; + CHECK(result == correct); + } + + SUBCASE("SeriesSplit with Nodes") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{SeriesSplit{{n1, n2, n3}}}; + NonNormalSPDecomposition result = as_non_normal(input); + NonNormalSPDecomposition correct = + NonNormalSPDecomposition{NonNormalSeriesSplit{{n1, n2, n3}}}; + CHECK(result == correct); + } + + SUBCASE("ParallelSplit with Nodes") { + SeriesParallelDecomposition input = + SeriesParallelDecomposition{ParallelSplit{{n1, n2}}}; + NonNormalSPDecomposition result = as_non_normal(input); + NonNormalSPDecomposition correct = + NonNormalSPDecomposition{NonNormalParallelSplit{{n1, n2}}}; + CHECK(result == correct); + } + + SUBCASE("SeriesSplit containing ParallelSplit") { + // S(P(n1, n2), n3) + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{n1, n2}}, n3}}}; + 
NonNormalSPDecomposition result = as_non_normal(input); + NonNormalSPDecomposition correct = NonNormalSPDecomposition{ + NonNormalSeriesSplit{{NonNormalParallelSplit{{n1, n2}}, n3}}}; + CHECK(result == correct); + } + + SUBCASE("ParallelSplit containing SeriesSplit") { + // P(S(n1, n2), n3) + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + ParallelSplit{{SeriesSplit{{n1, n2}}, n3}}}; + NonNormalSPDecomposition result = as_non_normal(input); + NonNormalSPDecomposition correct = NonNormalSPDecomposition{ + NonNormalParallelSplit{{NonNormalSeriesSplit{{n1, n2}}, n3}}}; + CHECK(result == correct); + } + + SUBCASE("deeply nested") { + // S(P(S(n1, n2), n3), n4) + SeriesParallelDecomposition input = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{SeriesSplit{{n1, n2}}, n3}}, n4}}}; + NonNormalSPDecomposition result = as_non_normal(input); + NonNormalSPDecomposition correct = + NonNormalSPDecomposition{NonNormalSeriesSplit{ + {NonNormalParallelSplit{{NonNormalSeriesSplit{{n1, n2}}, n3}}, + n4}}}; + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc new file mode 100644 index 0000000000..b1777cf9b4 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/normalize_sp_decomposition.cc @@ -0,0 +1,76 @@ +#include "utils/graph/series_parallel/normalize_sp_decomposition.h" +#include "utils/graph/series_parallel/non_normal_parallel_split.dtg.h" +#include "utils/graph/series_parallel/non_normal_series_split.dtg.h" +#include "utils/graph/series_parallel/non_normal_sp_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("normalize_sp_decomposition") { + Node n1 = Node{1}; + Node n2 = Node{2}; + Node n3 = Node{3}; + + SUBCASE("Empty") { + NonNormalSPDecomposition 
input = NonNormalSPDecomposition{ + NonNormalSeriesSplit{ + {NonNormalParallelSplit{{}}, NonNormalParallelSplit{{}}}}}; + CHECK_THROWS_AS(normalize_sp_decomposition(input), std::runtime_error); + } + + SUBCASE("Node Decomposition") { + NonNormalSPDecomposition input = NonNormalSPDecomposition{n1}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Series with Single Node") { + NonNormalSPDecomposition input = + NonNormalSPDecomposition{NonNormalSeriesSplit{{n1}}}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Parallel with Single Node") { + NonNormalSPDecomposition input = + NonNormalSPDecomposition{NonNormalParallelSplit{{n1}}}; + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n1}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Mixed Series") { + NonNormalSPDecomposition input = NonNormalSPDecomposition{ + NonNormalSeriesSplit{{NonNormalParallelSplit{{n1}}, n2}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n1, n2}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Mixed Parallel") { + NonNormalSPDecomposition input = NonNormalSPDecomposition{ + NonNormalParallelSplit{{NonNormalSeriesSplit{{n1}}, n2}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n1, n2}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + + SUBCASE("Nested") { + NonNormalSPDecomposition input = NonNormalSPDecomposition{ + NonNormalParallelSplit{{NonNormalSeriesSplit{ + {NonNormalParallelSplit{{n1, n2}}}}, + n3, + 
NonNormalSeriesSplit{{}}}}}; + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{ParallelSplit{{n1, n2, n3}}}; + SeriesParallelDecomposition result = normalize_sp_decomposition(input); + CHECK(correct == result); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc index f5766c9fdd..544b290208 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -157,4 +157,5 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_multiset correct = {input}; CHECK(result == correct); } + } diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc new file mode 100644 index 0000000000..e3f09253ea --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/dependencies_are_maintained.cc @@ -0,0 +1,125 @@ +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/containers/get_only.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("dependencies_are_maintained") { + DiGraph g = DiGraph::create(); + SUBCASE("Single Node") { + std::vector n = add_nodes(g, 1); + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n[0]}}}; + CHECK(dependencies_are_maintained(g, sp)); + } + + SUBCASE("SeriesSplit") { + SUBCASE("Valid SP-ization") { + std::vector n = add_nodes(g, 3); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n[0], n[1], n[2]}}}; + CHECK(dependencies_are_maintained(g, sp)); + } + + 
SUBCASE("Incorrect SP-ization") { + std::vector n = add_nodes(g, 3); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); + + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{SeriesSplit{{n[1], n[0], n[2]}}}; + CHECK_FALSE(dependencies_are_maintained(g, sp)); + } + } + + SUBCASE("ParallelSplit") { + SUBCASE("Valid SP-ization") { + std::vector n = add_nodes(g, 3); + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{ParallelSplit{{n[0], n[1], n[2]}}}; + CHECK(dependencies_are_maintained(g, sp)); + } + + SUBCASE("Incorrect SP-ization") { + std::vector n = add_nodes(g, 3); + + SeriesParallelDecomposition sp = + SeriesParallelDecomposition{ParallelSplit{{n[0], n[2]}}}; + CHECK_FALSE(dependencies_are_maintained(g, sp)); + } + } + + SUBCASE("Rhombus") { + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}}); + SUBCASE("Valid SP-izations") { + SeriesParallelDecomposition sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3)}}}; + CHECK(dependencies_are_maintained(g, sp_correct)); + + sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), n.at(1), n.at(2), n.at(3)}}}; + CHECK(dependencies_are_maintained(g, sp_correct)); + } + SUBCASE("Invalid SP-ization") { + SeriesParallelDecomposition sp_incorrect = SeriesParallelDecomposition{ + ParallelSplit{{n.at(0), SeriesSplit{{n.at(1), n.at(3)}}, n.at(2)}}}; + CHECK_FALSE(dependencies_are_maintained(g, sp_incorrect)); + } + } + + SUBCASE("Diamond") { + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}}); + + SUBCASE("Valid SP-izations") { + + 
SeriesParallelDecomposition sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), + ParallelSplit{{n.at(1), n.at(2)}}, + ParallelSplit{{n.at(3), n.at(4)}}, + n.at(5)}}}; + CHECK(dependencies_are_maintained(g, sp_correct)); + + sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), + n.at(1), + n.at(2), + ParallelSplit{{n.at(3), n.at(4)}}, + n.at(5)}}}; + CHECK(dependencies_are_maintained(g, sp_correct)); + + sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), + ParallelSplit{{n.at(1), n.at(2)}}, + n.at(3), + n.at(4), + n.at(5)}}}; + CHECK(dependencies_are_maintained(g, sp_correct)); + } + + SUBCASE("Invalid SP-izations") { + SeriesParallelDecomposition sp_correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), + ParallelSplit{{n.at(1), n.at(2), n.at(4)}}, + n.at(3), + n.at(5)}}}; + CHECK_FALSE(dependencies_are_maintained(g, sp_correct)); + } + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc new file mode 100644 index 0000000000..dfad2f69a5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc @@ -0,0 +1,359 @@ +#include "utils/graph/series_parallel/sp_ization/escribano_algo.h" +#include "test/utils/doctest/fmt/unordered_multiset.h" +#include "utils/containers/values.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" 
+#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("escribano_algo - subcomponents") { + SUBCASE("add_dummy_nodes") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + std::unordered_map node_types = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::PURE}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + }; + + DiGraph result = add_dummy_nodes(g, node_types); + CHECK(get_edges(result).size() == 6); + CHECK(get_nodes(result).size() == 6); + CHECK(get_incoming_edges(g, n.at(3)).size() == 2); + CHECK(get_outgoing_edges(g, n.at(0)).size() == 2); + + CHECK(node_types.size() == 6); + CHECK(values(node_types) == + std::unordered_multiset{NodeRole::PURE, + NodeRole::PURE, + NodeRole::PURE, + NodeRole::PURE, + NodeRole::DUMMY, + NodeRole::DUMMY}); + } + + SUBCASE("get_component") { + SUBCASE("2 layer graph, single simple component") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}}); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + }; + std::unordered_map depth_map = { + {n.at(0), 0_n}, + {n.at(2), 1_n}, + {n.at(3), 1_n}, + }; + std::unordered_set correct = {n.at(0), n.at(2), n.at(3)}; + std::unordered_set result = + get_component(g, n.at(2), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("2 layer graph, 
single complex component") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}}); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::PURE}, + {n.at(2), NodeRole::SYNC}, + {n.at(3), NodeRole::SYNC}, + {n.at(4), NodeRole::PURE}, + {n.at(5), NodeRole::PURE}, + }; + std::unordered_map depth_map = { + {n.at(0), 0_n}, + {n.at(1), 0_n}, + {n.at(4), 1_n}, + {n.at(5), 1_n}, + }; + SUBCASE("n.at(4)'s component") { + std::unordered_set correct = { + n.at(0), n.at(1), n.at(4), n.at(5)}; + std::unordered_set result = + get_component(g, n.at(4), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("n.at(5)'s component") { + std::unordered_set correct = { + n.at(0), n.at(1), n.at(4), n.at(5)}; + std::unordered_set result = + get_component(g, n.at(5), depth_map, node_roles); + CHECK(correct == result); + } + } + SUBCASE("3 layer graph, single connected component") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}, + DirectedEdge{n.at(4), n.at(6)}}); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + {n.at(4), NodeRole::SYNC}, + {n.at(5), NodeRole::PURE}, + {n.at(6), NodeRole::PURE}}; + + std::unordered_map depth_map = {{n.at(0), 0_n}, + {n.at(2), 1_n}, + {n.at(3), 1_n}, + {n.at(5), 2_n}, + {n.at(6), 2_n}}; + SUBCASE("n.at(5)'s component") { + std::unordered_set correct = { + n.at(2), n.at(3), n.at(5), n.at(6)}; + std::unordered_set result = + get_component(g, n.at(5), depth_map, node_roles); + 
CHECK(correct == result); + } + + SUBCASE("n.at(6)'s component") { + std::unordered_set correct = { + n.at(2), n.at(3), n.at(5), n.at(6)}; + std::unordered_set result = + get_component(g, n.at(6), depth_map, node_roles); + CHECK(correct == result); + } + } + SUBCASE("3 layer graph, multiple weakly connected components") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 10); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(5)}, + DirectedEdge{n.at(3), n.at(6)}, + DirectedEdge{n.at(4), n.at(6)}, + DirectedEdge{n.at(5), n.at(7)}, + DirectedEdge{n.at(5), n.at(8)}, + DirectedEdge{n.at(6), n.at(9)}, + }); + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + {n.at(4), NodeRole::PURE}, + {n.at(5), NodeRole::SYNC}, + {n.at(6), NodeRole::SYNC}, + {n.at(7), NodeRole::PURE}, + {n.at(8), NodeRole::PURE}, + {n.at(9), NodeRole::PURE}, + }; + + std::unordered_map depth_map = {{n.at(0), 0_n}, + {n.at(2), 1_n}, + {n.at(3), 1_n}, + {n.at(4), 1_n}, + {n.at(7), 2_n}, + {n.at(8), 2_n}, + {n.at(9), 2_n}}; + SUBCASE("n.at(7)'s component") { + std::unordered_set correct = {n.at(2), n.at(7), n.at(8)}; + std::unordered_set result = + get_component(g, n.at(7), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("n.at(8)'s component") { + std::unordered_set correct = {n.at(2), n.at(7), n.at(8)}; + std::unordered_set result = + get_component(g, n.at(8), depth_map, node_roles); + CHECK(correct == result); + } + SUBCASE("n.at(9)'s component") { + std::unordered_set correct = {n.at(3), n.at(4), n.at(9)}; + std::unordered_set result = + get_component(g, n.at(9), depth_map, node_roles); + CHECK(correct == result); + } + } + } + } + + TEST_CASE("escribano_algorithm") { + + SUBCASE("Single Node") { + DiGraph g = DiGraph::create(); + Node n 
= g.add_node(); + SeriesParallelDecomposition sp = escribano_sp_ization(g); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{Node{n}}; + CHECK(sp == correct); + CHECK(dependencies_are_maintained(g, sp)); + } + SUBCASE("Linear Graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[1], n[2]}}); + SeriesParallelDecomposition sp = escribano_sp_ization(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], n[1], n[2]}}}; + CHECK(sp == correct); + } + + SUBCASE("Rhombus") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[3]}}); + SeriesParallelDecomposition sp = escribano_sp_ization(g); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], ParallelSplit{{n[1], n[2]}}, n[3]}}}; + + CHECK(dependencies_are_maintained(g, sp)); + CHECK(correct == sp); + } + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + SeriesParallelDecomposition sp = escribano_sp_ization(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], n[1], ParallelSplit{{n[2], n[3]}}, n[4], n[5]}}}; + CHECK(sp == correct); + } + + SUBCASE("Diamond without crossing") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[5]}, + 
DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}, + }); + + SeriesParallelDecomposition sp = escribano_sp_ization(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{SeriesSplit{{n[1], n[3], n[4]}}, n[2]}}, + n[5]}}}; + SeriesParallelDecomposition result = sp; + CHECK(correct == result); + } + + SUBCASE("Diamond Graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + SeriesParallelDecomposition sp = escribano_sp_ization(g); + CHECK(dependencies_are_maintained(g, sp)); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], + ParallelSplit{{n[1], n[2]}}, + ParallelSplit{{n[3], n[4]}}, + n[5]}}}; + CHECK(sp == correct); + } + + SUBCASE("Sample Graph #2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 10); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[5]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[6]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[8]}, + DirectedEdge{n[4], n[8]}, + DirectedEdge{n[5], n[7]}, + DirectedEdge{n[7], n[8]}, + DirectedEdge{n[6], n[9]}, + DirectedEdge{n[8], n[9]}}); + SeriesParallelDecomposition sp = escribano_sp_ization(g); + CHECK(dependencies_are_maintained(g, sp)); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{n[1], n[3]}}, + ParallelSplit{ + {SeriesSplit{{n[2], n[6]}}, + SeriesSplit{{ParallelSplit{ + {SeriesSplit{{n[5], n[7]}}, n[4]}}, + n[8]}}}}, + n[9]}}}; + CHECK(sp == correct); + } + } +} + diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc new file mode 100644 index 0000000000..53b1adb37a --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc @@ -0,0 +1,483 @@ +#include "utils/graph/series_parallel/sp_ization/flexible_algo.h" +#include "utils/containers/values.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/algorithms/get_incoming_edges.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_outgoing_edges.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" +#include +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("flexible_algo") { + SUBCASE("Single Node") { + DiGraph g = DiGraph::create(); + Node n0 = g.add_node(); + + std::unordered_map cost_map = {{n0, 1.0f}}; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{n0}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("Tri Node Graph") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + 
DirectedEdge{n[2], n[3]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 1.0f}, + {n[3], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], ParallelSplit{{n[1], n[2]}}, n[3]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("Series") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[2], n[3]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 1.0f}, + {n[3], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], n[1], n[2], n[3]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("6 Node Diamond Graph - constant cost map") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 1.0f}, + {n[3], 1.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{{n[0], + ParallelSplit{{n[1], n[2]}}, + ParallelSplit{{n[3], n[4]}}, + n[5]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("6 Node Diamond Graph - cost map v2") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], 
n[3]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 10.0f}, + {n[3], 1000.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + n[1], + ParallelSplit{{SeriesSplit{{n[2], n[4]}}, n[3]}}, + n[5]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("6 Node Diamond Graph - cost map v3") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 1000.0f}, + {n[3], 1000.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{SeriesSplit{{n[1], n[3]}}, n[2]}}, + n[4], + n[5]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("With Parallel Strand") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 8); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[6]}, + DirectedEdge{n[5], n[7]}, + DirectedEdge{n[6], n[7]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 1.0f}, + {n[3], 1.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + {n[6], 1.0f}, + {n[7], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition 
correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{SeriesSplit{{n[1], n[3], n[5]}}, + SeriesSplit{{n[2], n[4], n[6]}}}}, + n[7]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("Simple With Parallel Strand") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 7); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[6]}, + DirectedEdge{n[4], n[6]}, + DirectedEdge{n[0], n[5]}, + DirectedEdge{n[5], n[6]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 100.0f}, + {n[3], 100.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + {n[6], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{ + {SeriesSplit{{ParallelSplit{ + {SeriesSplit{{n[1], n[3]}}, n[2]}}, + n[4]}}, + n[5]}}, + n[6]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("With Appendage - constant cost") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 8); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + DirectedEdge{n[0], n[6]}, + DirectedEdge{n[6], n[5]}, + DirectedEdge{n[2], n[7]}, + DirectedEdge{n[7], n[5]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 1.0f}, + {n[2], 1.0f}, + {n[3], 1.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + {n[6], 1.0f}, + {n[7], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{ + {n[6], + SeriesSplit{{ParallelSplit{{n[1], n[2]}}, + 
ParallelSplit{{n[3], n[4], n[7]}}}}}}, + n[5]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("With Appendage - weighted") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 8); + add_edges(g, + {DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + DirectedEdge{n[0], n[6]}, + DirectedEdge{n[6], n[5]}, + DirectedEdge{n[2], n[7]}, + DirectedEdge{n[7], n[5]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 100.0f}, + {n[2], 1.0f}, + {n[3], 1.0f}, + {n[4], 1.0f}, + {n[5], 1.0f}, + {n[6], 1.0f}, + {n[7], 100.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n[0], + ParallelSplit{ + {n[6], + SeriesSplit{ + {n[2], + ParallelSplit{ + {n[7], + SeriesSplit{ + {n[1], ParallelSplit{{n[3], n[4]}}}}}}}}}}, + n[5]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("Transitive Edge") { + DiGraph g2 = DiGraph::create(); + std::vector m = add_nodes(g2, 7); + add_edges(g2, + {DirectedEdge{m[0], m[1]}, + DirectedEdge{m[0], m[2]}, + DirectedEdge{m[1], m[3]}, + DirectedEdge{m[1], m[4]}, + DirectedEdge{m[2], m[5]}, + DirectedEdge{m[3], m[6]}, + DirectedEdge{m[4], m[5]}, + DirectedEdge{m[5], m[6]}}); + + std::unordered_map cost_map2 = { + {m[0], 1.0f}, + {m[1], 1.0f}, + {m[2], 1.0f}, + {m[3], 10.0f}, + {m[4], 1.0f}, + {m[5], 10.0f}, + {m[6], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g2, cost_map2); + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{m[0], + ParallelSplit{{SeriesSplit{{m[1], m[4]}}, m[2]}}, + ParallelSplit{{m[3], m[5]}}, + m[6]}}}; + + CHECK(result == correct); + CHECK(dependencies_are_maintained(g2, result)); + } + + 
SUBCASE("Graph From Paper - constant cost map") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 18); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[10]}, DirectedEdge{n[2], n[11]}, + DirectedEdge{n[2], n[12]}, DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[6]}, DirectedEdge{n[4], n[6]}, + DirectedEdge{n[4], n[7]}, DirectedEdge{n[4], n[10]}, + DirectedEdge{n[5], n[8]}, DirectedEdge{n[6], n[8]}, + DirectedEdge{n[6], n[9]}, DirectedEdge{n[7], n[8]}, + DirectedEdge{n[8], n[17]}, DirectedEdge{n[9], n[17]}, + DirectedEdge{n[10], n[9]}, DirectedEdge{n[10], n[16]}, + DirectedEdge{n[11], n[16]}, DirectedEdge{n[12], n[13]}, + DirectedEdge{n[12], n[14]}, DirectedEdge{n[13], n[15]}, + DirectedEdge{n[14], n[15]}, DirectedEdge{n[15], n[16]}, + DirectedEdge{n[16], n[17]}}); + + std::unordered_map cost_map; + for (int i = 0; i < 18; i++) { + cost_map[n[i]] = 1.0f; + } + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n[0], + ParallelSplit{ + {SeriesSplit{{n[1], ParallelSplit{{n[3], n[4]}}}}, + SeriesSplit{{n[2], ParallelSplit{{n[11], n[12]}}}}}}, + ParallelSplit{ + {SeriesSplit{{ParallelSplit{{n[10], n[5], n[6], n[7]}}, + ParallelSplit{{n[8], n[9]}}}}, + SeriesSplit{{ParallelSplit{{n[13], n[14]}}, n[15]}}}}, + n[16], + n[17]}}}; + + CHECK(dependencies_are_maintained(g, result)); + CHECK(result == correct); + } + + SUBCASE("Graph From Paper - non constant cost map") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 18); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[10]}, DirectedEdge{n[2], n[11]}, + DirectedEdge{n[2], n[12]}, DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[6]}, DirectedEdge{n[4], n[6]}, + DirectedEdge{n[4], n[7]}, 
DirectedEdge{n[4], n[10]}, + DirectedEdge{n[5], n[8]}, DirectedEdge{n[6], n[8]}, + DirectedEdge{n[6], n[9]}, DirectedEdge{n[7], n[8]}, + DirectedEdge{n[8], n[17]}, DirectedEdge{n[9], n[17]}, + DirectedEdge{n[10], n[16]}, DirectedEdge{n[11], n[16]}, + DirectedEdge{n[12], n[13]}, DirectedEdge{n[12], n[14]}, + DirectedEdge{n[13], n[15]}, DirectedEdge{n[14], n[15]}, + DirectedEdge{n[15], n[16]}, DirectedEdge{n[16], n[17]}}); + + std::unordered_map cost_map; + for (int i = 0; i < 18; i++) { + cost_map[n[i]] = 1.0f; + } + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + CHECK(dependencies_are_maintained(g, result)); + } + + SUBCASE("Additional Test Case") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 18); + add_edges(g, {DirectedEdge{n[0], n[1]}, DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[10]}, DirectedEdge{n[2], n[11]}, + DirectedEdge{n[2], n[12]}, DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[6]}, DirectedEdge{n[4], n[6]}, + DirectedEdge{n[4], n[7]}, DirectedEdge{n[4], n[10]}, + DirectedEdge{n[5], n[8]}, DirectedEdge{n[6], n[8]}, + DirectedEdge{n[6], n[9]}, DirectedEdge{n[7], n[8]}, + DirectedEdge{n[8], n[17]}, DirectedEdge{n[9], n[17]}, + DirectedEdge{n[10], n[16]}, DirectedEdge{n[11], n[16]}, + DirectedEdge{n[12], n[13]}, DirectedEdge{n[12], n[14]}, + DirectedEdge{n[13], n[15]}, DirectedEdge{n[14], n[15]}, + DirectedEdge{n[15], n[16]}, DirectedEdge{n[16], n[17]}}); + + std::unordered_map cost_map = { + {n[0], 1.0f}, + {n[1], 3.0f}, + {n[2], 5.0f}, + {n[3], 3.0f}, + {n[4], 5.0f}, + {n[5], 3.0f}, + {n[6], 5.0f}, + {n[7], 1.0f}, + {n[8], 5.0f}, + {n[9], 5.0f}, + {n[10], 3.0f}, + {n[11], 3.0f}, + {n[12], 1.0f}, + {n[13], 3.0f}, + {n[14], 3.0f}, + {n[15], 1.0f}, + {n[16], 1.0f}, + {n[17], 1.0f}, + }; + + SeriesParallelDecomposition result = flexible_sp_ization(g, cost_map); + CHECK(dependencies_are_maintained(g, result)); + } + } +} diff --git 
a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc new file mode 100644 index 0000000000..042ca071ab --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc @@ -0,0 +1,156 @@ +#include "utils/graph/series_parallel/sp_ization/naive_stratum_sync.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("naive_stratum_sync") { + + SUBCASE("Sample Graph #1") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[4]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[4], n[5]}, + }); + + std::unordered_map cost_map = { + {n[0], 1}, {n[1], 1}, {n[2], 2}, {n[3], 3}, {n[4], 1}, {n[5], 1}}; + + CHECK(work_cost(g, cost_map) == 9); + CHECK(critical_path_cost(g, cost_map) == 7); + + SeriesParallelDecomposition sp = stratum_sync_sp_ization(g); + + SUBCASE("structure") { + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], n[1], ParallelSplit{{n[2], n[3]}}, n[4], n[5]}}}; + SeriesParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 7; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #2") { + 
DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[2]}, + DirectedEdge{n[1], n[3]}, + DirectedEdge{n[2], n[5]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[4], n[5]}, + }); + + std::unordered_map cost_map = { + {n[0], 1}, {n[1], 1}, {n[2], 10}, {n[3], 1}, {n[4], 1}, {n[5], 1}}; + + CHECK(work_cost(g, cost_map) == 15); + CHECK(critical_path_cost(g, cost_map) == 12); + + SeriesParallelDecomposition sp = stratum_sync_sp_ization(g); + + SUBCASE("structure") { + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], ParallelSplit{{n[1], n[2]}}, n[3], n[4], n[5]}}}; + SeriesParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 14; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + + SUBCASE("Sample Graph #3") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 9); + add_edges(g, + { + DirectedEdge{n[0], n[1]}, + DirectedEdge{n[0], n[3]}, + DirectedEdge{n[1], n[2]}, + DirectedEdge{n[1], n[5]}, + DirectedEdge{n[1], n[4]}, + DirectedEdge{n[2], n[6]}, + DirectedEdge{n[3], n[4]}, + DirectedEdge{n[3], n[5]}, + DirectedEdge{n[3], n[8]}, + DirectedEdge{n[4], n[8]}, + DirectedEdge{n[5], n[7]}, + DirectedEdge{n[7], n[8]}, + }); + + std::unordered_map cost_map = {{n[0], 1}, + {n[1], 1}, + {n[2], 10}, + {n[3], 10}, + {n[4], 1}, + {n[5], 1}, + {n[6], 10}, + {n[7], 10}, + {n[8], 1}}; + + CHECK(work_cost(g, cost_map) == 45); + CHECK(critical_path_cost(g, cost_map) == 23); + + SeriesParallelDecomposition sp = stratum_sync_sp_ization(g); + + SUBCASE("structure") { + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n[0], + ParallelSplit{{n[1], n[3]}}, + ParallelSplit{{n[2], n[4], 
n[5]}}, + ParallelSplit{{n[6], n[7]}}, + n[8]}}}; + SeriesParallelDecomposition result = sp; + CHECK(correct == result); + } + SUBCASE("work cost") { + float correct = work_cost(g, cost_map); + float result = work_cost(sp, cost_map); + CHECK(correct == result); + } + + SUBCASE("critical path cost") { + float correct = 32; + float result = critical_path_cost(sp, cost_map); + CHECK(correct == result); + } + } + } +} + diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/node_role.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/node_role.cc new file mode 100644 index 0000000000..e56cd643e5 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/node_role.cc @@ -0,0 +1,72 @@ +#include "utils/graph/series_parallel/sp_ization/node_role.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_edges.h" +#include "utils/graph/digraph/directed_edge.dtg.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/node/node.dtg.h" +#include "utils/graph/series_parallel/sp_ization/node_role.dtg.h" +#include +#include + +using namespace FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("contract_out_nodes_of_given_role") { + SUBCASE("contract out dummy nodes") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + std::vector edges = { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }; + add_edges(g, edges); + + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::DUMMY}, + {n.at(2), NodeRole::DUMMY}, + {n.at(3), NodeRole::PURE}, + {n.at(4), NodeRole::PURE}, + }; + + DiGraph result = + contract_out_nodes_of_given_role(g, NodeRole::DUMMY, node_roles); + + CHECK(get_nodes(result) == + std::unordered_set{n.at(0), n.at(3), n.at(4)}); + CHECK(get_edges(result) == + 
std::unordered_set{DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(0), n.at(3)}, + DirectedEdge{n.at(3), n.at(4)}}); + } + + SUBCASE("contract out sync nodes") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + {DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}}); + + std::unordered_map node_roles = { + {n.at(0), NodeRole::PURE}, + {n.at(1), NodeRole::SYNC}, + {n.at(2), NodeRole::PURE}, + {n.at(3), NodeRole::PURE}, + }; + + DiGraph result = + contract_out_nodes_of_given_role(g, NodeRole::SYNC, node_roles); + + CHECK(get_nodes(result) == + std::unordered_set{n.at(0), n.at(2), n.at(3)}); + CHECK(get_edges(result) == + std::unordered_set{DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(3)}}); + } + } +} diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_spization.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_spization.cc new file mode 100644 index 0000000000..687511b439 --- /dev/null +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_spization.cc @@ -0,0 +1,245 @@ +#include "utils/graph/series_parallel/sp_ization/work_duplicating_spization.h" +#include "test/utils/rapidcheck.h" +#include "utils/containers/generate_map.h" +#include "utils/graph/algorithms.h" +#include "utils/graph/digraph/algorithms/get_initial_nodes.h" +#include "utils/graph/digraph/algorithms/get_terminal_nodes.h" +#include "utils/graph/digraph/digraph.h" +#include "utils/graph/instances/adjacency_digraph.h" +#include "utils/graph/node/algorithms.h" +#include "utils/graph/series_parallel/parallel_split.dtg.h" +#include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include "utils/graph/series_parallel/series_parallel_metrics.h" +#include "utils/graph/series_parallel/series_split.dtg.h" +#include + +using namespace FlexFlow; + +static std::pair> + 
generate_random_2_terminal_weighted_dag(int max_num_nodes = 10, + int max_num_edges = 20) { + assert(max_num_nodes >= 2); + + int num_nodes = *rc::gen::inRange(2, max_num_nodes); + + DiGraph g = DiGraph::create(); + std::vector nodes = add_nodes(g, num_nodes); + Node source = nodes.front(); + Node sink = nodes.back(); + + int num_edges = *rc::gen::inRange(0, max_num_edges + 1); + for (int i = 0; i < num_edges; i++) { + int src_idx = *rc::gen::inRange(0, num_nodes - 1); + int dst_idx = *rc::gen::inRange(src_idx + 1, num_nodes); + g.add_edge(DirectedEdge{nodes.at(src_idx), nodes.at(dst_idx)}); + } + + for (Node const &n : get_initial_nodes(g)) { + if (n != source) { + g.add_edge(DirectedEdge{source, n}); + } + } + for (Node const &n : get_terminal_nodes(g)) { + if (n != sink) { + g.add_edge(DirectedEdge{n, sink}); + } + } + + std::unordered_map cost_map = + generate_map(get_nodes(g), [](Node const &) { + return static_cast(*rc::gen::inRange(1, 101)); + }); + + return {g, cost_map}; +} + +TEST_SUITE(FF_TEST_SUITE) { + + TEST_CASE("naive_work_duplicating_spization") { + + SUBCASE("linear chain") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 3); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(1), n.at(2)}, + }); + + SeriesParallelDecomposition result = + naive_work_duplicating_spization(g); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{n.at(0), n.at(1), n.at(2)}}}; + CHECK(correct == result); + } + + SUBCASE("diamond") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + SeriesParallelDecomposition result = + naive_work_duplicating_spization(g); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{{SeriesSplit{{n.at(0), n.at(1)}}, + SeriesSplit{{n.at(0), 
n.at(2)}}}}, + n.at(3)}}}; + CHECK(correct == result); + } + + SUBCASE("parallel paths of different lengths") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(5)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + SeriesParallelDecomposition result = + naive_work_duplicating_spization(g); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{{ParallelSplit{ + {SeriesSplit{{n.at(0), n.at(1), n.at(3), n.at(4)}}, + SeriesSplit{{n.at(0), n.at(2)}}}}, + n.at(5)}}}; + CHECK(correct == result); + } + + RC_SUBCASE("critical path cost is preserved", + []() { + auto [g, cost_map] = + generate_random_2_terminal_weighted_dag(); + SeriesParallelDecomposition sp = + naive_work_duplicating_spization(g); + float original_cost = critical_path_cost(g, cost_map); + float sp_cost = critical_path_cost(sp, cost_map); + RC_ASSERT(original_cost == sp_cost); + }); + } + + TEST_CASE("work_duplicating_spization_with_coalescing") { + + SUBCASE("diamond") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 4); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(3)}, + }); + + SeriesParallelDecomposition result = + work_duplicating_spization_with_coalescing(g); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{ + {n.at(0), ParallelSplit{{n.at(1), n.at(2)}}, n.at(3)}}}; + CHECK(correct == result); + } + + SUBCASE("parallel paths of different lengths") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(5)}, + DirectedEdge{n.at(3), n.at(4)}, + 
DirectedEdge{n.at(4), n.at(5)}, + }); + + SeriesParallelDecomposition result = + work_duplicating_spization_with_coalescing(g); + + SeriesParallelDecomposition correct = SeriesParallelDecomposition{ + SeriesSplit{ + {n.at(0), + ParallelSplit{ + {SeriesSplit{{n.at(1), n.at(3), n.at(4)}}, n.at(2)}}, + n.at(5)}}}; + CHECK(correct == result); + } + + SUBCASE("parallel strands with cross edges") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 6); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + DirectedEdge{n.at(3), n.at(5)}, + DirectedEdge{n.at(4), n.at(5)}, + }); + + SeriesParallelDecomposition result = + work_duplicating_spization_with_coalescing(g); + + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), + n.at(1), + ParallelSplit{ + {SeriesSplit{{ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}, + n.at(3)}}, + n.at(5)}}}; + CHECK(correct == result); + } + + SUBCASE("graph with transitive edges") { + DiGraph g = DiGraph::create(); + std::vector n = add_nodes(g, 5); + add_edges(g, + { + DirectedEdge{n.at(0), n.at(1)}, + DirectedEdge{n.at(0), n.at(2)}, + DirectedEdge{n.at(0), n.at(4)}, + DirectedEdge{n.at(1), n.at(2)}, + DirectedEdge{n.at(1), n.at(3)}, + DirectedEdge{n.at(1), n.at(4)}, + DirectedEdge{n.at(2), n.at(4)}, + DirectedEdge{n.at(3), n.at(4)}, + }); + + SeriesParallelDecomposition result = + work_duplicating_spization_with_coalescing(g); + + SeriesParallelDecomposition correct = + SeriesParallelDecomposition{SeriesSplit{ + {n.at(0), n.at(1), ParallelSplit{{n.at(2), n.at(3)}}, n.at(4)}}}; + CHECK(correct == result); + } + + RC_SUBCASE("critical path cost is preserved", + []() { + auto [g, cost_map] = + generate_random_2_terminal_weighted_dag(); + SeriesParallelDecomposition sp = + work_duplicating_spization_with_coalescing(g); + 
float original_cost = critical_path_cost(g, cost_map); + float sp_cost = critical_path_cost(sp, cost_map); + RC_ASSERT(original_cost == sp_cost); + }); + } +}