feat: add precision checker with hook system and command-line control #102
base: master
Conversation
This PR introduces a comprehensive precision checking system for debugging numerical accuracy issues in distributed training:

**Core Features:**
- Two-level precision checking (module-level and function-level)
- Command-line flags: --precision_check, --precision_check_all_ranks
- Extensible hook system for Functions, Modules, and Tensors
- Automatic FP32 reference computation for validation

**Hook System:**
- Forward/backward pre/post hooks for Functions and Modules
- Tensor gradient hooks for inspection
- Unified hook type definitions to reduce code duplication

**Implementation:**
- PrecisionChecker utility with configurable check levels
- Integration with autograd Function and nn::Module
- Support for distributed training (per-rank checking)
- Detailed logging to precision_check_rank_[N].log files

**Documentation:**
- docs/hook_mechanism.md - Hook system architecture
- docs/precision_checker_guide.md - Usage guide

**Testing:**
- test/hook/test_hook.cc - Hook functionality tests
- test/hook/test_precision_check.cc - Precision checker tests

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
…omprehensive docs

- Add PrecisionCheckConfig and PrecisionCheckContext for better state management
- Refactor precision checker to use context-based architecture
- Add comprehensive documentation (hook_mechanism.md, precision_checker_guide.md)
- Add test cases for hook system and precision checking
- Update CMakeLists.txt to include new test targets
- Improve command-line flag handling in examples

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
- Unify Function and Module hook infrastructure into common/hook.h (see the sketch below)
- Remove duplicated HookHandle and HookHandleImpl classes
- Update precision_checker_guide.md and hook_mechanism.md
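For illustration, a minimal sketch of what a unified, type-erased hook handle in common/hook.h could look like; the names and structure below are assumptions for this sketch, not the PR's actual code:

```cpp
#include <functional>
#include <iterator>
#include <list>
#include <utility>

// Type-erased handle shared by Function, Module, and Tensor hooks, so each
// owner no longer needs its own HookHandle/HookHandleImpl pair.
class HookHandle {
public:
    explicit HookHandle(std::function<void()> remover) : remover_(std::move(remover)) {}

    // Detach the hook from whatever container registered it.
    void Remove() {
        if (remover_) {
            remover_();
            remover_ = nullptr;
        }
    }

private:
    std::function<void()> remover_;
};

// Example: a hook list whose handles erase their own entry on Remove().
template <typename Fn>
class HookList {
public:
    HookHandle Register(Fn fn) {
        hooks_.push_back(std::move(fn));
        auto it = std::prev(hooks_.end());
        return HookHandle([this, it] { hooks_.erase(it); });
    }

private:
    std::list<Fn> hooks_;  // std::list keeps iterators stable across inserts
};
```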
```cpp
int pp_rank = 0;

// Set thread-local global rank
nn::parallel::global::thread_global_rank = rank.GlobalRank();
```
Later on we should look into whether there is a more elegant alternative to this global variable.
This commit fixes the issue where only rank 0 generated precision check log files when running with tensor parallelism. The root cause was that GetLogStream() used process-global static variables, causing all threads in a single process to share the same log file handle.

Changes:
- Add thread_global_rank thread-local variable to track per-thread rank
- Convert GetLogStream() and TableHeaderPrinted() to use thread_local storage
- Set thread_global_rank in Train() function for each thread
- Move baseline output (key|md5 format) into table format branch to avoid duplicate output in simple format
- Add directory creation and error handling for log file opening

With these changes, each thread now creates its own log file based on its global rank (process_rank * nthread_per_process + thread_rank).

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
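A minimal sketch of the thread-local pattern this commit describes, assuming a hypothetical log directory; the real GetLogStream() may differ in detail:

```cpp
#include <filesystem>
#include <fstream>
#include <string>

namespace nn::parallel::global {
// Set by each training thread before any logging; per the commit message,
// the global rank is process_rank * nthread_per_process + thread_rank.
thread_local int thread_global_rank = 0;
}  // namespace nn::parallel::global

// Because the stream is thread_local rather than a process-global static,
// every thread opens its own precision_check_rank_[N].log instead of
// sharing rank 0's handle.
std::ofstream &GetLogStream() {
    thread_local std::ofstream stream;
    if (!stream.is_open()) {
        std::filesystem::create_directories("precision_check_logs");  // hypothetical path
        stream.open("precision_check_logs/precision_check_rank_"
                    + std::to_string(nn::parallel::global::thread_global_rank) + ".log");
    }
    return stream;
}
```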
Force-pushed from d35e92a to a7806d9.
Add tools/compare_loss.py to automate end-to-end loss comparison between two log directories, eliminating manual verification overhead as test cases scale up.

Co-Authored-By: Claude Sonnet 4.5 <[email protected]>
```diff
@@ -1,6 +1,8 @@
 #pragma once
+
+#include <functional>
```
I don't see any substantive change in this file; are the newly added forward declarations and headers actually used?
```diff
 void Init(int threads_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
-          int pipeline_parallel_size, int virtual_pipeline_parallel_size);
+          int pipeline_parallel_size, int virtual_pipeline_parallel_size,
+          const utils::PrecisionCheckConfig &precision_config = utils::PrecisionCheckConfig());
```
Avoid default values in function signatures as much as possible; otherwise every parameter added later will need a default value too.
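The reviewer's concern follows from a C++ rule: default arguments are only allowed on a trailing run of parameters, so one default forces defaults onto everything added after it. A contrived illustration (names here are made up):

```cpp
struct Config {};

void InitV1(int threads, const Config &cfg = Config());  // fine

// void InitV2(int threads, const Config &cfg = Config(), int n);
//   ^ ill-formed: 'n' follows a defaulted parameter but has no default

void InitV2(int threads, const Config &cfg = Config(), int n = 0);  // default forced onto 'n'
```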
```cpp
Layout layout_;
PrecisionCheckLevel precision_check_level_ = PrecisionCheckLevel::NONE;
utils::PrecisionCheckConfig precision_check_config_;
```
const could be added here.
```diff
 inline void InitAllEnv(int nthread_per_process, int tensor_parallel_size, bool sequence_parallel_enabled,
-                       int pipeline_parallel_size, int virtual_pipeline_parallel) {
+                       int pipeline_parallel_size, int virtual_pipeline_parallel,
+                       const utils::PrecisionCheckConfig &precision_config = utils::PrecisionCheckConfig()) {
```
Don't use a default value here.
```cpp
std::string baseline_path = "";  // baseline file path for comparison

// Parse from "key=value,key=value" string
static PrecisionCheckConfig Parse(const std::string &config_str) {
```
Let's put the implementation in a .cc file.
Put the implementation in a .cc file.
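For reference, a sketch of what the Parse implementation might look like once moved into a .cc file; the "level" field and the exact config keys here are assumptions for illustration:

```cpp
#include <sstream>
#include <string>

// Header-side declaration, abbreviated from the diff above:
struct PrecisionCheckConfig {
    std::string level;               // hypothetical field for illustration
    std::string baseline_path = "";  // baseline file path for comparison
    static PrecisionCheckConfig Parse(const std::string &config_str);
};

// precision_check_config.cc (hypothetical file name)
PrecisionCheckConfig PrecisionCheckConfig::Parse(const std::string &config_str) {
    PrecisionCheckConfig config;
    std::istringstream stream(config_str);
    std::string item;
    // Split on ',' into "key=value" items, then on '=' into key and value.
    while (std::getline(stream, item, ',')) {
        const auto pos = item.find('=');
        if (pos == std::string::npos) {
            continue;  // skip malformed entries
        }
        const std::string key = item.substr(0, pos);
        const std::string value = item.substr(pos + 1);
        if (key == "level") {
            config.level = value;
        } else if (key == "baseline_path") {
            config.baseline_path = value;
        }
    }
    return config;
}
```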
```cpp
namespace {

// Simple MD5 implementation
class MD5 {
```
By comparing tensors' MD5 directly, instead of looking at abs/rel diff with a tolerance threshold, isn't an exact match very hard to achieve? And it also gives no sense of how large the gap is.
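A sketch of the tolerance-based comparison the reviewer suggests, mirroring the usual allclose rule |a - b| <= atol + rtol * |b|; this is illustrative code, not the project's API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

struct DiffStats {
    double max_abs_diff = 0.0;
    double max_rel_diff = 0.0;
    bool all_close = true;
};

// Compare two same-sized float buffers element-wise, tracking the worst
// absolute and relative deviations alongside the pass/fail verdict.
DiffStats CompareBuffers(const float *a, const float *b, std::size_t n,
                         double atol = 1e-5, double rtol = 1e-3) {
    DiffStats stats;
    for (std::size_t i = 0; i < n; ++i) {
        const double abs_diff = std::abs(static_cast<double>(a[i]) - b[i]);
        const double denom = std::max(std::abs(static_cast<double>(b[i])), 1e-12);
        stats.max_abs_diff = std::max(stats.max_abs_diff, abs_diff);
        stats.max_rel_diff = std::max(stats.max_rel_diff, abs_diff / denom);
        if (abs_diff > atol + rtol * std::abs(static_cast<double>(b[i]))) {
            stats.all_close = false;
        }
    }
    return stats;
}
```

Unlike an MD5 match, this reports the size of the worst deviation and tolerates benign low-order-bit noise.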
The precision_checker configuration shouldn't live under parallel; I suggest moving it elsewhere, e.g. under utils. (This can wait until the next PR, to be changed together with the global module hook.)
```cpp
        utils::PrecisionChecker::RegisterForModule(this);
        precision_check_registered_ = true;
    }
}
```
We should have a mechanism for registering global module hooks. The precision_checker is essentially registering a global module hook today, so it should call a global module hook registration interface directly (for example, InitAllEnv could decide whether to register a global precision_check hook based on the precision argument it receives). (This can wait until the next PR.)
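One possible shape for such a mechanism; all names here (GlobalForwardHooks, RegisterGlobalForwardHook) are hypothetical:

```cpp
#include <functional>
#include <utility>
#include <vector>

class Module;  // forward declaration of the framework's module type

using GlobalModuleHook = std::function<void(Module &)>;

// One process-wide list consulted by every module, so callers such as
// InitAllEnv can enable precision checking without touching each module.
inline std::vector<GlobalModuleHook> &GlobalForwardHooks() {
    static std::vector<GlobalModuleHook> hooks;
    return hooks;
}

inline void RegisterGlobalForwardHook(GlobalModuleHook hook) {
    GlobalForwardHooks().push_back(std::move(hook));
}

// InitAllEnv could then do something like (hypothetical call):
//   if (precision_config.enabled) {
//       RegisterGlobalForwardHook([](Module &m) { /* run precision check on m */ });
//   }
```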
```cpp
}

// Register backward hooks on output tensors' grad_fn
if (!backward_pre_hooks_.empty() || !backward_post_hooks_.empty()) {
```
Since it's already written this way, wrap this condition in UNLIKELY.
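UNLIKELY is conventionally built on compiler branch-prediction hints; a common definition, assuming GCC/Clang builtins, looks like:

```cpp
// Hint to the compiler that the condition is rarely true, so the cold
// branch is laid out off the hot path.
#if defined(__GNUC__) || defined(__clang__)
#define UNLIKELY(x) __builtin_expect(!!(x), 0)
#else
#define UNLIKELY(x) (x)
#endif

// Applied to the condition above:
// if (UNLIKELY(!backward_pre_hooks_.empty() || !backward_post_hooks_.empty())) { ... }
```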
```python
    sys.exit(1 if total_mismatches > 0 else 0)


if __name__ == '__main__':
    main()
```

(No newline at end of file)
Add a newline at the end of the file.