From 6a74af08019005583a7a3206f72175cd52e50ca2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 11:31:04 -0800 Subject: [PATCH 01/63] Add make_dynamic_open_dataflow_graph_from_pcg. --- .../parallel_computation_graph.h | 6 ++ .../parallel_computation_graph.cc | 21 +++++ ...ake_dynamic_open_dataflow_graph_from_pcg.h | 14 ++++ ...ke_dynamic_open_dataflow_graph_from_pcg.cc | 77 +++++++++++++++++++ 4 files changed, 118 insertions(+) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 25dc0721cd..3d948ac107 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -54,6 +54,9 @@ std::unordered_map std::unordered_set get_initial_layers(ParallelComputationGraph const &); +std::unordered_map + get_outgoing_tensors(ParallelComputationGraph const &, + parallel_layer_guid_t const &); std::unordered_map get_incoming_tensors(ParallelComputationGraph const &, parallel_layer_guid_t const &); @@ -107,6 +110,9 @@ ParallelTensorShape get_parallel_tensor_shape(ParallelComputationGraph const &, std::vector topological_ordering(ParallelComputationGraph const &); +std::unordered_map + get_parallel_layer_attrs_mapping(ParallelComputationGraph const &pcg); + parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index f83628b8e1..907dc05620 100644 --- 
a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -212,6 +212,16 @@ std::unordered_set [](Node const &n) { return parallel_layer_guid_t{n}; }); } +std::unordered_map + get_outgoing_tensors(ParallelComputationGraph const &pcg, + parallel_layer_guid_t const &l) { + return map_values(get_outgoing_kwarg_dataflow_outputs_for_node( + pcg.raw_graph, l.raw_graph_node), + [](KwargDataflowOutput const &o) { + return parallel_tensor_guid_t{o}; + }); +} + std::unordered_map get_incoming_tensors(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { @@ -378,6 +388,17 @@ std::vector [](Node const &n) { return parallel_layer_guid_t{n}; }); } +std::unordered_map + get_parallel_layer_attrs_mapping(ParallelComputationGraph const &pcg) { + std::unordered_map + layer_attrs_mapping; + for (parallel_layer_guid_t const &layer_guid : get_parallel_layers(pcg)) { + layer_attrs_mapping.insert( + {layer_guid, get_parallel_layer_attrs(pcg, layer_guid)}); + } + return layer_attrs_mapping; +} + parallel_layer_guid_t get_parallel_layer_by_name(ParallelComputationGraph const &pcg, std::string const &name) { diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h new file mode 100644 index 0000000000..a71eb558c1 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H + +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph + 
make_dynamic_open_dataflow_graph_from_pcg(ParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc new file mode 100644 index 0000000000..841be27dfd --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc @@ -0,0 +1,77 @@ +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/containers/generate_map.h" +#include +#include +#include + +namespace FlexFlow { + +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_pcg( + ParallelComputationGraph const &pcg) { + DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); + + for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { + DynamicNodeAttrs result_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/attrs.op_attrs, + /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, + /*per_device_op_state=*/std::nullopt, + }; + + std::unordered_map result_inputs = + transform(get_incoming_tensors(pcg, layer), + [&](TensorSlotName const &slot_name, + parallel_tensor_guid_t const &tensor) { + ParallelTensorAttrs attrs = + get_parallel_tensor_attrs(pcg, tensor); + return std::pair{ + DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }, + DynamicValueAttrs{ + 
/*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }; + }); + std::unordered_map result_outputs = + transform(get_outgoing_tensors(pcg, layer), + [&](TensorSlotName const &slot_name, + parallel_tensor_guid_t const &tensor) { + ParallelTensorAttrs attrs = + get_parallel_tensor_attrs(pcg, tensor); + return std::pair{ + DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, + /*parallel_tensor_shape=*/attrs.shape, + /*shard_coord=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }; + }); + + result.invocations.emplace(result_inputs, result_attrs, result_outputs); + } + + return result; +} + +} // namespace FlexFlow From 587e08e80006a6533591c404075e12ee86c1ec82 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 14:10:13 -0800 Subject: [PATCH 02/63] Empty skeleton of the realm-execution backend. 
--- .proj.toml | 7 +++++++ lib/CMakeLists.txt | 1 + lib/realm-execution/CMakeLists.txt | 21 +++++++++++++++++++ .../parallel_computation_graph_instance.h | 12 +++++++++++ .../parallel_computation_graph_instance.cc | 1 + lib/realm-execution/test/CMakeLists.txt | 15 +++++++++++++ .../test/src/realm-execution/test_e2e.cc | 9 ++++++++ 7 files changed, 66 insertions(+) create mode 100644 lib/realm-execution/CMakeLists.txt create mode 100644 lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h create mode 100644 lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc create mode 100644 lib/realm-execution/test/CMakeLists.txt create mode 100644 lib/realm-execution/test/src/realm-execution/test_e2e.cc diff --git a/.proj.toml b/.proj.toml index 38690f710b..5dbbfbcdd7 100644 --- a/.proj.toml +++ b/.proj.toml @@ -85,6 +85,13 @@ has-cpu-only-benchmarks = false has-cuda-tests = true has-cuda-benchmarks = false +[targets.realm-execution] +type = "lib" +has-cpu-only-tests = true +has-cpu-only-benchmarks = false +has-cuda-tests = true +has-cuda-benchmarks = false + # [targets.local-pcg-execution] # type = "lib" # has-cpu-only-tests = true diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 2e71e577c0..cb3bd6d6ae 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(op-attrs) add_subdirectory(kernels) add_subdirectory(local-execution) add_subdirectory(local-pcg-execution) +add_subdirectory(realm-execution) add_subdirectory(task-spec) add_subdirectory(utils) add_subdirectory(ffi) diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt new file mode 100644 index 0000000000..7a38f70607 --- /dev/null +++ b/lib/realm-execution/CMakeLists.txt @@ -0,0 +1,21 @@ +ff_add_library( + NAME + realm-execution + SRC_PATTERNS + src/*.cc + PUBLIC_INCLUDE + include/ + PRIVATE_INCLUDE + src/ + DEPS + op-attrs 
+ utils + kernels + task-spec + pcg + spdlog + compiler + local-execution +) + +add_subdirectory(test) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h new file mode 100644 index 0000000000..58cc5234d9 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -0,0 +1,12 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H + +namespace FlexFlow { + +struct ParallelComputationGraphInstance { + public: +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc new file mode 100644 index 0000000000..a22f4730b7 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -0,0 +1 @@ +#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" diff --git a/lib/realm-execution/test/CMakeLists.txt b/lib/realm-execution/test/CMakeLists.txt new file mode 100644 index 0000000000..b3beff42c0 --- /dev/null +++ b/lib/realm-execution/test/CMakeLists.txt @@ -0,0 +1,15 @@ +ff_add_test_executable( + NAME + realm-execution-tests + SRC_PATTERNS + src/*.cc + PRIVATE_INCLUDE + src/ + DEPS + doctest + utils-test-common + realm-execution + kernels + op-attrs + task-spec +) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc new file mode 100644 
index 0000000000..55dfe427d5 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -0,0 +1,9 @@ +#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training") { + } +} From 50f6ec6b8dc3df127074df339d54631d3dc33fc8 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 14:25:52 -0800 Subject: [PATCH 03/63] More Realm execution skeleton. --- .../parallel_computation_graph_instance.h | 52 ++++++++++++++++++- .../parallel_computation_graph_instance.cc | 45 ++++++++++++++++ .../test/src/realm-execution/test_e2e.cc | 3 +- 3 files changed, 97 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 58cc5234d9..b0529761c1 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -1,12 +1,62 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H +#include "kernels/accessor.h" +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include 
"task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" +#include "utils/units/milliseconds_t.h" +#include + namespace FlexFlow { struct ParallelComputationGraphInstance { - public: +public: + ParallelComputationGraphInstance(DynamicOpenDataflowGraph, + Allocator &, + std::vector const &, + OptimizerAttrs const &, + std::optional const &, + std::optional); + DynamicOpenDataflowGraph const &get_dynamic_dataflow_graph() const; + Allocator &get_allocator() const; + std::vector const &get_topological_ordering() const; + OptimizerAttrs const &get_optimizer_attrs() const; + void update_optimizer_attrs_for_next_iter(); + std::optional const &get_loss_attrs() const; + std::optional get_loss_tensor_accessor() const; + +private: + DynamicOpenDataflowGraph dataflow_graph; + Allocator &allocator; + std::vector topological_ordering; + OptimizerAttrs optimizer_attrs; + std::optional loss_attrs; + std::optional logit_grad_tensor; }; +ParallelComputationGraphInstance create_parallel_computation_graph_instance( + ParallelComputationGraph const &pcg, + OptimizerAttrs const &optimizer_attrs, + std::optional const &loss_attrs, + std::optional label_tensor, + std::optional logit_tensor, + std::unordered_map const + &input_tensors, + Allocator &allocator, + ProfilingSettings const &profiling_settings, + device_handle_t const &device_handle, + FFIterationConfig const &iteration_config, + device_id_t device_idx); + } // namespace FlexFlow #endif diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index a22f4730b7..2f001a2975 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ 
b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1 +1,46 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "pcg/optimizer_attrs.h" + +namespace FlexFlow { + +ParallelComputationGraphInstance::ParallelComputationGraphInstance( + DynamicOpenDataflowGraph dataflow_graph, + Allocator &allocator, + std::vector const &topological_ordering, + OptimizerAttrs const &optimizer_attrs, + std::optional const &loss_attrs, + std::optional logit_grad_tensor) + : dataflow_graph(dataflow_graph), allocator(allocator), + topological_ordering(topological_ordering), + optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), + logit_grad_tensor(logit_grad_tensor) {} + +DynamicOpenDataflowGraph const & + ParallelComputationGraphInstance::get_dynamic_dataflow_graph() const { + return this->dataflow_graph; +} +Allocator &ParallelComputationGraphInstance::get_allocator() const { + return this->allocator; +} +std::vector const & + ParallelComputationGraphInstance::get_topological_ordering() const { + return this->topological_ordering; +} +OptimizerAttrs const & + ParallelComputationGraphInstance::get_optimizer_attrs() const { + return this->optimizer_attrs; +} +void ParallelComputationGraphInstance::update_optimizer_attrs_for_next_iter() { + this->optimizer_attrs = + get_optimizer_attrs_for_next_iter(this->optimizer_attrs); +} +std::optional const & + ParallelComputationGraphInstance::get_loss_attrs() const { + return this->loss_attrs; +} +std::optional + ParallelComputationGraphInstance::get_loss_tensor_accessor() const { + return this->logit_grad_tensor; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 55dfe427d5..78a57fb99f 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ 
b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -4,6 +4,5 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("RealmBackend e2e Training") { - } + TEST_CASE("RealmBackend e2e Training") {} } From 984aae5640ed6535bba64c9a69c9363ec7358d46 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 14:59:55 -0800 Subject: [PATCH 04/63] Stub creation. --- .../parallel_computation_graph_instance.cc | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 2f001a2975..29683c4dba 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,5 +1,6 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" #include "pcg/optimizer_attrs.h" +#include "utils/exception.h" namespace FlexFlow { @@ -43,4 +44,20 @@ std::optional return this->logit_grad_tensor; } +ParallelComputationGraphInstance create_parallel_computation_graph_instance( + ParallelComputationGraph const &pcg, + OptimizerAttrs const &optimizer_attrs, + std::optional const &loss_attrs, + std::optional label_tensor, + std::optional logit_tensor, + std::unordered_map const + &input_tensors, + Allocator &allocator, + ProfilingSettings const &profiling_settings, + device_handle_t const &device_handle, + FFIterationConfig const &iteration_config, + device_id_t device_idx) { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow From e9e11059136e9b3e14df1e48bbf24517aeffe7e8 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 15:17:00 -0800 Subject: [PATCH 05/63] More passes. 
--- .../parallel_computation_graph_instance.cc | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 29683c4dba..8f878c90d8 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,5 +1,12 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "local-execution/device_state_initialization.h" +#include "local-execution/tensor_allocation.h" #include "pcg/optimizer_attrs.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/loss_insertion.h" +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "task-spec/dynamic_graph/pass_expansion.h" +#include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" namespace FlexFlow { @@ -44,6 +51,15 @@ std::optional return this->logit_grad_tensor; } +static GenericTensorAccessorW + get_loss_tensor_accessor(DynamicOpenDataflowGraph const &dg, + DynamicValueAttrs const &value) { + return find_output_tensor(dg, value.tensor_guid, value.role) + .value() + .second.accessor.value() + .get(); +} + ParallelComputationGraphInstance create_parallel_computation_graph_instance( ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, @@ -57,6 +73,36 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( device_handle_t const &device_handle, FFIterationConfig const &iteration_config, device_id_t device_idx) { + + DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_pcg(pcg); + dg = 
perform_pass_expansion(dg); + + std::unordered_map inputs = + input_tensors; + std::optional logit_grad_value; + if (loss_attrs) { + auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( + dg, assert_unwrap(loss_attrs), assert_unwrap(logit_tensor)); + dg = dg2; + logit_grad_value = logit_grad_v; + inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); + } + + dg = perform_update_insertion(dg, optimizer_attrs); + dg = perform_tensor_allocation(dg, inputs, allocator); + + std::optional logit_grad_tensor = + transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { + return get_loss_tensor_accessor(dg, lgv); + }); + + dg = perform_device_state_initialization(dg, + allocator, + profiling_settings, + device_handle, + iteration_config, + optimizer_attrs, + device_idx); NOT_IMPLEMENTED(); } From 90b92c789aa5185e16d3c7b70d196fe9091bb006 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 16:39:30 -0800 Subject: [PATCH 06/63] Add Realm manager and test it. 
--- lib/realm-execution/CMakeLists.txt | 11 ++++---- .../include/realm-execution/realm_manager.h | 27 +++++++++++++++++++ .../src/realm-execution/realm_manager.cc | 22 +++++++++++++++ .../test/src/realm-execution/test_e2e.cc | 10 ++++++- 4 files changed, 64 insertions(+), 6 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_manager.h create mode 100644 lib/realm-execution/src/realm-execution/realm_manager.cc diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt index 7a38f70607..0a1b681b8d 100644 --- a/lib/realm-execution/CMakeLists.txt +++ b/lib/realm-execution/CMakeLists.txt @@ -8,14 +8,15 @@ ff_add_library( PRIVATE_INCLUDE src/ DEPS - op-attrs - utils + compiler kernels - task-spec + local-execution + op-attrs pcg + realm spdlog - compiler - local-execution + task-spec + utils ) add_subdirectory(test) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h new file mode 100644 index 0000000000..a08668e6cc --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -0,0 +1,27 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H + +#include "realm.h" + +namespace FlexFlow { + +struct RealmManager { +public: + RealmManager(int *argc, char ***argv); + + RealmManager() = delete; + RealmManager(RealmManager const &) = delete; + RealmManager(RealmManager &&) = delete; + + Realm::Runtime get_runtime(); + void shutdown(); + int wait_for_shutdown(); + +private: + Realm::Runtime runtime; + Realm::Event last_event = Realm::Event::NO_EVENT; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc new file mode 100644 index 0000000000..5a085bc04b --- /dev/null +++ 
b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -0,0 +1,22 @@ +#include "realm-execution/realm_manager.h" +#include "utils/exception.h" + +namespace FlexFlow { + +RealmManager::RealmManager(int *argc, char ***argv) { + bool ok = this->runtime.init(argc, argv); + ASSERT(ok); +} + +Realm::Runtime RealmManager::get_runtime() { + return this->runtime; +} + +void RealmManager::shutdown() { + this->runtime.shutdown(this->last_event); +} + +int RealmManager::wait_for_shutdown() { + return this->runtime.wait_for_shutdown(); +} +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 78a57fb99f..947a02e6be 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,8 +1,16 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "realm-execution/realm_manager.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("RealmBackend e2e Training") {} + TEST_CASE("RealmBackend e2e Training") { + char fake_executable_name[] = "fake_executable_name"; + std::vector fake_args{fake_executable_name}; + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + RealmManager manager(&fake_argc, &fake_argv); + manager.shutdown(); + } } From ef92d6fb622f06246315e11810a8fb45f1dc8f72 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 16:45:57 -0800 Subject: [PATCH 07/63] Do not expose raw runtime and properly wait in test. 
--- lib/realm-execution/include/realm-execution/realm_manager.h | 1 - lib/realm-execution/src/realm-execution/realm_manager.cc | 5 +---- lib/realm-execution/test/src/realm-execution/test_e2e.cc | 2 ++ 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index a08668e6cc..f9fa9f7de7 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -13,7 +13,6 @@ struct RealmManager { RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; - Realm::Runtime get_runtime(); void shutdown(); int wait_for_shutdown(); diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 5a085bc04b..014a16718a 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -8,10 +8,6 @@ RealmManager::RealmManager(int *argc, char ***argv) { ASSERT(ok); } -Realm::Runtime RealmManager::get_runtime() { - return this->runtime; -} - void RealmManager::shutdown() { this->runtime.shutdown(this->last_event); } @@ -19,4 +15,5 @@ void RealmManager::shutdown() { int RealmManager::wait_for_shutdown() { return this->runtime.wait_for_shutdown(); } + } // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 947a02e6be..b88807e079 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -12,5 +12,7 @@ TEST_SUITE(FF_TEST_SUITE) { char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); manager.shutdown(); + int result = manager.wait_for_shutdown(); + ASSERT(result == 0); } } From 01e23cd2ba11c8ab3e18de54a6246501f669258f Mon 
Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Wed, 4 Feb 2026 17:11:02 -0800 Subject: [PATCH 08/63] Sketch more Realm manager APIs. --- .../parallel_computation_graph_instance.h | 13 ++++++------ .../include/realm-execution/realm_manager.h | 8 +++++++ .../parallel_computation_graph_instance.cc | 21 +++++++++---------- .../src/realm-execution/realm_manager.cc | 11 ++++++++++ 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index b0529761c1..4ba77a7925 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" +#include "realm-execution/realm_manager.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -20,8 +21,8 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: - ParallelComputationGraphInstance(DynamicOpenDataflowGraph, - Allocator &, + ParallelComputationGraphInstance(RealmManager &, + DynamicOpenDataflowGraph, std::vector const &, OptimizerAttrs const &, std::optional const &, @@ -35,8 +36,8 @@ struct ParallelComputationGraphInstance { std::optional get_loss_tensor_accessor() const; private: + RealmManager &realm; DynamicOpenDataflowGraph dataflow_graph; - Allocator &allocator; std::vector topological_ordering; OptimizerAttrs optimizer_attrs; std::optional loss_attrs; @@ -44,6 +45,7 @@ 
struct ParallelComputationGraphInstance { }; ParallelComputationGraphInstance create_parallel_computation_graph_instance( + RealmManager &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, @@ -51,11 +53,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_tensor, std::unordered_map const &input_tensors, - Allocator &allocator, ProfilingSettings const &profiling_settings, - device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, - device_id_t device_idx); + FFIterationConfig const &iteration_config); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index f9fa9f7de7..9261bc91f4 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -1,6 +1,9 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_MANAGER_H +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "pcg/device_id_t.dtg.h" #include "realm.h" namespace FlexFlow { @@ -16,6 +19,11 @@ struct RealmManager { void shutdown(); int wait_for_shutdown(); + Allocator &get_current_device_allocator() const; + + device_handle_t const &get_current_device_handle() const; + device_id_t const &get_current_device_idx() const; + private: Realm::Runtime runtime; Realm::Event last_event = Realm::Event::NO_EVENT; diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 8f878c90d8..64c9da2f4c 100644 --- 
a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -12,13 +12,13 @@ namespace FlexFlow { ParallelComputationGraphInstance::ParallelComputationGraphInstance( + RealmManager &realm, DynamicOpenDataflowGraph dataflow_graph, - Allocator &allocator, std::vector const &topological_ordering, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional logit_grad_tensor) - : dataflow_graph(dataflow_graph), allocator(allocator), + : realm(realm), dataflow_graph(dataflow_graph), topological_ordering(topological_ordering), optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), logit_grad_tensor(logit_grad_tensor) {} @@ -28,7 +28,7 @@ DynamicOpenDataflowGraph const & return this->dataflow_graph; } Allocator &ParallelComputationGraphInstance::get_allocator() const { - return this->allocator; + return this->realm.get_current_device_allocator(); } std::vector const & ParallelComputationGraphInstance::get_topological_ordering() const { @@ -61,6 +61,7 @@ static GenericTensorAccessorW } ParallelComputationGraphInstance create_parallel_computation_graph_instance( + RealmManager &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, @@ -68,11 +69,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_tensor, std::unordered_map const &input_tensors, - Allocator &allocator, ProfilingSettings const &profiling_settings, - device_handle_t const &device_handle, - FFIterationConfig const &iteration_config, - device_id_t device_idx) { + FFIterationConfig const &iteration_config) { DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_pcg(pcg); dg = perform_pass_expansion(dg); @@ -89,7 +87,8 @@ ParallelComputationGraphInstance 
create_parallel_computation_graph_instance( } dg = perform_update_insertion(dg, optimizer_attrs); - dg = perform_tensor_allocation(dg, inputs, allocator); + dg = perform_tensor_allocation( + dg, inputs, realm.get_current_device_allocator()); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { @@ -97,12 +96,12 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( }); dg = perform_device_state_initialization(dg, - allocator, + realm.get_current_device_allocator(), profiling_settings, - device_handle, + realm.get_current_device_handle(), iteration_config, optimizer_attrs, - device_idx); + realm.get_current_device_idx()); NOT_IMPLEMENTED(); } diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 014a16718a..b136b4c379 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -16,4 +16,15 @@ int RealmManager::wait_for_shutdown() { return this->runtime.wait_for_shutdown(); } +Allocator &RealmManager::get_current_device_allocator() const { + NOT_IMPLEMENTED(); +} + +device_handle_t const &RealmManager::get_current_device_handle() const { + NOT_IMPLEMENTED(); +} +device_id_t const &RealmManager::get_current_device_idx() const { + NOT_IMPLEMENTED(); +} + } // namespace FlexFlow From 3e7d841ed71607bb3209a7ae4b1bafddce8d94c6 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 09:57:00 -0800 Subject: [PATCH 09/63] Add controller functionality. 
--- .../include/realm-execution/realm_manager.h | 17 ++++-- .../src/realm-execution/realm_manager.cc | 60 +++++++++++++++++-- .../test/src/realm-execution/test_e2e.cc | 4 +- 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index 9261bc91f4..497a1f3958 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -11,22 +11,31 @@ namespace FlexFlow { struct RealmManager { public: RealmManager(int *argc, char ***argv); + ~RealmManager(); RealmManager() = delete; RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; - void shutdown(); - int wait_for_shutdown(); + Realm::Event start_controller(void (*thunk)(RealmManager &)); + // Current device context Allocator &get_current_device_allocator() const; - device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; +private: + RealmManager(void const *, size_t, void const *, size_t, Realm::Processor); + + [[nodiscard]] Realm::Event merge_outstanding_events(); + + static void controller_task_wrapper( + void const *, size_t, void const *, size_t, Realm::Processor); + private: Realm::Runtime runtime; - Realm::Event last_event = Realm::Event::NO_EVENT; + std::vector outstanding_events; + bool is_root_runtime; }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index b136b4c379..acc11936c7 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -3,17 +3,48 @@ namespace FlexFlow { -RealmManager::RealmManager(int *argc, char ***argv) { +RealmManager::RealmManager(int *argc, char ***argv) : is_root_runtime(true) { bool ok = this->runtime.init(argc, argv); 
ASSERT(ok); } -void RealmManager::shutdown() { - this->runtime.shutdown(this->last_event); +RealmManager::RealmManager(void const *args, + size_t arglen, + void const *userdata, + size_t userdatalen, + Realm::Processor proc) + : runtime(Realm::Runtime::get_runtime()), is_root_runtime(false) {} + +RealmManager::~RealmManager() { + Realm::Event outstanding = this->merge_outstanding_events(); + if (is_root_runtime) { + this->runtime.shutdown(outstanding); + this->runtime.wait_for_shutdown(); + } else { + outstanding.wait(); + } } -int RealmManager::wait_for_shutdown() { - return this->runtime.wait_for_shutdown(); +Realm::Event RealmManager::start_controller(void (*thunk)(RealmManager &)) { + constexpr int CONTROLLER_TASK_ID = Realm::Processor::TASK_ID_FIRST_AVAILABLE; + Realm::Event task_ready = Realm::Processor::register_task_by_kind( + Realm::Processor::LOC_PROC, + /*global=*/false, + CONTROLLER_TASK_ID, + Realm::CodeDescriptor(RealmManager::controller_task_wrapper), + Realm::ProfilingRequestSet(), + &thunk, + sizeof(thunk)); + + Realm::Processor target_proc = + Realm::Machine::ProcessorQuery(Realm::Machine::get_machine()) + .only_kind(Realm::Processor::LOC_PROC) + .first(); + + Realm::Event task_complete = this->runtime.collective_spawn( + target_proc, CONTROLLER_TASK_ID, &thunk, sizeof(thunk), task_ready); + this->outstanding_events.push_back(task_complete); + return task_complete; } Allocator &RealmManager::get_current_device_allocator() const { @@ -27,4 +58,23 @@ device_id_t const &RealmManager::get_current_device_idx() const { NOT_IMPLEMENTED(); } +Realm::Event RealmManager::merge_outstanding_events() { + Realm::Event result = Realm::Event::merge_events(this->outstanding_events); + this->outstanding_events.clear(); + return result; +} + +void RealmManager::controller_task_wrapper(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + assert(arglen == sizeof(void (*)(RealmManager &))); + void 
(*thunk)(RealmManager &) = + *reinterpret_cast(const_cast(args)); + + RealmManager manager(args, arglen, userdata, userlen, proc); + thunk(manager); +} + } // namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index b88807e079..f09951e73c 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,8 +11,6 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - manager.shutdown(); - int result = manager.wait_for_shutdown(); - ASSERT(result == 0); + manager.start_controller([](RealmManager &manager) {}); } } From b9a30a623b3870daa1a4292d1e138c6e949013a0 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 12:29:19 -0800 Subject: [PATCH 10/63] Fix Realm tests. --- .flake/pkgs/legion.nix | 48 ------------------- .flake/pkgs/realm.nix | 44 +++++++++++++++++ flake.nix | 21 ++++---- lib/realm-execution/CMakeLists.txt | 2 +- .../src/realm-execution/realm_manager.cc | 2 +- .../test/src/realm-execution/realm_manager.cc | 22 +++++++++ 6 files changed, 78 insertions(+), 61 deletions(-) delete mode 100644 .flake/pkgs/legion.nix create mode 100644 .flake/pkgs/realm.nix create mode 100644 lib/realm-execution/test/src/realm-execution/realm_manager.cc diff --git a/.flake/pkgs/legion.nix b/.flake/pkgs/legion.nix deleted file mode 100644 index 361a66c4ff..0000000000 --- a/.flake/pkgs/legion.nix +++ /dev/null @@ -1,48 +0,0 @@ -{ lib -, stdenv -, fetchFromGitLab -, cmake -, cudaPackages ? { } -, cudaCapabilities ? [ "60" "70" "80" "86" ] -, maxDim ? 
5 -}: - -# from https://codeberg.org/Uli/nix-things/src/commit/776519e382c81b136c1d0b10d8c7b52b4acb9192/overlays/cq/python/libclang-python.nix - -let - cmakeFlag = x: if x then "1" else "0"; - - inherit (cudaPackages) cudatoolkit; -in - -stdenv.mkDerivation rec { - pname = "legion"; - version = "2025-01-06"; - - src = fetchFromGitLab { - owner = "StanfordLegion"; - repo = "legion"; - rev = "7be1abd0207eb1126c7629b16d1123fa6f58ce9d"; - sha256 = "sha256-gTjnGYYTQwTsrV1WcY0qqpTrlwbzAPcndurRy6XnG8A="; - }; - - nativeBuildInputs = [ - cmake - ]; - - cmakeFlags = [ - "-DLegion_USE_CUDA=1" - "-DLegion_CUDA_ARCH=${lib.concatStringsSep "," cudaCapabilities}" - "-DLegion_MAX_DIM=${toString maxDim}" - ]; - - buildInputs = [ - cudatoolkit - ]; - - meta = with lib; { - description = "Legion is a parallel programming model for distributed, heterogeneous machines"; - homepage = "https://legion.stanford.edu/"; - license = licenses.asl20; - }; -} diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix new file mode 100644 index 0000000000..1249c0ae28 --- /dev/null +++ b/.flake/pkgs/realm.nix @@ -0,0 +1,44 @@ +{ lib +, stdenv +, fetchFromGitHub +, cmake +, cudaPackages ? { } +, maxDim ? 
5 +}: + +let + inherit (cudaPackages) cudatoolkit; +in + +stdenv.mkDerivation rec { + pname = "realm"; + version = "2025-01-06"; + + # This version is compatible with Legion 7be1abd0207eb1126c7629b16d1123fa6f58ce9d + src = fetchFromGitHub { + owner = "StanfordLegion"; + repo = "realm"; + rev = "0ef7edc8c012d4ab6a50805c044cec8a8edeae33"; + sha256 = "sha256-57/a1lAgs+ajpRn0y0Lk1gP5nKt+N08WW0DIJP4vdho="; + }; + + nativeBuildInputs = [ + cmake + ]; + + cmakeFlags = [ + "-DBUILD_SHARED_LIBS=ON" + "-DREALM_ENABLE_CUDA=ON" + "-DREALM_MAX_DIM=${toString maxDim}" + ]; + + buildInputs = [ + cudatoolkit + ]; + + meta = with lib; { + description = "Realm is a distributed, event–based tasking runtime for building high-performance applications that span clusters of CPUs, GPUs, and other accelerators"; + homepage = "https://legion.stanford.edu/realm"; + license = licenses.asl20; + }; +} diff --git a/flake.nix b/flake.nix index 6ccd5616cd..dad0e2fc32 100644 --- a/flake.nix +++ b/flake.nix @@ -30,8 +30,8 @@ }; }; - outputs = { self, nixpkgs, flake-utils, proj-repo, nixGL, ... }: flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: - let + outputs = { self, nixpkgs, flake-utils, proj-repo, nixGL, ... }: flake-utils.lib.eachSystem [ "x86_64-linux" ] (system: + let pkgs = import nixpkgs { inherit system; config.allowUnfree = true; @@ -41,21 +41,21 @@ mkShell = attrs: pkgs.mkShell.override { stdenv = pkgs.cudaPackages.backendStdenv; } (attrs // { - hardeningDisable = ["all"]; # disable nixpkgs default compiler arguments, otherwise ubsan doesn't catch - # signed overflows due to the signedoverflow hardening setting. - # for more details, see the following (long-running) nixpkgs github issues: + hardeningDisable = ["all"]; # disable nixpkgs default compiler arguments, otherwise ubsan doesn't catch + # signed overflows due to the signedoverflow hardening setting. 
+ # for more details, see the following (long-running) nixpkgs github issues: # - https://github.com/NixOS/nixpkgs/issues/18995 # - https://github.com/NixOS/nixpkgs/issues/60919 }); proj = proj-repo.packages.${system}.proj; - in + in { packages = rec { libdwarf-lite = pkgs.callPackage ./.flake/pkgs/libdwarf-lite.nix { }; cpptrace = pkgs.callPackage ./.flake/pkgs/cpptrace.nix { inherit libdwarf-lite; }; libassert = pkgs.callPackage ./.flake/pkgs/libassert.nix { inherit cpptrace; }; - legion = pkgs.callPackage ./.flake/pkgs/legion.nix { }; + realm = pkgs.callPackage ./.flake/pkgs/realm.nix { }; bencher-cli = pkgs.callPackage ./.flake/pkgs/bencher-cli.nix { }; ffdb = pkgs.callPackage ./.flake/pkgs/ffdb { inherit proj; }; hpp2plantuml = pkgs.python3Packages.callPackage ./.flake/pkgs/hpp2plantuml.nix { }; @@ -83,8 +83,7 @@ shellHook = '' export PATH="$HOME/ff/.scripts/:$PATH" export RC_PARAMS="max_discard_ratio=100" - export CMAKE_FLAGS="-DFF_USE_EXTERNAL_LEGION=ON \ - -DFF_USE_EXTERNAL_NCCL=ON \ + export CMAKE_FLAGS="-DFF_USE_EXTERNAL_NCCL=ON \ -DFF_USE_EXTERNAL_JSON=ON \ -DFF_USE_EXTERNAL_FMT=ON \ -DFF_USE_EXTERNAL_SPDLOG=ON \ @@ -94,7 +93,7 @@ -DFF_USE_EXTERNAL_GBENCHMARK=ON \ -DFF_USE_EXTERNAL_LIBASSERT=ON" ''; - + buildInputs = builtins.concatLists [ (with pkgs; [ zlib @@ -125,7 +124,7 @@ ]) (with self.packages.${system}; [ libassert - legion + realm rapidcheckFull doctest ]) diff --git a/lib/realm-execution/CMakeLists.txt b/lib/realm-execution/CMakeLists.txt index 0a1b681b8d..08676525e1 100644 --- a/lib/realm-execution/CMakeLists.txt +++ b/lib/realm-execution/CMakeLists.txt @@ -13,10 +13,10 @@ ff_add_library( local-execution op-attrs pcg - realm spdlog task-spec utils + Realm::Realm ) add_subdirectory(test) diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index acc11936c7..33e7ca252e 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ 
b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -71,7 +71,7 @@ void RealmManager::controller_task_wrapper(void const *args, Realm::Processor proc) { assert(arglen == sizeof(void (*)(RealmManager &))); void (*thunk)(RealmManager &) = - *reinterpret_cast(const_cast(args)); + *reinterpret_cast(args); RealmManager manager(args, arglen, userdata, userlen, proc); thunk(manager); diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc new file mode 100644 index 0000000000..880268c018 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -0,0 +1,22 @@ +#include "realm-execution/realm_manager.h" +#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmManager") { + // Construct some fake command line for our test + char fake_executable_name[] = "fake_executable_name"; + std::vector fake_args{fake_executable_name}; + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + // Initialize Realm + RealmManager manager(&fake_argc, &fake_argv); + + // Launch a controller and wait on it + Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + event.wait(); + } +} From 814e13f7efa10db7eae5fc37ffa0799d024e9e73 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 12:33:22 -0800 Subject: [PATCH 11/63] Support passing closure arguments to controllers. 
--- .../include/realm-execution/realm_manager.h | 3 ++- lib/realm-execution/src/realm-execution/realm_manager.cc | 9 +++++---- .../test/src/realm-execution/realm_manager.cc | 7 +++++-- lib/realm-execution/test/src/realm-execution/test_e2e.cc | 3 ++- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index 497a1f3958..88cc11f744 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -17,7 +17,8 @@ struct RealmManager { RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; - Realm::Event start_controller(void (*thunk)(RealmManager &)); + [[nodiscard]] Realm::Event + start_controller(std::function); // Current device context Allocator &get_current_device_allocator() const; diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 33e7ca252e..0ccf3f4116 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -25,7 +25,8 @@ RealmManager::~RealmManager() { } } -Realm::Event RealmManager::start_controller(void (*thunk)(RealmManager &)) { +Realm::Event + RealmManager::start_controller(std::function thunk) { constexpr int CONTROLLER_TASK_ID = Realm::Processor::TASK_ID_FIRST_AVAILABLE; Realm::Event task_ready = Realm::Processor::register_task_by_kind( Realm::Processor::LOC_PROC, @@ -69,9 +70,9 @@ void RealmManager::controller_task_wrapper(void const *args, void const *userdata, size_t userlen, Realm::Processor proc) { - assert(arglen == sizeof(void (*)(RealmManager &))); - void (*thunk)(RealmManager &) = - *reinterpret_cast(args); + ASSERT(arglen == sizeof(std::function)); + std::function thunk = + *reinterpret_cast const *>(args); RealmManager manager(args, arglen, 
userdata, userlen, proc); thunk(manager); diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 880268c018..16b5338881 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -15,8 +15,11 @@ TEST_SUITE(FF_TEST_SUITE) { // Initialize Realm RealmManager manager(&fake_argc, &fake_argv); - // Launch a controller and wait on it - Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + // Launch a controller + int some_data = 123; + Realm::Event event = manager.start_controller( + [&](RealmManager &manager) { ASSERT(some_data == 123); }); + // Need to block on the completion of the event to ensure we don't race event.wait(); } } diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index f09951e73c..623b8318e6 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,6 +11,7 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - manager.start_controller([](RealmManager &manager) {}); + Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + event.wait(); } } From 3d0298cbd35a4b4683051cdcd8fbc05821ed32a5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 14:22:44 -0800 Subject: [PATCH 12/63] Move task IDs into Realm and assign IDs to remaining tasks. 
--- .../realm-execution}/task_id_t.dtg.toml | 5 +- .../include/realm-execution/task_id_t.h | 28 ++ .../src/realm-execution/task_id_t.cc | 192 ++++++++++++++ .../include/task-spec/ops/impl/dropout.h | 1 - .../task-spec/ops/op_task_id_t.dtg.toml | 18 -- .../task_id_with_noop_default_t.dtg.toml | 28 -- .../task-spec/task_id_with_noop_default_t.h | 28 -- .../task-spec/task_id_with_noop_default_t.cc | 243 ------------------ 8 files changed, 221 insertions(+), 322 deletions(-) rename lib/{task-spec/include/task-spec => realm-execution/include/realm-execution}/task_id_t.dtg.toml (98%) create mode 100644 lib/realm-execution/include/realm-execution/task_id_t.h create mode 100644 lib/realm-execution/src/realm-execution/task_id_t.cc delete mode 100644 lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml delete mode 100644 lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml delete mode 100644 lib/task-spec/include/task-spec/task_id_with_noop_default_t.h delete mode 100644 lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc diff --git a/lib/task-spec/include/task-spec/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/task_id_t.dtg.toml similarity index 98% rename from lib/task-spec/include/task-spec/task_id_t.dtg.toml rename to lib/realm-execution/include/realm-execution/task_id_t.dtg.toml index ce2de52d40..0336bc81a4 100644 --- a/lib/task-spec/include/task-spec/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/task_id_t.dtg.toml @@ -9,10 +9,7 @@ features = [ ] [[values]] -name = "TOP_LEVEL_TASK_ID" - -[[values]] -name = "FF_INIT_TASK_ID" +name = "CONTROLLER_TASK_ID" [[values]] name = "IMAGE_INIT_TASK_ID" diff --git a/lib/realm-execution/include/realm-execution/task_id_t.h b/lib/realm-execution/include/realm-execution/task_id_t.h new file mode 100644 index 0000000000..af20dc27f6 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/task_id_t.h @@ -0,0 +1,28 @@ +#ifndef 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H + +#include "op-attrs/pcg_operator_attrs.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/task_id_t.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include + +namespace FlexFlow { + +std::optional + get_task_id_for_op(DynamicNodeInvocation const &, + std::optional const &); + +std::optional + get_init_task_id_for_op_attrs(PCGOperatorAttrs const &); + +std::optional get_fwd_task_id_for_op_attrs(PCGOperatorAttrs const &); + +std::optional get_bwd_task_id_for_op_attrs(PCGOperatorAttrs const &); + +std::optional + get_update_task_id_for_optimizer_attrs(OptimizerAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/task_id_t.cc new file mode 100644 index 0000000000..94b5fb5b24 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/task_id_t.cc @@ -0,0 +1,192 @@ +#include "realm-execution/task_id_t.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "pcg/optimizers/adam_optimizer_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_task_type.dtg.h" +#include "utils/overload.h" + +namespace FlexFlow { + +std::optional + get_task_id_for_op(DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs) { + DynamicTaskType task_type = invocation.node_attrs.task_type.value(); + switch (task_type) { + case DynamicTaskType::FWD: + return get_fwd_task_id_for_op_attrs( + invocation.node_attrs.op_attrs.value()); + case DynamicTaskType::BWD: + return get_bwd_task_id_for_op_attrs( + invocation.node_attrs.op_attrs.value()); + case DynamicTaskType::UPD: + return get_update_task_id_for_optimizer_attrs(optimizer_attrs.value()); + case DynamicTaskType::LOSS: + return task_id_t::LOSS_BWD_TASK_ID; + default: + 
PANIC("Unhandled DynamicTaskType", task_type); + } +} + +std::optional + get_init_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { + + return op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { return std::nullopt; }, + [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_INIT_TASK_ID; }, + [](BroadcastAttrs const &) { return std::nullopt; }, + [](CastAttrs const &) { return std::nullopt; }, + [](CombineAttrs const &attrs) { return task_id_t::COMBINE_INIT_TASK_ID; }, + [](ConcatAttrs const &) { return std::nullopt; }, + [](Conv2DAttrs const &) { return task_id_t::CONV2D_INIT_TASK_ID; }, + [](DropoutAttrs const &) { return task_id_t::DROPOUT_INIT_TASK_ID; }, + [](ElementBinaryAttrs const &) { + return task_id_t::ELEMENTBINARY_INIT_TASK_ID; + }, + [](ElementUnaryAttrs const &) { + return task_id_t::ELEMENTUNARY_INIT_TASK_ID; + }, + [](EmbeddingAttrs const &) { return std::nullopt; }, + [](FlatAttrs const &) { return std::nullopt; }, + [](GatherAttrs const &) { return task_id_t::GATHER_INIT_TASK_ID; }, + [](InputAttrs const &) { return std::nullopt; }, + [](LayerNormAttrs const &) { return task_id_t::LAYERNORM_INIT_TASK_ID; }, + [](LinearAttrs const &) { return task_id_t::LINEAR_INIT_TASK_ID; }, + [](MultiHeadAttentionAttrs const &) { + return task_id_t::ATTENTION_INIT_TASK_ID; + }, + [](NoopAttrs const &) { return std::nullopt; }, + [](Pool2DAttrs const &) { return task_id_t::POOL2D_INIT_TASK_ID; }, + [](ReduceAttrs const &) { return task_id_t::REDUCE_INIT_TASK_ID; }, + [](ReductionAttrs const &attrs) { + return task_id_t::REDUCTION_INIT_TASK_ID; + }, + [](RepartitionAttrs const &attrs) { + return task_id_t::REPARTITION_INIT_TASK_ID; + }, + [](ReplicateAttrs const &attrs) { + return task_id_t::REPLICATE_INIT_TASK_ID; + }, + [](ReshapeAttrs const &) { return std::nullopt; }, + [](ReverseAttrs const &) { return std::nullopt; }, + [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, + [](SplitAttrs const &) { return std::nullopt; }, + 
[](TopKAttrs const &) { return std::nullopt; }, + [](TransposeAttrs const &) { return std::nullopt; }, + [](WeightAttrs const &) { return std::nullopt; }, + }); +} + +std::optional + get_fwd_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { + + return op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return task_id_t::BATCHMATMUL_FWD_TASK_ID; + }, + [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_FWD_TASK_ID; }, + [](BroadcastAttrs const &) { return task_id_t::BROADCAST_FWD_TASK_ID; }, + [](CastAttrs const &) { return task_id_t::CAST_FWD_TASK_ID; }, + [](CombineAttrs const &attrs) { return task_id_t::COMBINE_FWD_TASK_ID; }, + [](ConcatAttrs const &) { return task_id_t::CONCAT_FWD_TASK_ID; }, + [](Conv2DAttrs const &) { return task_id_t::CONV2D_FWD_TASK_ID; }, + [](DropoutAttrs const &) { return task_id_t::DROPOUT_FWD_TASK_ID; }, + [](ElementBinaryAttrs const &) { + return task_id_t::ELEMENTBINARY_FWD_TASK_ID; + }, + [](ElementUnaryAttrs const &) { + return task_id_t::ELEMENTUNARY_FWD_TASK_ID; + }, + [](EmbeddingAttrs const &) { return task_id_t::EMBED_FWD_TASK_ID; }, + [](FlatAttrs const &) { return task_id_t::FLAT_FWD_TASK_ID; }, + [](GatherAttrs const &) { return task_id_t::GATHER_FWD_TASK_ID; }, + [](InputAttrs const &) { return std::nullopt; }, + [](LayerNormAttrs const &) { return task_id_t::LAYERNORM_FWD_TASK_ID; }, + [](LinearAttrs const &) { return task_id_t::LINEAR_FWD_TASK_ID; }, + [](MultiHeadAttentionAttrs const &) { + return task_id_t::ATTENTION_FWD_TASK_ID; + }, + [](NoopAttrs const &) { return std::nullopt; }, + [](Pool2DAttrs const &) { return task_id_t::POOL2D_FWD_TASK_ID; }, + [](ReduceAttrs const &) { return task_id_t::REDUCE_FWD_TASK_ID; }, + [](ReductionAttrs const &attrs) { + return task_id_t::REDUCTION_FWD_TASK_ID; + }, + [](RepartitionAttrs const &attrs) { + return task_id_t::REPARTITION_FWD_TASK_ID; + }, + [](ReplicateAttrs const &attrs) { + return task_id_t::REPLICATE_FWD_TASK_ID; + }, + [](ReshapeAttrs const &) 
{ return task_id_t::RESHAPE_FWD_TASK_ID; }, + [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, + [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, + [](SplitAttrs const &) { return task_id_t::SPLIT_FWD_TASK_ID; }, + [](TopKAttrs const &) { return task_id_t::TOPK_FWD_TASK_ID; }, + [](TransposeAttrs const &) { return task_id_t::TRANSPOSE_FWD_TASK_ID; }, + [](WeightAttrs const &) { return std::nullopt; }, + }); +} + +std::optional + get_bwd_task_id_for_op_attrs(PCGOperatorAttrs const &op_attrs) { + + return op_attrs.visit>(overload{ + [](BatchMatmulAttrs const &) { + return task_id_t::BATCHMATMUL_BWD_TASK_ID; + }, + [](BatchNormAttrs const &) { return task_id_t::BATCHNORM_BWD_TASK_ID; }, + [](BroadcastAttrs const &) { return task_id_t::BROADCAST_BWD_TASK_ID; }, + [](CastAttrs const &) { return task_id_t::CAST_BWD_TASK_ID; }, + [](CombineAttrs const &attrs) { return task_id_t::COMBINE_BWD_TASK_ID; }, + [](ConcatAttrs const &) { return task_id_t::CONCAT_BWD_TASK_ID; }, + [](Conv2DAttrs const &) { return task_id_t::CONV2D_BWD_TASK_ID; }, + [](DropoutAttrs const &) { return task_id_t::DROPOUT_BWD_TASK_ID; }, + [](ElementBinaryAttrs const &) { + return task_id_t::ELEMENTBINARY_BWD_TASK_ID; + }, + [](ElementUnaryAttrs const &) { + return task_id_t::ELEMENTUNARY_BWD_TASK_ID; + }, + [](EmbeddingAttrs const &) { return task_id_t::EMBED_BWD_TASK_ID; }, + [](FlatAttrs const &) { return task_id_t::FLAT_BWD_TASK_ID; }, + [](GatherAttrs const &) { return task_id_t::GATHER_BWD_TASK_ID; }, + [](InputAttrs const &) { return std::nullopt; }, + [](LayerNormAttrs const &) { return task_id_t::LAYERNORM_BWD_TASK_ID; }, + [](LinearAttrs const &) { return task_id_t::LINEAR_BWD_TASK_ID; }, + [](MultiHeadAttentionAttrs const &) { + return task_id_t::ATTENTION_BWD_TASK_ID; + }, + [](NoopAttrs const &) { return std::nullopt; }, + [](Pool2DAttrs const &) { return task_id_t::POOL2D_BWD_TASK_ID; }, + [](ReduceAttrs const &) { return 
task_id_t::REDUCE_BWD_TASK_ID; }, + [](ReductionAttrs const &attrs) { + return task_id_t::REDUCTION_BWD_TASK_ID; + }, + [](RepartitionAttrs const &attrs) { + return task_id_t::REPARTITION_BWD_TASK_ID; + }, + [](ReplicateAttrs const &attrs) { + return task_id_t::REPLICATE_BWD_TASK_ID; + }, + [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, + [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, + [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, + [](SplitAttrs const &) { return task_id_t::SPLIT_BWD_TASK_ID; }, + [](TopKAttrs const &) { return task_id_t::TOPK_BWD_TASK_ID; }, + [](TransposeAttrs const &) { return task_id_t::TRANSPOSE_BWD_TASK_ID; }, + [](WeightAttrs const &) { return std::nullopt; }, + }); +} + +std::optional get_update_task_id_for_optimizer_attrs( + OptimizerAttrs const &optimizer_attrs) { + + return optimizer_attrs.visit>(overload{ + [](SGDOptimizerAttrs const &) { return task_id_t::SGD_UPD_NCCL_TASK_ID; }, + [](AdamOptimizerAttrs const &) { + return task_id_t::ADAM_UPD_NCCL_TASK_ID; + }, + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/ops/impl/dropout.h b/lib/task-spec/include/task-spec/ops/impl/dropout.h index a7b382ce62..192f2f8244 100644 --- a/lib/task-spec/include/task-spec/ops/impl/dropout.h +++ b/lib/task-spec/include/task-spec/ops/impl/dropout.h @@ -2,7 +2,6 @@ #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_OPS_IMPL_DROPOUT_H #include "op-attrs/ops/dropout_attrs.dtg.h" -#include "task-spec/task_id_t.dtg.h" #include "task-spec/task_impl_function.dtg.h" namespace FlexFlow { diff --git a/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml b/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml deleted file mode 100644 index 557da6cf4c..0000000000 --- a/lib/task-spec/include/task-spec/ops/op_task_id_t.dtg.toml +++ /dev/null @@ -1,18 +0,0 @@ -namespace = "FlexFlow" -name = "op_task_id_t" -type = "enum" -features = [ - "hash", - "json", - 
"rapidcheck", - "fmt", -] - -[[values]] -name = "INIT" - -[[values]] -name = "FWD" - -[[values]] -name = "BWD" diff --git a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml deleted file mode 100644 index 50349d5773..0000000000 --- a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.dtg.toml +++ /dev/null @@ -1,28 +0,0 @@ -namespace = "FlexFlow" -name = "task_id_with_noop_default_t" -type = "variant" -features = [ - "eq", - "ord", - "hash", - "fmt", - "rapidcheck", -] - -includes = [ - "task-spec/task_id_t.dtg.h", - "", -] - -src_includes = [ - "utils/rapidcheck/monostate.h", - "utils/fmt/monostate.h", -] - -[[values]] -type = "::FlexFlow::task_id_t" -key = "real_task" - -[[values]] -type = "std::monostate" -key = "noop_task" diff --git a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h b/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h deleted file mode 100644 index 054b73844e..0000000000 --- a/lib/task-spec/include/task-spec/task_id_with_noop_default_t.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ID_WITH_NOOP_DEFAULT_T_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_TASK_ID_WITH_NOOP_DEFAULT_T_H - -#include "op-attrs/computation_graph_op_attrs.dtg.h" -#include "op-attrs/operator_type.dtg.h" -#include "task-spec/ops/op_task_id_t.dtg.h" -#include "task-spec/task_id_with_noop_default_t.dtg.h" - -namespace FlexFlow { - -task_id_with_noop_default_t lift_task_id_t(task_id_t); -task_id_with_noop_default_t default_noop_task(); - -task_id_with_noop_default_t lower_op_task_id_to_task_id_with_noop_default_t( - op_task_id_t, ComputationGraphOpAttrs const &); - -task_id_with_noop_default_t - get_init_task_id_for_op_attrs(ComputationGraphOpAttrs const &); - -task_id_with_noop_default_t - get_fwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &); - -task_id_with_noop_default_t - 
get_bwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc b/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc deleted file mode 100644 index 20e0d00c57..0000000000 --- a/lib/task-spec/src/task-spec/task_id_with_noop_default_t.cc +++ /dev/null @@ -1,243 +0,0 @@ -#include "task-spec/task_id_with_noop_default_t.h" -#include "utils/overload.h" - -namespace FlexFlow { - -task_id_with_noop_default_t lift_task_id_t(task_id_t task_id) { - return task_id_with_noop_default_t{task_id}; -} - -task_id_with_noop_default_t default_noop_task() { - return task_id_with_noop_default_t{std::monostate{}}; -} - -task_id_with_noop_default_t lower_op_task_id_to_task_id_with_noop_default_t( - op_task_id_t op_task_id, ComputationGraphOpAttrs const &op_attrs) { - switch (op_task_id) { - case op_task_id_t::INIT: - return get_init_task_id_for_op_attrs(op_attrs); - case op_task_id_t::FWD: - return get_fwd_task_id_for_op_attrs(op_attrs); - case op_task_id_t::BWD: - return get_bwd_task_id_for_op_attrs(op_attrs); - default: - PANIC("Unhandled op_task_id_t", op_task_id); - } -} - -task_id_with_noop_default_t - get_init_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { - - return op_attrs.visit(overload{ - [](BatchMatmulAttrs const &) { return default_noop_task(); }, - [](BatchNormAttrs const &) { - return lift_task_id_t(task_id_t::BATCHNORM_INIT_TASK_ID); - }, - [](BroadcastAttrs const &) { return default_noop_task(); }, - [](CastAttrs const &) { return default_noop_task(); }, - [](ConcatAttrs const &) { return default_noop_task(); }, - [](Conv2DAttrs const &) { - return lift_task_id_t(task_id_t::CONV2D_INIT_TASK_ID); - }, - [](DropoutAttrs const &) { - return lift_task_id_t(task_id_t::DROPOUT_INIT_TASK_ID); - }, - [](ElementBinaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTBINARY_INIT_TASK_ID); - }, - [](ElementUnaryAttrs const &) { - return 
lift_task_id_t(task_id_t::ELEMENTUNARY_INIT_TASK_ID); - }, - [](EmbeddingAttrs const &) { return default_noop_task(); }, - [](FlatAttrs const &) { return default_noop_task(); }, - [](GatherAttrs const &) { - return lift_task_id_t(task_id_t::GATHER_INIT_TASK_ID); - }, - [](InputAttrs const &) { return default_noop_task(); }, - [](LayerNormAttrs const &) { - return lift_task_id_t(task_id_t::LAYERNORM_INIT_TASK_ID); - }, - [](LinearAttrs const &) { - return lift_task_id_t(task_id_t::LINEAR_INIT_TASK_ID); - }, - [](MultiHeadAttentionAttrs const &) { - return lift_task_id_t(task_id_t::ATTENTION_INIT_TASK_ID); - }, - [](NoopAttrs const &) { return default_noop_task(); }, - [](Pool2DAttrs const &) { - return lift_task_id_t(task_id_t::POOL2D_INIT_TASK_ID); - }, - [](ReduceAttrs const &) { - return lift_task_id_t(task_id_t::REDUCE_INIT_TASK_ID); - }, - [](ReshapeAttrs const &) { return default_noop_task(); }, - [](ReverseAttrs const &) { return default_noop_task(); }, - [](SoftmaxAttrs const &) { - return lift_task_id_t(task_id_t::SOFTMAX_INIT_TASK_ID); - }, - [](SplitAttrs const &) { return default_noop_task(); }, - [](TopKAttrs const &) { return default_noop_task(); }, - [](TransposeAttrs const &) { return default_noop_task(); }, - [](WeightAttrs const &) { return default_noop_task(); }, - }); -} - -task_id_with_noop_default_t - get_fwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { - - return op_attrs.visit(overload{ - [](BatchMatmulAttrs const &) { - return lift_task_id_t(task_id_t::BATCHMATMUL_FWD_TASK_ID); - }, - [](BatchNormAttrs const &) { - return lift_task_id_t(task_id_t::BATCHNORM_FWD_TASK_ID); - }, - [](BroadcastAttrs const &) { - return lift_task_id_t(task_id_t::BROADCAST_FWD_TASK_ID); - }, - [](CastAttrs const &) { - return lift_task_id_t(task_id_t::CAST_FWD_TASK_ID); - }, - [](ConcatAttrs const &) { - return lift_task_id_t(task_id_t::CONCAT_FWD_TASK_ID); - }, - [](Conv2DAttrs const &) { - return lift_task_id_t(task_id_t::CONV2D_FWD_TASK_ID); 
- }, - [](DropoutAttrs const &) { - return lift_task_id_t(task_id_t::DROPOUT_FWD_TASK_ID); - }, - [](ElementBinaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTBINARY_FWD_TASK_ID); - }, - [](ElementUnaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTUNARY_FWD_TASK_ID); - }, - [](EmbeddingAttrs const &) { - return lift_task_id_t(task_id_t::EMBED_FWD_TASK_ID); - }, - [](FlatAttrs const &) { - return lift_task_id_t(task_id_t::FLAT_FWD_TASK_ID); - }, - [](GatherAttrs const &) { - return lift_task_id_t(task_id_t::GATHER_FWD_TASK_ID); - }, - [](InputAttrs const &) { return default_noop_task(); }, - [](LayerNormAttrs const &) { - return lift_task_id_t(task_id_t::LAYERNORM_FWD_TASK_ID); - }, - [](LinearAttrs const &) { - return lift_task_id_t(task_id_t::LINEAR_FWD_TASK_ID); - }, - [](MultiHeadAttentionAttrs const &) { - return lift_task_id_t(task_id_t::ATTENTION_FWD_TASK_ID); - }, - [](NoopAttrs const &) { return default_noop_task(); }, - [](Pool2DAttrs const &) { - return lift_task_id_t(task_id_t::POOL2D_FWD_TASK_ID); - }, - [](ReduceAttrs const &) { - return lift_task_id_t(task_id_t::REDUCE_FWD_TASK_ID); - }, - [](ReshapeAttrs const &) { - return lift_task_id_t(task_id_t::RESHAPE_FWD_TASK_ID); - }, - [](ReverseAttrs const &) { - return lift_task_id_t(task_id_t::REVERSE_FWD_TASK_ID); - }, - [](SoftmaxAttrs const &) { - return lift_task_id_t(task_id_t::SOFTMAX_FWD_TASK_ID); - }, - [](SplitAttrs const &) { - return lift_task_id_t(task_id_t::SPLIT_FWD_TASK_ID); - }, - [](TopKAttrs const &) { - return lift_task_id_t(task_id_t::TOPK_FWD_TASK_ID); - }, - [](TransposeAttrs const &) { - return lift_task_id_t(task_id_t::TRANSPOSE_FWD_TASK_ID); - }, - [](WeightAttrs const &) { return default_noop_task(); }, - }); -} - -task_id_with_noop_default_t - get_bwd_task_id_for_op_attrs(ComputationGraphOpAttrs const &op_attrs) { - - return op_attrs.visit(overload{ - [](BatchMatmulAttrs const &) { - return lift_task_id_t(task_id_t::BATCHMATMUL_BWD_TASK_ID); - }, - 
[](BatchNormAttrs const &) { - return lift_task_id_t(task_id_t::BATCHNORM_BWD_TASK_ID); - }, - [](BroadcastAttrs const &) { - return lift_task_id_t(task_id_t::BROADCAST_BWD_TASK_ID); - }, - [](CastAttrs const &) { - return lift_task_id_t(task_id_t::CAST_BWD_TASK_ID); - }, - [](ConcatAttrs const &) { - return lift_task_id_t(task_id_t::CONCAT_BWD_TASK_ID); - }, - [](Conv2DAttrs const &) { - return lift_task_id_t(task_id_t::CONV2D_BWD_TASK_ID); - }, - [](DropoutAttrs const &) { - return lift_task_id_t(task_id_t::DROPOUT_BWD_TASK_ID); - }, - [](ElementBinaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTBINARY_BWD_TASK_ID); - }, - [](ElementUnaryAttrs const &) { - return lift_task_id_t(task_id_t::ELEMENTUNARY_BWD_TASK_ID); - }, - [](EmbeddingAttrs const &) { - return lift_task_id_t(task_id_t::EMBED_BWD_TASK_ID); - }, - [](FlatAttrs const &) { - return lift_task_id_t(task_id_t::FLAT_BWD_TASK_ID); - }, - [](GatherAttrs const &) { - return lift_task_id_t(task_id_t::GATHER_BWD_TASK_ID); - }, - [](InputAttrs const &) { return default_noop_task(); }, - [](LayerNormAttrs const &) { - return lift_task_id_t(task_id_t::LAYERNORM_BWD_TASK_ID); - }, - [](LinearAttrs const &) { - return lift_task_id_t(task_id_t::LINEAR_BWD_TASK_ID); - }, - [](MultiHeadAttentionAttrs const &) { - return lift_task_id_t(task_id_t::ATTENTION_BWD_TASK_ID); - }, - [](NoopAttrs const &) { return default_noop_task(); }, - [](Pool2DAttrs const &) { - return lift_task_id_t(task_id_t::POOL2D_BWD_TASK_ID); - }, - [](ReduceAttrs const &) { - return lift_task_id_t(task_id_t::REDUCE_BWD_TASK_ID); - }, - [](ReshapeAttrs const &) { - return lift_task_id_t(task_id_t::RESHAPE_BWD_TASK_ID); - }, - [](ReverseAttrs const &) { - return lift_task_id_t(task_id_t::REVERSE_BWD_TASK_ID); - }, - [](SoftmaxAttrs const &) { - return lift_task_id_t(task_id_t::SOFTMAX_BWD_TASK_ID); - }, - [](SplitAttrs const &) { - return lift_task_id_t(task_id_t::SPLIT_BWD_TASK_ID); - }, - [](TopKAttrs const &) { - return 
lift_task_id_t(task_id_t::TOPK_BWD_TASK_ID); - }, - [](TransposeAttrs const &) { - return lift_task_id_t(task_id_t::TRANSPOSE_BWD_TASK_ID); - }, - [](WeightAttrs const &) { return default_noop_task(); }, - }); -} - -} // namespace FlexFlow From d702afe043189284f8643c50dfdca2f5d742f5e0 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 14:34:22 -0800 Subject: [PATCH 13/63] Avoid pulling in the entire invocation. --- .../include/realm-execution/task_id_t.h | 4 ++-- lib/realm-execution/src/realm-execution/task_id_t.cc | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/task_id_t.h b/lib/realm-execution/include/realm-execution/task_id_t.h index af20dc27f6..38b82ad9e0 100644 --- a/lib/realm-execution/include/realm-execution/task_id_t.h +++ b/lib/realm-execution/include/realm-execution/task_id_t.h @@ -4,13 +4,13 @@ #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/task_id_t.dtg.h" -#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include namespace FlexFlow { std::optional - get_task_id_for_op(DynamicNodeInvocation const &, + get_task_id_for_op(DynamicNodeAttrs const &, std::optional const &); std::optional diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/task_id_t.cc index 94b5fb5b24..574dbb1e54 100644 --- a/lib/realm-execution/src/realm-execution/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/task_id_t.cc @@ -8,16 +8,14 @@ namespace FlexFlow { std::optional - get_task_id_for_op(DynamicNodeInvocation const &invocation, + get_task_id_for_op(DynamicNodeAttrs const &node_attrs, std::optional const &optimizer_attrs) { - DynamicTaskType task_type = invocation.node_attrs.task_type.value(); + DynamicTaskType task_type = node_attrs.task_type.value(); switch (task_type) { case 
DynamicTaskType::FWD: - return get_fwd_task_id_for_op_attrs( - invocation.node_attrs.op_attrs.value()); + return get_fwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); case DynamicTaskType::BWD: - return get_bwd_task_id_for_op_attrs( - invocation.node_attrs.op_attrs.value()); + return get_bwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); case DynamicTaskType::UPD: return get_update_task_id_for_optimizer_attrs(optimizer_attrs.value()); case DynamicTaskType::LOSS: From 4fcde77d3fe8b4a57c5243ddcbcb233446f546e5 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 14:45:46 -0800 Subject: [PATCH 14/63] Conversion into Realm task IDs. --- .../include/realm-execution/realm_task_id_t.h | 13 +++++++++++++ .../src/realm-execution/realm_manager.cc | 5 ++++- .../src/realm-execution/realm_task_id_t.cc | 10 ++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_task_id_t.h create mode 100644 lib/realm-execution/src/realm-execution/realm_task_id_t.cc diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h new file mode 100644 index 0000000000..6d2e316b14 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H + +#include "realm-execution/task_id_t.dtg.h" +#include "realm.h" + +namespace FlexFlow { + +Realm::Processor::TaskFuncID get_realm_task_id_for_task_id(task_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 0ccf3f4116..747f603f5d 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,4 
+1,6 @@ #include "realm-execution/realm_manager.h" +#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" namespace FlexFlow { @@ -27,7 +29,8 @@ RealmManager::~RealmManager() { Realm::Event RealmManager::start_controller(std::function thunk) { - constexpr int CONTROLLER_TASK_ID = Realm::Processor::TASK_ID_FIRST_AVAILABLE; + Realm::Processor::TaskFuncID CONTROLLER_TASK_ID = + get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID); Realm::Event task_ready = Realm::Processor::register_task_by_kind( Realm::Processor::LOC_PROC, /*global=*/false, diff --git a/lib/realm-execution/src/realm-execution/realm_task_id_t.cc b/lib/realm-execution/src/realm-execution/realm_task_id_t.cc new file mode 100644 index 0000000000..50b23dfe86 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_task_id_t.cc @@ -0,0 +1,10 @@ +#include "realm-execution/realm_task_id_t.h" + +namespace FlexFlow { + +Realm::Processor::TaskFuncID get_realm_task_id_for_task_id(task_id_t task_id) { + return Realm::Processor::TASK_ID_FIRST_AVAILABLE + + static_cast(task_id); +} + +} // namespace FlexFlow From e51b04e80057d6459f1070bd19a1915d4d3ca706 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 15:49:20 -0800 Subject: [PATCH 15/63] Add a top-level PRealm switch. 
--- .../include/realm-execution/realm.h | 20 +++++++++++++++++++ .../include/realm-execution/realm_manager.h | 2 +- .../include/realm-execution/realm_task_id_t.h | 2 +- .../src/realm-execution/task_id_t.cc | 1 - .../test/src/realm-execution/realm_manager.cc | 2 +- .../test/src/realm-execution/test_e2e.cc | 3 ++- 6 files changed, 25 insertions(+), 5 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm.h diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h new file mode 100644 index 0000000000..f15113ee92 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H + +#ifdef FLEXFLOW_USE_PREALM +#include +#else +#include +#endif + +namespace FlexFlow { + +#ifdef FLEXFLOW_USE_PREALM +namespace Realm = ::PRealm; +#else +namespace Realm = ::Realm; +#endif + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index 88cc11f744..b26adea548 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -4,7 +4,7 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "realm.h" +#include "realm-execution/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h index 6d2e316b14..8e6da1a2bd 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #define 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" -#include "realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/task_id_t.cc index 574dbb1e54..3521f50c02 100644 --- a/lib/realm-execution/src/realm-execution/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/task_id_t.cc @@ -2,7 +2,6 @@ #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" -#include "task-spec/dynamic_graph/dynamic_task_type.dtg.h" #include "utils/overload.h" namespace FlexFlow { diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 16b5338881..f9fbd986c2 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -17,7 +17,7 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - Realm::Event event = manager.start_controller( + FlexFlow::Realm::Event event = manager.start_controller( [&](RealmManager &manager) { ASSERT(some_data == 123); }); // Need to block on the completion of the event to ensure we don't race event.wait(); diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 623b8318e6..fa9f798e4f 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,7 +11,8 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - Realm::Event event = manager.start_controller([](RealmManager &manager) {}); + FlexFlow::Realm::Event event = + 
manager.start_controller([](RealmManager &manager) {}); event.wait(); } } From 895de3305d10882ed38519080d7348f5c74c6b5c Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 5 Feb 2026 17:08:38 -0800 Subject: [PATCH 16/63] Some work on Realm task registry. --- .../realm-execution/realm_task_registry.h | 13 +++++ .../realm-execution/realm_task_registry.cc | 55 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 lib/realm-execution/include/realm-execution/realm_task_registry.h create mode 100644 lib/realm-execution/src/realm-execution/realm_task_registry.cc diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h new file mode 100644 index 0000000000..3a4cee106c --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -0,0 +1,13 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/task_id_t.dtg.h" + +namespace FlexFlow { + +Realm::Event register_all_tasks(); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/realm_task_registry.cc new file mode 100644 index 0000000000..a5e52b7a7c --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_task_registry.cc @@ -0,0 +1,55 @@ +#include "realm-execution/realm.h" +#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/task_id_t.dtg.h" + +namespace FlexFlow { + +void op_task_wrapper( + void const *, size_t, void const *, size_t, Realm::Processor) {} + +static Realm::Event register_task(Realm::Processor::Kind target_kind, + task_id_t func_id, + void (*task_body)(void const *, + size_t, + void const *, + size_t, + Realm::Processor)) { + return 
Realm::Processor::register_task_by_kind( + target_kind, + /*global=*/false, + get_realm_task_id_for_task_id(func_id), + Realm::CodeDescriptor(task_body), + Realm::ProfilingRequestSet()); +} + +Realm::Event register_all_tasks() { + std::vector pending_registrations; + + std::vector init_task_ids = { + task_id_t::BATCHNORM_INIT_TASK_ID, + task_id_t::COMBINE_INIT_TASK_ID, + task_id_t::CONV2D_INIT_TASK_ID, + task_id_t::DROPOUT_INIT_TASK_ID, + task_id_t::ELEMENTBINARY_INIT_TASK_ID, + task_id_t::ELEMENTUNARY_INIT_TASK_ID, + task_id_t::GATHER_INIT_TASK_ID, + task_id_t::LAYERNORM_INIT_TASK_ID, + task_id_t::LINEAR_INIT_TASK_ID, + task_id_t::ATTENTION_INIT_TASK_ID, + task_id_t::POOL2D_INIT_TASK_ID, + task_id_t::REDUCE_INIT_TASK_ID, + task_id_t::REDUCTION_INIT_TASK_ID, + task_id_t::REPARTITION_INIT_TASK_ID, + task_id_t::REPLICATE_INIT_TASK_ID, + task_id_t::SOFTMAX_INIT_TASK_ID, + }; + + for (task_id_t init_task_id : init_task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::LOC_PROC, init_task_id, op_task_wrapper)); + } + + return Realm::Event::merge_events(pending_registrations); +} + +} // namespace FlexFlow From 09fde7d3f41ddb014c7712561995a2098a39e799 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 09:20:53 -0800 Subject: [PATCH 17/63] Split out the Realm context. 
--- .../parallel_computation_graph_instance.h | 8 +-- .../include/realm-execution/realm_context.h | 34 +++++++++++ .../include/realm-execution/realm_manager.h | 25 ++------ .../parallel_computation_graph_instance.cc | 4 +- .../src/realm-execution/realm_context.cc | 34 +++++++++++ .../src/realm-execution/realm_manager.cc | 60 +++++-------------- .../test/src/realm-execution/realm_manager.cc | 2 +- .../test/src/realm-execution/test_e2e.cc | 4 +- 8 files changed, 96 insertions(+), 75 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_context.h create mode 100644 lib/realm-execution/src/realm-execution/realm_context.cc diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 4ba77a7925..0dd87d566f 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -9,7 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "realm-execution/realm_manager.h" +#include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -21,7 +21,7 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: - ParallelComputationGraphInstance(RealmManager &, + ParallelComputationGraphInstance(RealmContext &, DynamicOpenDataflowGraph, std::vector const &, OptimizerAttrs const &, @@ -36,7 +36,7 @@ struct ParallelComputationGraphInstance { std::optional get_loss_tensor_accessor() 
const; private: - RealmManager &realm; + RealmContext &realm; DynamicOpenDataflowGraph dataflow_graph; std::vector topological_ordering; OptimizerAttrs optimizer_attrs; @@ -45,7 +45,7 @@ struct ParallelComputationGraphInstance { }; ParallelComputationGraphInstance create_parallel_computation_graph_instance( - RealmManager &realm, + RealmContext &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h new file mode 100644 index 0000000000..5539fe693e --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_CONTEXT_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_CONTEXT_H + +#include "kernels/allocation.h" +#include "kernels/device_handle_t.dtg.h" +#include "pcg/device_id_t.dtg.h" +#include "realm-execution/realm.h" + +namespace FlexFlow { + +struct RealmContext { +public: + RealmContext(); + virtual ~RealmContext(); + + RealmContext(RealmContext const &) = delete; + RealmContext(RealmContext &&) = delete; + + // Current device context + Allocator &get_current_device_allocator() const; + device_handle_t const &get_current_device_handle() const; + device_id_t const &get_current_device_idx() const; + +protected: + [[nodiscard]] Realm::Event merge_outstanding_events(); + +protected: + Realm::Runtime runtime; + std::vector outstanding_events; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index b26adea548..bf5e8f72f1 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -5,38 +5,21 @@ #include "kernels/device_handle_t.dtg.h" #include 
"pcg/device_id_t.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" namespace FlexFlow { -struct RealmManager { +struct RealmManager : private RealmContext { public: RealmManager(int *argc, char ***argv); - ~RealmManager(); + virtual ~RealmManager(); RealmManager() = delete; RealmManager(RealmManager const &) = delete; RealmManager(RealmManager &&) = delete; [[nodiscard]] Realm::Event - start_controller(std::function); - - // Current device context - Allocator &get_current_device_allocator() const; - device_handle_t const &get_current_device_handle() const; - device_id_t const &get_current_device_idx() const; - -private: - RealmManager(void const *, size_t, void const *, size_t, Realm::Processor); - - [[nodiscard]] Realm::Event merge_outstanding_events(); - - static void controller_task_wrapper( - void const *, size_t, void const *, size_t, Realm::Processor); - -private: - Realm::Runtime runtime; - std::vector outstanding_events; - bool is_root_runtime; + start_controller(std::function); }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 64c9da2f4c..c8100287f8 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -12,7 +12,7 @@ namespace FlexFlow { ParallelComputationGraphInstance::ParallelComputationGraphInstance( - RealmManager &realm, + RealmContext &realm, DynamicOpenDataflowGraph dataflow_graph, std::vector const &topological_ordering, OptimizerAttrs const &optimizer_attrs, @@ -61,7 +61,7 @@ static GenericTensorAccessorW } ParallelComputationGraphInstance create_parallel_computation_graph_instance( - 
RealmManager &realm, + RealmContext &realm, ParallelComputationGraph const &pcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc new file mode 100644 index 0000000000..5068373ebe --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -0,0 +1,34 @@ +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/task_id_t.dtg.h" +#include "utils/exception.h" + +namespace FlexFlow { + +RealmContext::RealmContext() {} + +RealmContext::~RealmContext() { + if (!this->outstanding_events.empty()) { + Realm::Event outstanding = this->merge_outstanding_events(); + outstanding.wait(); + } +} + +Allocator &RealmContext::get_current_device_allocator() const { + NOT_IMPLEMENTED(); +} + +device_handle_t const &RealmContext::get_current_device_handle() const { + NOT_IMPLEMENTED(); +} +device_id_t const &RealmContext::get_current_device_idx() const { + NOT_IMPLEMENTED(); +} + +Realm::Event RealmContext::merge_outstanding_events() { + Realm::Event result = Realm::Event::merge_events(this->outstanding_events); + this->outstanding_events.clear(); + return result; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 747f603f5d..501ba7536a 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -5,37 +5,39 @@ namespace FlexFlow { -RealmManager::RealmManager(int *argc, char ***argv) : is_root_runtime(true) { +RealmManager::RealmManager(int *argc, char ***argv) { bool ok = this->runtime.init(argc, argv); ASSERT(ok); } -RealmManager::RealmManager(void const *args, - size_t arglen, - void const *userdata, - size_t userdatalen, - Realm::Processor proc) - : 
runtime(Realm::Runtime::get_runtime()), is_root_runtime(false) {} - RealmManager::~RealmManager() { Realm::Event outstanding = this->merge_outstanding_events(); - if (is_root_runtime) { this->runtime.shutdown(outstanding); this->runtime.wait_for_shutdown(); - } else { - outstanding.wait(); - } +} + +static void controller_task_wrapper(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(std::function)); + std::function thunk = + *reinterpret_cast const *>(args); + + RealmContext ctx; + thunk(ctx); } Realm::Event - RealmManager::start_controller(std::function thunk) { + RealmManager::start_controller(std::function thunk) { Realm::Processor::TaskFuncID CONTROLLER_TASK_ID = get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID); Realm::Event task_ready = Realm::Processor::register_task_by_kind( Realm::Processor::LOC_PROC, /*global=*/false, CONTROLLER_TASK_ID, - Realm::CodeDescriptor(RealmManager::controller_task_wrapper), + Realm::CodeDescriptor(controller_task_wrapper), Realm::ProfilingRequestSet(), &thunk, sizeof(thunk)); @@ -51,34 +53,4 @@ Realm::Event return task_complete; } -Allocator &RealmManager::get_current_device_allocator() const { - NOT_IMPLEMENTED(); -} - -device_handle_t const &RealmManager::get_current_device_handle() const { - NOT_IMPLEMENTED(); -} -device_id_t const &RealmManager::get_current_device_idx() const { - NOT_IMPLEMENTED(); -} - -Realm::Event RealmManager::merge_outstanding_events() { - Realm::Event result = Realm::Event::merge_events(this->outstanding_events); - this->outstanding_events.clear(); - return result; -} - -void RealmManager::controller_task_wrapper(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(std::function)); - std::function thunk = - *reinterpret_cast const *>(args); - - RealmManager manager(args, arglen, userdata, userlen, proc); - thunk(manager); -} - } // 
namespace FlexFlow diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index f9fbd986c2..6c28a001ad 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -18,7 +18,7 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; FlexFlow::Realm::Event event = manager.start_controller( - [&](RealmManager &manager) { ASSERT(some_data == 123); }); + [&](RealmContext &ctx) { ASSERT(some_data == 123); }); // Need to block on the completion of the event to ensure we don't race event.wait(); } diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index fa9f798e4f..a30d5c4d8e 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -11,8 +11,6 @@ TEST_SUITE(FF_TEST_SUITE) { int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); RealmManager manager(&fake_argc, &fake_argv); - FlexFlow::Realm::Event event = - manager.start_controller([](RealmManager &manager) {}); - event.wait(); + (void)manager.start_controller([](RealmContext &ctx) {}); } } From c5a0ea9858de3c85ffbacc8757488bdb23b13efd Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 09:32:02 -0800 Subject: [PATCH 18/63] Switch to mapped PCG. 
--- .../parallel_computation_graph_instance.h | 4 ++-- .../parallel_computation_graph_instance.cc | 7 ++++--- .../src/realm-execution/realm_manager.cc | 12 ++++++------ ...ke_dynamic_open_dataflow_graph_from_mpcg.h | 14 ++++++++++++++ ...ake_dynamic_open_dataflow_graph_from_pcg.h | 14 -------------- ..._dynamic_open_dataflow_graph_from_mpcg.cc} | 19 ++++++++++--------- 6 files changed, 36 insertions(+), 34 deletions(-) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h delete mode 100644 lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h rename lib/task-spec/src/task-spec/dynamic_graph/{make_dynamic_open_dataflow_graph_from_pcg.cc => make_dynamic_open_dataflow_graph_from_mpcg.cc} (84%) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 0dd87d566f..06c2d2d912 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -7,8 +7,8 @@ #include "kernels/profiling_settings.dtg.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/device_id_t.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -46,7 +46,7 @@ struct ParallelComputationGraphInstance { ParallelComputationGraphInstance create_parallel_computation_graph_instance( 
RealmContext &realm, - ParallelComputationGraph const &pcg, + MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index c8100287f8..e7bf79f12d 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -4,7 +4,7 @@ #include "pcg/optimizer_attrs.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/loss_insertion.h" -#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" @@ -62,7 +62,7 @@ static GenericTensorAccessorW ParallelComputationGraphInstance create_parallel_computation_graph_instance( RealmContext &realm, - ParallelComputationGraph const &pcg, + MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, @@ -72,7 +72,8 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config) { - DynamicOpenDataflowGraph dg = make_dynamic_open_dataflow_graph_from_pcg(pcg); + DynamicOpenDataflowGraph dg = + make_dynamic_open_dataflow_graph_from_mpcg(mpcg); dg = perform_pass_expansion(dg); std::unordered_map inputs = diff --git 
a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 501ba7536a..0c34d77204 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -12,15 +12,15 @@ RealmManager::RealmManager(int *argc, char ***argv) { RealmManager::~RealmManager() { Realm::Event outstanding = this->merge_outstanding_events(); - this->runtime.shutdown(outstanding); - this->runtime.wait_for_shutdown(); + this->runtime.shutdown(outstanding); + this->runtime.wait_for_shutdown(); } static void controller_task_wrapper(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { ASSERT(arglen == sizeof(std::function)); std::function thunk = *reinterpret_cast const *>(args); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h new file mode 100644 index 0000000000..758a0c2813 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_MPCG_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_MPCG_H + +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( + MappedParallelComputationGraph const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h 
b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h deleted file mode 100644 index a71eb558c1..0000000000 --- a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H -#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_OPEN_DATAFLOW_GRAPH_FROM_PCG_H - -#include "pcg/parallel_computation_graph/parallel_computation_graph.dtg.h" -#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" - -namespace FlexFlow { - -DynamicOpenDataflowGraph - make_dynamic_open_dataflow_graph_from_pcg(ParallelComputationGraph const &); - -} // namespace FlexFlow - -#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc similarity index 84% rename from lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc rename to lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index 841be27dfd..e90ef10398 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -1,4 +1,4 @@ -#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_pcg.h" +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/pcg_operator_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" @@ -13,26 +13,27 @@ namespace FlexFlow { -DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_pcg( - ParallelComputationGraph const &pcg) { +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( + 
MappedParallelComputationGraph const &mpcg) { DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); - for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { + for (auto const &[layer, attrs] : + get_parallel_layer_attrs_mapping(mpcg.pcg)) { DynamicNodeAttrs result_attrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, + /*mapping=*/mpcg.mapped_tasks.at(layer), /*op_attrs=*/attrs.op_attrs, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; std::unordered_map result_inputs = - transform(get_incoming_tensors(pcg, layer), + transform(get_incoming_tensors(mpcg.pcg, layer), [&](TensorSlotName const &slot_name, parallel_tensor_guid_t const &tensor) { ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); + get_parallel_tensor_attrs(mpcg.pcg, tensor); return std::pair{ DynamicTensorSlot{ /*slot_name=*/slot_name, @@ -48,11 +49,11 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_pcg( }; }); std::unordered_map result_outputs = - transform(get_outgoing_tensors(pcg, layer), + transform(get_outgoing_tensors(mpcg.pcg, layer), [&](TensorSlotName const &slot_name, parallel_tensor_guid_t const &tensor) { ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); + get_parallel_tensor_attrs(mpcg.pcg, tensor); return std::pair{ DynamicTensorSlot{ /*slot_name=*/slot_name, From a587e53f6ede06e3e0e71d18ce352c125bace915 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 09:54:57 -0800 Subject: [PATCH 19/63] Add shard expansion pass (and implement shard expansion pass). 
--- .../parallel_computation_graph_instance.h | 3 ++- .../parallel_computation_graph_instance.cc | 9 ++++++--- .../task-spec/dynamic_graph/shard_expansion.cc | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 06c2d2d912..f361cec3ca 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -50,7 +51,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, - std::optional logit_tensor, + std::optional logit_tensor, std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index e7bf79f12d..80ed98f8c2 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ 
b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -3,9 +3,11 @@ #include "local-execution/tensor_allocation.h" #include "pcg/optimizer_attrs.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" #include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" +#include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" @@ -66,7 +68,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, std::optional label_tensor, - std::optional logit_tensor, + std::optional logit_tensor, std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, @@ -81,13 +83,14 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_grad_value; if (loss_attrs) { auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( - dg, assert_unwrap(loss_attrs), assert_unwrap(logit_tensor)); + dg, loss_attrs.value(), dynamic_tensor_guid_t{logit_tensor.value()}); dg = dg2; logit_grad_value = logit_grad_v; - inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); + inputs.insert(std::pair{label_v, label_tensor.value()}); } dg = perform_update_insertion(dg, optimizer_attrs); + dg = perform_shard_expansion(dg); dg = perform_tensor_allocation( dg, inputs, realm.get_current_device_allocator()); diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index ea253b63f8..33b7fb8591 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ 
-81,4 +81,19 @@ std::unordered_set }); } +DynamicOpenDataflowGraph + perform_shard_expansion(DynamicOpenDataflowGraph const &g) { + + ASSERT(no_part_of_graph_is_shard_expanded(g)); + + DynamicOpenDataflowGraph result = + flatmap_dynamic_invocation_set(g, [&](DynamicNodeInvocation const &i) { + return perform_shard_expansion_for_invocation(i); + }); + + ASSERT(graph_is_fully_shard_expanded(result)); + + return result; +} + } // namespace FlexFlow From 62b49f7bf6d1b07bb9de720d6a2bf36046bfe03a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 10:47:43 -0800 Subject: [PATCH 20/63] Add instance field to dynamic graph, more task IDs. --- .../include/realm-execution/realm_context.h | 2 +- .../include/realm-execution/realm_manager.h | 2 +- .../include/realm-execution/realm_task_id_t.h | 2 +- .../realm-execution/realm_task_registry.h | 4 +- .../realm-execution/realm_task_registry.cc | 81 +++++++++++++++++-- lib/task-spec/CMakeLists.txt | 1 + .../dynamic_value_attrs.dtg.toml | 6 ++ .../include/task-spec/realm/fmt/instance.h | 35 ++++++++ .../include/task-spec/realm}/realm.h | 4 +- .../task-spec/dynamic_graph/loss_insertion.cc | 2 + ...ake_dynamic_open_dataflow_graph_from_cg.cc | 2 + ...e_dynamic_open_dataflow_graph_from_mpcg.cc | 2 + .../dynamic_graph/update_insertion.cc | 1 + .../src/task-spec/realm/fmt/instance.h | 10 +++ .../dynamic_open_dataflow_graph.cc | 3 + .../dynamic_graph/machine_slicing.cc | 1 + .../task-spec/dynamic_graph/pass_expansion.cc | 3 + .../dynamic_graph/shard_expansion.cc | 1 + 18 files changed, 148 insertions(+), 14 deletions(-) create mode 100644 lib/task-spec/include/task-spec/realm/fmt/instance.h rename lib/{realm-execution/include/realm-execution => task-spec/include/task-spec/realm}/realm.h (63%) create mode 100644 lib/task-spec/src/task-spec/realm/fmt/instance.h diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 5539fe693e..357b05b699 
100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -4,7 +4,7 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "realm-execution/realm.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index bf5e8f72f1..ebf3bb401e 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -4,8 +4,8 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "realm-execution/realm.h" #include "realm-execution/realm_context.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h index 8e6da1a2bd..327cf9ffd0 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H -#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h index 3a4cee106c..d9d993795b 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H #define 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H -#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" +#include "task-spec/realm/realm.h" namespace FlexFlow { -Realm::Event register_all_tasks(); +[[nodiscard]] Realm::Event register_all_tasks(); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/realm_task_registry.cc index a5e52b7a7c..5c61c208fb 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/realm_task_registry.cc @@ -1,11 +1,13 @@ -#include "realm-execution/realm.h" +#include "realm-execution/realm_task_registry.h" #include "realm-execution/realm_task_id_t.h" -#include "realm-execution/task_id_t.dtg.h" +#include "utils/exception.h" namespace FlexFlow { -void op_task_wrapper( - void const *, size_t, void const *, size_t, Realm::Processor) {} +static void operation_task_wrapper( + void const *, size_t, void const *, size_t, Realm::Processor) { + NOT_IMPLEMENTED(); +} static Realm::Event register_task(Realm::Processor::Kind target_kind, task_id_t func_id, @@ -25,7 +27,8 @@ static Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::Event register_all_tasks() { std::vector pending_registrations; - std::vector init_task_ids = { + std::vector task_ids = { + // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, task_id_t::COMBINE_INIT_TASK_ID, task_id_t::CONV2D_INIT_TASK_ID, @@ -42,11 +45,75 @@ Realm::Event register_all_tasks() { task_id_t::REPARTITION_INIT_TASK_ID, task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, + + // Forward tasks + task_id_t::BATCHMATMUL_FWD_TASK_ID, + task_id_t::BATCHNORM_FWD_TASK_ID, + task_id_t::BROADCAST_FWD_TASK_ID, + task_id_t::CAST_FWD_TASK_ID, + task_id_t::COMBINE_FWD_TASK_ID, + task_id_t::CONCAT_FWD_TASK_ID, + task_id_t::CONV2D_FWD_TASK_ID, + task_id_t::DROPOUT_FWD_TASK_ID, + 
task_id_t::ELEMENTBINARY_FWD_TASK_ID, + task_id_t::ELEMENTUNARY_FWD_TASK_ID, + task_id_t::EMBED_FWD_TASK_ID, + task_id_t::FLAT_FWD_TASK_ID, + task_id_t::GATHER_FWD_TASK_ID, + task_id_t::LAYERNORM_FWD_TASK_ID, + task_id_t::LINEAR_FWD_TASK_ID, + task_id_t::ATTENTION_FWD_TASK_ID, + task_id_t::POOL2D_FWD_TASK_ID, + task_id_t::REDUCE_FWD_TASK_ID, + task_id_t::REDUCTION_FWD_TASK_ID, + task_id_t::REPARTITION_FWD_TASK_ID, + task_id_t::REPLICATE_FWD_TASK_ID, + task_id_t::RESHAPE_FWD_TASK_ID, + task_id_t::REVERSE_FWD_TASK_ID, + task_id_t::SOFTMAX_FWD_TASK_ID, + task_id_t::SPLIT_FWD_TASK_ID, + task_id_t::TOPK_FWD_TASK_ID, + task_id_t::TRANSPOSE_FWD_TASK_ID, + + // Backward tasks + task_id_t::BATCHMATMUL_BWD_TASK_ID, + task_id_t::BATCHNORM_BWD_TASK_ID, + task_id_t::BROADCAST_BWD_TASK_ID, + task_id_t::CAST_BWD_TASK_ID, + task_id_t::COMBINE_BWD_TASK_ID, + task_id_t::CONCAT_BWD_TASK_ID, + task_id_t::CONV2D_BWD_TASK_ID, + task_id_t::DROPOUT_BWD_TASK_ID, + task_id_t::ELEMENTBINARY_BWD_TASK_ID, + task_id_t::ELEMENTUNARY_BWD_TASK_ID, + task_id_t::EMBED_BWD_TASK_ID, + task_id_t::FLAT_BWD_TASK_ID, + task_id_t::GATHER_BWD_TASK_ID, + task_id_t::LAYERNORM_BWD_TASK_ID, + task_id_t::LINEAR_BWD_TASK_ID, + task_id_t::ATTENTION_BWD_TASK_ID, + task_id_t::POOL2D_BWD_TASK_ID, + task_id_t::REDUCE_BWD_TASK_ID, + task_id_t::REDUCTION_BWD_TASK_ID, + task_id_t::REPARTITION_BWD_TASK_ID, + task_id_t::REPLICATE_BWD_TASK_ID, + task_id_t::RESHAPE_BWD_TASK_ID, + task_id_t::REVERSE_BWD_TASK_ID, + task_id_t::SOFTMAX_BWD_TASK_ID, + task_id_t::SPLIT_BWD_TASK_ID, + task_id_t::TOPK_BWD_TASK_ID, + task_id_t::TRANSPOSE_BWD_TASK_ID, + + // Update tasks + task_id_t::SGD_UPD_NCCL_TASK_ID, + task_id_t::ADAM_UPD_NCCL_TASK_ID, }; - for (task_id_t init_task_id : init_task_ids) { + for (task_id_t task_id : task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::LOC_PROC, task_id, operation_task_wrapper)); pending_registrations.push_back(register_task( - Realm::Processor::LOC_PROC, init_task_id, 
op_task_wrapper)); + Realm::Processor::TOC_PROC, task_id, operation_task_wrapper)); } return Realm::Event::merge_events(pending_registrations); diff --git a/lib/task-spec/CMakeLists.txt b/lib/task-spec/CMakeLists.txt index 3c7c91af67..f4f5353f70 100644 --- a/lib/task-spec/CMakeLists.txt +++ b/lib/task-spec/CMakeLists.txt @@ -14,6 +14,7 @@ ff_add_library( pcg spdlog compiler + Realm::Realm ) add_subdirectory(test) diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index 89b94b1017..763ebf180f 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -14,6 +14,8 @@ includes = [ "op-attrs/parallel_tensor_space_coordinate.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", + "task-spec/realm/fmt/instance.h", + "task-spec/realm/realm.h", ] src_includes = [ @@ -36,6 +38,10 @@ type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" name = "accessor" type = "std::optional<::FlexFlow::DynamicTensorAccessor>" +[[fields]] +name = "instance" +type = "std::optional<::FlexFlow::Realm::RegionInstance>" + [[fields]] name = "role" type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/include/task-spec/realm/fmt/instance.h b/lib/task-spec/include/task-spec/realm/fmt/instance.h new file mode 100644 index 0000000000..23979c7efc --- /dev/null +++ b/lib/task-spec/include/task-spec/realm/fmt/instance.h @@ -0,0 +1,35 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H + +#include "task-spec/realm/realm.h" +#include "utils/check_fmtable.h" +#include +#include + +namespace fmt { + +template +struct formatter<::FlexFlow::Realm::RegionInstance, + Char, + std::enable_if_t::value>> + : formatter<::std::string> { 
+ template + auto format(::FlexFlow::Realm::RegionInstance const &m, + FormatContext &ctx) const -> decltype(ctx.out()) { + std::string result = fmt::format("", m.id); + + return formatter::format(result, ctx); + } +}; + +} // namespace fmt + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::Realm::RegionInstance const &m); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/task-spec/include/task-spec/realm/realm.h similarity index 63% rename from lib/realm-execution/include/realm-execution/realm.h rename to lib/task-spec/include/task-spec/realm/realm.h index f15113ee92..8123c9e9fa 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/task-spec/include/task-spec/realm/realm.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_H +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #ifdef FLEXFLOW_USE_PREALM #include diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 4270119612..837ade2aad 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -23,6 +23,7 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_loss(), }; DynamicValueAttrs logit_grad_value{ @@ -30,6 +31,7 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, 
/*role=*/mk_dynamic_tensor_role_bwd(), }; DynamicNodeInvocation loss_invocation{ diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 204597386e..294241b732 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -45,6 +45,7 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -64,6 +65,7 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index e90ef10398..eceb580a20 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -44,6 +44,7 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -64,6 +65,7 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc index 58a32db6c1..23708f3779 100644 --- 
a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -51,6 +51,7 @@ static DynamicNodeInvocation get_update_invocation_for_invocation( DynamicValueAttrs value_attrs = output.second; ASSERT(value_attrs.accessor == std::nullopt); + ASSERT(value_attrs.instance == std::nullopt); DynamicNodeAttrs update_node_attrs = i.node_attrs; update_node_attrs.task_type = DynamicTaskType::UPD; diff --git a/lib/task-spec/src/task-spec/realm/fmt/instance.h b/lib/task-spec/src/task-spec/realm/fmt/instance.h new file mode 100644 index 0000000000..fa15e1c16f --- /dev/null +++ b/lib/task-spec/src/task-spec/realm/fmt/instance.h @@ -0,0 +1,10 @@ +#include "task-spec/realm/fmt/instance.h" + +namespace FlexFlow { + +std::ostream &operator<<(std::ostream &s, + ::FlexFlow::Realm::RegionInstance const &m) { + return s << fmt::to_string(m); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index fc9110b6e4..bb9a45e59a 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -16,6 +16,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -29,6 +30,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -42,6 +44,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; diff --git 
a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc index 40d37f50df..c28e12e0af 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -76,6 +76,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index e8fcf2e40b..e57691b475 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -20,6 +20,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -113,6 +114,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -229,6 +231,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/tensor_type, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index 23fbb6e514..4d88dde805 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -121,6 +121,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, + /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; From 
ce403d44fa1a451cf1f7c3cd252c6ca6a58c608b Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 10:48:32 -0800 Subject: [PATCH 21/63] Fix filename. --- lib/task-spec/src/task-spec/realm/fmt/{instance.h => instance.cc} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename lib/task-spec/src/task-spec/realm/fmt/{instance.h => instance.cc} (100%) diff --git a/lib/task-spec/src/task-spec/realm/fmt/instance.h b/lib/task-spec/src/task-spec/realm/fmt/instance.cc similarity index 100% rename from lib/task-spec/src/task-spec/realm/fmt/instance.h rename to lib/task-spec/src/task-spec/realm/fmt/instance.cc From a4183dd045fd16d57602250aae1b2fd1e3f86a99 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 11:17:12 -0800 Subject: [PATCH 22/63] Some work in instance allocation and registry/manager. --- .../realm-execution/instance_allocation.h | 26 +++++ .../include/realm-execution/realm_context.h | 2 + .../realm-execution/realm_task_registry.h | 8 ++ .../realm-execution/instance_allocation.cc | 104 ++++++++++++++++++ .../parallel_computation_graph_instance.cc | 13 +-- .../src/realm-execution/realm_context.cc | 6 + .../src/realm-execution/realm_manager.cc | 46 ++++---- .../realm-execution/realm_task_registry.cc | 14 +-- 8 files changed, 182 insertions(+), 37 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/instance_allocation.h create mode 100644 lib/realm-execution/src/realm-execution/instance_allocation.cc diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h new file mode 100644 index 0000000000..ea07cf0601 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_INSTANCE_ALLOCATION_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_INSTANCE_ALLOCATION_H + +#include "realm-execution/realm_context.h" 
+#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" + +namespace FlexFlow { + +bool no_instances_are_allocated(DynamicOpenDataflowGraph const &); +bool all_instances_are_allocated(DynamicOpenDataflowGraph const &); + +bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g); + +DynamicValueAttrs + perform_instance_allocation_for_value(DynamicValueAttrs const &, + Allocator &); + +DynamicOpenDataflowGraph perform_instance_allocation( + DynamicOpenDataflowGraph const &, + std::unordered_map const + &preallocated, + RealmContext &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 357b05b699..c72fe30b72 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -21,6 +21,8 @@ struct RealmContext { device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; + Realm::Event get_outstanding_events(); + protected: [[nodiscard]] Realm::Event merge_outstanding_events(); diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h index d9d993795b..d6bf5b927f 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -6,6 +6,14 @@ namespace FlexFlow { +[[nodiscard]] Realm::Event register_task(Realm::Processor::Kind target_kind, + task_id_t func_id, + void (*task_body)(void const *, + size_t, + void const *, + size_t, + Realm::Processor)); + [[nodiscard]] Realm::Event register_all_tasks(); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc new file mode 100644 index 0000000000..76d89313a6 --- 
/dev/null +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -0,0 +1,104 @@ +#include "realm-execution/instance_allocation.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "utils/bidict/generate_bidict.h" +#include "utils/containers/all_are_true.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/map_values.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/exception.h" +#include "utils/optional.h" + +namespace FlexFlow { + +bool no_instances_are_allocated(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return !v.accessor.has_value() && !v.instance.has_value(); + })); +} + +bool all_instances_are_allocated(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return v.instance.has_value(); + })); +} + +bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g) { + return all_are_true( + transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { + return v.parallel_tensor_shape.has_value(); + })); +} + +DynamicValueAttrs + perform_instance_allocation_for_value(DynamicValueAttrs const &value, + RealmContext &ctx) { + ASSERT(value.accessor == std::nullopt); + ASSERT(value.instance == std::nullopt); + + TensorShape shape = + get_piece_shape(assert_unwrap(value.parallel_tensor_shape)); + + NOT_IMPLEMENTED(); + // GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + + DynamicValueAttrs result = value; + // result.accessor = DynamicTensorAccessor{accessor}; + + return result; +} + +DynamicOpenDataflowGraph perform_instance_allocation( + DynamicOpenDataflowGraph const &g, + std::unordered_map const + 
&preallocated, + RealmContext &ctx) { + ASSERT(no_instances_are_allocated(g)); + ASSERT(instances_are_ready_for_allocation(g)); + for (DynamicValueAttrs const &v : keys(preallocated)) { + ASSERT(v.accessor == std::nullopt); + ASSERT(v.instance == std::nullopt); + } + + std::unordered_set all_values = + unordered_set_of(get_dynamic_values(g)); + + bidict unallocated_to_allocated = + generate_bidict(all_values, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + if (contains_key(preallocated, v)) { + // FIXME: Attach external instance to existing + // allocation and use that + NOT_IMPLEMENTED(); + } else { + return perform_instance_allocation_for_value(v, ctx); + } + }); + + DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( + g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/map_values( + i.inputs, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + return unallocated_to_allocated.at_l(v); + }), + /*node_attrs=*/i.node_attrs, + /*outputs=*/ + map_values(i.outputs, + [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { + return unallocated_to_allocated.at_l(v); + }), + }; + }); + + ASSERT(all_instances_are_allocated(result)); + + return result; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 80ed98f8c2..ec80519cf3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,7 +1,7 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" #include "local-execution/device_state_initialization.h" -#include 
"local-execution/tensor_allocation.h" #include "pcg/optimizer_attrs.h" +#include "realm-execution/instance_allocation.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" @@ -63,7 +63,7 @@ static GenericTensorAccessorW } ParallelComputationGraphInstance create_parallel_computation_graph_instance( - RealmContext &realm, + RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, @@ -91,8 +91,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - dg = perform_tensor_allocation( - dg, inputs, realm.get_current_device_allocator()); + dg = perform_instance_allocation(dg, inputs, ctx); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { @@ -100,12 +99,12 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( }); dg = perform_device_state_initialization(dg, - realm.get_current_device_allocator(), + ctx.get_current_device_allocator(), profiling_settings, - realm.get_current_device_handle(), + ctx.get_current_device_handle(), iteration_config, optimizer_attrs, - realm.get_current_device_idx()); + ctx.get_current_device_idx()); NOT_IMPLEMENTED(); } diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 5068373ebe..ede6ae6d8d 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -25,6 +25,12 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +Realm::Event RealmContext::get_outstanding_events() { + Realm::Event result = this->merge_outstanding_events(); + 
this->outstanding_events.push_back(result); + return result; +} + Realm::Event RealmContext::merge_outstanding_events() { Realm::Event result = Realm::Event::merge_events(this->outstanding_events); this->outstanding_events.clear(); diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 0c34d77204..63c6266948 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,21 +1,11 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_task_id_t.h" +#include "realm-execution/realm_task_registry.h" #include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" namespace FlexFlow { -RealmManager::RealmManager(int *argc, char ***argv) { - bool ok = this->runtime.init(argc, argv); - ASSERT(ok); -} - -RealmManager::~RealmManager() { - Realm::Event outstanding = this->merge_outstanding_events(); - this->runtime.shutdown(outstanding); - this->runtime.wait_for_shutdown(); -} - static void controller_task_wrapper(void const *args, size_t arglen, void const *userdata, @@ -29,26 +19,36 @@ static void controller_task_wrapper(void const *args, thunk(ctx); } +RealmManager::RealmManager(int *argc, char ***argv) { + bool ok = this->runtime.init(argc, argv); + ASSERT(ok); + + // Register all tasks at initialization time so we don't need to later + register_all_tasks().wait(); + register_task(Realm::Processor::LOC_PROC, + task_id_t::CONTROLLER_TASK_ID, + controller_task_wrapper) + .wait(); +} + +RealmManager::~RealmManager() { + Realm::Event outstanding = this->merge_outstanding_events(); + this->runtime.shutdown(outstanding); + this->runtime.wait_for_shutdown(); +} + Realm::Event RealmManager::start_controller(std::function thunk) { - Realm::Processor::TaskFuncID CONTROLLER_TASK_ID = - get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID); - Realm::Event task_ready = 
Realm::Processor::register_task_by_kind( - Realm::Processor::LOC_PROC, - /*global=*/false, - CONTROLLER_TASK_ID, - Realm::CodeDescriptor(controller_task_wrapper), - Realm::ProfilingRequestSet(), - &thunk, - sizeof(thunk)); - Realm::Processor target_proc = Realm::Machine::ProcessorQuery(Realm::Machine::get_machine()) .only_kind(Realm::Processor::LOC_PROC) .first(); Realm::Event task_complete = this->runtime.collective_spawn( - target_proc, CONTROLLER_TASK_ID, &thunk, sizeof(thunk), task_ready); + target_proc, + get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID), + &thunk, + sizeof(thunk)); this->outstanding_events.push_back(task_complete); return task_complete; } diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/realm_task_registry.cc index 5c61c208fb..436a6af3f3 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/realm_task_registry.cc @@ -9,13 +9,13 @@ static void operation_task_wrapper( NOT_IMPLEMENTED(); } -static Realm::Event register_task(Realm::Processor::Kind target_kind, - task_id_t func_id, - void (*task_body)(void const *, - size_t, - void const *, - size_t, - Realm::Processor)) { +Realm::Event register_task(Realm::Processor::Kind target_kind, + task_id_t func_id, + void (*task_body)(void const *, + size_t, + void const *, + size_t, + Realm::Processor)) { return Realm::Processor::register_task_by_kind( target_kind, /*global=*/false, From 0274dd0343b47d91b868da72f2f54bee430a4d1a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 12:21:57 -0800 Subject: [PATCH 23/63] Instance allocation. 
--- .../realm-execution/instance_allocation.h | 2 +- .../include/realm-execution/realm_context.h | 10 ++ .../realm-execution/instance_allocation.cc | 16 ++-- .../parallel_computation_graph_instance.cc | 18 ++-- .../src/realm-execution/realm_context.cc | 93 +++++++++++++++++++ 5 files changed, 124 insertions(+), 15 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index ea07cf0601..d1dfa3fda0 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -15,7 +15,7 @@ DynamicValueAttrs perform_instance_allocation_for_value(DynamicValueAttrs const &, Allocator &); -DynamicOpenDataflowGraph perform_instance_allocation( +std::pair perform_instance_allocation( DynamicOpenDataflowGraph const &, std::unordered_map const &preallocated, diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index c72fe30b72..90ef402fb6 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -21,9 +21,19 @@ struct RealmContext { device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; + // Instance management + std::pair + create_instance(Realm::Memory memory, + TensorShape const &shape, + Realm::ProfilingRequestSet const &prs, + Realm::Event wait_on = Realm::Event::NO_EVENT); + + // Get the current set of outstanding events Realm::Event get_outstanding_events(); protected: + // Compact AND CLEAR the outstanding event queue + // Important: USER MUST BLOCK on event or else use it, or it WILL BE LOST [[nodiscard]] Realm::Event merge_outstanding_events(); protected: diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc 
b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 76d89313a6..0870117bfe 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -1,11 +1,13 @@ #include "realm-execution/instance_allocation.h" #include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "utils/bidict/generate_bidict.h" #include "utils/containers/all_are_true.h" #include "utils/containers/contains_key.h" +#include "utils/containers/make.h" #include "utils/containers/map_values.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" @@ -40,19 +42,19 @@ DynamicValueAttrs ASSERT(value.accessor == std::nullopt); ASSERT(value.instance == std::nullopt); - TensorShape shape = - get_piece_shape(assert_unwrap(value.parallel_tensor_shape)); + TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value()); - NOT_IMPLEMENTED(); - // GenericTensorAccessorW accessor = allocator.allocate_tensor(shape); + Realm::Memory memory = Realm::Memory::NO_MEMORY; // FIXME + auto [instance, ready] = + ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); DynamicValueAttrs result = value; - // result.accessor = DynamicTensorAccessor{accessor}; + result.instance = instance; return result; } -DynamicOpenDataflowGraph perform_instance_allocation( +std::pair perform_instance_allocation( DynamicOpenDataflowGraph const &g, std::unordered_map const &preallocated, @@ -98,7 +100,7 @@ DynamicOpenDataflowGraph perform_instance_allocation( ASSERT(all_instances_are_allocated(result)); - return result; + return std::pair{result, ctx.get_outstanding_events()}; } } // namespace FlexFlow diff --git 
a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index ec80519cf3..dddb624df3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -53,13 +53,12 @@ std::optional return this->logit_grad_tensor; } -static GenericTensorAccessorW - get_loss_tensor_accessor(DynamicOpenDataflowGraph const &dg, +static Realm::RegionInstance + get_loss_tensor_instance(DynamicOpenDataflowGraph const &dg, DynamicValueAttrs const &value) { return find_output_tensor(dg, value.tensor_guid, value.role) .value() - .second.accessor.value() - .get(); + .second.instance.value(); } ParallelComputationGraphInstance create_parallel_computation_graph_instance( @@ -91,11 +90,16 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - dg = perform_instance_allocation(dg, inputs, ctx); + Realm::Event instances_ready; + { + auto [dg2, ready] = perform_instance_allocation(dg, inputs, ctx); + dg = dg2; + instances_ready = ready; + } - std::optional logit_grad_tensor = + std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { - return get_loss_tensor_accessor(dg, lgv); + return get_loss_tensor_instance(dg, lgv); }); dg = perform_device_state_initialization(dg, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index ede6ae6d8d..6ab7f992fa 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,7 +1,9 @@ #include 
"realm-execution/realm_context.h" +#include "op-attrs/datatype.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" +#include "utils/positive_int/positive_int.h" namespace FlexFlow { @@ -25,6 +27,97 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +std::pair + RealmContext::create_instance(Realm::Memory memory, + TensorShape const &shape, + Realm::ProfilingRequestSet const &prs, + Realm::Event wait_on) { + std::vector dims{shape.dims.ff_ordered.begin(), + shape.dims.ff_ordered.end()}; + std::vector field_sizes{ + static_cast(int{size_of_datatype(shape.data_type)})}; + Realm::RegionInstance inst; + Realm::Event ready; + switch (shape.dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<1>(Realm::Point<1>::ZEROES(), + Realm::Point<1>(dims.data()) - + Realm::Point<1>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 2 + case 2: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<2>(Realm::Point<2>::ZEROES(), + Realm::Point<2>(dims.data()) - + Realm::Point<2>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 3 + case 3: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<3>(Realm::Point<3>::ZEROES(), + Realm::Point<3>(dims.data()) - + Realm::Point<3>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 4 + case 4: + ready = Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<4>(Realm::Point<4>::ZEROES(), + Realm::Point<4>(dims.data()) - + Realm::Point<4>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif +#if REALM_MAX_DIM >= 5 + case 5: + ready = 
Realm::RegionInstance::create_instance( + inst, + memory, + Realm::Rect<5>(Realm::Point<5>::ZEROES(), + Realm::Point<5>(dims.data()) - + Realm::Point<5>::ONES()), + field_sizes, + /*block_size=*/0 /*SOA*/, + prs, + wait_on); + break; +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM", + fmt::to_string(shape.dims.ff_ordered.num_dims())); + break; + } + this->outstanding_events.push_back(ready); + return std::pair{inst, ready}; +} + Realm::Event RealmContext::get_outstanding_events() { Realm::Event result = this->merge_outstanding_events(); this->outstanding_events.push_back(result); From 9d24b3dc40d517102e1d4b67c564aeac1d84ec5e Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 12:26:48 -0800 Subject: [PATCH 24/63] Simplify dims and use constructors. --- .../src/realm-execution/realm_context.cc | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 6ab7f992fa..4890eb4a5d 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -38,15 +38,15 @@ std::pair static_cast(int{size_of_datatype(shape.data_type)})}; Realm::RegionInstance inst; Realm::Event ready; - switch (shape.dims.ff_ordered.num_dims()) { + switch (dims.size()) { #if REALM_MAX_DIM >= 1 case 1: ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<1>(Realm::Point<1>::ZEROES(), - Realm::Point<1>(dims.data()) - - Realm::Point<1>::ONES()), + Realm::Rect<1>{Realm::Point<1>::ZEROES(), + Realm::Point<1>{dims.data()} - + Realm::Point<1>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -58,9 +58,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<2>(Realm::Point<2>::ZEROES(), - Realm::Point<2>(dims.data()) - - Realm::Point<2>::ONES()), + 
Realm::Rect<2>{Realm::Point<2>::ZEROES(), + Realm::Point<2>{dims.data()} - + Realm::Point<2>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -72,9 +72,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<3>(Realm::Point<3>::ZEROES(), - Realm::Point<3>(dims.data()) - - Realm::Point<3>::ONES()), + Realm::Rect<3>{Realm::Point<3>::ZEROES(), + Realm::Point<3>{dims.data()} - + Realm::Point<3>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -86,9 +86,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<4>(Realm::Point<4>::ZEROES(), - Realm::Point<4>(dims.data()) - - Realm::Point<4>::ONES()), + Realm::Rect<4>{Realm::Point<4>::ZEROES(), + Realm::Point<4>{dims.data()} - + Realm::Point<4>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -100,9 +100,9 @@ std::pair ready = Realm::RegionInstance::create_instance( inst, memory, - Realm::Rect<5>(Realm::Point<5>::ZEROES(), - Realm::Point<5>(dims.data()) - - Realm::Point<5>::ONES()), + Realm::Rect<5>{Realm::Point<5>::ZEROES(), + Realm::Point<5>{dims.data()} - + Realm::Point<5>::ONES()}, field_sizes, /*block_size=*/0 /*SOA*/, prs, @@ -111,7 +111,7 @@ std::pair #endif default: PANIC("TensorShape dims greater than REALM_MAX_DIM", - fmt::to_string(shape.dims.ff_ordered.num_dims())); + fmt::to_string(dims.size())); break; } this->outstanding_events.push_back(ready); From 60989fe4fd9d00b3cc0d4645ec3171282985157b Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 13:24:41 -0800 Subject: [PATCH 25/63] Refactor. 
--- .../src/realm-execution/realm_context.cc | 105 +++++++++--------- 1 file changed, 51 insertions(+), 54 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 4890eb4a5d..b2671f709e 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,5 +1,6 @@ #include "realm-execution/realm_context.h" #include "op-attrs/datatype.h" +#include "op-attrs/tensor_dims.dtg.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/task_id_t.dtg.h" #include "utils/exception.h" @@ -27,91 +28,87 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +template +static Realm::Rect rect_from_dims(TensorDims const &dims) { + std::vector values{dims.ff_ordered.begin(), dims.ff_ordered.end()}; + return Realm::Rect{Realm::Point::ZEROES(), + Realm::Point{values.data()} - + Realm::Point::ONES()}; +} + std::pair RealmContext::create_instance(Realm::Memory memory, TensorShape const &shape, Realm::ProfilingRequestSet const &prs, Realm::Event wait_on) { - std::vector dims{shape.dims.ff_ordered.begin(), - shape.dims.ff_ordered.end()}; std::vector field_sizes{ static_cast(int{size_of_datatype(shape.data_type)})}; Realm::RegionInstance inst; Realm::Event ready; - switch (dims.size()) { + switch (shape.dims.ff_ordered.num_dims()) { #if REALM_MAX_DIM >= 1 case 1: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<1>{Realm::Point<1>::ZEROES(), - Realm::Point<1>{dims.data()} - - Realm::Point<1>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<1>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 2 case 2: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - 
Realm::Rect<2>{Realm::Point<2>::ZEROES(), - Realm::Point<2>{dims.data()} - - Realm::Point<2>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<2>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 3 case 3: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<3>{Realm::Point<3>::ZEROES(), - Realm::Point<3>{dims.data()} - - Realm::Point<3>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<3>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 4 case 4: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<4>{Realm::Point<4>::ZEROES(), - Realm::Point<4>{dims.data()} - - Realm::Point<4>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<4>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif #if REALM_MAX_DIM >= 5 case 5: - ready = Realm::RegionInstance::create_instance( - inst, - memory, - Realm::Rect<5>{Realm::Point<5>::ZEROES(), - Realm::Point<5>{dims.data()} - - Realm::Point<5>::ONES()}, - field_sizes, - /*block_size=*/0 /*SOA*/, - prs, - wait_on); + ready = + Realm::RegionInstance::create_instance(inst, + memory, + rect_from_dims<5>(shape.dims), + field_sizes, + 0 /*SOA*/, + prs, + wait_on); break; #endif default: PANIC("TensorShape dims greater than REALM_MAX_DIM", - fmt::to_string(dims.size())); + fmt::to_string(shape.dims.ff_ordered.num_dims())); break; } this->outstanding_events.push_back(ready); From 8d46441031bf9c36ff9fca154de62c934ff5e068 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 14:39:40 -0800 Subject: [PATCH 26/63] Sketch out device mapping. 
--- .../include/realm-execution/realm_context.h | 5 +++ .../realm-execution/instance_allocation.cc | 41 +++++++++++-------- .../src/realm-execution/realm_context.cc | 9 ++++ 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 90ef402fb6..6ba64338c9 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -16,6 +16,11 @@ struct RealmContext { RealmContext(RealmContext const &) = delete; RealmContext(RealmContext &&) = delete; + // Device mapping + Realm::Processor + map_device_coord_to_processor(MachineSpaceCoordinate const &); + Realm::Memory get_nearest_memory(Realm::Processor) const; + // Current device context Allocator &get_current_device_allocator() const; device_handle_t const &get_current_device_handle() const; diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 0870117bfe..33b7b54937 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -2,8 +2,10 @@ #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.dtg.h" #include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "utils/bidict/generate_bidict.h" #include "utils/containers/all_are_true.h" #include "utils/containers/contains_key.h" @@ -37,14 +39,17 @@ bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g) { } DynamicValueAttrs - perform_instance_allocation_for_value(DynamicValueAttrs const &value, + 
perform_instance_allocation_for_value(DynamicNodeAttrs const &node, + DynamicValueAttrs const &value, RealmContext &ctx) { ASSERT(value.accessor == std::nullopt); ASSERT(value.instance == std::nullopt); TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value()); - Realm::Memory memory = Realm::Memory::NO_MEMORY; // FIXME + MachineSpaceCoordinate device_coord = assert_unwrap(node.device_coord); + Realm::Processor proc = ctx.map_device_coord_to_processor(device_coord); + Realm::Memory memory = ctx.get_nearest_memory(proc); auto [instance, ready] = ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); @@ -66,20 +71,20 @@ std::pair perform_instance_allocation( ASSERT(v.instance == std::nullopt); } - std::unordered_set all_values = - unordered_set_of(get_dynamic_values(g)); - - bidict unallocated_to_allocated = - generate_bidict(all_values, - [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - if (contains_key(preallocated, v)) { - // FIXME: Attach external instance to existing - // allocation and use that - NOT_IMPLEMENTED(); - } else { - return perform_instance_allocation_for_value(v, ctx); - } - }); + bidict unallocated_to_allocated; + auto allocate = [&](DynamicNodeAttrs const &n, DynamicValueAttrs const &v) { + if (contains_key(preallocated, v)) { + // FIXME: Attach external instance to existing allocation and use that + NOT_IMPLEMENTED(); + } else { + if (contains_key(unallocated_to_allocated, v)) { + return unallocated_to_allocated.at_l(v); + } else { + DynamicValueAttrs v2 = perform_instance_allocation_for_value(n, v, ctx); + uallocated_to_allocated.equate(v, v2); + } + } + }; DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { @@ -87,13 +92,13 @@ std::pair perform_instance_allocation( /*inputs=*/map_values( i.inputs, [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return unallocated_to_allocated.at_l(v); + return allocate(i.node_attrs, v); 
}), /*node_attrs=*/i.node_attrs, /*outputs=*/ map_values(i.outputs, [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return unallocated_to_allocated.at_l(v); + return allocate(i.node_attrs, v); }), }; }); diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index b2671f709e..30343652d7 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -17,6 +17,15 @@ RealmContext::~RealmContext() { } } +Realm::Processor RealmContext::map_device_coord_to_processor( + MachineSpaceCoordinate const &device_coord) { + NOT_IMPLEMENTED(); +} + +Realm::Memory get_nearest_memory(Realm::Processor proc) const { + NOT_IMPLEMENTED(); +} + Allocator &RealmContext::get_current_device_allocator() const { NOT_IMPLEMENTED(); } From 0dfa1a38b4995f5b01fc2b24bdcb7df19a27171f Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 15:53:27 -0800 Subject: [PATCH 27/63] Move instance backing to a separate map, remove realm from task-spec. 
--- .../include/realm-execution}/fmt/instance.h | 6 +- .../realm-execution/instance_allocation.h | 8 +-- .../include/realm-execution}/realm.h | 0 .../include/realm-execution/realm_context.h | 3 +- .../include/realm-execution/realm_manager.h | 2 +- .../include/realm-execution/realm_task_id_t.h | 2 +- .../realm-execution/realm_task_registry.h | 2 +- .../tensor_instance_backing.dtg.toml | 24 +++++++ .../realm-execution/tensor_instance_backing.h | 12 ++++ .../src/realm-execution}/fmt/instance.cc | 2 +- .../realm-execution/instance_allocation.cc | 72 ++++--------------- .../parallel_computation_graph_instance.cc | 17 +---- .../src/realm-execution/realm_context.cc | 2 +- .../tensor_instance_backing.cc | 11 +++ lib/task-spec/CMakeLists.txt | 1 - .../dynamic_value_attrs.dtg.toml | 6 -- .../task-spec/dynamic_graph/loss_insertion.cc | 2 - ...ake_dynamic_open_dataflow_graph_from_cg.cc | 2 - ...e_dynamic_open_dataflow_graph_from_mpcg.cc | 2 - .../dynamic_graph/update_insertion.cc | 1 - .../dynamic_open_dataflow_graph.cc | 3 - .../dynamic_graph/machine_slicing.cc | 1 - .../task-spec/dynamic_graph/pass_expansion.cc | 3 - .../dynamic_graph/shard_expansion.cc | 1 - 24 files changed, 74 insertions(+), 111 deletions(-) rename lib/{task-spec/include/task-spec/realm => realm-execution/include/realm-execution}/fmt/instance.h (83%) rename lib/{task-spec/include/task-spec/realm => realm-execution/include/realm-execution}/realm.h (100%) create mode 100644 lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tensor_instance_backing.h rename lib/{task-spec/src/task-spec/realm => realm-execution/src/realm-execution}/fmt/instance.cc (82%) create mode 100644 lib/realm-execution/src/realm-execution/tensor_instance_backing.cc diff --git a/lib/task-spec/include/task-spec/realm/fmt/instance.h b/lib/realm-execution/include/realm-execution/fmt/instance.h similarity index 83% rename from 
lib/task-spec/include/task-spec/realm/fmt/instance.h rename to lib/realm-execution/include/realm-execution/fmt/instance.h index 23979c7efc..b2efc59b7d 100644 --- a/lib/task-spec/include/task-spec/realm/fmt/instance.h +++ b/lib/realm-execution/include/realm-execution/fmt/instance.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H -#include "task-spec/realm/realm.h" +#include "realm-execution/realm.h" #include "utils/check_fmtable.h" #include #include @@ -15,8 +15,8 @@ struct formatter<::FlexFlow::Realm::RegionInstance, ::FlexFlow::Realm::RegionInstance>::value>> : formatter<::std::string> { template - auto format(::FlexFlow::Realm::RegionInstance const &m, - FormatContext &ctx) const -> decltype(ctx.out()) { + auto format(::FlexFlow::Realm::RegionInstance const &m, FormatContext &ctx) + -> decltype(ctx.out()) { std::string result = fmt::format("", m.id); return formatter::format(result, ctx); diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index d1dfa3fda0..59065694e9 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -2,20 +2,16 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_INSTANCE_ALLOCATION_H #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" namespace FlexFlow { -bool no_instances_are_allocated(DynamicOpenDataflowGraph const &); -bool all_instances_are_allocated(DynamicOpenDataflowGraph const &); - -bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g); - DynamicValueAttrs perform_instance_allocation_for_value(DynamicValueAttrs const &, Allocator &); -std::pair perform_instance_allocation( +TensorInstanceBacking 
perform_instance_allocation( DynamicOpenDataflowGraph const &, std::unordered_map const &preallocated, diff --git a/lib/task-spec/include/task-spec/realm/realm.h b/lib/realm-execution/include/realm-execution/realm.h similarity index 100% rename from lib/task-spec/include/task-spec/realm/realm.h rename to lib/realm-execution/include/realm-execution/realm.h diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 6ba64338c9..bfc1a53cd3 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -4,7 +4,8 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" -#include "task-spec/realm/realm.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "realm-execution/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index ebf3bb401e..bf5e8f72f1 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -4,8 +4,8 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" #include "pcg/device_id_t.dtg.h" +#include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/realm_task_id_t.h index 327cf9ffd0..8e6da1a2bd 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/realm_task_id_t.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#include 
"realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" -#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/realm_task_registry.h index d6bf5b927f..f800b1d8c4 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/realm_task_registry.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H +#include "realm-execution/realm.h" #include "realm-execution/task_id_t.dtg.h" -#include "task-spec/realm/realm.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml new file mode 100644 index 0000000000..bdf08df59c --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml @@ -0,0 +1,24 @@ +namespace = "FlexFlow" +name = "TensorInstanceBacking" +type = "struct" +features = [ + "eq", + #"fmt", + "hash", +] + +includes = [ + "", + "realm-execution/realm.h", + "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h", +] + +src_includes = [ + "realm-execution/fmt/instance.h", + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "backing" +type = "std::unordered_map<::FlexFlow::DynamicValueAttrs, std::pair<::FlexFlow::Realm::RegionInstance, ::FlexFlow::Realm::Event>>" diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.h b/lib/realm-execution/include/realm-execution/tensor_instance_backing.h new file mode 100644 index 0000000000..1d143b7409 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.h @@ -0,0 +1,12 @@ +#ifndef 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TENSOR_INSTANCE_BACKING_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TENSOR_INSTANCE_BACKING_H + +#include "realm-execution/tensor_instance_backing.dtg.h" + +namespace FlexFlow { + +TensorInstanceBacking make_empty_tensor_instance_backing(); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/realm/fmt/instance.cc b/lib/realm-execution/src/realm-execution/fmt/instance.cc similarity index 82% rename from lib/task-spec/src/task-spec/realm/fmt/instance.cc rename to lib/realm-execution/src/realm-execution/fmt/instance.cc index fa15e1c16f..f8eabe9bb0 100644 --- a/lib/task-spec/src/task-spec/realm/fmt/instance.cc +++ b/lib/realm-execution/src/realm-execution/fmt/instance.cc @@ -1,4 +1,4 @@ -#include "task-spec/realm/fmt/instance.h" +#include "realm-execution/fmt/instance.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 33b7b54937..c033f0bac1 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -1,7 +1,9 @@ #include "realm-execution/instance_allocation.h" +#include "local-execution/tensor_allocation.h" #include "op-attrs/parallel_tensor_shape.h" #include "op-attrs/tensor_shape.dtg.h" #include "realm-execution/realm_context.h" +#include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -17,95 +19,47 @@ namespace FlexFlow { -bool no_instances_are_allocated(DynamicOpenDataflowGraph const &g) { - return all_are_true( - transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { - return !v.accessor.has_value() && !v.instance.has_value(); - })); -} - -bool 
all_instances_are_allocated(DynamicOpenDataflowGraph const &g) { - return all_are_true( - transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { - return v.instance.has_value(); - })); -} - -bool instances_are_ready_for_allocation(DynamicOpenDataflowGraph const &g) { - return all_are_true( - transform(get_dynamic_values(g), [](DynamicValueAttrs const &v) -> bool { - return v.parallel_tensor_shape.has_value(); - })); -} - -DynamicValueAttrs +std::pair perform_instance_allocation_for_value(DynamicNodeAttrs const &node, DynamicValueAttrs const &value, RealmContext &ctx) { ASSERT(value.accessor == std::nullopt); - ASSERT(value.instance == std::nullopt); TensorShape shape = get_piece_shape(value.parallel_tensor_shape.value()); MachineSpaceCoordinate device_coord = assert_unwrap(node.device_coord); Realm::Processor proc = ctx.map_device_coord_to_processor(device_coord); Realm::Memory memory = ctx.get_nearest_memory(proc); - auto [instance, ready] = - ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); - - DynamicValueAttrs result = value; - result.instance = instance; - - return result; + return ctx.create_instance(memory, shape, Realm::ProfilingRequestSet()); } -std::pair perform_instance_allocation( +TensorInstanceBacking perform_instance_allocation( DynamicOpenDataflowGraph const &g, std::unordered_map const &preallocated, RealmContext &ctx) { - ASSERT(no_instances_are_allocated(g)); - ASSERT(instances_are_ready_for_allocation(g)); + ASSERT(no_tensors_are_allocated(g)); + ASSERT(tensors_are_ready_for_allocation(g)); for (DynamicValueAttrs const &v : keys(preallocated)) { ASSERT(v.accessor == std::nullopt); - ASSERT(v.instance == std::nullopt); } - bidict unallocated_to_allocated; + TensorInstanceBacking result = make_empty_tensor_instance_backing(); auto allocate = [&](DynamicNodeAttrs const &n, DynamicValueAttrs const &v) { if (contains_key(preallocated, v)) { // FIXME: Attach external instance to existing allocation and use that 
NOT_IMPLEMENTED(); } else { - if (contains_key(unallocated_to_allocated, v)) { - return unallocated_to_allocated.at_l(v); + if (contains_key(result.backing, v)) { + return result.backing.at(v); } else { - DynamicValueAttrs v2 = perform_instance_allocation_for_value(n, v, ctx); - uallocated_to_allocated.equate(v, v2); + result.backing.insert( + std::pair{v, perform_instance_allocation_for_value(n, v, ctx)}); } } }; - DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( - g, [&](DynamicNodeInvocation const &i) -> DynamicNodeInvocation { - return DynamicNodeInvocation{ - /*inputs=*/map_values( - i.inputs, - [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return allocate(i.node_attrs, v); - }), - /*node_attrs=*/i.node_attrs, - /*outputs=*/ - map_values(i.outputs, - [&](DynamicValueAttrs const &v) -> DynamicValueAttrs { - return allocate(i.node_attrs, v); - }), - }; - }); - - ASSERT(all_instances_are_allocated(result)); - - return std::pair{result, ctx.get_outstanding_events()}; + return result; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index dddb624df3..e0e4f769d3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -53,14 +53,6 @@ std::optional return this->logit_grad_tensor; } -static Realm::RegionInstance - get_loss_tensor_instance(DynamicOpenDataflowGraph const &dg, - DynamicValueAttrs const &value) { - return find_output_tensor(dg, value.tensor_guid, value.role) - .value() - .second.instance.value(); -} - ParallelComputationGraphInstance create_parallel_computation_graph_instance( RealmContext &ctx, 
MappedParallelComputationGraph const &mpcg, @@ -90,16 +82,11 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_update_insertion(dg, optimizer_attrs); dg = perform_shard_expansion(dg); - Realm::Event instances_ready; - { - auto [dg2, ready] = perform_instance_allocation(dg, inputs, ctx); - dg = dg2; - instances_ready = ready; - } + TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { - return get_loss_tensor_instance(dg, lgv); + return backing.backing.at(lgv).first; }); dg = perform_device_state_initialization(dg, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 30343652d7..4c02c13aa0 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -22,7 +22,7 @@ Realm::Processor RealmContext::map_device_coord_to_processor( NOT_IMPLEMENTED(); } -Realm::Memory get_nearest_memory(Realm::Processor proc) const { +Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) const { NOT_IMPLEMENTED(); } diff --git a/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc b/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc new file mode 100644 index 0000000000..53c2a2b271 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tensor_instance_backing.cc @@ -0,0 +1,11 @@ +#include "realm-execution/tensor_instance_backing.h" + +namespace FlexFlow { + +TensorInstanceBacking make_empty_tensor_instance_backing() { + return TensorInstanceBacking{ + /*backing=*/{}, + }; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/CMakeLists.txt b/lib/task-spec/CMakeLists.txt index f4f5353f70..3c7c91af67 100644 --- a/lib/task-spec/CMakeLists.txt +++ b/lib/task-spec/CMakeLists.txt @@ -14,7 +14,6 @@ ff_add_library( pcg spdlog 
compiler - Realm::Realm ) add_subdirectory(test) diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index 763ebf180f..89b94b1017 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -14,8 +14,6 @@ includes = [ "op-attrs/parallel_tensor_space_coordinate.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", - "task-spec/realm/fmt/instance.h", - "task-spec/realm/realm.h", ] src_includes = [ @@ -38,10 +36,6 @@ type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" name = "accessor" type = "std::optional<::FlexFlow::DynamicTensorAccessor>" -[[fields]] -name = "instance" -type = "std::optional<::FlexFlow::Realm::RegionInstance>" - [[fields]] name = "role" type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 837ade2aad..4270119612 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -23,7 +23,6 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_loss(), }; DynamicValueAttrs logit_grad_value{ @@ -31,7 +30,6 @@ LossInsertionResult perform_loss_insertion(DynamicOpenDataflowGraph const &dg, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, /*shard_coord=*/logit_value.shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_bwd(), }; DynamicNodeInvocation loss_invocation{ diff 
--git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 294241b732..204597386e 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -45,7 +45,6 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -65,7 +64,6 @@ DynamicOpenDataflowGraph /*parallel_tensor_shape=*/lift_to_parallel(attrs.shape), /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index eceb580a20..e90ef10398 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -44,7 +44,6 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; @@ -65,7 +64,6 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*parallel_tensor_shape=*/attrs.shape, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }, }; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc index 23708f3779..58a32db6c1 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc +++ 
b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -51,7 +51,6 @@ static DynamicNodeInvocation get_update_invocation_for_invocation( DynamicValueAttrs value_attrs = output.second; ASSERT(value_attrs.accessor == std::nullopt); - ASSERT(value_attrs.instance == std::nullopt); DynamicNodeAttrs update_node_attrs = i.node_attrs; update_node_attrs.task_type = DynamicTaskType::UPD; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index bb9a45e59a..fc9110b6e4 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -16,7 +16,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -30,7 +29,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; @@ -44,7 +42,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*tensor_type=*/std::nullopt, }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc index c28e12e0af..40d37f50df 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -76,7 +76,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc 
b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index e57691b475..e8fcf2e40b 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -20,7 +20,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -114,7 +113,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/tensor_role, }; }; @@ -231,7 +229,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/std::nullopt, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/tensor_type, }; }; diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index 4d88dde805..23fbb6e514 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -121,7 +121,6 @@ TEST_SUITE(FF_TEST_SUITE) { /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, /*accessor=*/std::nullopt, - /*instance=*/std::nullopt, /*role=*/std::nullopt, }; }; From a4bc84ecca300c6b4fc5e7e01004c0196ad87fdf Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 6 Feb 2026 16:51:32 -0800 Subject: [PATCH 28/63] Implement processor queries. 
--- .../include/realm-execution/realm_context.h | 11 +++- .../parallel_computation_graph_instance.cc | 7 ++- .../src/realm-execution/realm_context.cc | 56 ++++++++++++++++++- .../src/realm-execution/realm_manager.cc | 6 +- 4 files changed, 72 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index bfc1a53cd3..73d60e9f50 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -6,14 +6,16 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include namespace FlexFlow { struct RealmContext { public: - RealmContext(); + RealmContext(Realm::Processor); virtual ~RealmContext(); + RealmContext() = delete; RealmContext(RealmContext const &) = delete; RealmContext(RealmContext &&) = delete; @@ -23,6 +25,7 @@ struct RealmContext { Realm::Memory get_nearest_memory(Realm::Processor) const; // Current device context + Realm::Processor get_current_processor() const; Allocator &get_current_device_allocator() const; device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; @@ -42,9 +45,15 @@ struct RealmContext { // Important: USER MUST BLOCK on event or else use it, or it WILL BE LOST [[nodiscard]] Realm::Event merge_outstanding_events(); + void discover_machine_topology(); + protected: Realm::Runtime runtime; + Realm::Processor processor; std::vector outstanding_events; + std::unordered_map, + std::vector> + processors; }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index e0e4f769d3..5d6aeddf83 100644 --- 
a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -10,6 +10,7 @@ #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" #include "utils/exception.h" +#include "utils/optional.h" namespace FlexFlow { @@ -74,10 +75,12 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::optional logit_grad_value; if (loss_attrs) { auto [dg2, label_v, logit_grad_v] = perform_loss_insertion( - dg, loss_attrs.value(), dynamic_tensor_guid_t{logit_tensor.value()}); + dg, + assert_unwrap(loss_attrs), + dynamic_tensor_guid_t{assert_unwrap(logit_tensor)}); dg = dg2; logit_grad_value = logit_grad_v; - inputs.insert(std::pair{label_v, label_tensor.value()}); + inputs.insert(std::pair{label_v, assert_unwrap(label_tensor)}); } dg = perform_update_insertion(dg, optimizer_attrs); diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 4c02c13aa0..bf5f337796 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,14 +1,19 @@ #include "realm-execution/realm_context.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" +#include "pcg/device_type.dtg.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/task_id_t.dtg.h" +#include "utils/containers/contains_key.h" +#include "utils/containers/transform.h" #include "utils/exception.h" +#include "utils/nonnegative_int/nonnegative_int.h" +#include "utils/one_to_many/one_to_many.h" #include "utils/positive_int/positive_int.h" namespace FlexFlow { -RealmContext::RealmContext() {} +RealmContext::RealmContext(Realm::Processor proc) : processor(proc) {} RealmContext::~RealmContext() { if 
(!this->outstanding_events.empty()) { @@ -17,13 +22,45 @@ RealmContext::~RealmContext() { } } +static std::tuple + convert_machine_space_coordinate( + MachineSpaceCoordinate const &device_coord) { + Realm::AddressSpace as = int{device_coord.node_idx}; + Realm::Processor::Kind kind; + switch (device_coord.device_type) { + case DeviceType::CPU: + kind = Realm::Processor::Kind::LOC_PROC; + break; + case DeviceType::GPU: + kind = Realm::Processor::Kind::TOC_PROC; + break; + default: + PANIC("Unhandled DeviceType", fmt::to_string(device_coord.device_type)); + break; + } + nonnegative_int proc_in_node = device_coord.device_idx; + return std::tuple{as, kind, proc_in_node}; +} + Realm::Processor RealmContext::map_device_coord_to_processor( MachineSpaceCoordinate const &device_coord) { - NOT_IMPLEMENTED(); + this->discover_machine_topology(); + auto [as, kind, proc_in_node] = + convert_machine_space_coordinate(device_coord); + return this->processors.at(std::pair{as, kind}).at(int{proc_in_node}); } Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) const { - NOT_IMPLEMENTED(); + // FIMXE: this isn't going to do what you expect until + // https://github.com/StanfordLegion/realm/pull/392 merges + Realm::Machine::MemoryQuery mq(Realm::Machine::get_machine()); + mq.best_affinity_to(proc); + ASSERT(mq.count() > 0); + return mq.first(); +} + +Realm::Processor RealmContext::get_current_processor() const { + return this->processor; } Allocator &RealmContext::get_current_device_allocator() const { @@ -136,4 +173,17 @@ Realm::Event RealmContext::merge_outstanding_events() { return result; } +void RealmContext::discover_machine_topology() { + if (!this->processors.empty()) { + return; + } + + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + for (Realm::Processor proc : pq) { + Realm::AddressSpace as = proc.address_space(); + Realm::Processor::Kind kind = proc.kind(); + this->processors[std::pair{as, kind}].push_back(proc); + } +} + } // namespace 
FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 63c6266948..f8a3e4014b 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,4 +1,5 @@ #include "realm-execution/realm_manager.h" +#include "realm-execution/realm_context.h" #include "realm-execution/realm_task_id_t.h" #include "realm-execution/realm_task_registry.h" #include "realm-execution/task_id_t.dtg.h" @@ -15,11 +16,12 @@ static void controller_task_wrapper(void const *args, std::function thunk = *reinterpret_cast const *>(args); - RealmContext ctx; + RealmContext ctx{proc}; thunk(ctx); } -RealmManager::RealmManager(int *argc, char ***argv) { +RealmManager::RealmManager(int *argc, char ***argv) + : RealmContext(Realm::Processor::NO_PROC) { bool ok = this->runtime.init(argc, argv); ASSERT(ok); From 02b71a81bada07327e3812e287a78aed37249c57 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 7 Feb 2026 11:41:18 -0800 Subject: [PATCH 29/63] Enable PRealm. --- .flake/pkgs/realm.nix | 10 ++++++---- lib/realm-execution/include/realm-execution/realm.h | 2 ++ .../realm-execution/tensor_instance_backing.dtg.toml | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.flake/pkgs/realm.nix b/.flake/pkgs/realm.nix index 1249c0ae28..b809573690 100644 --- a/.flake/pkgs/realm.nix +++ b/.flake/pkgs/realm.nix @@ -3,6 +3,7 @@ , fetchFromGitHub , cmake , cudaPackages ? { } +, zlib , maxDim ? 
5 }: @@ -12,14 +13,13 @@ in stdenv.mkDerivation rec { pname = "realm"; - version = "2025-01-06"; + version = "2026-02-06"; - # This version is compatible with Legion 7be1abd0207eb1126c7629b16d1123fa6f58ce9d src = fetchFromGitHub { owner = "StanfordLegion"; repo = "realm"; - rev = "0ef7edc8c012d4ab6a50805c044cec8a8edeae33"; - sha256 = "sha256-57/a1lAgs+ajpRn0y0Lk1gP5nKt+N08WW0DIJP4vdho="; + rev = "0405b67ca14b586f7dec0dcddee194cecee7efa6"; + sha256 = "sha256-iUPVV1rh3QuyDKgXuu8aDlaZGlNwcpPvPsSVLWp8tr4="; }; nativeBuildInputs = [ @@ -29,11 +29,13 @@ stdenv.mkDerivation rec { cmakeFlags = [ "-DBUILD_SHARED_LIBS=ON" "-DREALM_ENABLE_CUDA=ON" + "-DREALM_ENABLE_PREALM=ON" "-DREALM_MAX_DIM=${toString maxDim}" ]; buildInputs = [ cudatoolkit + zlib ]; meta = with lib; { diff --git a/lib/realm-execution/include/realm-execution/realm.h b/lib/realm-execution/include/realm-execution/realm.h index 8123c9e9fa..b6913e66f5 100644 --- a/lib/realm-execution/include/realm-execution/realm.h +++ b/lib/realm-execution/include/realm-execution/realm.h @@ -1,6 +1,8 @@ #ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H #define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_REALM_H +#define FLEXFLOW_USE_PREALM + #ifdef FLEXFLOW_USE_PREALM #include #else diff --git a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml index bdf08df59c..e6a8bd58d9 100644 --- a/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tensor_instance_backing.dtg.toml @@ -4,7 +4,7 @@ type = "struct" features = [ "eq", #"fmt", - "hash", + #"hash", ] includes = [ From b144d6dfbd1ec7457ae98f4239ac6c5f084a1b5a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 10:18:32 -0800 Subject: [PATCH 30/63] Move tasks to dedicated file, stub out device state init, shuffle directories. 
--- .../distributed_device_state_initialization.h | 21 ++++++++++++++++ .../{ => tasks}/realm_task_id_t.h | 2 +- .../{ => tasks}/realm_task_registry.h | 2 +- .../realm-execution/tasks/realm_tasks.h | 15 ++++++++++++ .../{ => tasks}/task_id_t.dtg.toml | 0 .../realm-execution/{ => tasks}/task_id_t.h | 2 +- ...distributed_device_state_initialization.cc | 15 ++++++++++++ .../parallel_computation_graph_instance.cc | 17 ++++++------- .../src/realm-execution/realm_context.cc | 4 ++-- .../src/realm-execution/realm_manager.cc | 23 +++--------------- .../{ => tasks}/realm_task_id_t.cc | 2 +- .../{ => tasks}/realm_task_registry.cc | 21 ++++++++-------- .../src/realm-execution/tasks/realm_tasks.cc | 24 +++++++++++++++++++ .../realm-execution/{ => tasks}/task_id_t.cc | 2 +- 14 files changed, 104 insertions(+), 46 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h rename lib/realm-execution/include/realm-execution/{ => tasks}/realm_task_id_t.h (86%) rename lib/realm-execution/include/realm-execution/{ => tasks}/realm_task_registry.h (94%) create mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_tasks.h rename lib/realm-execution/include/realm-execution/{ => tasks}/task_id_t.dtg.toml (100%) rename lib/realm-execution/include/realm-execution/{ => tasks}/task_id_t.h (94%) create mode 100644 lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc rename lib/realm-execution/src/realm-execution/{ => tasks}/realm_task_id_t.cc (82%) rename lib/realm-execution/src/realm-execution/{ => tasks}/realm_task_registry.cc (86%) create mode 100644 lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc rename lib/realm-execution/src/realm-execution/{ => tasks}/task_id_t.cc (99%) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h new file mode 
100644 index 0000000000..4121f10341 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_STATE_INITIALIZATION_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_STATE_INITIALIZATION_H + +#include "kernels/profiling_settings.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph perform_distributed_device_state_initialization( + DynamicOpenDataflowGraph const &, + RealmContext &ctx, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h similarity index 86% rename from lib/realm-execution/include/realm-execution/realm_task_id_t.h rename to lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h index 8e6da1a2bd..cd5eba2f34 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H #include "realm-execution/realm.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/task_id_t.dtg.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/realm_task_registry.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h similarity index 94% rename from lib/realm-execution/include/realm-execution/realm_task_registry.h rename to 
lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h index f800b1d8c4..a0277382bf 100644 --- a/lib/realm-execution/include/realm-execution/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H #include "realm-execution/realm.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/task_id_t.dtg.h" namespace FlexFlow { diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h new file mode 100644 index 0000000000..d2b104faa8 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H + +#include "realm-execution/realm.h" + +namespace FlexFlow { + +void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); + +void controller_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml similarity index 100% rename from lib/realm-execution/include/realm-execution/task_id_t.dtg.toml rename to lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml diff --git a/lib/realm-execution/include/realm-execution/task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h similarity index 94% rename from lib/realm-execution/include/realm-execution/task_id_t.h rename to lib/realm-execution/include/realm-execution/tasks/task_id_t.h index 38b82ad9e0..4a5d9299ae 100644 --- a/lib/realm-execution/include/realm-execution/task_id_t.h +++ 
b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h @@ -3,7 +3,7 @@ #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc new file mode 100644 index 0000000000..c6d0621f3d --- /dev/null +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -0,0 +1,15 @@ +#include "realm-execution/distributed_device_state_initialization.h" +#include "utils/exception.h" + +namespace FlexFlow { + +DynamicOpenDataflowGraph perform_distributed_device_state_initialization( + DynamicOpenDataflowGraph const &dg, + RealmContext &ctx, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs) { + NOT_IMPLEMENTED(); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 5d6aeddf83..bb763334d5 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -1,6 +1,6 @@ #include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" -#include "local-execution/device_state_initialization.h" #include "pcg/optimizer_attrs.h" +#include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" #include 
"task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" @@ -92,14 +92,15 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( return backing.backing.at(lgv).first; }); - dg = perform_device_state_initialization(dg, - ctx.get_current_device_allocator(), - profiling_settings, - ctx.get_current_device_handle(), - iteration_config, - optimizer_attrs, - ctx.get_current_device_idx()); + dg = perform_distributed_device_state_initialization( + dg, ctx, profiling_settings, iteration_config, optimizer_attrs); NOT_IMPLEMENTED(); + + // TODO list: + // * per-device state initialization (RPC mechanism?) + // * Realm allocator + // * task body + // * external instances } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index bf5f337796..37f72ba86d 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -2,8 +2,8 @@ #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" #include "pcg/device_type.dtg.h" -#include "realm-execution/realm_task_id_t.h" -#include "realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/containers/contains_key.h" #include "utils/containers/transform.h" #include "utils/exception.h" diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index f8a3e4014b..9d8b9f0b7f 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,25 +1,12 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" -#include "realm-execution/realm_task_id_t.h" -#include "realm-execution/realm_task_registry.h" -#include 
"realm-execution/task_id_t.dtg.h" +#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/exception.h" namespace FlexFlow { -static void controller_task_wrapper(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(std::function)); - std::function thunk = - *reinterpret_cast const *>(args); - - RealmContext ctx{proc}; - thunk(ctx); -} - RealmManager::RealmManager(int *argc, char ***argv) : RealmContext(Realm::Processor::NO_PROC) { bool ok = this->runtime.init(argc, argv); @@ -27,10 +14,6 @@ RealmManager::RealmManager(int *argc, char ***argv) // Register all tasks at initialization time so we don't need to later register_all_tasks().wait(); - register_task(Realm::Processor::LOC_PROC, - task_id_t::CONTROLLER_TASK_ID, - controller_task_wrapper) - .wait(); } RealmManager::~RealmManager() { diff --git a/lib/realm-execution/src/realm-execution/realm_task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_id_t.cc similarity index 82% rename from lib/realm-execution/src/realm-execution/realm_task_id_t.cc rename to lib/realm-execution/src/realm-execution/tasks/realm_task_id_t.cc index 50b23dfe86..ec1aa143a6 100644 --- a/lib/realm-execution/src/realm-execution/realm_task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_id_t.cc @@ -1,4 +1,4 @@ -#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/tasks/realm_task_id_t.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc similarity index 86% rename from lib/realm-execution/src/realm-execution/realm_task_registry.cc rename to lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 436a6af3f3..7e30edbc9f 100644 --- 
a/lib/realm-execution/src/realm-execution/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,14 +1,10 @@ -#include "realm-execution/realm_task_registry.h" -#include "realm-execution/realm_task_id_t.h" +#include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/realm_tasks.h" #include "utils/exception.h" namespace FlexFlow { -static void operation_task_wrapper( - void const *, size_t, void const *, size_t, Realm::Processor) { - NOT_IMPLEMENTED(); -} - Realm::Event register_task(Realm::Processor::Kind target_kind, task_id_t func_id, void (*task_body)(void const *, @@ -110,12 +106,15 @@ Realm::Event register_all_tasks() { }; for (task_id_t task_id : task_ids) { - pending_registrations.push_back(register_task( - Realm::Processor::LOC_PROC, task_id, operation_task_wrapper)); - pending_registrations.push_back(register_task( - Realm::Processor::TOC_PROC, task_id, operation_task_wrapper)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, task_id, op_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::TOC_PROC, task_id, op_task_body)); } + pending_registrations.push_back(register_task(Realm::Processor::LOC_PROC, + task_id_t::CONTROLLER_TASK_ID, + controller_task_body)); return Realm::Event::merge_events(pending_registrations); } diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc new file mode 100644 index 0000000000..a50f7f3e47 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc @@ -0,0 +1,24 @@ +#include "realm-execution/tasks/realm_tasks.h" +#include "realm-execution/realm_context.h" + +namespace FlexFlow { + +void op_task_body( + void const *, size_t, void const *, size_t, Realm::Processor) { + NOT_IMPLEMENTED(); +} + +void controller_task_body(void const *args, + 
size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(std::function)); + std::function thunk = + *reinterpret_cast const *>(args); + + RealmContext ctx{proc}; + thunk(ctx); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc similarity index 99% rename from lib/realm-execution/src/realm-execution/task_id_t.cc rename to lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index 3521f50c02..5a99f2bea8 100644 --- a/lib/realm-execution/src/realm-execution/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -1,4 +1,4 @@ -#include "realm-execution/task_id_t.h" +#include "realm-execution/tasks/task_id_t.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" From 4d43a7bb175573b6b6665cad7e65c56b42a391c7 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 10:47:41 -0800 Subject: [PATCH 31/63] Make use of task args struct. 
--- .../realm-execution/tasks/realm_tasks.h | 20 +++++++++++++++++++ .../src/realm-execution/tasks/realm_tasks.cc | 14 +++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h index d2b104faa8..ceda961914 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h @@ -2,11 +2,31 @@ #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H #include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct DeviceInitTaskArgs { +public: + DynamicNodeInvocation *invocation; +}; +static_assert(std::has_unique_object_representations_v); + +void device_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +struct ControllerTaskArgs { +public: + std::function thunk; +}; + void controller_task_body( void const *, size_t, void const *, size_t, Realm::Processor); diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc index a50f7f3e47..b1da1f0694 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc @@ -1,5 +1,6 @@ #include "realm-execution/tasks/realm_tasks.h" #include "realm-execution/realm_context.h" +#include "utils/exception.h" namespace FlexFlow { @@ -8,17 +9,22 @@ void op_task_body( NOT_IMPLEMENTED(); } +void device_init_task_body( + void const *, size_t, void 
const *, size_t, Realm::Processor) { + NOT_IMPLEMENTED(); +} + void controller_task_body(void const *args, size_t arglen, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(std::function)); - std::function thunk = - *reinterpret_cast const *>(args); + ASSERT(arglen == sizeof(ControllerTaskArgs)); + ControllerTaskArgs task_args = + *reinterpret_cast(args); RealmContext ctx{proc}; - thunk(ctx); + task_args.thunk(ctx); } } // namespace FlexFlow From 499191170fbcdd2a93444ef063f9c54c8b242068 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 10:54:01 -0800 Subject: [PATCH 32/63] Use task args struct. --- lib/realm-execution/src/realm-execution/realm_manager.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 9d8b9f0b7f..dec2ed7847 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -2,6 +2,7 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/realm_tasks.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/exception.h" @@ -29,11 +30,13 @@ Realm::Event .only_kind(Realm::Processor::LOC_PROC) .first(); + ControllerTaskArgs task_args; + task_args.thunk = thunk; Realm::Event task_complete = this->runtime.collective_spawn( target_proc, get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID), - &thunk, - sizeof(thunk)); + &task_args, + sizeof(task_args)); this->outstanding_events.push_back(task_complete); return task_complete; } From 6f65c510eccc1ff685cbb9aef70479ed8f34a663 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 14:26:45 -0800 Subject: [PATCH 33/63] Refactor task APIs. 
--- .../include/realm-execution/realm_context.h | 18 +++++++ .../tasks/impl/controller_task.h | 19 +++++++ .../tasks/impl/device_init_return_task.h | 21 ++++++++ .../tasks/impl/device_init_task.h | 24 +++++++++ .../realm-execution/tasks/impl/op_task.h | 21 ++++++++ .../tasks/realm_task_registry.h | 4 +- .../realm-execution/tasks/realm_tasks.h | 35 ------------ .../realm-execution/tasks/task_id_t.dtg.toml | 3 ++ .../include/realm-execution/tasks/task_id_t.h | 4 +- .../src/realm-execution/realm_context.cc | 35 ++++++++++++ .../src/realm-execution/realm_manager.cc | 15 +----- .../tasks/impl/controller_task.cc | 37 +++++++++++++ .../tasks/impl/device_init_return_task.cc | 49 +++++++++++++++++ .../tasks/impl/device_init_task.cc | 54 +++++++++++++++++++ .../src/realm-execution/tasks/impl/op_task.cc | 48 +++++++++++++++++ .../tasks/realm_task_registry.cc | 5 +- .../src/realm-execution/tasks/realm_tasks.cc | 30 ----------- 17 files changed, 339 insertions(+), 83 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/op_task.h delete mode 100644 lib/realm-execution/include/realm-execution/tasks/realm_tasks.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc delete mode 100644 lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc diff --git a/lib/realm-execution/include/realm-execution/realm_context.h 
b/lib/realm-execution/include/realm-execution/realm_context.h index 73d60e9f50..422c4f4027 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -6,6 +6,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include namespace FlexFlow { @@ -30,6 +31,23 @@ struct RealmContext { device_handle_t const &get_current_device_handle() const; device_id_t const &get_current_device_idx() const; + // Task creation + Realm::Event spawn_task(Realm::Processor proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); + + Realm::Event + collective_spawn_task(Realm::Processor target_proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); + // Instance management std::pair create_instance(Realm::Memory memory, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h new file mode 100644 index 0000000000..d4c397bb37 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_CONTROLLER_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_CONTROLLER_TASK_H + +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" + +namespace FlexFlow { + +void controller_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event + collective_spawn_controller_task(RealmContext &ctx, + Realm::Processor &target_proc, + std::function thunk); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h new file mode 100644 index 0000000000..fc6c8bdb9f --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H + +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" + +namespace FlexFlow { + +void device_init_return_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event spawn_device_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState const &result, + DeviceSpecificPerDeviceOpState *origin_result_ptr); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h new file mode 100644 index 0000000000..bd4ca269df --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H + +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +void device_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event + spawn_device_init_task(RealmContext &ctx, + 
Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional<OptimizerAttrs> const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h new file mode 100644 index 0000000000..4c3e6d38d1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H + +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional<OptimizerAttrs> const &optimizer_attrs); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h index a0277382bf..8114f1a82c 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_registry.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_REGISTRY_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_REGISTRY_H #include "realm-execution/realm.h" #include "realm-execution/tasks/task_id_t.dtg.h" diff --git
a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h b/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h deleted file mode 100644 index ceda961914..0000000000 --- a/lib/realm-execution/include/realm-execution/tasks/realm_tasks.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASKS_H - -#include "realm-execution/realm.h" -#include "realm-execution/realm_context.h" -#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include - -namespace FlexFlow { - -void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); - -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct DeviceInitTaskArgs { -public: - DynamicNodeInvocation *invocation; -}; -static_assert(std::has_unique_object_representations_v); - -void device_init_task_body( - void const *, size_t, void const *, size_t, Realm::Processor); - -struct ControllerTaskArgs { -public: - std::function thunk; -}; - -void controller_task_body( - void const *, size_t, void const *, size_t, Realm::Processor); - -} // namespace FlexFlow - -#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index 0336bc81a4..34e5183488 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -11,6 +11,9 @@ features = [ [[values]] name = "CONTROLLER_TASK_ID" +[[values]] +name = "DEVICE_INIT_RETURN_TASK_ID" + [[values]] name = "IMAGE_INIT_TASK_ID" diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h index 4a5d9299ae..53945d2e5b 100644 --- 
a/lib/realm-execution/include/realm-execution/tasks/task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASK_ID_T_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_TASK_ID_T_H #include "op-attrs/pcg_operator_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 37f72ba86d..7e6c73c9e7 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -74,6 +74,41 @@ device_id_t const &RealmContext::get_current_device_idx() const { NOT_IMPLEMENTED(); } +Realm::Event + RealmContext::spawn_task(Realm::Processor proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + Realm::Event result = proc.spawn(get_realm_task_id_for_task_id(task_id), + args, + arglen, + requests, + wait_on, + priority); + this->outstanding_events.push_back(result); + return result; +} + +Realm::Event RealmContext::collective_spawn_task(Realm::Processor target_proc, + task_id_t task_id, + void const *args, + size_t arglen, + Realm::Event wait_on, + int priority) { + Realm::Event result = + this->runtime.collective_spawn(target_proc, + get_realm_task_id_for_task_id(task_id), + args, + arglen, + wait_on, + priority); + this->outstanding_events.push_back(result); + return result; +} + template static Realm::Rect rect_from_dims(TensorDims const &dims) { std::vector values{dims.ff_ordered.begin(), dims.ff_ordered.end()}; diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc 
b/lib/realm-execution/src/realm-execution/realm_manager.cc index dec2ed7847..7233103cc3 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,10 +1,7 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" -#include "realm-execution/tasks/realm_task_id_t.h" +#include "realm-execution/tasks/impl/controller_task.h" #include "realm-execution/tasks/realm_task_registry.h" -#include "realm-execution/tasks/realm_tasks.h" -#include "realm-execution/tasks/task_id_t.dtg.h" -#include "utils/exception.h" namespace FlexFlow { @@ -30,15 +27,7 @@ Realm::Event .only_kind(Realm::Processor::LOC_PROC) .first(); - ControllerTaskArgs task_args; - task_args.thunk = thunk; - Realm::Event task_complete = this->runtime.collective_spawn( - target_proc, - get_realm_task_id_for_task_id(task_id_t::CONTROLLER_TASK_ID), - &task_args, - sizeof(task_args)); - this->outstanding_events.push_back(task_complete); - return task_complete; + return collective_spawn_controller_task(*this, target_proc, thunk); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc new file mode 100644 index 0000000000..2fd5cee52d --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc @@ -0,0 +1,37 @@ +#include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/task_id_t.h" + +namespace FlexFlow { + +struct ControllerTaskArgs { +public: + std::function thunk; +}; + +void controller_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(ControllerTaskArgs)); + ControllerTaskArgs task_args = + *reinterpret_cast(args); + + RealmContext ctx{proc}; + task_args.thunk(ctx); +} + +Realm::Event collective_spawn_controller_task( + RealmContext &ctx, + 
Realm::Processor &target_proc, + std::function thunk) { + ControllerTaskArgs task_args; + task_args.thunk = thunk; + + return ctx.collective_spawn_task(target_proc, + task_id_t::CONTROLLER_TASK_ID, + &task_args, + sizeof(task_args)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc new file mode 100644 index 0000000000..fa421cda30 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc @@ -0,0 +1,49 @@ +#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" + +namespace FlexFlow { + +// FIXME: Can't make this trivially copyable? +struct DeviceInitReturnTaskArgs { +public: + DeviceInitReturnTaskArgs() = delete; + DeviceInitReturnTaskArgs(DeviceSpecificPerDeviceOpState result, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) + : result(result), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} + +public: + DeviceSpecificPerDeviceOpState result; + Realm::Processor origin_proc; + DeviceSpecificPerDeviceOpState *origin_result_ptr; +}; + +void device_init_return_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceInitReturnTaskArgs)); + DeviceInitReturnTaskArgs task_args = + *reinterpret_cast(args); + + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + *task_args.origin_result_ptr = task_args.result; +} + +Realm::Event spawn_device_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState const &result, + DeviceSpecificPerDeviceOpState *origin_result_ptr) { + DeviceInitReturnTaskArgs task_args{result, origin_proc, origin_result_ptr}; + + return ctx.spawn_task(origin_proc, + task_id_t::DEVICE_INIT_RETURN_TASK_ID, + 
&task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc new file mode 100644 index 0000000000..0deb8407c4 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -0,0 +1,54 @@ +#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/task_id_t.h" +#include "utils/optional.h" +#include + +namespace FlexFlow { + +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct DeviceInitTaskArgs { +public: + DynamicNodeInvocation const *invocation; + Realm::Processor origin_proc; + DeviceSpecificPerDeviceOpState *origin_result_ptr; +}; +static_assert(std::has_unique_object_representations_v); + +void device_init_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceInitTaskArgs)); + DeviceInitTaskArgs task_args = + *reinterpret_cast(args); + + // FIXME: not safe to dereference unless we're on the same address space + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + + RealmContext ctx{proc}; + NOT_IMPLEMENTED(); +} + +Realm::Event + spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr) { + DeviceInitTaskArgs task_args; + task_args.invocation = &invocation; + task_args.origin_proc = ctx.get_current_processor(); + task_args.origin_result_ptr = result_ptr; + + return ctx.spawn_task(target_proc, + assert_unwrap(get_init_task_id_for_op_attrs( + assert_unwrap(invocation.node_attrs.op_attrs))), + &task_args, + sizeof(task_args), + 
Realm::ProfilingRequestSet{}); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc new file mode 100644 index 0000000000..9d9a36e2d5 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -0,0 +1,48 @@ +#include "realm-execution/tasks/impl/op_task.h" +#include "realm-execution/tasks/task_id_t.h" +#include "utils/optional.h" +#include + +namespace FlexFlow { + +// TODO: at some point we're going to have to actually serialize these, but for +// now just pass the pointer and assume we're running inside a single address +// space +struct OpTaskArgs { +public: + DynamicNodeInvocation const *invocation; + Realm::Processor origin_proc; +}; +static_assert(std::has_unique_object_representations_v); + +void op_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(OpTaskArgs)); + OpTaskArgs task_args = *reinterpret_cast(args); + + // FIXME: not safe to dereference unless we're on the same address space + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + + RealmContext ctx{proc}; + NOT_IMPLEMENTED(); +} + +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + std::optional const &optimizer_attrs) { + OpTaskArgs task_args; + task_args.invocation = &invocation; + return ctx.spawn_task( + target_proc, + assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 7e30edbc9f..c604d1b06a 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ 
b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/realm_task_registry.h" +#include "realm-execution/tasks/impl/controller_task.h" +#include "realm-execution/tasks/impl/device_init_return_task.h" +#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/realm_task_id_t.h" -#include "realm-execution/tasks/realm_tasks.h" #include "utils/exception.h" namespace FlexFlow { diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc b/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc deleted file mode 100644 index b1da1f0694..0000000000 --- a/lib/realm-execution/src/realm-execution/tasks/realm_tasks.cc +++ /dev/null @@ -1,30 +0,0 @@ -#include "realm-execution/tasks/realm_tasks.h" -#include "realm-execution/realm_context.h" -#include "utils/exception.h" - -namespace FlexFlow { - -void op_task_body( - void const *, size_t, void const *, size_t, Realm::Processor) { - NOT_IMPLEMENTED(); -} - -void device_init_task_body( - void const *, size_t, void const *, size_t, Realm::Processor) { - NOT_IMPLEMENTED(); -} - -void controller_task_body(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(ControllerTaskArgs)); - ControllerTaskArgs task_args = - *reinterpret_cast(args); - - RealmContext ctx{proc}; - task_args.thunk(ctx); -} - -} // namespace FlexFlow From fce23cf89ae22672ed3038a62d1ae63b12bfdecc Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 14:45:44 -0800 Subject: [PATCH 34/63] Finish implementation of device init task. 
--- .../tasks/impl/device_init_task.h | 15 +++--- .../realm-execution/tasks/realm_task_id_t.h | 4 +- .../tasks/impl/device_init_task.cc | 50 ++++++++++++++++--- .../tasks/realm_task_registry.cc | 13 ++++- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h index bd4ca269df..ebce5fed4c 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -1,23 +1,26 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H +#include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { void device_init_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event - spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - std::optional const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr); +Realm::Event spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h index 
cd5eba2f34..a3c6891fb0 100644 --- a/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h +++ b/lib/realm-execution/include/realm-execution/tasks/realm_task_id_t.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_TASK_ID_T_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_REALM_TASK_ID_T_H #include "realm-execution/realm.h" #include "realm-execution/tasks/task_id_t.dtg.h" diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index 0deb8407c4..c27fc5802b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/impl/device_init_task.h" +#include "local-execution/device_state_initialization.h" +#include "realm-execution/tasks/impl/device_init_return_task.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/optional.h" +#include #include namespace FlexFlow { @@ -9,8 +12,22 @@ namespace FlexFlow { // now just pass the pointer and assume we're running inside a single address // space struct DeviceInitTaskArgs { + DeviceInitTaskArgs() = delete; + DeviceInitTaskArgs(DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + FFIterationConfig const *iteration_config, + OptimizerAttrs const *optimizer_attrs, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) + : invocation(invocation), profiling_settings(profiling_settings), + iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), + origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + public: DynamicNodeInvocation const *invocation; + 
ProfilingSettings const *profiling_settings; + FFIterationConfig const *iteration_config; + OptimizerAttrs const *optimizer_attrs; Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; @@ -29,19 +46,40 @@ void device_init_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; - NOT_IMPLEMENTED(); + DynamicNodeInvocation result_invocation = + initialize_node(*task_args.invocation, + ctx.get_current_device_allocator(), + *task_args.profiling_settings, + ctx.get_current_device_handle(), + *task_args.iteration_config, + *task_args.optimizer_attrs, + ctx.get_current_device_idx()); + std::optional result_state = + result_invocation.node_attrs.per_device_op_state; + if (result_state) { + spawn_device_init_return_task(ctx, + task_args.origin_proc, + assert_unwrap(result_state), + task_args.origin_result_ptr); + } } Realm::Event spawn_device_init_task(RealmContext &ctx, Realm::Processor &target_proc, DynamicNodeInvocation const &invocation, - std::optional const &optimizer_attrs, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, DeviceSpecificPerDeviceOpState *result_ptr) { - DeviceInitTaskArgs task_args; - task_args.invocation = &invocation; - task_args.origin_proc = ctx.get_current_processor(); - task_args.origin_result_ptr = result_ptr; + DeviceInitTaskArgs task_args{ + &invocation, + &profiling_settings, + &iteration_config, + &optimizer_attrs, + ctx.get_current_processor(), + result_ptr, + }; return ctx.spawn_task(target_proc, assert_unwrap(get_init_task_id_for_op_attrs( diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index c604d1b06a..c63d4727a9 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ 
b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -26,7 +26,7 @@ Realm::Event register_task(Realm::Processor::Kind target_kind, Realm::Event register_all_tasks() { std::vector pending_registrations; - std::vector task_ids = { + std::vector init_task_ids = { // Init tasks task_id_t::BATCHNORM_INIT_TASK_ID, task_id_t::COMBINE_INIT_TASK_ID, @@ -44,7 +44,14 @@ Realm::Event register_all_tasks() { task_id_t::REPARTITION_INIT_TASK_ID, task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, + }; + for (task_id_t task_id : init_task_ids) { + pending_registrations.push_back(register_task( + Realm::Processor::TOC_PROC, task_id, device_init_task_body)); + } + + std::vector task_ids = { // Forward tasks task_id_t::BATCHMATMUL_FWD_TASK_ID, task_id_t::BATCHNORM_FWD_TASK_ID, @@ -118,6 +125,10 @@ Realm::Event register_all_tasks() { pending_registrations.push_back(register_task(Realm::Processor::LOC_PROC, task_id_t::CONTROLLER_TASK_ID, controller_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, + task_id_t::DEVICE_INIT_RETURN_TASK_ID, + device_init_return_task_body)); return Realm::Event::merge_events(pending_registrations); } From 6fc3b9b699755b799e736b10cecf9d4080216fe6 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 15:14:24 -0800 Subject: [PATCH 35/63] Finish implementation of device state initialization. 
--- .../tasks/impl/device_init_task.h | 15 ++--- ...distributed_device_state_initialization.cc | 57 ++++++++++++++++++- .../tasks/impl/device_init_task.cc | 29 +++++----- 3 files changed, 79 insertions(+), 22 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h index ebce5fed4c..af07139483 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -14,13 +14,14 @@ namespace FlexFlow { void device_init_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr); +std::optional + spawn_device_init_task(RealmContext &ctx, + Realm::Processor &target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index c6d0621f3d..f7fcea87e7 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -1,5 +1,11 @@ #include "realm-execution/distributed_device_state_initialization.h" -#include "utils/exception.h" +#include "local-execution/device_state_initialization.h" +#include "realm-execution/tasks/impl/device_init_task.h" 
+#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" +#include "utils/optional.h" +#include +#include namespace FlexFlow { @@ -9,7 +15,54 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs) { - NOT_IMPLEMENTED(); + + // Initialize all operators and save the per-device op state + ASSERT(no_nodes_are_initialized(dg)); + + std::unordered_map + result_map; + for (DynamicNodeInvocation const &invocation : dg.invocations) { + Realm::Processor target_proc = ctx.map_device_coord_to_processor( + assert_unwrap(invocation.node_attrs.device_coord)); + + // FIXME: in the absense of a real serializer we're just tossing around raw + // bytes, which means we need to bypass the constructor for this type (yes, + // ugh) + DeviceSpecificPerDeviceOpState *output = + static_cast( + malloc(sizeof(DeviceSpecificPerDeviceOpState))); + std::optional result = + spawn_device_init_task(ctx, + target_proc, + invocation, + profiling_settings, + iteration_config, + optimizer_attrs, + output); + if (result) { + result_map[invocation] = output; + } else { + free(output); + } + } + + ctx.get_outstanding_events().wait(); + + DynamicOpenDataflowGraph result = transform_dynamic_invocation_set( + dg, [&](DynamicNodeInvocation const &invocation) { + DynamicNodeInvocation result = invocation; + auto device_state = result_map.find(invocation); + if (device_state != result_map.end()) { + result.node_attrs.per_device_op_state = *device_state->second; + } + return result; + }); + + for (auto &[invocation, output] : result_map) { + free(output); + } + + return result; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index c27fc5802b..91b753d639 100644 
--- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -1,6 +1,7 @@ #include "realm-execution/tasks/impl/device_init_task.h" #include "local-execution/device_state_initialization.h" #include "realm-execution/tasks/impl/device_init_return_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/optional.h" #include @@ -56,15 +57,13 @@ void device_init_task_body(void const *args, ctx.get_current_device_idx()); std::optional result_state = result_invocation.node_attrs.per_device_op_state; - if (result_state) { - spawn_device_init_return_task(ctx, - task_args.origin_proc, - assert_unwrap(result_state), - task_args.origin_result_ptr); - } + spawn_device_init_return_task(ctx, + task_args.origin_proc, + assert_unwrap(result_state), + task_args.origin_result_ptr); } -Realm::Event +std::optional spawn_device_init_task(RealmContext &ctx, Realm::Processor &target_proc, DynamicNodeInvocation const &invocation, @@ -81,12 +80,16 @@ Realm::Event result_ptr, }; - return ctx.spawn_task(target_proc, - assert_unwrap(get_init_task_id_for_op_attrs( - assert_unwrap(invocation.node_attrs.op_attrs))), - &task_args, - sizeof(task_args), - Realm::ProfilingRequestSet{}); + std::optional task_id = get_init_task_id_for_op_attrs( + assert_unwrap(invocation.node_attrs.op_attrs)); + if (task_id) { + return ctx.spawn_task(target_proc, + assert_unwrap(task_id), + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}); + } + return std::nullopt; } } // namespace FlexFlow From 2de35164797fb36bf746755470644db515bc1429 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 15:15:51 -0800 Subject: [PATCH 36/63] Block on initialization. 
--- .../parallel_computation_graph_instance.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index bb763334d5..cdb3e5fe46 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -87,6 +87,10 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_shard_expansion(dg); TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); + // FIXME: for now we're going to be lazy and block on everything rather than + // do fine-grained dependencies + ctx.get_outstanding_events().wait(); + std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { return backing.backing.at(lgv).first; From 2a174e02d8a4bd6a4877229c92769f8136628c17 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 17:04:10 -0800 Subject: [PATCH 37/63] Wire up rest of Realm implementation. 
--- .../parallel_computation_graph_instance.h | 19 +-- .../realm-execution/tasks/impl/op_task.h | 8 +- .../parallel_computation_graph_instance.cc | 159 +++++++++++++++--- .../tasks/impl/device_init_task.cc | 13 +- .../src/realm-execution/tasks/impl/op_task.cc | 49 +++++- 5 files changed, 206 insertions(+), 42 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index f361cec3ca..0886dcf4c0 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -23,30 +23,27 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: ParallelComputationGraphInstance(RealmContext &, - DynamicOpenDataflowGraph, std::vector const &, OptimizerAttrs const &, std::optional const &, - std::optional); - DynamicOpenDataflowGraph const &get_dynamic_dataflow_graph() const; - Allocator &get_allocator() const; - std::vector const &get_topological_ordering() const; + std::optional); + RealmContext &get_realm_context(); + std::vector const &get_execution_order() const; OptimizerAttrs const &get_optimizer_attrs() const; void update_optimizer_attrs_for_next_iter(); std::optional const &get_loss_attrs() const; - std::optional get_loss_tensor_accessor() const; + std::optional get_loss_tensor_instance() const; private: - RealmContext &realm; - DynamicOpenDataflowGraph dataflow_graph; - std::vector topological_ordering; + RealmContext &ctx; + std::vector execution_order; OptimizerAttrs optimizer_attrs; std::optional loss_attrs; - std::optional logit_grad_tensor; + std::optional logit_grad_tensor; }; ParallelComputationGraphInstance create_parallel_computation_graph_instance( - 
RealmContext &realm, + RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 4c3e6d38d1..dd75ed66ea 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -1,10 +1,13 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_OP_TASK_H +#include "kernels/profiling_settings.dtg.h" +#include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" namespace FlexFlow { @@ -12,8 +15,11 @@ void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor &target_proc, + Realm::Processor target_proc, DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + std::optional const &loss_attrs, + FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index cdb3e5fe46..2683d019c3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -2,6 +2,8 @@ 
#include "pcg/optimizer_attrs.h" #include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/tasks/impl/op_task.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" @@ -9,33 +11,27 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" -#include "utils/exception.h" +#include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/optional.h" namespace FlexFlow { ParallelComputationGraphInstance::ParallelComputationGraphInstance( - RealmContext &realm, - DynamicOpenDataflowGraph dataflow_graph, - std::vector const &topological_ordering, + RealmContext &ctx, + std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, std::optional const &loss_attrs, - std::optional logit_grad_tensor) - : realm(realm), dataflow_graph(dataflow_graph), - topological_ordering(topological_ordering), + std::optional logit_grad_tensor) + : ctx(ctx), execution_order(execution_order), optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), logit_grad_tensor(logit_grad_tensor) {} -DynamicOpenDataflowGraph const & - ParallelComputationGraphInstance::get_dynamic_dataflow_graph() const { - return this->dataflow_graph; -} -Allocator &ParallelComputationGraphInstance::get_allocator() const { - return this->realm.get_current_device_allocator(); +RealmContext &ParallelComputationGraphInstance::get_realm_context() { + return this->ctx; } std::vector const & - ParallelComputationGraphInstance::get_topological_ordering() const { - return this->topological_ordering; + ParallelComputationGraphInstance::get_execution_order() const { + return this->execution_order; } OptimizerAttrs const & 
ParallelComputationGraphInstance::get_optimizer_attrs() const { @@ -49,8 +45,8 @@ std::optional const & ParallelComputationGraphInstance::get_loss_attrs() const { return this->loss_attrs; } -std::optional - ParallelComputationGraphInstance::get_loss_tensor_accessor() const { +std::optional + ParallelComputationGraphInstance::get_loss_tensor_instance() const { return this->logit_grad_tensor; } @@ -88,7 +84,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); // FIXME: for now we're going to be lazy and block on everything rather than - // do fine-grained dependencies + // do fine-grained dependencies on instances ctx.get_outstanding_events().wait(); std::optional logit_grad_tensor = @@ -98,13 +94,134 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( dg = perform_distributed_device_state_initialization( dg, ctx, profiling_settings, iteration_config, optimizer_attrs); - NOT_IMPLEMENTED(); + + // Compute the topological ordering of the graph + auto [kwarg_graph, node_map] = + labelled_open_kwarg_dataflow_graph_from_dynamic_open_dataflow_graph(dg); + std::vector node_topo_order = get_topological_ordering(kwarg_graph); + std::vector invocation_topo_order = transform( + node_topo_order, [&](Node node) { return node_map.at_l(node); }); + + return ParallelComputationGraphInstance{ctx, + invocation_topo_order, + optimizer_attrs, + loss_attrs, + logit_grad_tensor}; // TODO list: - // * per-device state initialization (RPC mechanism?) 
// * Realm allocator - // * task body // * external instances } +static std::unordered_map + execute_distributed_dynamic_node_invocation_set( + RealmContext &ctx, + std::vector const &invocations, + OptimizerAttrs const &optimizer_attrs, + ProfilingSettings const &profiling_settings, + std::optional const &loss_attrs, + FFIterationConfig iteration_config) { + return unordered_map_from_pairs( + transform(invocations, [&](DynamicNodeInvocation const &invocation) { + Realm::Event result = + spawn_op_task(ctx, + ctx.map_device_coord_to_processor(assert_unwrap( + invocation.node_attrs.device_coord)), + invocation, + profiling_settings, + loss_attrs, + iteration_config, + optimizer_attrs); + return std::pair{invocation.node_attrs.layer_guid, result}; + })); +} + +std::unordered_map + perform_all_passes_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + instance.get_execution_order(); + std::unordered_map result = + execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); + instance.update_optimizer_attrs_for_next_iter(); + return result; +} + +std::unordered_map + perform_forward_pass_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + filter(instance.get_execution_order(), + [](DynamicNodeInvocation const &invocation) { + DynamicTaskType task_type = + assert_unwrap(invocation.node_attrs.task_type); + return task_type == DynamicTaskType::FWD; + }); + + return 
execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); +} + +std::unordered_map + perform_backward_pass_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + filter(instance.get_execution_order(), + [](DynamicNodeInvocation const &invocation) { + DynamicTaskType task_type = + assert_unwrap(invocation.node_attrs.task_type); + return task_type == DynamicTaskType::BWD; + }); + + return execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); +} + +std::unordered_map + perform_update_pass_for_parallel_computation_graph_instance( + ParallelComputationGraphInstance &instance, + ProfilingSettings const &profiling_settings, + FFIterationConfig iteration_config) { + std::vector const &execution_order = + filter(instance.get_execution_order(), + [](DynamicNodeInvocation const &invocation) { + DynamicTaskType task_type = + assert_unwrap(invocation.node_attrs.task_type); + return task_type == DynamicTaskType::UPD; + }); + + std::unordered_map result = + execute_distributed_dynamic_node_invocation_set( + /*ctx=*/instance.get_realm_context(), + /*invocations=*/execution_order, + /*optimizer_attrs=*/instance.get_optimizer_attrs(), + /*profiling_settings=*/profiling_settings, + /*loss_attrs=*/instance.get_loss_attrs(), + /*iteration_config=*/iteration_config); + instance.update_optimizer_attrs_for_next_iter(); + return result; 
+} + } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index 91b753d639..49b5568d26 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -3,6 +3,7 @@ #include "realm-execution/tasks/impl/device_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" #include "utils/optional.h" #include #include @@ -43,7 +44,7 @@ void device_init_task_body(void const *args, DeviceInitTaskArgs task_args = *reinterpret_cast(args); - // FIXME: not safe to dereference unless we're on the same address space + // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; @@ -55,11 +56,15 @@ void device_init_task_body(void const *args, *task_args.iteration_config, *task_args.optimizer_attrs, ctx.get_current_device_idx()); - std::optional result_state = - result_invocation.node_attrs.per_device_op_state; + DeviceSpecificPerDeviceOpState result_state = + assert_unwrap(result_invocation.node_attrs.per_device_op_state); + // Important: to make sure this doesn't get deallocated, we intentionally leak + // the allocation here + DeviceSpecificPerDeviceOpState *result_state_ptr = + new DeviceSpecificPerDeviceOpState{result_state}; spawn_device_init_return_task(ctx, task_args.origin_proc, - assert_unwrap(result_state), + *result_state_ptr, task_args.origin_result_ptr); } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 9d9a36e2d5..79c152844b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ 
b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,5 +1,7 @@ #include "realm-execution/tasks/impl/op_task.h" +#include "local-execution/task_execution.h" #include "realm-execution/tasks/task_id_t.h" +#include "task-spec/per_device_op_state.h" #include "utils/optional.h" #include @@ -9,8 +11,24 @@ namespace FlexFlow { // now just pass the pointer and assume we're running inside a single address // space struct OpTaskArgs { +public: + OpTaskArgs() = delete; + OpTaskArgs(DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + std::optional const *loss_attrs, + FFIterationConfig const *iteration_config, + std::optional const *optimizer_attrs, + Realm::Processor origin_proc) + : invocation(invocation), profiling_settings(profiling_settings), + loss_attrs(loss_attrs), iteration_config(iteration_config), + optimizer_attrs(optimizer_attrs) {} + public: DynamicNodeInvocation const *invocation; + ProfilingSettings const *profiling_settings; + std::optional const *loss_attrs; + FFIterationConfig const *iteration_config; + std::optional const *optimizer_attrs; Realm::Processor origin_proc; }; static_assert(std::has_unique_object_representations_v); @@ -23,20 +41,41 @@ void op_task_body(void const *args, ASSERT(arglen == sizeof(OpTaskArgs)); OpTaskArgs task_args = *reinterpret_cast(args); - // FIXME: not safe to dereference unless we're on the same address space + // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; - NOT_IMPLEMENTED(); + execute_dynamic_node_invocation( + /*invocation=*/*task_args.invocation, + /*allocator=*/ctx.get_current_device_allocator(), + /*profiling_settings=*/*task_args.profiling_settings, + /*ff_handle=*/ctx.get_current_device_handle(), + /*loss_attrs=*/*task_args.loss_attrs, + /*per_device_op_state=*/ + transform(task_args.invocation->node_attrs.per_device_op_state, + [&](DeviceSpecificPerDeviceOpState 
const &op_state) { + return get_device_state_from_device_specific( + op_state, ctx.get_current_device_idx()); + }), + /*iteration_config=*/*task_args.iteration_config, + /*optimizer_attrs=*/*task_args.optimizer_attrs, + /*device_idx=*/ctx.get_current_device_idx()); } Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor &target_proc, + Realm::Processor target_proc, DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + std::optional const &loss_attrs, + FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs) { - OpTaskArgs task_args; - task_args.invocation = &invocation; + OpTaskArgs task_args{&invocation, + &profiling_settings, + &loss_attrs, + &iteration_config, + &optimizer_attrs, + ctx.get_current_processor()}; return ctx.spawn_task( target_proc, assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), From 7e78e3fee00a5840e91de8778018908b63095a53 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Tue, 10 Feb 2026 17:16:31 -0800 Subject: [PATCH 38/63] Implement Realm device idx. 
--- .../include/realm-execution/realm_context.h | 2 +- .../src/realm-execution/realm_context.cc | 26 +++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 422c4f4027..e28e91234e 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -29,7 +29,7 @@ struct RealmContext { Realm::Processor get_current_processor() const; Allocator &get_current_device_allocator() const; device_handle_t const &get_current_device_handle() const; - device_id_t const &get_current_device_idx() const; + device_id_t get_current_device_idx() const; // Task creation Realm::Event spawn_task(Realm::Processor proc, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 7e6c73c9e7..781561c95a 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,6 +1,7 @@ #include "realm-execution/realm_context.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" +#include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" @@ -70,8 +71,29 @@ Allocator &RealmContext::get_current_device_allocator() const { device_handle_t const &RealmContext::get_current_device_handle() const { NOT_IMPLEMENTED(); } -device_id_t const &RealmContext::get_current_device_idx() const { - NOT_IMPLEMENTED(); +device_id_t RealmContext::get_current_device_idx() const { + Realm::Processor proc = this->get_current_processor(); + + // FIXME: find a more efficient way to implement this than scanning the + // machine every time + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + pq.same_address_space_as(proc); + 
nonnegative_int idx{0}; + for (Realm::Processor p : pq) { + if (p == proc) { + break; + } + idx++; + } + + switch (proc.kind()) { + case Realm::Processor::LOC_PROC: + return make_device_id_t_from_idx(idx, DeviceType::CPU); + case Realm::Processor::TOC_PROC: + return make_device_id_t_from_idx(idx, DeviceType::GPU); + default: + PANIC("Unhandled Realm::ProcessorKind", fmt::to_string(int{proc.kind()})); + } } Realm::Event From 5ffc1ddf8ad93b6314806bff2248836e005d277c Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 09:47:00 -0800 Subject: [PATCH 39/63] Updates to compile against latest local-execution. --- .../parallel_computation_graph_instance.h | 3 -- .../realm-execution/tasks/impl/op_task.h | 1 - .../parallel_computation_graph_instance.cc | 32 ++++++------------- .../tasks/impl/device_init_task.cc | 11 +++++-- .../src/realm-execution/tasks/impl/op_task.cc | 8 +---- .../src/realm-execution/tasks/task_id_t.cc | 12 ++++--- ...e_dynamic_open_dataflow_graph_from_mpcg.cc | 2 +- 7 files changed, 28 insertions(+), 41 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index 0886dcf4c0..de06f457e2 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -25,20 +25,17 @@ struct ParallelComputationGraphInstance { ParallelComputationGraphInstance(RealmContext &, std::vector const &, OptimizerAttrs const &, - std::optional const &, std::optional); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; OptimizerAttrs const &get_optimizer_attrs() const; void update_optimizer_attrs_for_next_iter(); - std::optional const 
&get_loss_attrs() const; std::optional get_loss_tensor_instance() const; private: RealmContext &ctx; std::vector execution_order; OptimizerAttrs optimizer_attrs; - std::optional loss_attrs; std::optional logit_grad_tensor; }; diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index dd75ed66ea..3fcffc30fa 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -18,7 +18,6 @@ Realm::Event Realm::Processor target_proc, DynamicNodeInvocation const &invocation, ProfilingSettings const &profiling_settings, - std::optional const &loss_attrs, FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs); diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc index 2683d019c3..05dfec74c3 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc @@ -20,11 +20,9 @@ ParallelComputationGraphInstance::ParallelComputationGraphInstance( RealmContext &ctx, std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, - std::optional const &loss_attrs, std::optional logit_grad_tensor) : ctx(ctx), execution_order(execution_order), - optimizer_attrs(optimizer_attrs), loss_attrs(loss_attrs), - logit_grad_tensor(logit_grad_tensor) {} + optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} RealmContext &ParallelComputationGraphInstance::get_realm_context() { return this->ctx; @@ -41,10 +39,6 @@ void ParallelComputationGraphInstance::update_optimizer_attrs_for_next_iter() { 
this->optimizer_attrs = get_optimizer_attrs_for_next_iter(this->optimizer_attrs); } -std::optional const & - ParallelComputationGraphInstance::get_loss_attrs() const { - return this->loss_attrs; -} std::optional ParallelComputationGraphInstance::get_loss_tensor_instance() const { return this->logit_grad_tensor; @@ -102,15 +96,15 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::vector invocation_topo_order = transform( node_topo_order, [&](Node node) { return node_map.at_l(node); }); - return ParallelComputationGraphInstance{ctx, - invocation_topo_order, - optimizer_attrs, - loss_attrs, - logit_grad_tensor}; + return ParallelComputationGraphInstance{ + ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: // * Realm allocator // * external instances + // * dependencies + // * task argument serializer + // * copies } static std::unordered_map @@ -119,7 +113,6 @@ static std::unordered_map std::vector const &invocations, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, - std::optional const &loss_attrs, FFIterationConfig iteration_config) { return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { @@ -129,7 +122,6 @@ static std::unordered_map invocation.node_attrs.device_coord)), invocation, profiling_settings, - loss_attrs, iteration_config, optimizer_attrs); return std::pair{invocation.node_attrs.layer_guid, result}; @@ -141,7 +133,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = instance.get_execution_order(); std::unordered_map result = execute_distributed_dynamic_node_invocation_set( @@ -149,7 +141,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - 
/*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; @@ -160,7 +151,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = filter(instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = @@ -173,7 +164,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); } @@ -182,7 +172,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = filter(instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = @@ -195,7 +185,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); } @@ -204,7 +193,7 @@ std::unordered_map ParallelComputationGraphInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { - std::vector const &execution_order = + std::vector execution_order = filter(instance.get_execution_order(), [](DynamicNodeInvocation const &invocation) { DynamicTaskType task_type = @@ -218,7 +207,6 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, - /*loss_attrs=*/instance.get_loss_attrs(), /*iteration_config=*/iteration_config); 
instance.update_optimizer_attrs_for_next_iter(); return result; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index 49b5568d26..cc080255e2 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -4,6 +4,7 @@ #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "utils/optional.h" #include #include @@ -85,9 +86,13 @@ std::optional result_ptr, }; - std::optional task_id = get_init_task_id_for_op_attrs( - assert_unwrap(invocation.node_attrs.op_attrs)); - if (task_id) { + std::optional task_id = + and_then(and_then(invocation.node_attrs.op_attrs, + [](TrainingOperationAttrs const &op_attrs) { + return op_attrs.try_require_pcg_op(); + }), + get_init_task_id_for_op_attrs); + if (task_id.has_value()) { return ctx.spawn_task(target_proc, assert_unwrap(task_id), &task_args, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 79c152844b..5f6ab40607 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -15,18 +15,15 @@ struct OpTaskArgs { OpTaskArgs() = delete; OpTaskArgs(DynamicNodeInvocation const *invocation, ProfilingSettings const *profiling_settings, - std::optional const *loss_attrs, FFIterationConfig const *iteration_config, std::optional const *optimizer_attrs, Realm::Processor origin_proc) : invocation(invocation), profiling_settings(profiling_settings), - loss_attrs(loss_attrs), iteration_config(iteration_config), - optimizer_attrs(optimizer_attrs) {} + iteration_config(iteration_config), 
optimizer_attrs(optimizer_attrs) {} public: DynamicNodeInvocation const *invocation; ProfilingSettings const *profiling_settings; - std::optional const *loss_attrs; FFIterationConfig const *iteration_config; std::optional const *optimizer_attrs; Realm::Processor origin_proc; @@ -50,7 +47,6 @@ void op_task_body(void const *args, /*allocator=*/ctx.get_current_device_allocator(), /*profiling_settings=*/*task_args.profiling_settings, /*ff_handle=*/ctx.get_current_device_handle(), - /*loss_attrs=*/*task_args.loss_attrs, /*per_device_op_state=*/ transform(task_args.invocation->node_attrs.per_device_op_state, [&](DeviceSpecificPerDeviceOpState const &op_state) { @@ -67,12 +63,10 @@ Realm::Event Realm::Processor target_proc, DynamicNodeInvocation const &invocation, ProfilingSettings const &profiling_settings, - std::optional const &loss_attrs, FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs) { OpTaskArgs task_args{&invocation, &profiling_settings, - &loss_attrs, &iteration_config, &optimizer_attrs, ctx.get_current_processor()}; diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index 5a99f2bea8..94e1b887e7 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -2,6 +2,7 @@ #include "pcg/optimizer_attrs.dtg.h" #include "pcg/optimizers/adam_optimizer_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "utils/optional.h" #include "utils/overload.h" namespace FlexFlow { @@ -9,14 +10,17 @@ namespace FlexFlow { std::optional get_task_id_for_op(DynamicNodeAttrs const &node_attrs, std::optional const &optimizer_attrs) { - DynamicTaskType task_type = node_attrs.task_type.value(); + DynamicTaskType task_type = assert_unwrap(node_attrs.task_type); switch (task_type) { case DynamicTaskType::FWD: - return 
get_fwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); + return get_fwd_task_id_for_op_attrs( + assert_unwrap(node_attrs.op_attrs).require_pcg_op()); case DynamicTaskType::BWD: - return get_bwd_task_id_for_op_attrs(node_attrs.op_attrs.value()); + return get_bwd_task_id_for_op_attrs( + assert_unwrap(node_attrs.op_attrs).require_pcg_op()); case DynamicTaskType::UPD: - return get_update_task_id_for_optimizer_attrs(optimizer_attrs.value()); + return get_update_task_id_for_optimizer_attrs( + assert_unwrap(optimizer_attrs)); case DynamicTaskType::LOSS: return task_id_t::LOSS_BWD_TASK_ID; default: diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc index e90ef10398..ced98dfd44 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.cc @@ -23,7 +23,7 @@ DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mpcg( /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/mpcg.mapped_tasks.at(layer), - /*op_attrs=*/attrs.op_attrs, + /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, /*per_device_op_state=*/std::nullopt, }; From e1b6fcadfdc138824a7d02f20a3573338e78988a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 10:01:14 -0800 Subject: [PATCH 40/63] Fix up function arguments. 
--- .../distributed_device_state_initialization.h | 2 +- .../include/realm-execution/instance_allocation.h | 11 ++++++----- .../parallel_computation_graph_instance.h | 9 +++++---- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index 4121f10341..d2ed093c0b 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -10,7 +10,7 @@ namespace FlexFlow { DynamicOpenDataflowGraph perform_distributed_device_state_initialization( - DynamicOpenDataflowGraph const &, + DynamicOpenDataflowGraph const &dg, RealmContext &ctx, ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, diff --git a/lib/realm-execution/include/realm-execution/instance_allocation.h b/lib/realm-execution/include/realm-execution/instance_allocation.h index 59065694e9..09709201ce 100644 --- a/lib/realm-execution/include/realm-execution/instance_allocation.h +++ b/lib/realm-execution/include/realm-execution/instance_allocation.h @@ -7,15 +7,16 @@ namespace FlexFlow { -DynamicValueAttrs - perform_instance_allocation_for_value(DynamicValueAttrs const &, - Allocator &); +std::pair + perform_instance_allocation_for_value(DynamicNodeAttrs const &node, + DynamicValueAttrs const &value, + RealmContext &ctx); TensorInstanceBacking perform_instance_allocation( - DynamicOpenDataflowGraph const &, + DynamicOpenDataflowGraph const &g, std::unordered_map const &preallocated, - RealmContext &); + RealmContext &ctx); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h 
b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h index de06f457e2..f48879a2bb 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h @@ -22,10 +22,11 @@ namespace FlexFlow { struct ParallelComputationGraphInstance { public: - ParallelComputationGraphInstance(RealmContext &, - std::vector const &, - OptimizerAttrs const &, - std::optional); + ParallelComputationGraphInstance( + RealmContext &ctx, + std::vector const &execution_order, + OptimizerAttrs const &optimizer_attrs, + std::optional logit_grad_tensor); RealmContext &get_realm_context(); std::vector const &get_execution_order() const; OptimizerAttrs const &get_optimizer_attrs() const; From e2ccf4afa96ee45930e4e872e10d60e1c382840d Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 10:50:16 -0800 Subject: [PATCH 41/63] Rename PCGInstance and add dependency set. 
--- .../realm-execution/atomic_dependency_set.h | 26 ++++++++++++ .../include/realm-execution/dependency_set.h | 34 +++++++++++++++ .../pcg_instance.h} | 13 +++--- .../realm-execution/atomic_dependency_set.cc | 23 +++++++++++ .../src/realm-execution/dependency_set.cc | 41 +++++++++++++++++++ .../pcg_instance.cc} | 31 +++++++------- .../test/src/realm-execution/realm_manager.cc | 1 - .../test/src/realm-execution/test_e2e.cc | 2 +- 8 files changed, 150 insertions(+), 21 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/atomic_dependency_set.h create mode 100644 lib/realm-execution/include/realm-execution/dependency_set.h rename lib/realm-execution/include/realm-execution/{parallel_computation_graph_instance/parallel_computation_graph_instance.h => pcg_instance/pcg_instance.h} (84%) create mode 100644 lib/realm-execution/src/realm-execution/atomic_dependency_set.cc create mode 100644 lib/realm-execution/src/realm-execution/dependency_set.cc rename lib/realm-execution/src/realm-execution/{parallel_computation_graph_instance/parallel_computation_graph_instance.cc => pcg_instance/pcg_instance.cc} (90%) diff --git a/lib/realm-execution/include/realm-execution/atomic_dependency_set.h b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h new file mode 100644 index 0000000000..8a1ae96b3e --- /dev/null +++ b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_ATOMIC_DEPENDENCY_SET_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_ATOMIC_DEPENDENCY_SET_H + +#include "realm-execution/realm.h" +#include + +namespace FlexFlow { + +struct AtomicDependencySet { +public: + AtomicDependencySet() = delete; + explicit AtomicDependencySet(Realm::Event precondition); + + void add_writer(Realm::Event writer); + void add_reader(Realm::Event reader); + + Realm::Event get_current_outstanding_events() const; + +private: + Realm::Event writer; 
+ std::vector readers; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/dependency_set.h b/lib/realm-execution/include/realm-execution/dependency_set.h new file mode 100644 index 0000000000..a7100076b2 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/dependency_set.h @@ -0,0 +1,34 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEPENDENCY_SET_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEPENDENCY_SET_H + +#include "realm-execution/atomic_dependency_set.h" +#include "realm-execution/realm.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include + +namespace FlexFlow { + +struct DependencySet { +public: + DependencySet() = delete; + explicit DependencySet(Realm::Event precondition); + + void add_writer(DynamicValueAttrs const &value, Realm::Event writer); + void add_reader(DynamicValueAttrs const &value, Realm::Event reader); + + Realm::Event + get_current_outstanding_events(DynamicValueAttrs const &value) const; + +private: + AtomicDependencySet & + get_atomic_dependency_set(DynamicValueAttrs const &value); + +private: + Realm::Event precondition; + std::unordered_map + atomic_dependencies; +}; + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h similarity index 84% rename from lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h rename to lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index f48879a2bb..3c5b4189ea 100644 --- a/lib/realm-execution/include/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -1,5 +1,5 @@ -#ifndef 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PARALLEL_COMPUTATION_GRAPH_INSTANCE_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H #include "kernels/accessor.h" #include "kernels/allocation.h" @@ -20,9 +20,12 @@ namespace FlexFlow { -struct ParallelComputationGraphInstance { +struct PCGInstance { public: - ParallelComputationGraphInstance( + PCGInstance() = delete; + PCGInstance(PCGInstance const &) = delete; + PCGInstance(PCGInstance &&) = delete; + explicit PCGInstance( RealmContext &ctx, std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, @@ -40,7 +43,7 @@ struct ParallelComputationGraphInstance { std::optional logit_grad_tensor; }; -ParallelComputationGraphInstance create_parallel_computation_graph_instance( +PCGInstance create_pcg_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, diff --git a/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc new file mode 100644 index 0000000000..bdc05b7c46 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc @@ -0,0 +1,23 @@ +#include "realm-execution/atomic_dependency_set.h" + +namespace FlexFlow { + +AtomicDependencySet::AtomicDependencySet(Realm::Event precondition) + : writer(precondition) {} + +void AtomicDependencySet::add_writer(Realm::Event writer) { + this->writer = Realm::Event::merge_events( + writer, this->get_current_outstanding_events()); + this->readers.clear(); +} + +void AtomicDependencySet::add_reader(Realm::Event reader) { + this->readers.push_back(reader); +} + +Realm::Event AtomicDependencySet::get_current_outstanding_events() const { + Realm::Event readers = 
Realm::Event::merge_events(this->readers); + return Realm::Event::merge_events(writer, readers); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/dependency_set.cc b/lib/realm-execution/src/realm-execution/dependency_set.cc new file mode 100644 index 0000000000..3af03ffcef --- /dev/null +++ b/lib/realm-execution/src/realm-execution/dependency_set.cc @@ -0,0 +1,41 @@ +#include "realm-execution/dependency_set.h" +#include "realm-execution/atomic_dependency_set.h" +#include "utils/containers/contains_key.h" + +namespace FlexFlow { + +DependencySet::DependencySet(Realm::Event precondition) + : precondition(precondition) {} + +void DependencySet::add_writer(DynamicValueAttrs const &value, + Realm::Event writer) { + AtomicDependencySet &atomic_dependence_set = + this->get_atomic_dependency_set(value); + atomic_dependence_set.add_writer(writer); +} + +void DependencySet::add_reader(DynamicValueAttrs const &value, + Realm::Event reader) { + AtomicDependencySet &atomic_dependence_set = + this->get_atomic_dependency_set(value); + atomic_dependence_set.add_reader(reader); +} + +Realm::Event DependencySet::get_current_outstanding_events( + DynamicValueAttrs const &value) const { + if (contains_key(this->atomic_dependencies, value)) { + return this->atomic_dependencies.at(value).get_current_outstanding_events(); + } + return this->precondition; +} + +AtomicDependencySet & + DependencySet::get_atomic_dependency_set(DynamicValueAttrs const &value) { + if (!contains_key(this->atomic_dependencies, value)) { + this->atomic_dependencies.insert( + {value, AtomicDependencySet{this->precondition}}); + } + return this->atomic_dependencies.at(value); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc similarity index 90% rename from 
lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc rename to lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 05dfec74c3..c1654397ec 100644 --- a/lib/realm-execution/src/realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -1,5 +1,6 @@ -#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "realm-execution/pcg_instance/pcg_instance.h" #include "pcg/optimizer_attrs.h" +#include "realm-execution/dependency_set.h" #include "realm-execution/distributed_device_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" @@ -16,7 +17,7 @@ namespace FlexFlow { -ParallelComputationGraphInstance::ParallelComputationGraphInstance( +PCGInstance::PCGInstance( RealmContext &ctx, std::vector const &execution_order, OptimizerAttrs const &optimizer_attrs, @@ -24,27 +25,26 @@ ParallelComputationGraphInstance::ParallelComputationGraphInstance( : ctx(ctx), execution_order(execution_order), optimizer_attrs(optimizer_attrs), logit_grad_tensor(logit_grad_tensor) {} -RealmContext &ParallelComputationGraphInstance::get_realm_context() { +RealmContext &PCGInstance::get_realm_context() { return this->ctx; } std::vector const & - ParallelComputationGraphInstance::get_execution_order() const { + PCGInstance::get_execution_order() const { return this->execution_order; } -OptimizerAttrs const & - ParallelComputationGraphInstance::get_optimizer_attrs() const { +OptimizerAttrs const &PCGInstance::get_optimizer_attrs() const { return this->optimizer_attrs; } -void ParallelComputationGraphInstance::update_optimizer_attrs_for_next_iter() { +void PCGInstance::update_optimizer_attrs_for_next_iter() { this->optimizer_attrs = get_optimizer_attrs_for_next_iter(this->optimizer_attrs); } 
std::optional - ParallelComputationGraphInstance::get_loss_tensor_instance() const { + PCGInstance::get_loss_tensor_instance() const { return this->logit_grad_tensor; } -ParallelComputationGraphInstance create_parallel_computation_graph_instance( +PCGInstance create_parallel_computation_graph_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, @@ -96,7 +96,7 @@ ParallelComputationGraphInstance create_parallel_computation_graph_instance( std::vector invocation_topo_order = transform( node_topo_order, [&](Node node) { return node_map.at_l(node); }); - return ParallelComputationGraphInstance{ + return PCGInstance{ ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: @@ -114,6 +114,9 @@ static std::unordered_map OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { + // For simplicity we'll track a dependency on all outstanding operations up to + // this point. This will create an effective barrier between phases. 
+ DependencySet dependency_set{ctx.get_outstanding_events()}; return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { Realm::Event result = @@ -130,7 +133,7 @@ static std::unordered_map std::unordered_map perform_all_passes_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = @@ -148,7 +151,7 @@ std::unordered_map std::unordered_map perform_forward_pass_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = @@ -169,7 +172,7 @@ std::unordered_map std::unordered_map perform_backward_pass_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = @@ -190,7 +193,7 @@ std::unordered_map std::unordered_map perform_update_pass_for_parallel_computation_graph_instance( - ParallelComputationGraphInstance &instance, + PCGInstance &instance, ProfilingSettings const &profiling_settings, FFIterationConfig iteration_config) { std::vector execution_order = diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 6c28a001ad..94e0d7d0f4 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -1,5 +1,4 @@ #include "realm-execution/realm_manager.h" -#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" #include using namespace ::FlexFlow; diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc 
b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index a30d5c4d8e..37f1a9b42c 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,4 +1,4 @@ -#include "realm-execution/parallel_computation_graph_instance/parallel_computation_graph_instance.h" +#include "realm-execution/pcg_instance/pcg_instance.h" #include "realm-execution/realm_manager.h" #include From ffd2738eb6aee30513e18d05cb391465349a458c Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 11:17:21 -0800 Subject: [PATCH 42/63] Dependency tracking. --- .../realm-execution/atomic_dependency_set.h | 3 +- .../include/realm-execution/dependency_set.h | 4 +- .../distributed_device_state_initialization.h | 3 +- .../tasks/impl/controller_task.h | 3 +- .../tasks/impl/device_init_return_task.h | 3 +- .../tasks/impl/device_init_task.h | 3 +- .../realm-execution/tasks/impl/op_task.h | 14 +++---- .../realm-execution/atomic_dependency_set.cc | 12 ++++-- .../src/realm-execution/dependency_set.cc | 12 +++++- ...distributed_device_state_initialization.cc | 6 ++- .../pcg_instance/pcg_instance.cc | 39 +++++++++++++++---- .../src/realm-execution/realm_manager.cc | 3 +- .../tasks/impl/controller_task.cc | 12 +++--- .../tasks/impl/device_init_return_task.cc | 6 ++- .../tasks/impl/device_init_task.cc | 9 +++-- .../src/realm-execution/tasks/impl/op_task.cc | 17 ++++---- 16 files changed, 101 insertions(+), 48 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/atomic_dependency_set.h b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h index 8a1ae96b3e..da6ba86638 100644 --- a/lib/realm-execution/include/realm-execution/atomic_dependency_set.h +++ b/lib/realm-execution/include/realm-execution/atomic_dependency_set.h @@ -14,7 +14,8 @@ struct AtomicDependencySet { void add_writer(Realm::Event writer); void add_reader(Realm::Event reader); - Realm::Event get_current_outstanding_events() 
const; + Realm::Event get_dependency_for_writer() const; + Realm::Event get_dependency_for_reader() const; private: Realm::Event writer; diff --git a/lib/realm-execution/include/realm-execution/dependency_set.h b/lib/realm-execution/include/realm-execution/dependency_set.h index a7100076b2..629a40e2e7 100644 --- a/lib/realm-execution/include/realm-execution/dependency_set.h +++ b/lib/realm-execution/include/realm-execution/dependency_set.h @@ -16,8 +16,8 @@ struct DependencySet { void add_writer(DynamicValueAttrs const &value, Realm::Event writer); void add_reader(DynamicValueAttrs const &value, Realm::Event reader); - Realm::Event - get_current_outstanding_events(DynamicValueAttrs const &value) const; + Realm::Event get_dependency_for_writer(DynamicValueAttrs const &value) const; + Realm::Event get_dependency_for_reader(DynamicValueAttrs const &value) const; private: AtomicDependencySet & diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index d2ed093c0b..5530f473d8 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -14,7 +14,8 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( RealmContext &ctx, ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs); + OptimizerAttrs const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h index d4c397bb37..7134973ead 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h +++ 
b/lib/realm-execution/include/realm-execution/tasks/impl/controller_task.h @@ -12,7 +12,8 @@ void controller_task_body( Realm::Event collective_spawn_controller_task(RealmContext &ctx, Realm::Processor &target_proc, - std::function thunk); + std::function thunk, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h index fc6c8bdb9f..0f92b35c24 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h @@ -14,7 +14,8 @@ Realm::Event spawn_device_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr); + DeviceSpecificPerDeviceOpState *origin_result_ptr, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h index af07139483..7842963c7b 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h @@ -21,7 +21,8 @@ std::optional ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr); + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 3fcffc30fa..21d8795339 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ 
b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -13,13 +13,13 @@ namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs); +Realm::Event spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc index bdc05b7c46..ba4fcc5a9f 100644 --- a/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc +++ b/lib/realm-execution/src/realm-execution/atomic_dependency_set.cc @@ -6,8 +6,8 @@ AtomicDependencySet::AtomicDependencySet(Realm::Event precondition) : writer(precondition) {} void AtomicDependencySet::add_writer(Realm::Event writer) { - this->writer = Realm::Event::merge_events( - writer, this->get_current_outstanding_events()); + this->writer = + Realm::Event::merge_events(writer, this->get_dependency_for_writer()); this->readers.clear(); } @@ -15,9 +15,13 @@ void AtomicDependencySet::add_reader(Realm::Event reader) { this->readers.push_back(reader); } -Realm::Event AtomicDependencySet::get_current_outstanding_events() const { +Realm::Event AtomicDependencySet::get_dependency_for_writer() const { Realm::Event readers = Realm::Event::merge_events(this->readers); - return Realm::Event::merge_events(writer, readers); + return Realm::Event::merge_events(this->writer, readers); +} + +Realm::Event AtomicDependencySet::get_dependency_for_reader() const { + return this->writer; } 
} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/dependency_set.cc b/lib/realm-execution/src/realm-execution/dependency_set.cc index 3af03ffcef..84412a125d 100644 --- a/lib/realm-execution/src/realm-execution/dependency_set.cc +++ b/lib/realm-execution/src/realm-execution/dependency_set.cc @@ -21,10 +21,18 @@ void DependencySet::add_reader(DynamicValueAttrs const &value, atomic_dependence_set.add_reader(reader); } -Realm::Event DependencySet::get_current_outstanding_events( +Realm::Event DependencySet::get_dependency_for_writer( DynamicValueAttrs const &value) const { if (contains_key(this->atomic_dependencies, value)) { - return this->atomic_dependencies.at(value).get_current_outstanding_events(); + return this->atomic_dependencies.at(value).get_dependency_for_writer(); + } + return this->precondition; +} + +Realm::Event DependencySet::get_dependency_for_reader( + DynamicValueAttrs const &value) const { + if (contains_key(this->atomic_dependencies, value)) { + return this->atomic_dependencies.at(value).get_dependency_for_reader(); } return this->precondition; } diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index f7fcea87e7..4ea8d0bbd1 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -14,7 +14,8 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( RealmContext &ctx, ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs) { + OptimizerAttrs const &optimizer_attrs, + Realm::Event precondition) { // Initialize all operators and save the per-device op state ASSERT(no_nodes_are_initialized(dg)); @@ -38,7 +39,8 @@ DynamicOpenDataflowGraph 
perform_distributed_device_state_initialization( profiling_settings, iteration_config, optimizer_attrs, - output); + output, + precondition); if (result) { result_map[invocation] = output; } else { diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index c1654397ec..e636cbf259 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -7,11 +7,14 @@ #include "realm-execution/tasks/impl/op_task.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "task-spec/dynamic_graph/loss_insertion.h" #include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mpcg.h" #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/update_insertion.h" +#include "utils/containers/transform.h" +#include "utils/containers/values.h" #include "utils/graph/digraph/algorithms/get_topological_ordering.h" #include "utils/optional.h" @@ -77,17 +80,20 @@ PCGInstance create_parallel_computation_graph_instance( dg = perform_shard_expansion(dg); TensorInstanceBacking backing = perform_instance_allocation(dg, inputs, ctx); - // FIXME: for now we're going to be lazy and block on everything rather than - // do fine-grained dependencies on instances - ctx.get_outstanding_events().wait(); - std::optional logit_grad_tensor = transform(logit_grad_value, [&](DynamicValueAttrs const &lgv) { return backing.backing.at(lgv).first; }); + // FIXME: for now we're going to be lazy and block on everything rather than + // do fine-grained dependencies on instances dg = perform_distributed_device_state_initialization( - dg, ctx, profiling_settings, iteration_config, 
optimizer_attrs); + dg, + ctx, + profiling_settings, + iteration_config, + optimizer_attrs, + ctx.get_outstanding_events()); // Compute the topological ordering of the graph auto [kwarg_graph, node_map] = @@ -102,7 +108,6 @@ PCGInstance create_parallel_computation_graph_instance( // TODO list: // * Realm allocator // * external instances - // * dependencies // * task argument serializer // * copies } @@ -119,6 +124,19 @@ static std::unordered_map DependencySet dependency_set{ctx.get_outstanding_events()}; return unordered_map_from_pairs( transform(invocations, [&](DynamicNodeInvocation const &invocation) { + std::vector input_dependencies = + transform(vector_of(values(invocation.inputs)), + [&](DynamicValueAttrs const &value) { + return dependency_set.get_dependency_for_reader(value); + }); + std::vector output_dependencies = + transform(vector_of(values(invocation.outputs)), + [&](DynamicValueAttrs const &value) { + return dependency_set.get_dependency_for_writer(value); + }); + Realm::Event dependencies = Realm::Event::merge_events( + Realm::Event::merge_events(input_dependencies), + Realm::Event::merge_events(output_dependencies)); Realm::Event result = spawn_op_task(ctx, ctx.map_device_coord_to_processor(assert_unwrap( @@ -126,7 +144,14 @@ static std::unordered_map invocation, profiling_settings, iteration_config, - optimizer_attrs); + optimizer_attrs, + dependencies); + for (DynamicValueAttrs const &value : values(invocation.inputs)) { + dependency_set.add_reader(value, result); + } + for (DynamicValueAttrs const &value : values(invocation.outputs)) { + dependency_set.add_writer(value, result); + } return std::pair{invocation.node_attrs.layer_guid, result}; })); } diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index 7233103cc3..adafea47e6 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -27,7 
+27,8 @@ Realm::Event .only_kind(Realm::Processor::LOC_PROC) .first(); - return collective_spawn_controller_task(*this, target_proc, thunk); + return collective_spawn_controller_task( + *this, target_proc, thunk, Realm::Event::NO_EVENT); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc index 2fd5cee52d..285e8acaa7 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/controller_task.cc @@ -21,17 +21,19 @@ void controller_task_body(void const *args, task_args.thunk(ctx); } -Realm::Event collective_spawn_controller_task( - RealmContext &ctx, - Realm::Processor &target_proc, - std::function thunk) { +Realm::Event + collective_spawn_controller_task(RealmContext &ctx, + Realm::Processor &target_proc, + std::function thunk, + Realm::Event precondition) { ControllerTaskArgs task_args; task_args.thunk = thunk; return ctx.collective_spawn_task(target_proc, task_id_t::CONTROLLER_TASK_ID, &task_args, - sizeof(task_args)); + sizeof(task_args), + precondition); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc index fa421cda30..610500a94b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc @@ -36,14 +36,16 @@ Realm::Event spawn_device_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr) { + DeviceSpecificPerDeviceOpState *origin_result_ptr, + Realm::Event precondition) { DeviceInitReturnTaskArgs task_args{result, origin_proc, origin_result_ptr}; return ctx.spawn_task(origin_proc, 
task_id_t::DEVICE_INIT_RETURN_TASK_ID, &task_args, sizeof(task_args), - Realm::ProfilingRequestSet{}); + Realm::ProfilingRequestSet{}, + precondition); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc index cc080255e2..7f36f48921 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc @@ -66,7 +66,8 @@ void device_init_task_body(void const *args, spawn_device_init_return_task(ctx, task_args.origin_proc, *result_state_ptr, - task_args.origin_result_ptr); + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); } std::optional @@ -76,7 +77,8 @@ std::optional ProfilingSettings const &profiling_settings, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr) { + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition) { DeviceInitTaskArgs task_args{ &invocation, &profiling_settings, @@ -97,7 +99,8 @@ std::optional assert_unwrap(task_id), &task_args, sizeof(task_args), - Realm::ProfilingRequestSet{}); + Realm::ProfilingRequestSet{}, + precondition); } return std::nullopt; } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 5f6ab40607..216f0badde 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -58,13 +58,13 @@ void op_task_body(void const *args, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event - spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs) { +Realm::Event 
spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{&invocation, &profiling_settings, &iteration_config, @@ -75,7 +75,8 @@ Realm::Event assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), &task_args, sizeof(task_args), - Realm::ProfilingRequestSet{}); + Realm::ProfilingRequestSet{}, + precondition); } } // namespace FlexFlow From 81cc4850cc2cccff70590bc21e29fb99cd73ddf2 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 11:18:53 -0800 Subject: [PATCH 43/63] Add event argument to controller. --- lib/realm-execution/include/realm-execution/realm_manager.h | 3 ++- lib/realm-execution/src/realm-execution/realm_manager.cc | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_manager.h b/lib/realm-execution/include/realm-execution/realm_manager.h index bf5e8f72f1..8a79476bcf 100644 --- a/lib/realm-execution/include/realm-execution/realm_manager.h +++ b/lib/realm-execution/include/realm-execution/realm_manager.h @@ -19,7 +19,8 @@ struct RealmManager : private RealmContext { RealmManager(RealmManager &&) = delete; [[nodiscard]] Realm::Event - start_controller(std::function); + start_controller(std::function, + Realm::Event wait_on = Realm::Event::NO_EVENT); }; } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index adafea47e6..fc74fffe5d 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -21,14 +21,14 @@ RealmManager::~RealmManager() { } Realm::Event - RealmManager::start_controller(std::function thunk) { + 
RealmManager::start_controller(std::function thunk, + Realm::Event wait_on) { Realm::Processor target_proc = Realm::Machine::ProcessorQuery(Realm::Machine::get_machine()) .only_kind(Realm::Processor::LOC_PROC) .first(); - return collective_spawn_controller_task( - *this, target_proc, thunk, Realm::Event::NO_EVENT); + return collective_spawn_controller_task(*this, target_proc, thunk, wait_on); } } // namespace FlexFlow From bb0ea6bd658065b73242ed2c46bca93bad028337 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 11:53:33 -0800 Subject: [PATCH 44/63] Implement the allocator. --- .../include/realm-execution/realm_allocator.h | 31 +++++++++++ .../include/realm-execution/realm_context.h | 6 ++- .../pcg_instance/pcg_instance.cc | 2 +- .../src/realm-execution/realm_allocator.cc | 53 +++++++++++++++++++ .../src/realm-execution/realm_context.cc | 10 ++-- 5 files changed, 95 insertions(+), 7 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/realm_allocator.h create mode 100644 lib/realm-execution/src/realm-execution/realm_allocator.cc diff --git a/lib/realm-execution/include/realm-execution/realm_allocator.h b/lib/realm-execution/include/realm-execution/realm_allocator.h new file mode 100644 index 0000000000..dab6f3ea63 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/realm_allocator.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_ALLOCATOR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REALM_ALLOCATOR_H + +#include "kernels/allocation.h" +#include "realm-execution/realm.h" + +namespace FlexFlow { + +struct RealmAllocator : public IAllocator { + RealmAllocator(Realm::Processor processor, Realm::Memory memory); + RealmAllocator(RealmAllocator const &) = delete; + RealmAllocator(RealmAllocator &&) = delete; + ~RealmAllocator() = default; + + void *allocate(size_t) override; + void deallocate(void *) override; + + DeviceType get_allocation_device_type() 
const override; + +private: + Realm::Processor processor; + Realm::Memory memory; + std::unordered_map ptr_instances; +}; +CHECK_RC_COPY_VIRTUAL_COMPLIANT(RealmAllocator); + +Allocator get_realm_allocator(Realm::Processor processor, Realm::Memory memory); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index e28e91234e..755bf595d6 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -6,6 +6,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/realm_allocator.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include @@ -23,11 +24,11 @@ struct RealmContext { // Device mapping Realm::Processor map_device_coord_to_processor(MachineSpaceCoordinate const &); - Realm::Memory get_nearest_memory(Realm::Processor) const; + static Realm::Memory get_nearest_memory(Realm::Processor); // Current device context Realm::Processor get_current_processor() const; - Allocator &get_current_device_allocator() const; + Allocator &get_current_device_allocator(); device_handle_t const &get_current_device_handle() const; device_id_t get_current_device_idx() const; @@ -68,6 +69,7 @@ struct RealmContext { protected: Realm::Runtime runtime; Realm::Processor processor; + Allocator allocator; std::vector outstanding_events; std::unordered_map, std::vector> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index e636cbf259..93b42743a0 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -106,7 +106,7 @@ PCGInstance create_parallel_computation_graph_instance( ctx, invocation_topo_order, 
optimizer_attrs, logit_grad_tensor}; // TODO list: - // * Realm allocator + // * current device handle // * external instances // * task argument serializer // * copies diff --git a/lib/realm-execution/src/realm-execution/realm_allocator.cc b/lib/realm-execution/src/realm-execution/realm_allocator.cc new file mode 100644 index 0000000000..f24106b0bc --- /dev/null +++ b/lib/realm-execution/src/realm-execution/realm_allocator.cc @@ -0,0 +1,53 @@ +#include "realm-execution/realm_allocator.h" +#include "kernels/device.h" +#include "pcg/device_type.dtg.h" + +namespace FlexFlow { + +RealmAllocator::RealmAllocator(Realm::Processor processor, Realm::Memory memory) + : processor(processor), memory(memory) {} + +void *RealmAllocator::allocate(size_t requested_memory_size) { + Realm::Rect<1> bounds{Realm::Point<1>::ZEROES(), + Realm::Point<1>{requested_memory_size} - + Realm::Point<1>::ONES()}; + std::vector field_sizes{1}; + Realm::RegionInstance inst; + Realm::Event ready = + Realm::RegionInstance::create_instance(inst, + this->memory, + bounds, + field_sizes, + 0 /*SOA*/, + Realm::ProfilingRequestSet{}); + ready.wait(); + void *ptr = + inst.pointer_untyped(/*offset=*/0, /*datalen=*/requested_memory_size); + ASSERT(ptr); + this->ptr_instances.insert({ptr, inst}); + return ptr; +} + +void RealmAllocator::deallocate(void *ptr) { + this->ptr_instances.at(ptr).destroy(Realm::Event::NO_EVENT); + this->ptr_instances.erase(ptr); +} + +DeviceType RealmAllocator::get_allocation_device_type() const { + switch (this->processor.kind()) { + case Realm::Processor::Kind::LOC_PROC: + return DeviceType::CPU; + case Realm::Processor::Kind::TOC_PROC: + return DeviceType::GPU; + default: + PANIC("Unhandled FwbTensorType", this->processor.kind()); + } +} + +Allocator get_realm_allocator(Realm::Processor processor, + Realm::Memory memory) { + Allocator allocator = Allocator::create(processor, memory); + return allocator; +} + +} // namespace FlexFlow diff --git 
a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 781561c95a..a77383779f 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -14,7 +14,9 @@ namespace FlexFlow { -RealmContext::RealmContext(Realm::Processor proc) : processor(proc) {} +RealmContext::RealmContext(Realm::Processor proc) + : processor(proc), allocator(get_realm_allocator( + proc, RealmContext::get_nearest_memory(proc))) {} RealmContext::~RealmContext() { if (!this->outstanding_events.empty()) { @@ -51,7 +53,7 @@ Realm::Processor RealmContext::map_device_coord_to_processor( return this->processors.at(std::pair{as, kind}).at(int{proc_in_node}); } -Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) const { +Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) { // FIMXE: this isn't going to do what you expect until // https://github.com/StanfordLegion/realm/pull/392 merges Realm::Machine::MemoryQuery mq(Realm::Machine::get_machine()); @@ -64,8 +66,8 @@ Realm::Processor RealmContext::get_current_processor() const { return this->processor; } -Allocator &RealmContext::get_current_device_allocator() const { - NOT_IMPLEMENTED(); +Allocator &RealmContext::get_current_device_allocator() { + return this->allocator; } device_handle_t const &RealmContext::get_current_device_handle() const { From feb5897b2a927c17cc367488e40cda9b6a89a435 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 12:17:50 -0800 Subject: [PATCH 45/63] Implement device handle. 
--- .../include/realm-execution/realm_allocator.h | 2 + .../include/realm-execution/realm_context.h | 10 ++++- .../pcg_instance/pcg_instance.cc | 1 - .../src/realm-execution/realm_context.cc | 42 ++++++++++++++++--- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/realm_allocator.h b/lib/realm-execution/include/realm-execution/realm_allocator.h index dab6f3ea63..d72f2d7f91 100644 --- a/lib/realm-execution/include/realm-execution/realm_allocator.h +++ b/lib/realm-execution/include/realm-execution/realm_allocator.h @@ -8,6 +8,8 @@ namespace FlexFlow { struct RealmAllocator : public IAllocator { RealmAllocator(Realm::Processor processor, Realm::Memory memory); + + RealmAllocator() = delete; RealmAllocator(RealmAllocator const &) = delete; RealmAllocator(RealmAllocator &&) = delete; ~RealmAllocator() = default; diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index 755bf595d6..eb4d6d0935 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -3,18 +3,19 @@ #include "kernels/allocation.h" #include "kernels/device_handle_t.dtg.h" +#include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" -#include "realm-execution/realm_allocator.h" #include "realm-execution/tasks/task_id_t.dtg.h" +#include #include namespace FlexFlow { struct RealmContext { public: - RealmContext(Realm::Processor); + RealmContext(Realm::Processor processor); virtual ~RealmContext(); RealmContext() = delete; @@ -66,10 +67,15 @@ struct RealmContext { void discover_machine_topology(); + static std::optional + make_device_handle_for_processor(Realm::Processor processor); + protected: Realm::Runtime runtime; Realm::Processor processor; Allocator allocator; + std::optional 
managed_handle; + device_handle_t device_handle; std::vector outstanding_events; std::unordered_map, std::vector> diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 93b42743a0..d56dbb9ca9 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -106,7 +106,6 @@ PCGInstance create_parallel_computation_graph_instance( ctx, invocation_topo_order, optimizer_attrs, logit_grad_tensor}; // TODO list: - // * current device handle // * external instances // * task argument serializer // * copies diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index a77383779f..38ce052da9 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -1,22 +1,27 @@ #include "realm-execution/realm_context.h" +#include "kernels/device_handle_t.dtg.h" +#include "kernels/device_handle_t.h" #include "op-attrs/datatype.h" #include "op-attrs/tensor_dims.dtg.h" #include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" +#include "realm-execution/realm_allocator.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "utils/containers/contains_key.h" #include "utils/containers/transform.h" -#include "utils/exception.h" #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/one_to_many/one_to_many.h" #include "utils/positive_int/positive_int.h" namespace FlexFlow { -RealmContext::RealmContext(Realm::Processor proc) - : processor(proc), allocator(get_realm_allocator( - proc, RealmContext::get_nearest_memory(proc))) {} +RealmContext::RealmContext(Realm::Processor processor) + : processor(processor), + allocator(get_realm_allocator( + processor, 
RealmContext::get_nearest_memory(processor))), + managed_handle(RealmContext::make_device_handle_for_processor(processor)), + device_handle(device_handle_t_from_managed_handle(managed_handle)) {} RealmContext::~RealmContext() { if (!this->outstanding_events.empty()) { @@ -54,6 +59,10 @@ Realm::Processor RealmContext::map_device_coord_to_processor( } Realm::Memory RealmContext::get_nearest_memory(Realm::Processor proc) { + if (!proc.exists()) { + return Realm::Memory::NO_MEMORY; + } + // FIMXE: this isn't going to do what you expect until // https://github.com/StanfordLegion/realm/pull/392 merges Realm::Machine::MemoryQuery mq(Realm::Machine::get_machine()); @@ -71,8 +80,9 @@ Allocator &RealmContext::get_current_device_allocator() { } device_handle_t const &RealmContext::get_current_device_handle() const { - NOT_IMPLEMENTED(); + return this->device_handle; } + device_id_t RealmContext::get_current_device_idx() const { Realm::Processor proc = this->get_current_processor(); @@ -245,4 +255,26 @@ void RealmContext::discover_machine_topology() { } } +std::optional + RealmContext::make_device_handle_for_processor(Realm::Processor processor) { + if (!processor.exists()) { + return std::nullopt; + } + + switch (processor.kind()) { + case Realm::Processor::LOC_PROC: + return std::nullopt; + case Realm::Processor::TOC_PROC: + // FIXME: not sure what workSpaceSize to choose here + return initialize_multi_gpu_handle( + /*num_ranks=*/Realm::Machine::get_machine().get_address_space_count(), + /*my_rank=*/processor.address_space(), + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + default: + PANIC("Unhandled Realm::ProcessorKind", + fmt::to_string(int{processor.kind()})); + } +} + } // namespace FlexFlow From 202889fceb01cda4565a2da779dfd6354f98e18a Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 14:47:23 -0800 Subject: [PATCH 46/63] Distributed device handle initialization. 
--- .../distributed_device_handle.h | 38 +++++++ .../impl/device_handle_init_return_task.h | 24 +++++ .../tasks/impl/device_handle_init_task.h | 24 +++++ .../tasks/impl/device_init_task.h | 29 ----- ...task.h => device_state_init_return_task.h} | 8 +- .../tasks/impl/device_state_init_task.h | 29 +++++ .../realm-execution/tasks/task_id_t.dtg.toml | 8 +- .../distributed_device_handle.cc | 50 +++++++++ ...distributed_device_state_initialization.cc | 18 ++-- .../impl/device_handle_init_return_task.cc | 55 ++++++++++ .../tasks/impl/device_handle_init_task.cc | 100 ++++++++++++++++++ .../tasks/impl/device_init_return_task.cc | 51 --------- .../impl/device_state_init_return_task.cc | 53 ++++++++++ ...init_task.cc => device_state_init_task.cc} | 67 ++++++------ .../tasks/realm_task_registry.cc | 10 +- 15 files changed, 432 insertions(+), 132 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/distributed_device_handle.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h delete mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h rename lib/realm-execution/include/realm-execution/tasks/impl/{device_init_return_task.h => device_state_init_return_task.h} (77%) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h create mode 100644 lib/realm-execution/src/realm-execution/distributed_device_handle.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc delete mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc rename 
lib/realm-execution/src/realm-execution/tasks/impl/{device_init_task.cc => device_state_init_task.cc} (58%) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h new file mode 100644 index 0000000000..ca3f08fc41 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -0,0 +1,38 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H + +#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific.h" +#include +#include + +namespace FlexFlow { + +struct DistributedDeviceHandle { +public: + DistributedDeviceHandle() = delete; + explicit DistributedDeviceHandle( + std::map>> const + &handles); + + DeviceSpecific> const & + at(Realm::Processor processor) const; + +private: + std::map>> + handles; +}; + +DistributedDeviceHandle create_distributed_device_handle( + RealmContext &ctx, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + Realm::Event precondition = Realm::Event::NO_EVENT); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h new file mode 100644 index 0000000000..8b358ee4ce --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H + +#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/realm.h" +#include 
"realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" + +namespace FlexFlow { + +void device_handle_init_return_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event spawn_device_handle_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecific> const &result, + DeviceSpecific> + *origin_result_ptr, + Realm::Event precondition); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h new file mode 100644 index 0000000000..c26633bd9a --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H + +#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" + +namespace FlexFlow { + +void device_handle_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +Realm::Event spawn_device_handle_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + DeviceSpecific> *result_ptr, + Realm::Event precondition); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h deleted file mode 100644 index 7842963c7b..0000000000 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_task.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_TASK_H - -#include "kernels/profiling_settings.dtg.h" -#include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/realm.h" -#include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" -#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" -#include "task-spec/ff_iteration_config.dtg.h" - -namespace FlexFlow { - -void device_init_task_body( - void const *, size_t, void const *, size_t, Realm::Processor); - -std::optional - spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition); - -} // namespace FlexFlow - -#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h similarity index 77% rename from lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h rename to lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h index 0f92b35c24..8f44680815 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_return_task.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_INIT_RETURN_TASK_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_RETURN_TASK_H +#define 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_RETURN_TASK_H #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" @@ -7,10 +7,10 @@ namespace FlexFlow { -void device_init_return_task_body( +void device_state_init_return_task_body( void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_device_init_return_task( +Realm::Event spawn_device_state_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState const &result, diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h new file mode 100644 index 0000000000..4cd65a0a2a --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -0,0 +1,29 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_STATE_INIT_TASK_H + +#include "kernels/profiling_settings.dtg.h" +#include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/realm_context.h" +#include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/ff_iteration_config.dtg.h" + +namespace FlexFlow { + +void device_state_init_task_body( + void const *, size_t, void const *, size_t, Realm::Processor); + +std::optional + spawn_device_state_init_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition); + +} // namespace FlexFlow + +#endif diff --git 
a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index 34e5183488..97b19b5f51 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -12,7 +12,13 @@ features = [ name = "CONTROLLER_TASK_ID" [[values]] -name = "DEVICE_INIT_RETURN_TASK_ID" +name = "DEVICE_HANDLE_INIT_TASK_ID" + +[[values]] +name = "DEVICE_HANDLE_INIT_RETURN_TASK_ID" + +[[values]] +name = "DEVICE_STATE_INIT_RETURN_TASK_ID" [[values]] name = "IMAGE_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc new file mode 100644 index 0000000000..00c2e76360 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -0,0 +1,50 @@ +#include "realm-execution/distributed_device_handle.h" +#include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "task-spec/device_specific.h" + +namespace FlexFlow { + +DistributedDeviceHandle::DistributedDeviceHandle( + std::map>> const + &handles) + : handles(handles) {} + +DeviceSpecific> const & + DistributedDeviceHandle::at(Realm::Processor processor) const { + return this->handles.at(processor); +} + +DistributedDeviceHandle + create_distributed_device_handle(RealmContext &ctx, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + Realm::Event precondition) { + std::map>> + handles; + + // Allocate space for the result before launching any tasks + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + for (Realm::Processor proc : pq) { + handles.insert( + {proc, + DeviceSpecific>::create( + ctx.get_current_device_idx(), std::nullopt)}); + } + + for (auto &[proc, handle] : handles) { + spawn_device_handle_init_task(ctx, + proc, + workSpaceSize, + allowTensorOpMathConversion, + &handle, + precondition); 
+ } + + ctx.get_outstanding_events().wait(); + + return DistributedDeviceHandle{handles}; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index 4ea8d0bbd1..9627a71e87 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -1,6 +1,6 @@ #include "realm-execution/distributed_device_state_initialization.h" #include "local-execution/device_state_initialization.h" -#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/device_state_init_task.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "utils/optional.h" @@ -33,14 +33,14 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( static_cast( malloc(sizeof(DeviceSpecificPerDeviceOpState))); std::optional result = - spawn_device_init_task(ctx, - target_proc, - invocation, - profiling_settings, - iteration_config, - optimizer_attrs, - output, - precondition); + spawn_device_state_init_task(ctx, + target_proc, + invocation, + profiling_settings, + iteration_config, + optimizer_attrs, + output, + precondition); if (result) { result_map[invocation] = output; } else { diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc new file mode 100644 index 0000000000..2839beef0c --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc @@ -0,0 +1,55 @@ +#include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" + +namespace FlexFlow { + +// FIXME: Can't make this 
trivially copyable? +struct DeviceHandleInitReturnTaskArgs { +public: + DeviceHandleInitReturnTaskArgs() = delete; + DeviceHandleInitReturnTaskArgs( + DeviceSpecific> result, + Realm::Processor origin_proc, + DeviceSpecific> + *origin_result_ptr) + : result(result), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} + +public: + DeviceSpecific> result; + Realm::Processor origin_proc; + DeviceSpecific> *origin_result_ptr; +}; + +void device_handle_init_return_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceHandleInitReturnTaskArgs)); + DeviceHandleInitReturnTaskArgs task_args = + *reinterpret_cast(args); + + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + *task_args.origin_result_ptr = task_args.result; +} + +Realm::Event spawn_device_handle_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecific> const &result, + DeviceSpecific> + *origin_result_ptr, + Realm::Event precondition) { + DeviceHandleInitReturnTaskArgs task_args{ + result, origin_proc, origin_result_ptr}; + + return ctx.spawn_task(origin_proc, + task_id_t::DEVICE_HANDLE_INIT_RETURN_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}, + precondition); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc new file mode 100644 index 0000000000..86a576d26b --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -0,0 +1,100 @@ +#include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "realm-execution/tasks/impl/device_handle_init_return_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" +#include + +namespace FlexFlow { + +// TODO: at some point we're going to have to actually serialize these, but for +// 
now just pass the pointer and assume we're running inside a single address +// space +struct DeviceHandleInitTaskArgs { + DeviceHandleInitTaskArgs() = delete; + DeviceHandleInitTaskArgs( + size_t workSpaceSize, + bool allowTensorOpMathConversion, + Realm::Processor origin_proc, + DeviceSpecific> + *origin_result_ptr) + : workSpaceSize(workSpaceSize), + allowTensorOpMathConversion(allowTensorOpMathConversion), + origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + +public: + size_t workSpaceSize; + bool allowTensorOpMathConversion; + Realm::Processor origin_proc; + DeviceSpecific> *origin_result_ptr; +}; +static_assert(std::is_trivially_copy_constructible_v); + +static std::optional + make_device_handle_for_processor(Realm::Processor processor, + size_t workSpaceSize, + bool allowTensorOpMathConversion) { + switch (processor.kind()) { + case Realm::Processor::LOC_PROC: + return std::nullopt; + case Realm::Processor::TOC_PROC: + return new ManagedPerDeviceFFHandle{initialize_multi_gpu_handle( + /*num_ranks=*/Realm::Machine::get_machine().get_address_space_count(), + /*my_rank=*/processor.address_space(), + /*workSpaceSize=*/workSpaceSize, + /*allowTensorOpMathConversion=*/allowTensorOpMathConversion)}; + default: + PANIC("Unhandled Realm::ProcessorKind", + fmt::to_string(int{processor.kind()})); + } +} + +void device_handle_init_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceHandleInitTaskArgs)); + DeviceHandleInitTaskArgs task_args = + *reinterpret_cast(args); + + // FIXME: serialize instead of passing pointers around + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + + RealmContext ctx{proc}; + DeviceSpecific> managed_handle = + DeviceSpecific>::create( + ctx.get_current_device_idx(), + make_device_handle_for_processor( + proc, + task_args.workSpaceSize, + task_args.allowTensorOpMathConversion)); + + 
spawn_device_handle_init_return_task(ctx, + task_args.origin_proc, + managed_handle, + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); +} + +Realm::Event spawn_device_handle_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + size_t workSpaceSize, + bool allowTensorOpMathConversion, + DeviceSpecific> *result_ptr, + Realm::Event precondition) { + DeviceHandleInitTaskArgs task_args{ + workSpaceSize, + allowTensorOpMathConversion, + ctx.get_current_processor(), + result_ptr, + }; + + return ctx.spawn_task(target_proc, + task_id_t::DEVICE_HANDLE_INIT_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}, + precondition); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc deleted file mode 100644 index 610500a94b..0000000000 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_return_task.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include "realm-execution/tasks/impl/device_init_task.h" -#include "realm-execution/tasks/task_id_t.dtg.h" - -namespace FlexFlow { - -// FIXME: Can't make this trivially copyable? 
-struct DeviceInitReturnTaskArgs { -public: - DeviceInitReturnTaskArgs() = delete; - DeviceInitReturnTaskArgs(DeviceSpecificPerDeviceOpState result, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) - : result(result), origin_proc(origin_proc), - origin_result_ptr(origin_result_ptr) {} - -public: - DeviceSpecificPerDeviceOpState result; - Realm::Processor origin_proc; - DeviceSpecificPerDeviceOpState *origin_result_ptr; -}; - -void device_init_return_task_body(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceInitReturnTaskArgs)); - DeviceInitReturnTaskArgs task_args = - *reinterpret_cast(args); - - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); - *task_args.origin_result_ptr = task_args.result; -} - -Realm::Event spawn_device_init_return_task( - RealmContext &ctx, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState const &result, - DeviceSpecificPerDeviceOpState *origin_result_ptr, - Realm::Event precondition) { - DeviceInitReturnTaskArgs task_args{result, origin_proc, origin_result_ptr}; - - return ctx.spawn_task(origin_proc, - task_id_t::DEVICE_INIT_RETURN_TASK_ID, - &task_args, - sizeof(task_args), - Realm::ProfilingRequestSet{}, - precondition); -} - -} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc new file mode 100644 index 0000000000..c1bd7c1081 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc @@ -0,0 +1,53 @@ +#include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/task_id_t.dtg.h" + +namespace FlexFlow { + +// FIXME: Can't make this trivially copyable? 
+struct DeviceStateInitReturnTaskArgs { +public: + DeviceStateInitReturnTaskArgs() = delete; + DeviceStateInitReturnTaskArgs( + DeviceSpecificPerDeviceOpState result, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) + : result(result), origin_proc(origin_proc), + origin_result_ptr(origin_result_ptr) {} + +public: + DeviceSpecificPerDeviceOpState result; + Realm::Processor origin_proc; + DeviceSpecificPerDeviceOpState *origin_result_ptr; +}; + +void device_state_init_return_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceStateInitReturnTaskArgs)); + DeviceStateInitReturnTaskArgs task_args = + *reinterpret_cast(args); + + ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + *task_args.origin_result_ptr = task_args.result; +} + +Realm::Event spawn_device_state_init_return_task( + RealmContext &ctx, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState const &result, + DeviceSpecificPerDeviceOpState *origin_result_ptr, + Realm::Event precondition) { + DeviceStateInitReturnTaskArgs task_args{ + result, origin_proc, origin_result_ptr}; + + return ctx.spawn_task(origin_proc, + task_id_t::DEVICE_STATE_INIT_RETURN_TASK_ID, + &task_args, + sizeof(task_args), + Realm::ProfilingRequestSet{}, + precondition); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc similarity index 58% rename from lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc rename to lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 7f36f48921..f63efba14b 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,6 +1,6 @@ -#include 
"realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/device_state_init_task.h" #include "local-execution/device_state_initialization.h" -#include "realm-execution/tasks/impl/device_init_return_task.h" +#include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" @@ -14,14 +14,14 @@ namespace FlexFlow { // TODO: at some point we're going to have to actually serialize these, but for // now just pass the pointer and assume we're running inside a single address // space -struct DeviceInitTaskArgs { - DeviceInitTaskArgs() = delete; - DeviceInitTaskArgs(DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, - FFIterationConfig const *iteration_config, - OptimizerAttrs const *optimizer_attrs, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) +struct DeviceStateInitTaskArgs { + DeviceStateInitTaskArgs() = delete; + DeviceStateInitTaskArgs(DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + FFIterationConfig const *iteration_config, + OptimizerAttrs const *optimizer_attrs, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) : invocation(invocation), profiling_settings(profiling_settings), iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} @@ -34,16 +34,17 @@ struct DeviceInitTaskArgs { Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; -static_assert(std::has_unique_object_representations_v); +static_assert( + std::has_unique_object_representations_v); -void device_init_task_body(void const *args, - size_t arglen, - void const *userdata, - size_t userlen, - Realm::Processor proc) { - ASSERT(arglen == 
sizeof(DeviceInitTaskArgs)); - DeviceInitTaskArgs task_args = - *reinterpret_cast(args); +void device_state_init_task_body(void const *args, + size_t arglen, + void const *userdata, + size_t userlen, + Realm::Processor proc) { + ASSERT(arglen == sizeof(DeviceStateInitTaskArgs)); + DeviceStateInitTaskArgs task_args = + *reinterpret_cast(args); // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); @@ -63,23 +64,23 @@ void device_init_task_body(void const *args, // the allocation here DeviceSpecificPerDeviceOpState *result_state_ptr = new DeviceSpecificPerDeviceOpState{result_state}; - spawn_device_init_return_task(ctx, - task_args.origin_proc, - *result_state_ptr, - task_args.origin_result_ptr, - Realm::Event::NO_EVENT); + spawn_device_state_init_return_task(ctx, + task_args.origin_proc, + *result_state_ptr, + task_args.origin_result_ptr, + Realm::Event::NO_EVENT); } std::optional - spawn_device_init_task(RealmContext &ctx, - Realm::Processor &target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition) { - DeviceInitTaskArgs task_args{ + spawn_device_state_init_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition) { + DeviceStateInitTaskArgs task_args{ &invocation, &profiling_settings, &iteration_config, diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index c63d4727a9..9150ce6892 100644 --- 
a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,7 +1,7 @@ #include "realm-execution/tasks/realm_task_registry.h" #include "realm-execution/tasks/impl/controller_task.h" -#include "realm-execution/tasks/impl/device_init_return_task.h" -#include "realm-execution/tasks/impl/device_init_task.h" +#include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/impl/device_state_init_task.h" #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tasks/realm_task_id_t.h" #include "utils/exception.h" @@ -48,7 +48,7 @@ Realm::Event register_all_tasks() { for (task_id_t task_id : init_task_ids) { pending_registrations.push_back(register_task( - Realm::Processor::TOC_PROC, task_id, device_init_task_body)); + Realm::Processor::TOC_PROC, task_id, device_state_init_task_body)); } std::vector task_ids = { @@ -127,8 +127,8 @@ Realm::Event register_all_tasks() { controller_task_body)); pending_registrations.push_back( register_task(Realm::Processor::LOC_PROC, - task_id_t::DEVICE_INIT_RETURN_TASK_ID, - device_init_return_task_body)); + task_id_t::DEVICE_STATE_INIT_RETURN_TASK_ID, + device_state_init_return_task_body)); return Realm::Event::merge_events(pending_registrations); } From 8f816f058af64565fc12a122c9f9488e81e90423 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 16:45:58 -0800 Subject: [PATCH 47/63] Distributed device handle initialization. 
--- lib/kernels/include/kernels/device_handle_t.h | 3 ++ lib/kernels/src/kernels/device_handle_t.cc | 9 ++++ ...ific_managed_per_device_ff_handle.dtg.toml | 16 ++++++ ...ce_specific_managed_per_device_ff_handle.h | 19 +++++++ .../distributed_device_handle.h | 16 +++--- .../distributed_device_state_initialization.h | 2 + .../include/realm-execution/fmt/instance.h | 4 +- .../include/realm-execution/hash/processor.h | 16 ++++++ .../pcg_instance/pcg_instance.h | 2 + .../include/realm-execution/realm_context.h | 3 -- .../impl/device_handle_init_return_task.h | 8 ++- .../tasks/impl/device_handle_init_task.h | 5 +- .../tasks/impl/device_state_init_task.h | 20 ++++---- .../realm-execution/tasks/impl/op_task.h | 17 ++++--- ...e_specific_managed_per_device_ff_handle.cc | 21 ++++++++ .../distributed_device_handle.cc | 17 +++---- ...distributed_device_state_initialization.cc | 2 + .../src/realm-execution/hash/processor.cc | 11 +++++ .../pcg_instance/pcg_instance.cc | 30 ++++++++---- .../src/realm-execution/realm_context.cc | 30 +----------- .../impl/device_handle_init_return_task.cc | 15 +++--- .../tasks/impl/device_handle_init_task.cc | 12 ++--- .../impl/device_state_init_return_task.cc | 1 - .../tasks/impl/device_state_init_task.cc | 49 +++++++++++-------- .../src/realm-execution/tasks/impl/op_task.cc | 29 +++++++---- 25 files changed, 224 insertions(+), 133 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h create mode 100644 lib/realm-execution/include/realm-execution/hash/processor.h create mode 100644 lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc create mode 100644 lib/realm-execution/src/realm-execution/hash/processor.cc diff --git a/lib/kernels/include/kernels/device_handle_t.h b/lib/kernels/include/kernels/device_handle_t.h index 
9b7769355e..0836503717 100644 --- a/lib/kernels/include/kernels/device_handle_t.h +++ b/lib/kernels/include/kernels/device_handle_t.h @@ -9,6 +9,9 @@ namespace FlexFlow { device_handle_t device_handle_t_from_managed_handle( std::optional const &managed_handle); +device_handle_t device_handle_t_from_managed_handle_ptr( + std::optional const &managed_handle); + device_handle_t gpu_make_device_handle_t(PerDeviceFFHandle const &ff_handle); device_handle_t cpu_make_device_handle_t(); diff --git a/lib/kernels/src/kernels/device_handle_t.cc b/lib/kernels/src/kernels/device_handle_t.cc index 85f9e2a388..0225ee8e94 100644 --- a/lib/kernels/src/kernels/device_handle_t.cc +++ b/lib/kernels/src/kernels/device_handle_t.cc @@ -11,6 +11,15 @@ device_handle_t device_handle_t_from_managed_handle( } } +device_handle_t device_handle_t_from_managed_handle_ptr( + std::optional const &managed_handle) { + if (managed_handle.has_value()) { + return gpu_make_device_handle_t(managed_handle.value()->raw_handle()); + } else { + return cpu_make_device_handle_t(); + } +} + device_handle_t gpu_make_device_handle_t(PerDeviceFFHandle const &ff_handle) { return device_handle_t{ ff_handle, diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml new file mode 100644 index 0000000000..1458adcba3 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml @@ -0,0 +1,16 @@ +namespace = "FlexFlow" +name = "DeviceSpecificManagedPerDeviceFFHandle" +type = "struct" +features = [ + "eq", +] + +includes = [ + "", + "kernels/managed_per_device_ff_handle.h", + "task-spec/device_specific.h", +] + +[[fields]] +name = "handle" +type = "::FlexFlow::DeviceSpecific>" diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h 
b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h new file mode 100644 index 0000000000..eefa6c86ac --- /dev/null +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEVICE_SPECIFIC_MANAGED_PER_DEVICE_FF_HANDLE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DEVICE_SPECIFIC_MANAGED_PER_DEVICE_FF_HANDLE_H + +#include "kernels/device_handle_t.dtg.h" +#include "kernels/managed_per_device_ff_handle.h" +#include "pcg/device_id_t.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" + +namespace FlexFlow { + +DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( + device_id_t const &, std::optional const &); + +device_handle_t device_handle_t_from_device_specific_managed_handle( + DeviceSpecificManagedPerDeviceFFHandle const &, device_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index ca3f08fc41..3f55c47192 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -1,12 +1,11 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H -#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/hash/processor.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific.h" -#include -#include +#include namespace FlexFlow { @@ -14,17 +13,14 @@ struct DistributedDeviceHandle { public: 
DistributedDeviceHandle() = delete; explicit DistributedDeviceHandle( - std::map>> const + std::unordered_map const &handles); - DeviceSpecific> const & + DeviceSpecificManagedPerDeviceFFHandle const & at(Realm::Processor processor) const; private: - std::map>> - handles; + std::unordered_map handles; }; DistributedDeviceHandle create_distributed_device_handle( diff --git a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h index 5530f473d8..ca24ecdd4c 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_state_initialization.h @@ -3,6 +3,7 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/distributed_device_handle.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/ff_iteration_config.dtg.h" @@ -13,6 +14,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( DynamicOpenDataflowGraph const &dg, RealmContext &ctx, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, Realm::Event precondition); diff --git a/lib/realm-execution/include/realm-execution/fmt/instance.h b/lib/realm-execution/include/realm-execution/fmt/instance.h index b2efc59b7d..c7c2df6735 100644 --- a/lib/realm-execution/include/realm-execution/fmt/instance.h +++ b/lib/realm-execution/include/realm-execution/fmt/instance.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FMT_PAIR_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H +#define 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_FMT_INSTANCE_H #include "realm-execution/realm.h" #include "utils/check_fmtable.h" diff --git a/lib/realm-execution/include/realm-execution/hash/processor.h b/lib/realm-execution/include/realm-execution/hash/processor.h new file mode 100644 index 0000000000..e5eb8eb503 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/hash/processor.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_HASH_PROCESSOR_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_HASH_PROCESSOR_H + +#include "realm-execution/realm.h" +#include + +namespace std { + +template <> +struct hash<::FlexFlow::Realm::Processor> { + size_t operator()(::FlexFlow::Realm::Processor const &p) const; +}; + +} // namespace std + +#endif diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index 3c5b4189ea..b917477df4 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -10,6 +10,7 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/optimizer_attrs.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_device_handle.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" @@ -53,6 +54,7 @@ PCGInstance create_pcg_instance( std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h 
index eb4d6d0935..b8baad41b9 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -30,7 +30,6 @@ struct RealmContext { // Current device context Realm::Processor get_current_processor() const; Allocator &get_current_device_allocator(); - device_handle_t const &get_current_device_handle() const; device_id_t get_current_device_idx() const; // Task creation @@ -74,8 +73,6 @@ struct RealmContext { Realm::Runtime runtime; Realm::Processor processor; Allocator allocator; - std::optional managed_handle; - device_handle_t device_handle; std::vector outstanding_events; std::unordered_map, std::vector> diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h index 8b358ee4ce..9bae546403 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h @@ -1,10 +1,9 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H -#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" namespace FlexFlow { @@ -14,9 +13,8 @@ void device_handle_init_return_task_body( Realm::Event spawn_device_handle_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, - DeviceSpecific> const &result, - DeviceSpecific> - *origin_result_ptr, + DeviceSpecificManagedPerDeviceFFHandle const &result, + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr, Realm::Event 
precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h index c26633bd9a..624eb6e682 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h @@ -1,10 +1,9 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H -#include "kernels/managed_per_device_ff_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" namespace FlexFlow { @@ -16,7 +15,7 @@ Realm::Event spawn_device_handle_init_task( Realm::Processor target_proc, size_t workSpaceSize, bool allowTensorOpMathConversion, - DeviceSpecific> *result_ptr, + DeviceSpecificManagedPerDeviceFFHandle *result_ptr, Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 4cd65a0a2a..933d4f9283 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -3,6 +3,7 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" @@ -14,15 +15,16 @@ namespace FlexFlow { void device_state_init_task_body( void 
const *, size_t, void const *, size_t, Realm::Processor); -std::optional - spawn_device_state_init_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition); +std::optional spawn_device_state_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 21d8795339..847154192a 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -4,6 +4,7 @@ #include "kernels/profiling_settings.dtg.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" @@ -13,13 +14,15 @@ namespace FlexFlow { void op_task_body(void const *, size_t, void const *, size_t, Realm::Processor); -Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition); +Realm::Event + spawn_op_task(RealmContext &ctx, + 
Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition); } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc new file mode 100644 index 0000000000..440b9d18f7 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -0,0 +1,21 @@ +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "kernels/device_handle_t.h" + +namespace FlexFlow { + +DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( + device_id_t const &device_id, + std::optional const &managed_handle) { + return DeviceSpecificManagedPerDeviceFFHandle{ + DeviceSpecific>::create( + device_id, managed_handle)}; +} + +device_handle_t device_handle_t_from_device_specific_managed_handle( + DeviceSpecificManagedPerDeviceFFHandle const &device_specific, + device_id_t device_idx) { + return device_handle_t_from_managed_handle_ptr( + *device_specific.handle.get(device_idx)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc index 00c2e76360..404feb014c 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -1,16 +1,16 @@ #include "realm-execution/distributed_device_handle.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_handle_init_task.h" #include "task-spec/device_specific.h" namespace FlexFlow { 
DistributedDeviceHandle::DistributedDeviceHandle( - std::map>> const + std::unordered_map const &handles) : handles(handles) {} -DeviceSpecific> const & +DeviceSpecificManagedPerDeviceFFHandle const & DistributedDeviceHandle::at(Realm::Processor processor) const { return this->handles.at(processor); } @@ -20,17 +20,14 @@ DistributedDeviceHandle size_t workSpaceSize, bool allowTensorOpMathConversion, Realm::Event precondition) { - std::map>> - handles; + std::unordered_map handles; // Allocate space for the result before launching any tasks Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); for (Realm::Processor proc : pq) { - handles.insert( - {proc, - DeviceSpecific>::create( - ctx.get_current_device_idx(), std::nullopt)}); + handles.insert({proc, + make_device_specific_managed_handle( + ctx.get_current_device_idx(), std::nullopt)}); } for (auto &[proc, handle] : handles) { diff --git a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc index 9627a71e87..cab2b49e15 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_state_initialization.cc @@ -13,6 +13,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( DynamicOpenDataflowGraph const &dg, RealmContext &ctx, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config, OptimizerAttrs const &optimizer_attrs, Realm::Event precondition) { @@ -37,6 +38,7 @@ DynamicOpenDataflowGraph perform_distributed_device_state_initialization( target_proc, invocation, profiling_settings, + device_handle.at(target_proc), iteration_config, optimizer_attrs, output, diff --git a/lib/realm-execution/src/realm-execution/hash/processor.cc b/lib/realm-execution/src/realm-execution/hash/processor.cc new 
file mode 100644 index 0000000000..dcc1bc5d06 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/hash/processor.cc @@ -0,0 +1,11 @@ +#include "realm-execution/hash/processor.h" +#include + +namespace std { + +size_t hash<::FlexFlow::Realm::Processor>::operator()( + ::FlexFlow::Realm::Processor const &p) const { + return hash<::FlexFlow::Realm::Processor::id_t>{}(p.id); +} + +} // namespace std diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index d56dbb9ca9..c79d8e8abd 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -57,6 +57,7 @@ PCGInstance create_parallel_computation_graph_instance( std::unordered_map const &input_tensors, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config) { DynamicOpenDataflowGraph dg = @@ -91,6 +92,7 @@ PCGInstance create_parallel_computation_graph_instance( dg, ctx, profiling_settings, + device_handle, iteration_config, optimizer_attrs, ctx.get_outstanding_events()); @@ -117,6 +119,7 @@ static std::unordered_map std::vector const &invocations, OptimizerAttrs const &optimizer_attrs, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { // For simplicity we'll track a dependency on all outstanding operations up to // this point. This will create an effective barrier between phases. 
@@ -136,15 +139,16 @@ static std::unordered_map Realm::Event dependencies = Realm::Event::merge_events( Realm::Event::merge_events(input_dependencies), Realm::Event::merge_events(output_dependencies)); - Realm::Event result = - spawn_op_task(ctx, - ctx.map_device_coord_to_processor(assert_unwrap( - invocation.node_attrs.device_coord)), - invocation, - profiling_settings, - iteration_config, - optimizer_attrs, - dependencies); + Realm::Processor target_proc = ctx.map_device_coord_to_processor( + assert_unwrap(invocation.node_attrs.device_coord)); + Realm::Event result = spawn_op_task(ctx, + target_proc, + invocation, + profiling_settings, + device_handle.at(target_proc), + iteration_config, + optimizer_attrs, + dependencies); for (DynamicValueAttrs const &value : values(invocation.inputs)) { dependency_set.add_reader(value, result); } @@ -159,6 +163,7 @@ std::unordered_map perform_all_passes_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = instance.get_execution_order(); @@ -168,6 +173,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; @@ -177,6 +183,7 @@ std::unordered_map perform_forward_pass_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = filter(instance.get_execution_order(), @@ -191,6 +198,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, 
/*iteration_config=*/iteration_config); } @@ -198,6 +206,7 @@ std::unordered_map perform_backward_pass_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = filter(instance.get_execution_order(), @@ -212,6 +221,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); } @@ -219,6 +229,7 @@ std::unordered_map perform_update_pass_for_parallel_computation_graph_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, FFIterationConfig iteration_config) { std::vector execution_order = filter(instance.get_execution_order(), @@ -234,6 +245,7 @@ std::unordered_map /*invocations=*/execution_order, /*optimizer_attrs=*/instance.get_optimizer_attrs(), /*profiling_settings=*/profiling_settings, + /*device_handle=*/device_handle, /*iteration_config=*/iteration_config); instance.update_optimizer_attrs_for_next_iter(); return result; diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 38ce052da9..3427e8cbee 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -19,9 +19,7 @@ namespace FlexFlow { RealmContext::RealmContext(Realm::Processor processor) : processor(processor), allocator(get_realm_allocator( - processor, RealmContext::get_nearest_memory(processor))), - managed_handle(RealmContext::make_device_handle_for_processor(processor)), - device_handle(device_handle_t_from_managed_handle(managed_handle)) {} + processor, RealmContext::get_nearest_memory(processor))) {} RealmContext::~RealmContext() { if 
(!this->outstanding_events.empty()) { @@ -79,10 +77,6 @@ Allocator &RealmContext::get_current_device_allocator() { return this->allocator; } -device_handle_t const &RealmContext::get_current_device_handle() const { - return this->device_handle; -} - device_id_t RealmContext::get_current_device_idx() const { Realm::Processor proc = this->get_current_processor(); @@ -255,26 +249,4 @@ void RealmContext::discover_machine_topology() { } } -std::optional - RealmContext::make_device_handle_for_processor(Realm::Processor processor) { - if (!processor.exists()) { - return std::nullopt; - } - - switch (processor.kind()) { - case Realm::Processor::LOC_PROC: - return std::nullopt; - case Realm::Processor::TOC_PROC: - // FIXME: not sure what workSpaceSize to choose here - return initialize_multi_gpu_handle( - /*num_ranks=*/Realm::Machine::get_machine().get_address_space_count(), - /*my_rank=*/processor.address_space(), - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); - default: - PANIC("Unhandled Realm::ProcessorKind", - fmt::to_string(int{processor.kind()})); - } -} - } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc index 2839beef0c..bda6f7781c 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_return_task.cc @@ -3,22 +3,20 @@ namespace FlexFlow { -// FIXME: Can't make this trivially copyable? 
struct DeviceHandleInitReturnTaskArgs { public: DeviceHandleInitReturnTaskArgs() = delete; DeviceHandleInitReturnTaskArgs( - DeviceSpecific> result, + DeviceSpecificManagedPerDeviceFFHandle result, Realm::Processor origin_proc, - DeviceSpecific> - *origin_result_ptr) + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr) : result(result), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} public: - DeviceSpecific> result; + DeviceSpecificManagedPerDeviceFFHandle result; Realm::Processor origin_proc; - DeviceSpecific> *origin_result_ptr; + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr; }; void device_handle_init_return_task_body(void const *args, @@ -37,9 +35,8 @@ void device_handle_init_return_task_body(void const *args, Realm::Event spawn_device_handle_init_return_task( RealmContext &ctx, Realm::Processor origin_proc, - DeviceSpecific> const &result, - DeviceSpecific> - *origin_result_ptr, + DeviceSpecificManagedPerDeviceFFHandle const &result, + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr, Realm::Event precondition) { DeviceHandleInitReturnTaskArgs task_args{ result, origin_proc, origin_result_ptr}; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc index 86a576d26b..cd5608ca7e 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -1,4 +1,5 @@ #include "realm-execution/tasks/impl/device_handle_init_task.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_handle_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include @@ -14,8 +15,7 @@ struct DeviceHandleInitTaskArgs { size_t workSpaceSize, bool allowTensorOpMathConversion, Realm::Processor origin_proc, - DeviceSpecific> - 
*origin_result_ptr) + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr) : workSpaceSize(workSpaceSize), allowTensorOpMathConversion(allowTensorOpMathConversion), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} @@ -24,7 +24,7 @@ struct DeviceHandleInitTaskArgs { size_t workSpaceSize; bool allowTensorOpMathConversion; Realm::Processor origin_proc; - DeviceSpecific> *origin_result_ptr; + DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr; }; static_assert(std::is_trivially_copy_constructible_v); @@ -60,8 +60,8 @@ void device_handle_init_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; - DeviceSpecific> managed_handle = - DeviceSpecific>::create( + DeviceSpecificManagedPerDeviceFFHandle managed_handle = + make_device_specific_managed_handle( ctx.get_current_device_idx(), make_device_handle_for_processor( proc, @@ -80,7 +80,7 @@ Realm::Event spawn_device_handle_init_task( Realm::Processor target_proc, size_t workSpaceSize, bool allowTensorOpMathConversion, - DeviceSpecific> *result_ptr, + DeviceSpecificManagedPerDeviceFFHandle *result_ptr, Realm::Event precondition) { DeviceHandleInitTaskArgs task_args{ workSpaceSize, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc index c1bd7c1081..306697e950 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_return_task.cc @@ -3,7 +3,6 @@ namespace FlexFlow { -// FIXME: Can't make this trivially copyable? 
struct DeviceStateInitReturnTaskArgs { public: DeviceStateInitReturnTaskArgs() = delete; diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index f63efba14b..5a51b1c803 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,5 +1,7 @@ #include "realm-execution/tasks/impl/device_state_init_task.h" +#include "kernels/device_handle_t.dtg.h" #include "local-execution/device_state_initialization.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" @@ -16,26 +18,28 @@ namespace FlexFlow { // space struct DeviceStateInitTaskArgs { DeviceStateInitTaskArgs() = delete; - DeviceStateInitTaskArgs(DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, - FFIterationConfig const *iteration_config, - OptimizerAttrs const *optimizer_attrs, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) + DeviceStateInitTaskArgs( + DynamicNodeInvocation const *invocation, + ProfilingSettings const *profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const *iteration_config, + OptimizerAttrs const *optimizer_attrs, + Realm::Processor origin_proc, + DeviceSpecificPerDeviceOpState *origin_result_ptr) : invocation(invocation), profiling_settings(profiling_settings), - iteration_config(iteration_config), optimizer_attrs(optimizer_attrs), - origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + device_handle(device_handle), iteration_config(iteration_config), + optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), + 
origin_result_ptr(origin_result_ptr) {} public: DynamicNodeInvocation const *invocation; ProfilingSettings const *profiling_settings; + DeviceSpecificManagedPerDeviceFFHandle device_handle; FFIterationConfig const *iteration_config; OptimizerAttrs const *optimizer_attrs; Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; -static_assert( - std::has_unique_object_representations_v); void device_state_init_task_body(void const *args, size_t arglen, @@ -50,11 +54,14 @@ void device_state_init_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; + device_handle_t device_handle = + device_handle_t_from_device_specific_managed_handle( + task_args.device_handle, ctx.get_current_device_idx()); DynamicNodeInvocation result_invocation = initialize_node(*task_args.invocation, ctx.get_current_device_allocator(), *task_args.profiling_settings, - ctx.get_current_device_handle(), + device_handle, *task_args.iteration_config, *task_args.optimizer_attrs, ctx.get_current_device_idx()); @@ -71,18 +78,20 @@ void device_state_init_task_body(void const *args, Realm::Event::NO_EVENT); } -std::optional - spawn_device_state_init_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - DeviceSpecificPerDeviceOpState *result_ptr, - Realm::Event precondition) { +std::optional spawn_device_state_init_task( + RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, + DeviceSpecificPerDeviceOpState *result_ptr, + Realm::Event precondition) { DeviceStateInitTaskArgs task_args{ &invocation, 
&profiling_settings, + device_handle, &iteration_config, &optimizer_attrs, ctx.get_current_processor(), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index 216f0badde..e17973febb 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,5 +1,6 @@ #include "realm-execution/tasks/impl/op_task.h" #include "local-execution/task_execution.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/per_device_op_state.h" #include "utils/optional.h" @@ -15,20 +16,22 @@ struct OpTaskArgs { OpTaskArgs() = delete; OpTaskArgs(DynamicNodeInvocation const *invocation, ProfilingSettings const *profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, FFIterationConfig const *iteration_config, std::optional const *optimizer_attrs, Realm::Processor origin_proc) : invocation(invocation), profiling_settings(profiling_settings), - iteration_config(iteration_config), optimizer_attrs(optimizer_attrs) {} + device_handle(device_handle), iteration_config(iteration_config), + optimizer_attrs(optimizer_attrs) {} public: DynamicNodeInvocation const *invocation; ProfilingSettings const *profiling_settings; + DeviceSpecificManagedPerDeviceFFHandle device_handle; FFIterationConfig const *iteration_config; std::optional const *optimizer_attrs; Realm::Processor origin_proc; }; -static_assert(std::has_unique_object_representations_v); void op_task_body(void const *args, size_t arglen, @@ -42,11 +45,14 @@ void op_task_body(void const *args, ASSERT(task_args.origin_proc.address_space() == proc.address_space()); RealmContext ctx{proc}; + device_handle_t device_handle = + device_handle_t_from_device_specific_managed_handle( + task_args.device_handle, ctx.get_current_device_idx()); execute_dynamic_node_invocation( 
/*invocation=*/*task_args.invocation, /*allocator=*/ctx.get_current_device_allocator(), /*profiling_settings=*/*task_args.profiling_settings, - /*ff_handle=*/ctx.get_current_device_handle(), + /*ff_handle=*/device_handle, /*per_device_op_state=*/ transform(task_args.invocation->node_attrs.per_device_op_state, [&](DeviceSpecificPerDeviceOpState const &op_state) { @@ -58,15 +64,18 @@ void op_task_body(void const *args, /*device_idx=*/ctx.get_current_device_idx()); } -Realm::Event spawn_op_task(RealmContext &ctx, - Realm::Processor target_proc, - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - FFIterationConfig const &iteration_config, - std::optional const &optimizer_attrs, - Realm::Event precondition) { +Realm::Event + spawn_op_task(RealmContext &ctx, + Realm::Processor target_proc, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, + DeviceSpecificManagedPerDeviceFFHandle const &device_handle, + FFIterationConfig const &iteration_config, + std::optional const &optimizer_attrs, + Realm::Event precondition) { OpTaskArgs task_args{&invocation, &profiling_settings, + device_handle, &iteration_config, &optimizer_attrs, ctx.get_current_processor()}; From 37beaa47a178dbc19978a241911fa7691673f071 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 17:02:00 -0800 Subject: [PATCH 48/63] Test distributed device handle. 
--- .../realm-execution/distributed_device_handle.h | 6 ++++-- .../realm-execution/distributed_device_handle.cc | 7 ++++--- .../realm-execution/tasks/realm_task_registry.cc | 14 ++++++++++++++ .../test/src/realm-execution/realm_manager.cc | 14 ++++++++++++-- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index 3f55c47192..40f3b98fb3 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -13,14 +13,16 @@ struct DistributedDeviceHandle { public: DistributedDeviceHandle() = delete; explicit DistributedDeviceHandle( - std::unordered_map const + std::unordered_map const &handles); DeviceSpecificManagedPerDeviceFFHandle const & at(Realm::Processor processor) const; private: - std::unordered_map handles; + std::unordered_map + handles; }; DistributedDeviceHandle create_distributed_device_handle( diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc index 404feb014c..3cd01f292e 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -6,8 +6,8 @@ namespace FlexFlow { DistributedDeviceHandle::DistributedDeviceHandle( - std::unordered_map const - &handles) + std::unordered_map const &handles) : handles(handles) {} DeviceSpecificManagedPerDeviceFFHandle const & @@ -20,7 +20,8 @@ DistributedDeviceHandle size_t workSpaceSize, bool allowTensorOpMathConversion, Realm::Event precondition) { - std::unordered_map handles; + std::unordered_map + handles; // Allocate space for the result before launching any tasks Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); diff --git 
a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index 9150ce6892..cff12c2391 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -1,5 +1,7 @@ #include "realm-execution/tasks/realm_task_registry.h" #include "realm-execution/tasks/impl/controller_task.h" +#include "realm-execution/tasks/impl/device_handle_init_return_task.h" +#include "realm-execution/tasks/impl/device_handle_init_task.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" #include "realm-execution/tasks/impl/device_state_init_task.h" #include "realm-execution/tasks/impl/op_task.h" @@ -125,6 +127,18 @@ Realm::Event register_all_tasks() { pending_registrations.push_back(register_task(Realm::Processor::LOC_PROC, task_id_t::CONTROLLER_TASK_ID, controller_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, + task_id_t::DEVICE_HANDLE_INIT_TASK_ID, + device_handle_init_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::TOC_PROC, + task_id_t::DEVICE_HANDLE_INIT_TASK_ID, + device_handle_init_task_body)); + pending_registrations.push_back( + register_task(Realm::Processor::LOC_PROC, + task_id_t::DEVICE_HANDLE_INIT_RETURN_TASK_ID, + device_handle_init_return_task_body)); pending_registrations.push_back( register_task(Realm::Processor::LOC_PROC, task_id_t::DEVICE_STATE_INIT_RETURN_TASK_ID, diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 94e0d7d0f4..41fa63f4f9 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -1,4 +1,5 @@ #include "realm-execution/realm_manager.h" +#include "realm-execution/distributed_device_handle.h" #include using 
namespace ::FlexFlow; @@ -16,8 +17,17 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - FlexFlow::Realm::Event event = manager.start_controller( - [&](RealmContext &ctx) { ASSERT(some_data == 123); }); + FlexFlow::Realm::Event event = + manager.start_controller([&](RealmContext &ctx) { + // Data is captured and retains value + ASSERT(some_data == 123); + + // Launch some basic task to ensure everything works + DistributedDeviceHandle handle = create_distributed_device_handle( + /*ctx=*/ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + }); // Need to block on the completion of the event to ensure we don't race event.wait(); } From c61604065b0ff1d952a39e27891d0e6448e38cf6 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Thu, 12 Feb 2026 20:54:38 -0800 Subject: [PATCH 49/63] Guard the kinds of procs we run on. --- .../src/realm-execution/distributed_device_handle.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc index 3cd01f292e..87376be9b1 100644 --- a/lib/realm-execution/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/src/realm-execution/distributed_device_handle.cc @@ -26,9 +26,12 @@ DistributedDeviceHandle // Allocate space for the result before launching any tasks Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); for (Realm::Processor proc : pq) { - handles.insert({proc, - make_device_specific_managed_handle( - ctx.get_current_device_idx(), std::nullopt)}); + if (proc.kind() == Realm::Processor::LOC_PROC || + proc.kind() == Realm::Processor::TOC_PROC) { + handles.insert({proc, + make_device_specific_managed_handle( + ctx.get_current_device_idx(), std::nullopt)}); + } } for (auto &[proc, handle] : handles) { From 26046eeca4a32399a79108abcc12cb98a03e2962 Mon Sep 17 00:00:00 2001 From: Elliott 
Slaughter Date: Fri, 13 Feb 2026 10:30:17 -0800 Subject: [PATCH 50/63] Switch to own DeviceSpecific implementation with raw pointers. --- ...pecific_managed_per_device_ff_handle.dtg.toml | 16 ---------------- ...evice_specific_managed_per_device_ff_handle.h | 14 +++++++++++++- .../realm-execution/distributed_device_handle.h | 2 +- .../tasks/impl/device_handle_init_return_task.h | 2 +- .../tasks/impl/device_handle_init_task.h | 2 +- .../tasks/impl/device_state_init_task.h | 2 +- .../include/realm-execution/tasks/impl/op_task.h | 2 +- ...vice_specific_managed_per_device_ff_handle.cc | 16 ++++++++++++---- 8 files changed, 30 insertions(+), 26 deletions(-) delete mode 100644 lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml deleted file mode 100644 index 1458adcba3..0000000000 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.dtg.toml +++ /dev/null @@ -1,16 +0,0 @@ -namespace = "FlexFlow" -name = "DeviceSpecificManagedPerDeviceFFHandle" -type = "struct" -features = [ - "eq", -] - -includes = [ - "", - "kernels/managed_per_device_ff_handle.h", - "task-spec/device_specific.h", -] - -[[fields]] -name = "handle" -type = "::FlexFlow::DeviceSpecific>" diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index eefa6c86ac..19a70491a2 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,10 +4,22 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" 
#include "pcg/device_id_t.dtg.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" namespace FlexFlow { +struct DeviceSpecificManagedPerDeviceFFHandle { +public: + DeviceSpecificManagedPerDeviceFFHandle() = delete; + explicit DeviceSpecificManagedPerDeviceFFHandle( + device_id_t owner, std::optional handle); + + std::optional get(device_id_t device_idx) const; + +private: + device_id_t owner; + std::optional handle; +}; + DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &, std::optional const &); diff --git a/lib/realm-execution/include/realm-execution/distributed_device_handle.h b/lib/realm-execution/include/realm-execution/distributed_device_handle.h index 40f3b98fb3..268be3583d 100644 --- a/lib/realm-execution/include/realm-execution/distributed_device_handle.h +++ b/lib/realm-execution/include/realm-execution/distributed_device_handle.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_DISTRIBUTED_DEVICE_HANDLE_H -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/hash/processor.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h index 9bae546403..a87652b5ce 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_return_task.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H #define 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_RETURN_TASK_H -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h index 624eb6e682..312ed26add 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H #define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_DEVICE_HANDLE_INIT_TASK_H -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h index 933d4f9283..4ed8c1726d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task.h @@ -3,7 +3,7 @@ #include "kernels/profiling_settings.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" diff --git 
a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h index 847154192a..9d4c2fd451 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task.h @@ -4,7 +4,7 @@ #include "kernels/profiling_settings.dtg.h" #include "op-attrs/ops/loss_functions/loss_attrs.dtg.h" #include "pcg/optimizer_attrs.dtg.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.dtg.h" +#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/realm.h" #include "realm-execution/realm_context.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index 440b9d18f7..99ff7a6dd6 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -3,19 +3,27 @@ namespace FlexFlow { +DeviceSpecificManagedPerDeviceFFHandle::DeviceSpecificManagedPerDeviceFFHandle( + device_id_t owner, std::optional handle) + : owner(owner), handle(handle) {} + +std::optional + DeviceSpecificManagedPerDeviceFFHandle::get(device_id_t device_idx) const { + ASSERT(this->owner == device_idx); + return this->handle; +} + DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &device_id, std::optional const &managed_handle) { - return DeviceSpecificManagedPerDeviceFFHandle{ - DeviceSpecific>::create( - device_id, managed_handle)}; + return DeviceSpecificManagedPerDeviceFFHandle{device_id, managed_handle}; } device_handle_t device_handle_t_from_device_specific_managed_handle( DeviceSpecificManagedPerDeviceFFHandle const &device_specific, 
device_id_t device_idx) { return device_handle_t_from_managed_handle_ptr( - *device_specific.handle.get(device_idx)); + *device_specific.get(device_idx)); } } // namespace FlexFlow From 12c494092954f8f34e9ef5d4b0cc433d8575fed0 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 10:57:11 -0800 Subject: [PATCH 51/63] Separate device handle test. --- .../distributed_device_handle.cc | 38 +++++++++++++++++++ .../test/src/realm-execution/realm_manager.cc | 16 ++++---- .../test/src/realm-execution/test_e2e.cc | 5 +++ 3 files changed, 51 insertions(+), 8 deletions(-) create mode 100644 lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc diff --git a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc new file mode 100644 index 0000000000..5a5402a140 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc @@ -0,0 +1,38 @@ +#include "realm-execution/distributed_device_handle.h" +#include "realm-execution/realm_manager.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("DistributedDeviceHandle") { + // Construct some fake command line for our test + char fake_executable_name[] = "fake_executable_name"; + char arg0[] = "-ll:cpu"; + char arg1[] = "2"; + std::vector fake_args{fake_executable_name, arg0, arg1}; + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager(&fake_argc, &fake_argv); + + (void)manager.start_controller([](RealmContext &ctx) { + DistributedDeviceHandle handle = create_distributed_device_handle( + /*ctx=*/ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + // Make sure we have handles for the processors we're expecting + Realm::Machine::ProcessorQuery pq(Realm::Machine::get_machine()); + 
pq.only_kind(Realm::Processor::LOC_PROC); + for (Realm::Processor proc : pq) { + handle.at(proc); + } + }); + } +} + +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 41fa63f4f9..5fe659cdc2 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -2,7 +2,10 @@ #include "realm-execution/distributed_device_handle.h" #include +namespace test { + using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmManager") { @@ -17,18 +20,15 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - FlexFlow::Realm::Event event = + Realm::Event event = manager.start_controller([&](RealmContext &ctx) { // Data is captured and retains value ASSERT(some_data == 123); - - // Launch some basic task to ensure everything works - DistributedDeviceHandle handle = create_distributed_device_handle( - /*ctx=*/ctx, - /*workSpaceSize=*/1024 * 1024, - /*allowTensorOpMathConversion=*/true); }); - // Need to block on the completion of the event to ensure we don't race + // Need to block on the completion of the event to ensure we don't race, + // because the lambda captures the environment event.wait(); } } + +} // namespace test diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 37f1a9b42c..9592cb221c 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -2,7 +2,10 @@ #include "realm-execution/realm_manager.h" #include +namespace test { + using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training") { @@ -14,3 +17,5 @@ TEST_SUITE(FF_TEST_SUITE) { (void)manager.start_controller([](RealmContext 
&ctx) {}); } } + +} // namespace test From 1b424610c9464626f28ebe7c2f20fcf9a88010a1 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 12:26:57 -0800 Subject: [PATCH 52/63] More work on Realm tests. --- .../parallel_computation_graph.h | 4 + .../parallel_computation_graph.cc | 21 +++ .../pcg_instance/pcg_instance.h | 32 +++- .../pcg_instance/pcg_instance.cc | 10 +- .../test/src/internal/realm_test_utils.cc | 28 +++ .../test/src/internal/realm_test_utils.h | 15 ++ .../distributed_device_handle.cc | 8 +- .../test/src/realm-execution/realm_manager.cc | 15 +- .../test/src/realm-execution/test_e2e.cc | 173 +++++++++++++++++- 9 files changed, 283 insertions(+), 23 deletions(-) create mode 100644 lib/realm-execution/test/src/internal/realm_test_utils.cc create mode 100644 lib/realm-execution/test/src/internal/realm_test_utils.h diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 3d948ac107..21f33f6d3d 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -32,6 +32,10 @@ ParallelLayerAddedResult add_parallel_layer( ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, TensorShape const &tensor_shape); +ParallelLayerAddedResult + pcg_add_input_layer_with_grad(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape); + OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index 907dc05620..959747dbc7 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ 
-142,6 +142,27 @@ ParallelLayerAddedResult pcg_add_input_layer(ParallelComputationGraph &pcg, }); } +ParallelLayerAddedResult + pcg_add_input_layer_with_grad(ParallelComputationGraph &pcg, + TensorShape const &tensor_shape) { + ParallelLayerAttrs layer_attrs = ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{InputAttrs{tensor_shape}}, + /*name=*/std::nullopt, + }; + + return add_parallel_layer(/*pcg=*/pcg, + /*layer_attrs=*/layer_attrs, + /*inputs=*/{}, + /*weights=*/{}, + /*output_flags=*/ + std::unordered_map{ + { + TensorSlotName::OUTPUT, + CreateGrad::YES, + }, + }); +} + OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer) { PCGOperatorAttrs op_attrs = pcg_get_op_attrs(pcg, layer); diff --git a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h index b917477df4..b0037f51b2 100644 --- a/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h +++ b/lib/realm-execution/include/realm-execution/pcg_instance/pcg_instance.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_PCG_INSTANCE_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_PCG_INSTANCE_PCG_INSTANCE_H #include "kernels/accessor.h" #include "kernels/allocation.h" @@ -57,6 +57,34 @@ PCGInstance create_pcg_instance( DistributedDeviceHandle const &device_handle, FFIterationConfig const &iteration_config); +std::unordered_map + perform_all_passes_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + +std::unordered_map + perform_forward_pass_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings 
const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + +std::unordered_map + perform_backward_pass_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + +std::unordered_map + perform_update_pass_for_pcg_instance( + PCGInstance &instance, + ProfilingSettings const &profiling_settings, + DistributedDeviceHandle const &device_handle, + FFIterationConfig iteration_config); + } // namespace FlexFlow #endif diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index c79d8e8abd..de7cdcb687 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -47,7 +47,7 @@ std::optional return this->logit_grad_tensor; } -PCGInstance create_parallel_computation_graph_instance( +PCGInstance create_pcg_instance( RealmContext &ctx, MappedParallelComputationGraph const &mpcg, OptimizerAttrs const &optimizer_attrs, @@ -160,7 +160,7 @@ static std::unordered_map } std::unordered_map - perform_all_passes_for_parallel_computation_graph_instance( + perform_all_passes_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -180,7 +180,7 @@ std::unordered_map } std::unordered_map - perform_forward_pass_for_parallel_computation_graph_instance( + perform_forward_pass_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, @@ -203,7 +203,7 @@ std::unordered_map } std::unordered_map - perform_backward_pass_for_parallel_computation_graph_instance( + perform_backward_pass_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, 
DistributedDeviceHandle const &device_handle, @@ -226,7 +226,7 @@ std::unordered_map } std::unordered_map - perform_update_pass_for_parallel_computation_graph_instance( + perform_update_pass_for_pcg_instance( PCGInstance &instance, ProfilingSettings const &profiling_settings, DistributedDeviceHandle const &device_handle, diff --git a/lib/realm-execution/test/src/internal/realm_test_utils.cc b/lib/realm-execution/test/src/internal/realm_test_utils.cc new file mode 100644 index 0000000000..e381feb8de --- /dev/null +++ b/lib/realm-execution/test/src/internal/realm_test_utils.cc @@ -0,0 +1,28 @@ +#include "internal/realm_test_utils.h" +#include +#include + +namespace FlexFlow { + +static char *leak_string_contents(std::string const &str) { + // Realm command-line arguments require char* so intentionally leak the + // allocated string contents here + std::vector *content = new std::vector{str.begin(), str.end()}; + content->push_back(0); // NUL byte + return content->data(); +} + +std::vector make_fake_realm_args(positive_int num_cpus, + nonnegative_int num_gpus) { + std::vector result; + result.push_back(leak_string_contents("fake_executable_name")); + result.push_back(leak_string_contents("-ll:cpu")); + result.push_back(leak_string_contents(fmt::to_string(num_cpus))); + if (num_gpus > 0) { + result.push_back(leak_string_contents("-ll:gpu")); + result.push_back(leak_string_contents(fmt::to_string(num_gpus))); + } + return result; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/test/src/internal/realm_test_utils.h b/lib/realm-execution/test/src/internal/realm_test_utils.h new file mode 100644 index 0000000000..8e2775ad8b --- /dev/null +++ b/lib/realm-execution/test/src/internal/realm_test_utils.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_TEST_SRC_INTERNAL_REALM_TEST_UTILS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_TEST_SRC_INTERNAL_REALM_TEST_UTILS_H + +#include "utils/nonnegative_int/nonnegative_int.h" +#include 
"utils/positive_int/positive_int.h" +#include + +namespace FlexFlow { + +std::vector make_fake_realm_args(positive_int num_cpus, + nonnegative_int num_gpus); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc index 5a5402a140..fb7dff01e3 100644 --- a/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc +++ b/lib/realm-execution/test/src/realm-execution/distributed_device_handle.cc @@ -1,4 +1,5 @@ #include "realm-execution/distributed_device_handle.h" +#include "internal/realm_test_utils.h" #include "realm-execution/realm_manager.h" #include @@ -9,11 +10,8 @@ namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("DistributedDeviceHandle") { - // Construct some fake command line for our test - char fake_executable_name[] = "fake_executable_name"; - char arg0[] = "-ll:cpu"; - char arg1[] = "2"; - std::vector fake_args{fake_executable_name, arg0, arg1}; + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); diff --git a/lib/realm-execution/test/src/realm-execution/realm_manager.cc b/lib/realm-execution/test/src/realm-execution/realm_manager.cc index 5fe659cdc2..450d7fd3ec 100644 --- a/lib/realm-execution/test/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/test/src/realm-execution/realm_manager.cc @@ -1,4 +1,5 @@ #include "realm-execution/realm_manager.h" +#include "internal/realm_test_utils.h" #include "realm-execution/distributed_device_handle.h" #include @@ -9,9 +10,8 @@ namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmManager") { - // Construct some fake command line for our test - char fake_executable_name[] = "fake_executable_name"; - std::vector fake_args{fake_executable_name}; + std::vector fake_args = + 
make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); @@ -20,11 +20,10 @@ TEST_SUITE(FF_TEST_SUITE) { // Launch a controller int some_data = 123; - Realm::Event event = - manager.start_controller([&](RealmContext &ctx) { - // Data is captured and retains value - ASSERT(some_data == 123); - }); + Realm::Event event = manager.start_controller([&](RealmContext &ctx) { + // Data is captured and retains value + ASSERT(some_data == 123); + }); // Need to block on the completion of the event to ensure we don't race, // because the lambda captures the environment event.wait(); diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 9592cb221c..33ad2bbbc1 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,5 +1,12 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_device_handle.h" #include "realm-execution/pcg_instance/pcg_instance.h" #include "realm-execution/realm_manager.h" +#include "utils/containers/require_only_key.h" #include namespace test { @@ -9,12 +16,172 @@ namespace Realm = ::FlexFlow::Realm; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("RealmBackend e2e Training") { - char fake_executable_name[] = "fake_executable_name"; - std::vector fake_args{fake_executable_name}; + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/0_n); int fake_argc = fake_args.size(); char **fake_argv = fake_args.data(); + RealmManager manager(&fake_argc, &fake_argv); - (void)manager.start_controller([](RealmContext &ctx) {}); + + (void)manager.start_controller([](RealmContext &ctx) { + 
Allocator allocator = ctx.get_current_device_allocator(); + + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + GenericTensorAccessorW label_tensor_backing = + allocator.allocate_tensor(output_tensor_shape); + + // construct computation graph + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + GenericTensorAccessorW label_tensor = + allocator.allocate_tensor(label_tensor_shape); + + TensorShape weight_shape_1 = TensorShape{ + TensorDims{FFOrdered{hidden_dim, data_dim}}, DataType::FLOAT}; + TensorShape weight_shape_2 = TensorShape{ + TensorDims{FFOrdered{output_dim, hidden_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer_with_grad(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult weights_layer_1 = add_parallel_layer( + pcg, + ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{ + weight_shape_1, InitializerAttrs{GlorotNormalAttrs{0}}}}, + std::nullopt}, + {}, + {}); + parallel_tensor_guid_t t_weights_1 = + require_only_key(weights_layer_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult weights_layer_2 = add_parallel_layer( + pcg, + ParallelLayerAttrs{ + PCGOperatorAttrs{WeightAttrs{ + weight_shape_2, InitializerAttrs{GlorotNormalAttrs{0}}}}, + std::nullopt}, + {}, + {}); + parallel_tensor_guid_t t_weights_2 = + require_only_key(weights_layer_2.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_operator_1 = add_parallel_layer( + pcg, + 
ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{hidden_dim, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + { + { + TensorSlotName::INPUT, + t_input, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_weights_1, + }, + }); + parallel_tensor_guid_t t_linear_1 = + require_only_key(linear_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult linear_operator_2 = add_parallel_layer( + pcg, + ParallelLayerAttrs{PCGOperatorAttrs{LinearAttrs{output_dim, + /*use_bias=*/false, + DataType::FLOAT, + Activation::RELU, + std::nullopt}}, + std::nullopt}, + { + { + TensorSlotName::INPUT, + t_linear_1, + }, + }, + { + { + TensorSlotName::WEIGHT, + t_weights_2, + }, + }); + parallel_tensor_guid_t t_linear_2 = + require_only_key(linear_operator_2.outputs, TensorSlotName::OUTPUT); + + MappedParallelComputationGraph mpcg{pcg, {}}; + + // instantiate computation graph + LossAttrs loss_attrs = LossAttrs{ + NonconfigurableLossAttrs{LossFunction::CATEGORICAL_CROSSENTROPY}}; + OptimizerAttrs optimizer_attrs = + OptimizerAttrs{SGDOptimizerAttrs{/*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001}}; + + std::unordered_map + input_tensors; + + DistributedDeviceHandle device_handle = create_distributed_device_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/loss_attrs, + /*label_tensor=*/label_tensor, + /*logit_tensor=*/t_linear_2, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 5; + std::vector loss_values; + + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + 
/*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + // loss_values.push_back(copy_tensor_accessor_r( + // pcg_instance.get_loss_tensor_accessor().value(), + // allocator)); + } + + // // Assert that each sample in the batch has a lower loss in last epoch + // // than the first epoch + // GenericTensorAccessorR first_epoch_loss = loss_values.at(0); + // GenericTensorAccessorR last_epoch_loss = loss_values.back(); + // CHECK_MESSAGE(did_loss_decrease(first_epoch_loss, last_epoch_loss), + // check_kv("first_epoch_loss", + // format_accessor_r_contents(first_epoch_loss)), + // check_kv("last_epoch_loss", + // format_accessor_r_contents(last_epoch_loss))); + }); } } From 657a9f9b25166f2f4cd8be40dde5fae658fd00bb Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 16:55:22 -0800 Subject: [PATCH 53/63] JSON serialization of a bunch of data types. --- lib/pcg/include/pcg/layer_guid_t.dtg.toml | 1 + .../mapped_operator_task_group.h | 12 ++++++ .../parallel_layer_guid_t.dtg.toml | 1 + .../mapped_operator_task_group.cc | 17 ++++++++ .../mapped_operator_task_group.cc | 42 ++++++++++++++++++ .../dynamic_layer_guid_t.dtg.toml | 1 + .../serializable_dynamic_node_attrs.dtg.toml | 43 +++++++++++++++++++ ...ializable_dynamic_node_invocation.dtg.toml | 33 ++++++++++++++ .../serializable_dynamic_value_attrs.dtg.toml | 34 +++++++++++++++ .../training_operation_attrs.dtg.toml | 1 + 10 files changed, 185 insertions(+) create mode 100644 lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml diff --git a/lib/pcg/include/pcg/layer_guid_t.dtg.toml b/lib/pcg/include/pcg/layer_guid_t.dtg.toml 
index d73cf547da..2f2f7694a0 100644 --- a/lib/pcg/include/pcg/layer_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/layer_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h index 5b1cad5e99..ebfdefa478 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h @@ -5,6 +5,7 @@ #include "pcg/machine_space_coordinate.dtg.h" #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "utils/bidict/bidict.h" +#include namespace FlexFlow { @@ -45,4 +46,15 @@ struct hash<::FlexFlow::MappedOperatorTaskGroup> { }; } // namespace std + +namespace nlohmann { + +template <> +struct adl_serializer<::FlexFlow::MappedOperatorTaskGroup> { + static ::FlexFlow::MappedOperatorTaskGroup from_json(json const &j); + static void to_json(json &j, ::FlexFlow::MappedOperatorTaskGroup const &t); +}; + +} // namespace nlohmann + #endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml index 618bcb0dc4..292b361fc8 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc index b96a447383..4436efd727 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc +++ 
b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -90,3 +90,20 @@ size_t hash<::FlexFlow::MappedOperatorTaskGroup>::operator()( } } // namespace std + +namespace nlohmann { + +::FlexFlow::MappedOperatorTaskGroup + adl_serializer<::FlexFlow::MappedOperatorTaskGroup>::from_json( + json const &j) { + return ::FlexFlow::MappedOperatorTaskGroup{j.template get< + ::FlexFlow::bidict<::FlexFlow::MachineSpaceCoordinate, + ::FlexFlow::OperatorAtomicTaskShardBinding>>()}; +} + +void adl_serializer<::FlexFlow::MappedOperatorTaskGroup>::to_json( + json &j, ::FlexFlow::MappedOperatorTaskGroup const &t) { + j = t.get_shard_bindings(); +} + +} // namespace nlohmann diff --git a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc new file mode 100644 index 0000000000..1c3667afc7 --- /dev/null +++ b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -0,0 +1,42 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "op-attrs/parallel_tensor_space_coordinate.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("adl_serializer") { + bidict + shard_bindings{ + {MachineSpaceCoordinate{0_n, 0_n, DeviceType::CPU}, + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::INPUT, + ParallelTensorSpaceCoordinate{ + 0_n, 0_n, FFOrdered{1_n, 2_n, 3_n}}}, + }, + }}, + }; + MappedOperatorTaskGroup deserialized{shard_bindings}; + nlohmann::json serialized = shard_bindings; + + SUBCASE("to_json") { + nlohmann::json result = deserialized; + nlohmann::json correct = serialized; + + CHECK(result == correct); + } + + 
SUBCASE("from_json") { + MappedOperatorTaskGroup result = serialized; + MappedOperatorTaskGroup correct = deserialized; + + CHECK(result == correct); + } + } +} diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml index c6e6673f33..bd64f52567 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml new file mode 100644 index 0000000000..3c43e1d637 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml @@ -0,0 +1,43 @@ +namespace = "FlexFlow" +name = "SerializableDynamicNodeAttrs" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "task-spec/dynamic_graph/dynamic_task_type.dtg.h", + "pcg/machine_space_coordinate.dtg.h", + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", + "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h", + "task-spec/dynamic_graph/training_operation_attrs.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "task_type" +type = "std::optional<::FlexFlow::DynamicTaskType>" + +[[fields]] +name = "device_coord" +type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" + +[[fields]] +name = "mapping" +type = "std::optional<::FlexFlow::MappedOperatorTaskGroup>" + +[[fields]] +name = "op_attrs" +type = "std::optional<::FlexFlow::TrainingOperationAttrs>" + +[[fields]] +name = "layer_guid" +type = "::FlexFlow::dynamic_layer_guid_t" diff --git 
a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml new file mode 100644 index 0000000000..01f4cc8876 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "SerializableDynamicNodeInvocation" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "", + "task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.h", +] + +src_includes = [ + "utils/hash/unordered_map.h", + "utils/fmt/unordered_map.h", +] + +[[fields]] +name = "inputs" +type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" + +[[fields]] +name = "node_attrs" +type = "::FlexFlow::SerializableDynamicNodeAttrs" + +[[fields]] +name = "outputs" +type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml new file mode 100644 index 0000000000..05864b4b47 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -0,0 +1,34 @@ +namespace = "FlexFlow" +name = "SerializableDynamicValueAttrs" +type = "struct" +features = [ + "eq", + "hash", + "fmt", + "json", +] + +includes = [ + "", + "task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.h", + "op-attrs/parallel_tensor_shape.dtg.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = 
"parallel_tensor_shape" +type = "std::optional<::FlexFlow::ParallelTensorShape>" + +[[fields]] +name = "shard_coord" +type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" + +[[fields]] +name = "role" +type = "std::optional<::FlexFlow::DynamicTensorRole>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 66c475b3a9..1051d8ac13 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ From bed8e8a090c6ddd59d0593f61225d849a86dad35 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 17:20:49 -0800 Subject: [PATCH 54/63] Make more stuff serializable. --- .../parallel_tensor_guid_t.dtg.toml | 1 + lib/pcg/include/pcg/tensor_guid_t.dtg.toml | 1 + .../dynamic_tensor_guid_t.dtg.toml | 1 + .../serializable_dynamic_value_attrs.dtg.toml | 4 +++ .../serializable_dynamic_value_attrs.h | 16 +++++++++++ .../serializable_dynamic_value_attrs.cc | 27 +++++++++++++++++++ .../kwarg_dataflow_output.dtg.toml | 1 + 7 files changed, 51 insertions(+) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml index 4494a31ac2..2710a15664 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git 
a/lib/pcg/include/pcg/tensor_guid_t.dtg.toml b/lib/pcg/include/pcg/tensor_guid_t.dtg.toml index 151f7b1f0f..e8caf0021f 100644 --- a/lib/pcg/include/pcg/tensor_guid_t.dtg.toml +++ b/lib/pcg/include/pcg/tensor_guid_t.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml index 75e9099104..c9171b928b 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "json", ] includes = [ diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml index 05864b4b47..6209bfa247 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -21,6 +21,10 @@ src_includes = [ "utils/json/optional.h", ] +[[fields]] +name = "tensor_guid" +type = "::FlexFlow::dynamic_tensor_guid_t" + [[fields]] name = "parallel_tensor_shape" type = "std::optional<::FlexFlow::ParallelTensorShape>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h new file mode 100644 index 0000000000..6272265b7e --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_VALUE_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_VALUE_ATTRS_H + +#include 
"task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.h" + +namespace FlexFlow { + +SerializableDynamicValueAttrs + dynamic_value_attrs_to_serializable(DynamicValueAttrs const &); +DynamicValueAttrs dynamic_value_attrs_from_serializable( + SerializableDynamicValueAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc new file mode 100644 index 0000000000..2dc0b509ab --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_value_attrs.cc @@ -0,0 +1,27 @@ +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include + +namespace FlexFlow { + +SerializableDynamicValueAttrs + dynamic_value_attrs_to_serializable(DynamicValueAttrs const &attrs) { + return SerializableDynamicValueAttrs{ + /*tensor_guid=*/attrs.tensor_guid, + /*parallel_tensor_shape=*/attrs.parallel_tensor_shape, + /*shard_coord=*/attrs.shard_coord, + /*role=*/attrs.role, + }; +} + +DynamicValueAttrs dynamic_value_attrs_from_serializable( + SerializableDynamicValueAttrs const &attrs) { + return DynamicValueAttrs{ + /*tensor_guid=*/attrs.tensor_guid, + /*parallel_tensor_shape=*/attrs.parallel_tensor_shape, + /*shard_coord=*/attrs.shard_coord, + /*accessor=*/std::nullopt, + /*role=*/attrs.role, + }; +} + +} // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml index f286fb90a7..5b537eac88 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output.dtg.toml @@ -6,6 +6,7 @@ features = [ "ord", "hash", "fmt", + "json", ] template_params = [ From 
374e4b6f5c6f104c18cc0301d4d21d61e5bd7635 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Fri, 13 Feb 2026 22:18:55 -0800 Subject: [PATCH 55/63] To-do notes. --- .../src/realm-execution/pcg_instance/pcg_instance.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index de7cdcb687..199f2dc090 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -111,6 +111,10 @@ PCGInstance create_pcg_instance( // * external instances // * task argument serializer // * copies + // * parallel operator implementation (partition, reduce, gather, etc.) + // * and fused parallel operators (reduce + broadcast = allreduce) + // * memory-optimizing compiler integration (tensor creation/destruction, + // tensor reuse) } static std::unordered_map From 31afd42affe23347cddd3e26429f1e98fe596697 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 12:17:15 -0800 Subject: [PATCH 56/63] More serialization routines. 
--- .../serializable_dynamic_node_attrs.h | 16 ++++++++++ .../serializable_dynamic_node_invocation.h | 16 ++++++++++ .../serializable_dynamic_node_attrs.cc | 29 +++++++++++++++++ .../serializable_dynamic_node_invocation.cc | 31 +++++++++++++++++++ 4 files changed, 92 insertions(+) create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h create mode 100644 lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc create mode 100644 lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h new file mode 100644 index 0000000000..7a274a1e7b --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_ATTRS_H + +#include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.h" + +namespace FlexFlow { + +SerializableDynamicNodeAttrs + dynamic_node_attrs_to_serializable(DynamicNodeAttrs const &); +DynamicNodeAttrs + dynamic_node_attrs_from_serializable(SerializableDynamicNodeAttrs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h new file mode 100644 index 0000000000..2bcdb9a898 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.h @@ -0,0 +1,16 @@ +#ifndef 
_FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_INVOCATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_SERIALIZABLE_DYNAMIC_NODE_INVOCATION_H + +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +SerializableDynamicNodeInvocation + dynamic_node_invocation_to_serializable(DynamicNodeInvocation const &); +DynamicNodeInvocation dynamic_node_invocation_from_serializable( + SerializableDynamicNodeInvocation const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc new file mode 100644 index 0000000000..d613194d14 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc @@ -0,0 +1,29 @@ +#include "task-spec/dynamic_graph/serializable_dynamic_node_attrs.h" +#include + +namespace FlexFlow { + +SerializableDynamicNodeAttrs + dynamic_node_attrs_to_serializable(DynamicNodeAttrs const &attrs) { + return SerializableDynamicNodeAttrs{ + /*task_type=*/attrs.task_type, + /*device_coord=*/attrs.device_coord, + /*mapping=*/attrs.mapping, + /*op_attrs=*/attrs.op_attrs, + /*layer_guid=*/attrs.layer_guid, + }; +} + +DynamicNodeAttrs dynamic_node_attrs_from_serializable( + SerializableDynamicNodeAttrs const &attrs) { + return DynamicNodeAttrs{ + /*task_type=*/attrs.task_type, + /*device_coord=*/attrs.device_coord, + /*mapping=*/attrs.mapping, + /*op_attrs=*/attrs.op_attrs, + /*layer_guid=*/attrs.layer_guid, + /*per_device_op_state=*/std::nullopt, + }; +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc new file mode 100644 index 0000000000..334623ee67 --- /dev/null 
+++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_invocation.cc @@ -0,0 +1,31 @@ +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_attrs.h" +#include "task-spec/dynamic_graph/serializable_dynamic_value_attrs.h" +#include "utils/containers/map_values.h" + +namespace FlexFlow { + +SerializableDynamicNodeInvocation dynamic_node_invocation_to_serializable( + DynamicNodeInvocation const &invocation) { + return SerializableDynamicNodeInvocation{ + /*inputs=*/map_values(invocation.inputs, + dynamic_value_attrs_to_serializable), + /*node_attrs=*/dynamic_node_attrs_to_serializable(invocation.node_attrs), + /*outputs=*/ + map_values(invocation.outputs, dynamic_value_attrs_to_serializable), + }; +} + +DynamicNodeInvocation dynamic_node_invocation_from_serializable( + SerializableDynamicNodeInvocation const &invocation) { + return DynamicNodeInvocation{ + /*inputs=*/map_values(invocation.inputs, + dynamic_value_attrs_from_serializable), + /*node_attrs=*/ + dynamic_node_attrs_from_serializable(invocation.node_attrs), + /*outputs=*/ + map_values(invocation.outputs, dynamic_value_attrs_from_serializable), + }; +} + +} // namespace FlexFlow From f47303321b8e30f3e2257ffcdd2c16bfd3768882 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 12:18:49 -0800 Subject: [PATCH 57/63] Most of serializer finished. 
--- .../serializable_realm_processor.dtg.toml | 17 ++++++ .../serializer/serializable_realm_processor.h | 16 +++++ .../tasks/serializer/task_arg_serializer.h | 26 ++++++++ .../tasks/impl/device_state_init_task.cc | 61 +++++++++++++------ .../serializable_realm_processor.cc | 15 +++++ 5 files changed, 115 insertions(+), 20 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml new file mode 100644 index 0000000000..3cb64d95c1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.dtg.toml @@ -0,0 +1,17 @@ +namespace = "FlexFlow" +name = "SerializableRealmProcessor" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "realm-execution/realm.h", +] + +[[fields]] +name = "id" +type = "::FlexFlow::Realm::Processor::id_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h new file mode 100644 index 0000000000..6b29b6e223 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_realm_processor.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_PROCESSOR_H +#define 
_FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_SERIALIZABLE_REALM_PROCESSOR_H + +#include "realm-execution/realm.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h" + +namespace FlexFlow { + +SerializableRealmProcessor + realm_processor_to_serializable(Realm::Processor const &); +Realm::Processor + realm_processor_from_serializable(SerializableRealmProcessor const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h new file mode 100644 index 0000000000..fc5abba587 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h @@ -0,0 +1,26 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_TASK_ARG_SERIALIZER_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_SERIALIZER_TASK_ARG_SERIALIZER_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::string serialize_task_args(T const &args) { + nlohmann::json j; + args.serialize(j); + return j.dump(); +} + +template +T deserialize_task_args(void const *args, size_t arglen) { + nlohmann::json j = nlohmann::json::parse( + std::string_view{reinterpret_cast(args), arglen}); + return T::deserialize(j); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 5a51b1c803..0e7730e485 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -3,11 +3,16 @@ #include "local-execution/device_state_initialization.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include 
"realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/device_specific_per_device_op_state.dtg.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" +#include "utils/exception.h" #include "utils/optional.h" +#include #include #include @@ -19,11 +24,11 @@ namespace FlexFlow { struct DeviceStateInitTaskArgs { DeviceStateInitTaskArgs() = delete; DeviceStateInitTaskArgs( - DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, + DynamicNodeInvocation const &invocation, + ProfilingSettings const &profiling_settings, DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const *iteration_config, - OptimizerAttrs const *optimizer_attrs, + FFIterationConfig const &iteration_config, + OptimizerAttrs const &optimizer_attrs, Realm::Processor origin_proc, DeviceSpecificPerDeviceOpState *origin_result_ptr) : invocation(invocation), profiling_settings(profiling_settings), @@ -31,12 +36,28 @@ struct DeviceStateInitTaskArgs { optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} + void serialize(nlohmann::json &j) const { + j = { + {"invocation", dynamic_node_invocation_to_serializable(invocation)}, + {"profiling_settings", profiling_settings}, + // {"device_handle", device_handle}, + {"iteration_config", iteration_config}, + {"optimizer_attrs", optimizer_attrs}, + {"origin_proc", realm_processor_to_serializable(origin_proc)}, + {"origin_result_ptr", reinterpret_cast(origin_result_ptr)}, + }; + } + + static DeviceStateInitTaskArgs deserialize(nlohmann::json const &j) { + NOT_IMPLEMENTED(); + } + public: - DynamicNodeInvocation 
const *invocation; - ProfilingSettings const *profiling_settings; + DynamicNodeInvocation invocation; + ProfilingSettings profiling_settings; DeviceSpecificManagedPerDeviceFFHandle device_handle; - FFIterationConfig const *iteration_config; - OptimizerAttrs const *optimizer_attrs; + FFIterationConfig iteration_config; + OptimizerAttrs optimizer_attrs; Realm::Processor origin_proc; DeviceSpecificPerDeviceOpState *origin_result_ptr; }; @@ -46,9 +67,8 @@ void device_state_init_task_body(void const *args, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceStateInitTaskArgs)); DeviceStateInitTaskArgs task_args = - *reinterpret_cast(args); + deserialize_task_args(args, arglen); // FIXME: serialize instead of passing pointers around ASSERT(task_args.origin_proc.address_space() == proc.address_space()); @@ -58,12 +78,12 @@ void device_state_init_task_body(void const *args, device_handle_t_from_device_specific_managed_handle( task_args.device_handle, ctx.get_current_device_idx()); DynamicNodeInvocation result_invocation = - initialize_node(*task_args.invocation, + initialize_node(task_args.invocation, ctx.get_current_device_allocator(), - *task_args.profiling_settings, + task_args.profiling_settings, device_handle, - *task_args.iteration_config, - *task_args.optimizer_attrs, + task_args.iteration_config, + task_args.optimizer_attrs, ctx.get_current_device_idx()); DeviceSpecificPerDeviceOpState result_state = assert_unwrap(result_invocation.node_attrs.per_device_op_state); @@ -89,11 +109,11 @@ std::optional spawn_device_state_init_task( DeviceSpecificPerDeviceOpState *result_ptr, Realm::Event precondition) { DeviceStateInitTaskArgs task_args{ - &invocation, - &profiling_settings, + invocation, + profiling_settings, device_handle, - &iteration_config, - &optimizer_attrs, + iteration_config, + optimizer_attrs, ctx.get_current_processor(), result_ptr, }; @@ -105,10 +125,11 @@ std::optional spawn_device_state_init_task( }), 
get_init_task_id_for_op_attrs); if (task_id.has_value()) { + std::string args = serialize_task_args(task_args); return ctx.spawn_task(target_proc, assert_unwrap(task_id), - &task_args, - sizeof(task_args), + args.data(), + args.size(), Realm::ProfilingRequestSet{}, precondition); } diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc new file mode 100644 index 0000000000..b16e2891c4 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_realm_processor.cc @@ -0,0 +1,15 @@ +#include "realm-execution/tasks/serializer/serializable_realm_processor.h" + +namespace FlexFlow { + +SerializableRealmProcessor + realm_processor_to_serializable(Realm::Processor const &proc) { + return SerializableRealmProcessor{proc.id}; +} + +Realm::Processor + realm_processor_from_serializable(SerializableRealmProcessor const &proc) { + return Realm::Processor{proc.id}; +} + +} // namespace FlexFlow From 877fd8a04b1b9db1647b0b738fbd5173a10e2371 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 12:41:51 -0800 Subject: [PATCH 58/63] Finish serialization of device init task. 
--- ...ce_specific_managed_per_device_ff_handle.h | 6 ++++ ...e_specific_managed_per_device_ff_handle.cc | 28 +++++++++++++++++++ .../tasks/impl/device_state_init_task.cc | 24 ++++++++++++++-- 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index 19a70491a2..45617ffcbf 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,6 +4,8 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" +#include +#include namespace FlexFlow { @@ -15,6 +17,10 @@ struct DeviceSpecificManagedPerDeviceFFHandle { std::optional get(device_id_t device_idx) const; + void serialize(nlohmann::json &j) const; + static DeviceSpecificManagedPerDeviceFFHandle + deserialize(nlohmann::json const &j); + private: device_id_t owner; std::optional handle; diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index 99ff7a6dd6..ea0782fd4b 100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -1,5 +1,8 @@ #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "kernels/device_handle_t.h" +#include "utils/containers/transform.h" +#include "utils/json/optional.h" +#include namespace FlexFlow { @@ -13,6 +16,31 @@ std::optional return this->handle; } +void DeviceSpecificManagedPerDeviceFFHandle::serialize( + nlohmann::json &j) const { + j = { + {"owner", owner}, + {"handle", + 
transform(handle, + [](ManagedPerDeviceFFHandle *ptr) { + return reinterpret_cast(ptr); + })}, + }; +} + +DeviceSpecificManagedPerDeviceFFHandle + DeviceSpecificManagedPerDeviceFFHandle::deserialize( + nlohmann::json const &j) { + return DeviceSpecificManagedPerDeviceFFHandle{ + /*owner=*/j.at("owner").get(), + /*handle=*/ + transform(j.at("handle").get>(), + [](uintptr_t ptrval) { + return reinterpret_cast(ptrval); + }), + }; +} + DeviceSpecificManagedPerDeviceFFHandle make_device_specific_managed_handle( device_id_t const &device_id, std::optional const &managed_handle) { diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 0e7730e485..312c3f2401 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -3,6 +3,7 @@ #include "local-execution/device_state_initialization.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h" #include "realm-execution/tasks/serializer/serializable_realm_processor.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" @@ -10,7 +11,6 @@ #include "task-spec/device_specific_per_device_op_state.dtg.h" #include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" -#include "utils/exception.h" #include "utils/optional.h" #include #include @@ -37,10 +37,12 @@ struct DeviceStateInitTaskArgs { origin_result_ptr(origin_result_ptr) {} void serialize(nlohmann::json &j) const { + nlohmann::json j_device_handle; + device_handle.serialize(j_device_handle); j = { {"invocation", 
dynamic_node_invocation_to_serializable(invocation)}, {"profiling_settings", profiling_settings}, - // {"device_handle", device_handle}, + {"device_handle", j_device_handle}, {"iteration_config", iteration_config}, {"optimizer_attrs", optimizer_attrs}, {"origin_proc", realm_processor_to_serializable(origin_proc)}, @@ -49,7 +51,23 @@ struct DeviceStateInitTaskArgs { } static DeviceStateInitTaskArgs deserialize(nlohmann::json const &j) { - NOT_IMPLEMENTED(); + return DeviceStateInitTaskArgs{ + /*invocation=*/dynamic_node_invocation_from_serializable( + j.at("invocation").get()), + /*profiling_settings=*/ + j.at("profiling_settings").get(), + /*device_handle=*/ + DeviceSpecificManagedPerDeviceFFHandle::deserialize( + j.at("device_handle")), + /*iteration_config=*/j.at("iteration_config").get(), + /*optimizer_attrs=*/j.at("optimizer_attrs").get(), + /*origin_proc=*/ + realm_processor_from_serializable( + j.at("origin_proc").get()), + /*origin_result_ptr=*/ + reinterpret_cast( + j.at("origin_result_ptr").get()), + }; } public: From 0aa0664aacbb7c83f7b8d87daf456eb0b063e97c Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 14:55:30 -0800 Subject: [PATCH 59/63] Switch over to explicit DTGs for task arguments and serialization. 
--- ...ce_specific_managed_per_device_ff_handle.h | 5 +- .../device_handle_init_task_args.dtg.toml | 26 ++++++ .../impl/device_state_init_task_args.dtg.toml | 42 ++++++++++ ...able_device_handle_init_task_args.dtg.toml | 30 +++++++ ...erializable_device_handle_init_task_args.h | 17 ++++ ...zable_device_state_init_task_args.dtg.toml | 48 +++++++++++ ...serializable_device_state_init_task_args.h | 16 ++++ .../serializable_device_specific_ptr.dtg.toml | 28 +++++++ .../tasks/serializer/task_arg_serializer.h | 5 +- ...e_specific_managed_per_device_ff_handle.cc | 24 +++--- .../pcg_instance/pcg_instance.cc | 1 + .../tasks/impl/device_handle_init_task.cc | 35 ++------ .../tasks/impl/device_state_init_task.cc | 82 ++----------------- ...rializable_device_handle_init_task_args.cc | 28 +++++++ ...erializable_device_state_init_task_args.cc | 36 ++++++++ 15 files changed, 304 insertions(+), 119 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h create mode 100644 lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc diff --git 
a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h index 45617ffcbf..d48a80f438 100644 --- a/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h +++ b/lib/realm-execution/include/realm-execution/device_specific_managed_per_device_ff_handle.h @@ -4,6 +4,7 @@ #include "kernels/device_handle_t.dtg.h" #include "kernels/managed_per_device_ff_handle.h" #include "pcg/device_id_t.dtg.h" +#include "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h" #include #include @@ -17,9 +18,9 @@ struct DeviceSpecificManagedPerDeviceFFHandle { std::optional get(device_id_t device_idx) const; - void serialize(nlohmann::json &j) const; + SerializableDeviceSpecificPtr serialize() const; static DeviceSpecificManagedPerDeviceFFHandle - deserialize(nlohmann::json const &j); + deserialize(SerializableDeviceSpecificPtr const &j); private: device_id_t owner; diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml new file mode 100644 index 0000000000..c0ba37bb5d --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_handle_init_task_args.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "DeviceHandleInitTaskArgs" +type = "struct" +features = [] + +includes = [ + "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/realm.h", + "realm-execution/tasks/serializer/serializable_realm_processor.h", +] + +[[fields]] +name = "workSpaceSize" +type = "size_t" + +[[fields]] +name = "allowTensorOpMathConversion" +type = "bool" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::Realm::Processor" + +[[fields]] +name = "origin_result_ptr" +type = "::FlexFlow::DeviceSpecificManagedPerDeviceFFHandle *" diff 
--git a/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml new file mode 100644 index 0000000000..a9aa77dde9 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/device_state_init_task_args.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "DeviceStateInitTaskArgs" +type = "struct" +features = [] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/device_specific_managed_per_device_ff_handle.h", + "realm-execution/realm.h", + "task-spec/device_specific_per_device_op_state.dtg.h", + "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::DynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::DeviceSpecificManagedPerDeviceFFHandle" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::Realm::Processor" + +[[fields]] +name = "origin_result_ptr" +type = "::FlexFlow::DeviceSpecificPerDeviceOpState *" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml new file mode 100644 index 0000000000..3a187924c8 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "SerializableDeviceHandleInitTaskArgs" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + 
"realm-execution/realm.h", + "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", +] + +[[fields]] +name = "workSpaceSize" +type = "size_t" + +[[fields]] +name = "allowTensorOpMathConversion" +type = "bool" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::SerializableRealmProcessor" + +[[fields]] +name = "origin_result_ptr" +type = "uintptr_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h new file mode 100644 index 0000000000..b239221c16 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H + +#include "realm-execution/tasks/impl/device_handle_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.h" + +namespace FlexFlow { + +SerializableDeviceHandleInitTaskArgs + device_handle_init_task_args_to_serializable( + DeviceHandleInitTaskArgs const &); +DeviceHandleInitTaskArgs device_handle_init_task_args_from_serializable( + SerializableDeviceHandleInitTaskArgs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml new file mode 100644 index 0000000000..68076b7d70 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml @@ -0,0 +1,48 @@ +namespace = "FlexFlow" +name = "SerializableDeviceStateInitTaskArgs" +type = "struct" +features = [ + "eq", + 
"fmt", + "hash", + "json", +] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/realm.h", + "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", + "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", + "task-spec/device_specific_per_device_op_state.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::SerializableDynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::SerializableDeviceSpecificPtr" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "::FlexFlow::OptimizerAttrs" + +[[fields]] +name = "origin_proc" +type = "::FlexFlow::SerializableRealmProcessor" + +[[fields]] +name = "origin_result_ptr" +type = "uintptr_t" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h new file mode 100644 index 0000000000..2467f2067c --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H + +#include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.h" + +namespace FlexFlow { + +SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( + DeviceStateInitTaskArgs const &); +DeviceStateInitTaskArgs 
device_state_init_task_args_from_serializable( + SerializableDeviceStateInitTaskArgs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml new file mode 100644 index 0000000000..07cf61f7e1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.toml @@ -0,0 +1,28 @@ +namespace = "FlexFlow" +name = "SerializableDeviceSpecificPtr" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "pcg/device_id_t.dtg.h", + "cstdint", + "optional", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "device_idx" +type = "::FlexFlow::device_id_t" + +[[fields]] +name = "ptr" +type = "std::optional" diff --git a/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h index fc5abba587..3208368d2d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h +++ b/lib/realm-execution/include/realm-execution/tasks/serializer/task_arg_serializer.h @@ -9,8 +9,7 @@ namespace FlexFlow { template std::string serialize_task_args(T const &args) { - nlohmann::json j; - args.serialize(j); + nlohmann::json j = args; return j.dump(); } @@ -18,7 +17,7 @@ template T deserialize_task_args(void const *args, size_t arglen) { nlohmann::json j = nlohmann::json::parse( std::string_view{reinterpret_cast(args), arglen}); - return T::deserialize(j); + return j.get(); } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc index ea0782fd4b..6e0cef0bb2 
100644 --- a/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc +++ b/lib/realm-execution/src/realm-execution/device_specific_managed_per_device_ff_handle.cc @@ -16,25 +16,25 @@ std::optional return this->handle; } -void DeviceSpecificManagedPerDeviceFFHandle::serialize( - nlohmann::json &j) const { - j = { - {"owner", owner}, - {"handle", - transform(handle, - [](ManagedPerDeviceFFHandle *ptr) { - return reinterpret_cast(ptr); - })}, +SerializableDeviceSpecificPtr + DeviceSpecificManagedPerDeviceFFHandle::serialize() const { + return SerializableDeviceSpecificPtr{ + /*device_idx=*/owner, + /*ptr=*/ + transform(handle, + [](ManagedPerDeviceFFHandle *ptr) { + return reinterpret_cast(ptr); + }), }; } DeviceSpecificManagedPerDeviceFFHandle DeviceSpecificManagedPerDeviceFFHandle::deserialize( - nlohmann::json const &j) { + SerializableDeviceSpecificPtr const &handle) { return DeviceSpecificManagedPerDeviceFFHandle{ - /*owner=*/j.at("owner").get(), + /*owner=*/handle.device_idx, /*handle=*/ - transform(j.at("handle").get>(), + transform(handle.ptr, [](uintptr_t ptrval) { return reinterpret_cast(ptrval); }), diff --git a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc index 199f2dc090..8e6ab022aa 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance/pcg_instance.cc @@ -110,6 +110,7 @@ PCGInstance create_pcg_instance( // TODO list: // * external instances // * task argument serializer + // * pass instances to task and convert to tensor accessor // * copies // * parallel operator implementation (partition, reduce, gather, etc.) 
// * and fused parallel operators (reduce + broadcast = allreduce) diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc index cd5608ca7e..5cd53ea062 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -1,33 +1,14 @@ #include "realm-execution/tasks/impl/device_handle_init_task.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_handle_init_return_task.h" +#include "realm-execution/tasks/impl/device_handle_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.h" +#include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include namespace FlexFlow { -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct DeviceHandleInitTaskArgs { - DeviceHandleInitTaskArgs() = delete; - DeviceHandleInitTaskArgs( - size_t workSpaceSize, - bool allowTensorOpMathConversion, - Realm::Processor origin_proc, - DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr) - : workSpaceSize(workSpaceSize), - allowTensorOpMathConversion(allowTensorOpMathConversion), - origin_proc(origin_proc), origin_result_ptr(origin_result_ptr) {} - -public: - size_t workSpaceSize; - bool allowTensorOpMathConversion; - Realm::Processor origin_proc; - DeviceSpecificManagedPerDeviceFFHandle *origin_result_ptr; -}; -static_assert(std::is_trivially_copy_constructible_v); - static std::optional make_device_handle_for_processor(Realm::Processor processor, size_t workSpaceSize, @@ -52,12 +33,10 @@ void device_handle_init_task_body(void const *args, void const *userdata, size_t userlen, 
Realm::Processor proc) { - ASSERT(arglen == sizeof(DeviceHandleInitTaskArgs)); DeviceHandleInitTaskArgs task_args = - *reinterpret_cast(args); - - // FIXME: serialize instead of passing pointers around - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + device_handle_init_task_args_from_serializable( + deserialize_task_args(args, + arglen)); RealmContext ctx{proc}; DeviceSpecificManagedPerDeviceFFHandle managed_handle = @@ -89,6 +68,8 @@ Realm::Event spawn_device_handle_init_task( result_ptr, }; + std::string args = serialize_task_args( + device_handle_init_task_args_to_serializable(task_args)); return ctx.spawn_task(target_proc, task_id_t::DEVICE_HANDLE_INIT_TASK_ID, &task_args, diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc index 312c3f2401..99c72cf5e7 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_state_init_task.cc @@ -1,95 +1,26 @@ #include "realm-execution/tasks/impl/device_state_init_task.h" -#include "kernels/device_handle_t.dtg.h" #include "local-execution/device_state_initialization.h" -#include "realm-execution/device_specific_managed_per_device_ff_handle.h" #include "realm-execution/tasks/impl/device_state_init_return_task.h" -#include "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h" -#include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" #include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" -#include "task-spec/device_specific_per_device_op_state.dtg.h" -#include 
"task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" -#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "utils/optional.h" -#include #include #include namespace FlexFlow { -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct DeviceStateInitTaskArgs { - DeviceStateInitTaskArgs() = delete; - DeviceStateInitTaskArgs( - DynamicNodeInvocation const &invocation, - ProfilingSettings const &profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const &iteration_config, - OptimizerAttrs const &optimizer_attrs, - Realm::Processor origin_proc, - DeviceSpecificPerDeviceOpState *origin_result_ptr) - : invocation(invocation), profiling_settings(profiling_settings), - device_handle(device_handle), iteration_config(iteration_config), - optimizer_attrs(optimizer_attrs), origin_proc(origin_proc), - origin_result_ptr(origin_result_ptr) {} - - void serialize(nlohmann::json &j) const { - nlohmann::json j_device_handle; - device_handle.serialize(j_device_handle); - j = { - {"invocation", dynamic_node_invocation_to_serializable(invocation)}, - {"profiling_settings", profiling_settings}, - {"device_handle", j_device_handle}, - {"iteration_config", iteration_config}, - {"optimizer_attrs", optimizer_attrs}, - {"origin_proc", realm_processor_to_serializable(origin_proc)}, - {"origin_result_ptr", reinterpret_cast(origin_result_ptr)}, - }; - } - - static DeviceStateInitTaskArgs deserialize(nlohmann::json const &j) { - return DeviceStateInitTaskArgs{ - /*invocation=*/dynamic_node_invocation_from_serializable( - j.at("invocation").get()), - /*profiling_settings=*/ - j.at("profiling_settings").get(), - /*device_handle=*/ - DeviceSpecificManagedPerDeviceFFHandle::deserialize( - j.at("device_handle")), - /*iteration_config=*/j.at("iteration_config").get(), - 
/*optimizer_attrs=*/j.at("optimizer_attrs").get(), - /*origin_proc=*/ - realm_processor_from_serializable( - j.at("origin_proc").get()), - /*origin_result_ptr=*/ - reinterpret_cast( - j.at("origin_result_ptr").get()), - }; - } - -public: - DynamicNodeInvocation invocation; - ProfilingSettings profiling_settings; - DeviceSpecificManagedPerDeviceFFHandle device_handle; - FFIterationConfig iteration_config; - OptimizerAttrs optimizer_attrs; - Realm::Processor origin_proc; - DeviceSpecificPerDeviceOpState *origin_result_ptr; -}; - void device_state_init_task_body(void const *args, size_t arglen, void const *userdata, size_t userlen, Realm::Processor proc) { DeviceStateInitTaskArgs task_args = - deserialize_task_args(args, arglen); - - // FIXME: serialize instead of passing pointers around - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + device_state_init_task_args_from_serializable( + deserialize_task_args(args, + arglen)); RealmContext ctx{proc}; device_handle_t device_handle = @@ -143,7 +74,8 @@ std::optional spawn_device_state_init_task( }), get_init_task_id_for_op_attrs); if (task_id.has_value()) { - std::string args = serialize_task_args(task_args); + std::string args = serialize_task_args( + device_state_init_task_args_to_serializable(task_args)); return ctx.spawn_task(target_proc, assert_unwrap(task_id), args.data(), diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc new file mode 100644 index 0000000000..a44a5a5db1 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_handle_init_task_args.cc @@ -0,0 +1,28 @@ +#include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.h" + +namespace FlexFlow { + +SerializableDeviceHandleInitTaskArgs + device_handle_init_task_args_to_serializable( + DeviceHandleInitTaskArgs const &args) { + 
return SerializableDeviceHandleInitTaskArgs{ + /*workSpaceSize=*/args.workSpaceSize, + /*allowTensorOpMathConversion=*/args.allowTensorOpMathConversion, + /*origin_proc=*/realm_processor_to_serializable(args.origin_proc), + /*origin_result_ptr=*/reinterpret_cast(args.origin_result_ptr), + }; +} + +DeviceHandleInitTaskArgs device_handle_init_task_args_from_serializable( + SerializableDeviceHandleInitTaskArgs const &args) { + return DeviceHandleInitTaskArgs{ + /*workSpaceSize=*/args.workSpaceSize, + /*allowTensorOpMathConversion=*/args.allowTensorOpMathConversion, + /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), + /*origin_result_ptr=*/ + reinterpret_cast( + args.origin_result_ptr), + }; +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc new file mode 100644 index 0000000000..528ff26867 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_device_state_init_task_args.cc @@ -0,0 +1,36 @@ +#include "realm-execution/tasks/impl/serializable_device_state_init_task_args.h" +#include "realm-execution/tasks/serializer/serializable_realm_processor.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" + +namespace FlexFlow { + +SerializableDeviceStateInitTaskArgs device_state_init_task_args_to_serializable( + DeviceStateInitTaskArgs const &args) { + return SerializableDeviceStateInitTaskArgs{ + /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/args.device_handle.serialize(), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + /*origin_proc=*/realm_processor_to_serializable(args.origin_proc), + /*origin_result_ptr=*/reinterpret_cast(args.origin_result_ptr), + }; +} + +DeviceStateInitTaskArgs 
device_state_init_task_args_from_serializable( + SerializableDeviceStateInitTaskArgs const &args) { + return DeviceStateInitTaskArgs{ + /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/ + DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + /*origin_proc=*/realm_processor_from_serializable(args.origin_proc), + /*origin_result_ptr=*/ + reinterpret_cast( + args.origin_result_ptr), + }; +} + +} // namespace FlexFlow From 5f4cce685b99742fea2a55a3eeeb93823f35be1b Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 15:13:38 -0800 Subject: [PATCH 60/63] Convert op task args. --- .../tasks/impl/op_task_args.dtg.toml | 32 ++++++++++ ...able_device_handle_init_task_args.dtg.toml | 1 - ...erializable_device_handle_init_task_args.h | 4 +- ...zable_device_state_init_task_args.dtg.toml | 1 - ...serializable_device_state_init_task_args.h | 4 +- .../impl/serializable_op_task_args.dtg.toml | 42 +++++++++++++ .../tasks/impl/serializable_op_task_args.h | 14 +++++ .../tasks/impl/device_handle_init_task.cc | 4 +- .../src/realm-execution/tasks/impl/op_task.cc | 60 ++++++------------- .../tasks/impl/serializable_op_task_args.cc | 27 +++++++++ 10 files changed, 139 insertions(+), 50 deletions(-) create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml create mode 100644 lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h create mode 100644 lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml new 
file mode 100644 index 0000000000..814f9f802b --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/op_task_args.dtg.toml @@ -0,0 +1,32 @@ +namespace = "FlexFlow" +name = "OpTaskArgs" +type = "struct" +features = [] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/device_specific_managed_per_device_ff_handle.h", + "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::DynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::DeviceSpecificManagedPerDeviceFFHandle" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = "std::optional<::FlexFlow::OptimizerAttrs>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml index 3a187924c8..34f52880f8 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.toml @@ -9,7 +9,6 @@ features = [ ] includes = [ - "realm-execution/realm.h", "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", ] diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h index b239221c16..63d70fe10a 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h +++ 
b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_handle_init_task_args.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_ARGS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_HANDLE_INIT_TASK_ARGS_H #include "realm-execution/tasks/impl/device_handle_init_task_args.dtg.h" #include "realm-execution/tasks/impl/serializable_device_handle_init_task_args.dtg.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml index 68076b7d70..c99d2758c0 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.toml @@ -11,7 +11,6 @@ features = [ includes = [ "kernels/profiling_settings.dtg.h", "pcg/optimizer_attrs.dtg.h", - "realm-execution/realm.h", "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", "realm-execution/tasks/serializer/serializable_realm_processor.dtg.h", "task-spec/device_specific_per_device_op_state.dtg.h", diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h index 2467f2067c..f028820974 100644 --- a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h +++ 
b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_device_state_init_task_args.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H -#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_H +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_ARGS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_DEVICE_STATE_INIT_TASK_ARGS_H #include "realm-execution/tasks/impl/device_state_init_task_args.dtg.h" #include "realm-execution/tasks/impl/serializable_device_state_init_task_args.dtg.h" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml new file mode 100644 index 0000000000..a0f89e3ae2 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.dtg.toml @@ -0,0 +1,42 @@ +namespace = "FlexFlow" +name = "SerializableOpTaskArgs" +type = "struct" +features = [ + "eq", + "fmt", + "hash", + "json", +] + +includes = [ + "kernels/profiling_settings.dtg.h", + "pcg/optimizer_attrs.dtg.h", + "realm-execution/tasks/serializer/serializable_device_specific_ptr.dtg.h", + "task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.h", + "task-spec/ff_iteration_config.dtg.h", +] + +src_includes = [ + "utils/fmt/optional.h", + "utils/json/optional.h", +] + +[[fields]] +name = "invocation" +type = "::FlexFlow::SerializableDynamicNodeInvocation" + +[[fields]] +name = "profiling_settings" +type = "::FlexFlow::ProfilingSettings" + +[[fields]] +name = "device_handle" +type = "::FlexFlow::SerializableDeviceSpecificPtr" + +[[fields]] +name = "iteration_config" +type = "::FlexFlow::FFIterationConfig" + +[[fields]] +name = "optimizer_attrs" +type = 
"std::optional<::FlexFlow::OptimizerAttrs>" diff --git a/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h new file mode 100644 index 0000000000..3b2d05d0b6 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/tasks/impl/serializable_op_task_args.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_OP_TASK_ARGS_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_TASKS_IMPL_SERIALIZABLE_OP_TASK_ARGS_H + +#include "realm-execution/tasks/impl/op_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_op_task_args.dtg.h" + +namespace FlexFlow { + +SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &); +OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc index 5cd53ea062..b806aa1277 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/device_handle_init_task.cc @@ -72,8 +72,8 @@ Realm::Event spawn_device_handle_init_task( device_handle_init_task_args_to_serializable(task_args)); return ctx.spawn_task(target_proc, task_id_t::DEVICE_HANDLE_INIT_TASK_ID, - &task_args, - sizeof(task_args), + args.data(), + args.size(), Realm::ProfilingRequestSet{}, precondition); } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc index e17973febb..d8b8873442 100644 --- a/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc +++ b/lib/realm-execution/src/realm-execution/tasks/impl/op_task.cc @@ -1,6 +1,9 @@ #include "realm-execution/tasks/impl/op_task.h" 
#include "local-execution/task_execution.h" #include "realm-execution/device_specific_managed_per_device_ff_handle.h" +#include "realm-execution/tasks/impl/op_task_args.dtg.h" +#include "realm-execution/tasks/impl/serializable_op_task_args.h" +#include "realm-execution/tasks/serializer/task_arg_serializer.h" #include "realm-execution/tasks/task_id_t.h" #include "task-spec/per_device_op_state.h" #include "utils/optional.h" @@ -8,59 +11,31 @@ namespace FlexFlow { -// TODO: at some point we're going to have to actually serialize these, but for -// now just pass the pointer and assume we're running inside a single address -// space -struct OpTaskArgs { -public: - OpTaskArgs() = delete; - OpTaskArgs(DynamicNodeInvocation const *invocation, - ProfilingSettings const *profiling_settings, - DeviceSpecificManagedPerDeviceFFHandle const &device_handle, - FFIterationConfig const *iteration_config, - std::optional const *optimizer_attrs, - Realm::Processor origin_proc) - : invocation(invocation), profiling_settings(profiling_settings), - device_handle(device_handle), iteration_config(iteration_config), - optimizer_attrs(optimizer_attrs) {} - -public: - DynamicNodeInvocation const *invocation; - ProfilingSettings const *profiling_settings; - DeviceSpecificManagedPerDeviceFFHandle device_handle; - FFIterationConfig const *iteration_config; - std::optional const *optimizer_attrs; - Realm::Processor origin_proc; -}; - void op_task_body(void const *args, size_t arglen, void const *userdata, size_t userlen, Realm::Processor proc) { - ASSERT(arglen == sizeof(OpTaskArgs)); - OpTaskArgs task_args = *reinterpret_cast(args); - - // FIXME: serialize instead of passing pointers around - ASSERT(task_args.origin_proc.address_space() == proc.address_space()); + OpTaskArgs task_args = op_task_args_from_serializable( + deserialize_task_args(args, arglen)); RealmContext ctx{proc}; device_handle_t device_handle = device_handle_t_from_device_specific_managed_handle( task_args.device_handle, 
ctx.get_current_device_idx()); execute_dynamic_node_invocation( - /*invocation=*/*task_args.invocation, + /*invocation=*/task_args.invocation, /*allocator=*/ctx.get_current_device_allocator(), - /*profiling_settings=*/*task_args.profiling_settings, + /*profiling_settings=*/task_args.profiling_settings, /*ff_handle=*/device_handle, /*per_device_op_state=*/ - transform(task_args.invocation->node_attrs.per_device_op_state, + transform(task_args.invocation.node_attrs.per_device_op_state, [&](DeviceSpecificPerDeviceOpState const &op_state) { return get_device_state_from_device_specific( op_state, ctx.get_current_device_idx()); }), - /*iteration_config=*/*task_args.iteration_config, - /*optimizer_attrs=*/*task_args.optimizer_attrs, + /*iteration_config=*/task_args.iteration_config, + /*optimizer_attrs=*/task_args.optimizer_attrs, /*device_idx=*/ctx.get_current_device_idx()); } @@ -73,17 +48,18 @@ Realm::Event FFIterationConfig const &iteration_config, std::optional const &optimizer_attrs, Realm::Event precondition) { - OpTaskArgs task_args{&invocation, - &profiling_settings, + OpTaskArgs task_args{invocation, + profiling_settings, device_handle, - &iteration_config, - &optimizer_attrs, - ctx.get_current_processor()}; + iteration_config, + optimizer_attrs}; + std::string args = + serialize_task_args(op_task_args_to_serializable(task_args)); return ctx.spawn_task( target_proc, assert_unwrap(get_task_id_for_op(invocation.node_attrs, optimizer_attrs)), - &task_args, - sizeof(task_args), + args.data(), + args.size(), Realm::ProfilingRequestSet{}, precondition); } diff --git a/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc new file mode 100644 index 0000000000..0513bc6df7 --- /dev/null +++ b/lib/realm-execution/src/realm-execution/tasks/impl/serializable_op_task_args.cc @@ -0,0 +1,27 @@ +#include "realm-execution/tasks/impl/serializable_op_task_args.h" +#include 
"task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" + +namespace FlexFlow { + +SerializableOpTaskArgs op_task_args_to_serializable(OpTaskArgs const &args) { + return SerializableOpTaskArgs{ + /*invocation=*/dynamic_node_invocation_to_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/args.device_handle.serialize(), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + }; +} + +OpTaskArgs op_task_args_from_serializable(SerializableOpTaskArgs const &args) { + return OpTaskArgs{ + /*invocation=*/dynamic_node_invocation_from_serializable(args.invocation), + /*profiling_settings=*/args.profiling_settings, + /*device_handle=*/ + DeviceSpecificManagedPerDeviceFFHandle::deserialize(args.device_handle), + /*iteration_config=*/args.iteration_config, + /*optimizer_attrs=*/args.optimizer_attrs, + }; +} + +} // namespace FlexFlow From 81c0d8b13e265f1b53f12af890ad289f0b3c89e9 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 17:37:07 -0800 Subject: [PATCH 61/63] Map the PCG for test. 
--- .../test/src/realm-execution/test_e2e.cc | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/lib/realm-execution/test/src/realm-execution/test_e2e.cc b/lib/realm-execution/test/src/realm-execution/test_e2e.cc index 33ad2bbbc1..8e5edf72ad 100644 --- a/lib/realm-execution/test/src/realm-execution/test_e2e.cc +++ b/lib/realm-execution/test/src/realm-execution/test_e2e.cc @@ -1,7 +1,12 @@ #include "internal/realm_test_utils.h" #include "kernels/allocation.h" #include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" #include "realm-execution/distributed_device_handle.h" #include "realm-execution/pcg_instance/pcg_instance.h" @@ -126,7 +131,44 @@ TEST_SUITE(FF_TEST_SUITE) { parallel_tensor_guid_t t_linear_2 = require_only_key(linear_operator_2.outputs, TensorSlotName::OUTPUT); - MappedParallelComputationGraph mpcg{pcg, {}}; + MachineSpaceCoordinate cpu0{0_n, 0_n, DeviceType::CPU}; + ParallelTensorSpaceCoordinate tensor_coord0{0_n, 0_n, FFOrdered{0_n}}; + MappedParallelComputationGraph mpcg{ + pcg, + { + {inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {weights_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{ + {{TensorSlotName::OUTPUT, tensor_coord0}}}}}}}, + {linear_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + 
OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + {linear_operator_2.parallel_layer, + MappedOperatorTaskGroup{ + {{cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::WEIGHT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}}}}}, + }, + }; // instantiate computation graph LossAttrs loss_attrs = LossAttrs{ From 6803fb7ad61ec1fbaffa3568eb0fc91f793b8328 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 17:43:29 -0800 Subject: [PATCH 62/63] Fix a bug in shard expansion. --- lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index 33b7fb8591..402e0ef055 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -15,7 +15,7 @@ bool value_is_shard_expanded(DynamicValueAttrs const &n) { bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &g) { auto slot_is_shard_expanded = [](DynamicTensorSlot const &) -> bool { - return true; + return false; }; return no_part_of_dynamic_graph_satisfies(g, From 445bee0fda56e63f37ae71ae827e1ef9b0727f12 Mon Sep 17 00:00:00 2001 From: Elliott Slaughter Date: Sat, 14 Feb 2026 17:53:00 -0800 Subject: [PATCH 63/63] Finish body of instance allocation. 
--- .../src/realm-execution/instance_allocation.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index c033f0bac1..b740859e22 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -5,6 +5,7 @@ #include "realm-execution/realm_context.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" @@ -14,6 +15,7 @@ #include "utils/containers/make.h" #include "utils/containers/map_values.h" #include "utils/containers/unordered_set_of.h" +#include "utils/containers/values.h" #include "utils/exception.h" #include "utils/optional.h" @@ -59,6 +61,15 @@ TensorInstanceBacking perform_instance_allocation( } }; + for (DynamicNodeInvocation const &invocation : g.invocations) { + for (DynamicValueAttrs const &input : values(invocation.inputs)) { + allocate(invocation.node_attrs, input); + } + for (DynamicValueAttrs const &output : values(invocation.outputs)) { + allocate(invocation.node_attrs, output); + } + } + return result; }