Conversation

@chen2021673
This PR introduces a comprehensive precision checking system for debugging numerical accuracy issues in distributed training:

Core Features:

  • Two-level precision checking (module-level and function-level)
  • Command-line flags: --precision_check, --precision_check_all_ranks
  • Extensible hook system for Functions, Modules, and Tensors
  • Automatic FP32 reference computation for validation

Hook System:

  • Forward/backward pre/post hooks for Functions and Modules
  • Tensor gradient hooks for inspection
  • Unified hook type definitions to reduce code duplication

Implementation:

  • PrecisionChecker utility with configurable check levels
  • Integration with autograd Function and nn::Module
  • Support for distributed training (per-rank checking)
  • Detailed logging to precision_check_rank_[N].log files

Documentation:

  • docs/hook_mechanism.md - Hook system architecture
  • docs/precision_checker_guide.md - Usage guide

Testing:

  • test/hook/test_hook.cc - Hook functionality tests
  • test/hook/test_precision_check.cc - Precision checker tests

chen2021673 and others added 4 commits January 13, 2026 10:10
…omprehensive docs

- Add PrecisionCheckConfig and PrecisionCheckContext for better state management
- Refactor precision checker to use context-based architecture
- Add comprehensive documentation (hook_mechanism.md, precision_checker_guide.md)
- Add test cases for hook system and precision checking
- Update CMakeLists.txt to include new test targets
- Improve command-line flag handling in examples

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
- Unify Function and Module hook infrastructure into common/hook.h
- Remove duplicated HookHandle and HookHandleImpl classes
- Update precision_checker_guide.md and hook_mechanism.md
int pp_rank = 0;

// Set thread-local global rank
nn::parallel::global::thread_global_rank = rank.GlobalRank();
Author

We should look into whether there is a more elegant replacement for this global variable later.

This commit fixes the issue where only rank 0 generated precision check
log files when running with tensor parallelism. The root cause was that
GetLogStream() used process-global static variables, causing all threads
in a single process to share the same log file handle.

Changes:
- Add thread_global_rank thread-local variable to track per-thread rank
- Convert GetLogStream() and TableHeaderPrinted() to use thread_local storage
- Set thread_global_rank in Train() function for each thread
- Move baseline output (key|md5 format) into table format branch to avoid
  duplicate output in simple format
- Add directory creation and error handling for log file opening

With these changes, each thread now creates its own log file based on
its global rank (process_rank * nthread_per_process + thread_rank).

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
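The per-thread logging scheme this commit describes can be sketched as follows. The names `thread_global_rank`, `GetLogStream`, and the rank formula come from the commit message; the namespace and everything else here are assumptions for illustration:

```cpp
#include <fstream>
#include <sstream>
#include <string>

namespace precision_check_sketch {  // hypothetical namespace

// Set once per worker thread (e.g. at the top of Train()).
thread_local int thread_global_rank = 0;

// global rank = process_rank * nthread_per_process + thread_rank
inline int GlobalRank(int process_rank, int nthread_per_process, int thread_rank) {
    return process_rank * nthread_per_process + thread_rank;
}

inline std::string LogFileName() {
    std::ostringstream name;
    name << "precision_check_rank_" << thread_global_rank << ".log";
    return name.str();
}

// thread_local (not a process-global static): each thread opens its own
// log file, which is exactly the bug this commit fixes.
inline std::ofstream &GetLogStream() {
    thread_local std::ofstream stream(LogFileName(), std::ios::app);
    return stream;
}

}  // namespace precision_check_sketch
```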
Add tools/compare_loss.py to automate end-to-end loss comparison between
two log directories, eliminating manual verification overhead as test cases
scale up.

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
@@ -1,6 +1,8 @@
#pragma once

#include <functional>
Contributor

I don't see any substantive change in this file; are the newly added forward declarations and headers actually used anywhere?

void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size, int virtual_pipeline_parallel_size);
int pipeline_parallel_size, int virtual_pipeline_parallel_size,
const utils::PrecisionCheckConfig &precision_config = utils::PrecisionCheckConfig());
Collaborator

Avoid default values in function signatures wherever possible; otherwise every parameter added later also has to carry a default.
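One way to follow this advice is a forwarding overload instead of a defaulted parameter. A minimal sketch, with a stand-in `PrecisionCheckConfig` and a trimmed-down `Init` signature (the real one takes more parameters); `last_baseline_path` exists only so the sketch is observable:

```cpp
#include <string>

// Stand-in for the real utils::PrecisionCheckConfig.
namespace utils { struct PrecisionCheckConfig { std::string baseline_path; }; }

std::string last_baseline_path;  // illustration only: records what Init saw

// Primary overload: no defaults, so parameters can be appended freely later.
void Init(int pipeline_parallel_size, int virtual_pipeline_parallel_size,
          const utils::PrecisionCheckConfig &precision_config) {
    (void)pipeline_parallel_size;
    (void)virtual_pipeline_parallel_size;
    last_baseline_path = precision_config.baseline_path;
}

// Convenience overload for callers that don't care about precision checking;
// it simply forwards a default-constructed config.
void Init(int pipeline_parallel_size, int virtual_pipeline_parallel_size) {
    Init(pipeline_parallel_size, virtual_pipeline_parallel_size,
         utils::PrecisionCheckConfig());
}
```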


Layout layout_;
PrecisionCheckLevel precision_check_level_ = PrecisionCheckLevel::NONE;
utils::PrecisionCheckConfig precision_check_config_;
Collaborator

const could be added here.

inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
int pipeline_parallel_size, int virtual_pipeline_parallel) {
int pipeline_parallel_size, int virtual_pipeline_parallel,
const utils::PrecisionCheckConfig &precision_config = utils::PrecisionCheckConfig()) {
Collaborator

Don't use a default value here.

std::string baseline_path = ""; // baseline file path for comparison

// Parse from "key=value,key=value" string
static PrecisionCheckConfig Parse(const std::string &config_str) {
Collaborator

Put the implementation in the .cc file.
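The split might look like the sketch below: declaration in the header, definition in the .cc. Only `baseline_path` and the `key=value,key=value` format appear in the diff above; the parsing details and any other keys are assumptions:

```cpp
#include <sstream>
#include <string>

// --- precision_check_config.h (sketch): declaration only ---
struct PrecisionCheckConfig {
    std::string baseline_path;  // baseline file path for comparison
    static PrecisionCheckConfig Parse(const std::string &config_str);
};

// --- precision_check_config.cc (sketch): out-of-line definition ---
PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) {
    PrecisionCheckConfig config;
    std::istringstream stream(config_str);
    std::string item;
    // Split "key=value,key=value" on commas, then each item on the first '='.
    while (std::getline(stream, item, ',')) {
        const auto eq = item.find('=');
        if (eq == std::string::npos) continue;  // skip malformed entries
        const std::string key = item.substr(0, eq);
        const std::string value = item.substr(eq + 1);
        if (key == "baseline_path") config.baseline_path = value;
        // ... other keys would be handled here ...
    }
    return config;
}
```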

Collaborator


Put the implementation in the .cc file.

namespace {

// Simple MD5 implementation
class MD5 {
Collaborator

Comparing tensor MD5s directly, instead of checking abs/rel diff against a tolerance, makes exact matches hard to achieve and gives no sense of how large the difference is, doesn't it?
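The tolerance-based alternative suggested here is the usual allclose-style criterion: pass when every element pair is within `atol + rtol * |expected|`, and also report the maximum deviation so the size of the gap is visible. A sketch on flat float buffers (all names and thresholds are placeholders, not the project's API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct DiffReport {
    bool all_close = true;
    double max_abs_diff = 0.0;   // worst absolute deviation seen
    double max_rel_diff = 0.0;   // worst relative deviation seen
};

// Unlike an MD5 comparison, this tolerates small rounding differences
// and quantifies how far apart the tensors actually are.
DiffReport CompareTensors(const std::vector<float> &actual,
                          const std::vector<float> &expected,
                          double atol, double rtol) {
    DiffReport report;
    const std::size_t n = std::min(actual.size(), expected.size());
    for (std::size_t i = 0; i < n; ++i) {
        const double a = actual[i], e = expected[i];
        const double abs_diff = std::fabs(a - e);
        const double denom = std::max(std::fabs(e), 1e-12);  // guard div-by-zero
        report.max_abs_diff = std::max(report.max_abs_diff, abs_diff);
        report.max_rel_diff = std::max(report.max_rel_diff, abs_diff / denom);
        if (abs_diff > atol + rtol * std::fabs(e)) report.all_close = false;
    }
    return report;
}
```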

Collaborator

@kilinchange kilinchange Jan 19, 2026

The precision_checker configuration shouldn't live under parallel; I suggest moving it elsewhere, e.g. under utils. (This can wait for the next PR, together with the global module hook change.)

utils::PrecisionChecker::RegisterForModule(this);
precision_check_registered_ = true;
}
}
Collaborator


https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html#torch-nn-modules-module-register-module-forward-hook

We should have a mechanism for registering global module hooks. What precision_checker does today is effectively register a global module hook, so precision_checker should call such a global-hook registration interface directly (for example, InitAllEnv could decide whether to register a global precision_check hook based on the precision argument it receives). (This can wait for the next PR.)
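A minimal registry in the spirit of torch's `register_module_forward_hook` (hooks fire for every module's forward pass, not just modules they were registered on) might look like this. All names here are hypothetical, not the project's API:

```cpp
#include <functional>
#include <unordered_map>
#include <utility>

class Module;  // stands in for nn::Module

// A global forward hook observes every module's forward pass.
using GlobalForwardHook = std::function<void(Module *)>;

class GlobalHookRegistry {
public:
    static GlobalHookRegistry &Instance() {
        static GlobalHookRegistry registry;
        return registry;
    }

    // Returns a handle id that can later be passed to Remove().
    int RegisterForwardHook(GlobalForwardHook hook) {
        const int id = next_id_++;
        hooks_[id] = std::move(hook);
        return id;
    }

    void Remove(int id) { hooks_.erase(id); }

    // Module::Forward would call this once per forward pass.
    void RunForwardHooks(Module *module) {
        for (auto &entry : hooks_) entry.second(module);
    }

private:
    int next_id_ = 0;
    std::unordered_map<int, GlobalForwardHook> hooks_;
};
```

Under this design, the precision checker becomes just one client: something like InitAllEnv registering a precision-check hook when the flag is set, instead of each module registering itself.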

}

// Register backward hooks on output tensors' grad_fn
if (!backward_pre_hooks_.empty() || !backward_post_hooks_.empty()) {
Collaborator


Since it's already written this way, wrap this condition in UNLIKELY.
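If the codebase doesn't already define it, UNLIKELY is conventionally a branch-prediction hint wrapped around `__builtin_expect`; a sketch (the project's own macro may differ, and C++20 code could use the `[[unlikely]]` attribute instead):

```cpp
// Hints to the compiler that the condition is rarely true, so the hot
// path (no backward hooks registered) stays on the fall-through branch.
#if defined(__GNUC__) || defined(__clang__)
#define UNLIKELY(x) (__builtin_expect(!!(x), 0))
#else
#define UNLIKELY(x) (x)  // no-op fallback for other compilers
#endif

// Usage at the call site from the snippet above:
// if (UNLIKELY(!backward_pre_hooks_.empty() || !backward_post_hooks_.empty())) {
//     ...
// }
```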

sys.exit(1 if total_mismatches > 0 else 0)

if __name__ == '__main__':
    main()
No newline at end of file
Collaborator

Add a newline at the end of the file.
