diff --git a/src/utils/checkpointing_metrics/README.md b/src/utils/checkpointing_metrics/README.md index c3bc229..5a7cf75 100644 --- a/src/utils/checkpointing_metrics/README.md +++ b/src/utils/checkpointing_metrics/README.md @@ -1,21 +1,83 @@ # Checkpoint statistics calculator -This Python utility calculates checkpoint write time statistics from NVIDIA NeMo log files. +This Python utility calculates checkpoint write time statistics from NVIDIA log files. It supports **multiple log formats** through a plugin-based parser architecture. + +## Supported Log Formats + +| Format | Flag Value | Description | +|---|---|---| +| NeMo 2 | `nemo2` (default) | NeMo 2.x checkpoint logs with `Global Checkpoint Save` messages | +| NeMo 1 | `nemo1` | NeMo 1.x checkpoint logs with `Checkpoint save for step` messages | ## Usage ``` -python calculate_checkpoint_metrics.py --gcs_logs_path - +python calculate_checkpoint_metrics.py --gcs_logs_path [--log_format auto|nemo1|nemo2] ``` + ### Required arguments - `--gcs_logs_path`: The path to NeMo logs in a GCS bucket. E.g. `gs://logs_bucket/experiment_name/experiment_version` +### Optional arguments + +- `--log_format`: The log format to parse. Choices: `auto`, `nemo1`, `nemo2`. Default: `auto` (auto-detects from log content). + +### Examples + +**NeMo 2 logs** (default): +``` +python calculate_checkpoint_metrics.py \ + --gcs_logs_path gs://tess-benchmark-outputs/muzi-8b-dl-ckpt-20260217-175559 +``` + +**NeMo 1 logs**: +``` +python calculate_checkpoint_metrics.py \ + --gcs_logs_path gs://tess-benchmark-outputs/nemo1-experiment \ + --log_format nemo1 +``` + +### Sample output + +``` +> python calculate_checkpoint_metrics.py \ + --gcs_logs_path gs://tess-benchmark-outputs/muzi-8b-dl-ckpt-20260217-175559 +Analyzing file: muzi-8b-dl-ckpt-20260217-175559/nemo_log_globalrank-1_localrank-1.txt, Global rank: 1, Local rank: 1 +Auto-detected log format: nemo2 +Analyzing file: muzi-8b-dl-ckpt-20260217-175559/run_0/nemo_log_globalrank-2_localrank-2.txt, Global rank: 2, Local rank: 2 +Auto-detected log format: nemo2 +Analyzing file: muzi-8b-dl-ckpt-20260217-175559/nemo_log_globalrank-0_localrank-0.txt, Global rank: 0, Local rank: 0 +Auto-detected log format: nemo2 +Analyzing file: muzi-8b-dl-ckpt-20260217-175559/nemo_log_globalrank-3_localrank-3.txt, Global rank: 3, Local rank: 3 +Auto-detected log format: nemo2 +min checkpoint write duration: 4.9700000286102295s +max checkpoint write duration: 34.86500000953674s +average checkpoint write duration: 17.880999982357025s +checkpoint write time standard deviation: 12.443055009810607 +``` + ### Dependencies The utility uses the `google-cloud-storage` Python package. You can install the package to your Python environment using the following command. ``` -pip install google-cloud-storage` +pip install google-cloud-storage +``` + +## Adding a New Log Format + +To add support for a new framework: + +1. Create `_parser.py` in the `checkpointing_metrics/` directory +2. Subclass `LogParser` from `log_parser.py` +3. Implement the abstract methods (regex patterns, time extraction, step normalization) +4. Decorate the class with `@register_parser` +5. Add an import in `calculate_checkpoint_metrics.py` to trigger registration +6. Add tests — no changes needed to core logic + +## Testing + +``` +python3 -m unittest calculate_checkpoint_metrics_test -v ``` \ No newline at end of file diff --git a/src/utils/checkpointing_metrics/calculate_checkpoint_metrics.py b/src/utils/checkpointing_metrics/calculate_checkpoint_metrics.py index 236cfba..f6a7010 100644 --- a/src/utils/checkpointing_metrics/calculate_checkpoint_metrics.py +++ b/src/utils/checkpointing_metrics/calculate_checkpoint_metrics.py @@ -23,18 +23,31 @@ from google.cloud import storage import log_patterns import utils +from log_parser import get_parser, available_parsers, detect_format_from_line, default_filename_validator +import nemo1_parser +import nemo2_parser def process_metrics_from_logs( gcs_logs_path: str, + log_format: str = "auto", ): """Process NeMo logs stored in a GCS bucket and calculate checkpointing metrics. Args: gcs_logs_path: The path to the NeMo logs in a GCS bucket. + log_format: The log format to parse ('nemo1', 'nemo2', 'auto', etc.). + When 'auto', the format is detected from the log content. """ + if log_format == "auto": + parser = None + filename_val = default_filename_validator + else: + parser = get_parser(log_format) + filename_val = parser.validate_filename + storage_client = storage.Client() logs_bucket_name = gcs_logs_path.split("/")[2] match_glob = f'{"/".join(gcs_logs_path.split("/")[3:])}/**' @@ -43,11 +56,10 @@ def process_metrics_from_logs( ckpt_write_times = utils.process_logs_files( logs_bucket=logs_bucket, match_glob=match_glob, - process_logs_file=process_ckpt_write_times, - filename_val=lambda file_path: re.search( - log_patterns.NEMO_LOG_FILE_NAME, file_path - ) - is not None, + process_logs_file=lambda bucket, path: process_ckpt_write_times( + bucket, path, parser + ), + filename_val=filename_val, ) compute_write_duration_per_step(ckpt_write_times) @@ -56,6 +68,7 @@ def process_metrics_from_logs( def process_ckpt_write_times( logs_bucket: storage.bucket.Bucket, file_path: str, + parser=None, ): """Process checkpoint write times from NeMo logs. @@ -63,6 +76,7 @@ def process_ckpt_write_times( logs_bucket: The bucket which contains the logs from the benchmark run. file_path: The path to the NeMo log file. + parser: A LogParser instance for framework-specific parsing. Returns: A list of dictionaries, representing ckpt write data per global_rank. @@ -70,6 +84,8 @@ def process_ckpt_write_times( """ global generate_warnings + auto_detect = parser is None + ckpt_write_results = [] ckpt_write_times = {} blob = logs_bucket.blob(file_path) @@ -78,7 +94,13 @@ def process_ckpt_write_times( content = blob.download_as_string() stream = io.TextIOWrapper(io.BytesIO(content), encoding="utf-8") - file_path_match = re.search(log_patterns.NEMO_LOG_FILE_NAME, file_path) + # For auto-detect mode, find file path match using any parser. + if auto_detect: + file_path_match = re.search( + log_patterns.NEMO_LOG_FILE_NAME, file_path + ) + else: + file_path_match = re.search(parser.log_file_pattern, file_path) if not file_path_match: raise ValueError( f"Invalid file path: {file_path}. Valid pattern:" @@ -92,10 +114,21 @@ def process_ckpt_write_times( ) for line in stream: - start_match = re.search(log_patterns.CHECKPOINT_WRITE_START, line) + # Auto-detect: try all parsers until one matches. + if auto_detect: + detected_parser, start_match = detect_format_from_line(line) + if detected_parser: + parser = detected_parser + auto_detect = False + print(f"Auto-detected log format: {parser.name}") + else: + start_match = None + else: + start_match = parser.checkpoint_start_pattern.search(line) + if start_match: - step = start_match.group(1) - start_time = utils.parse_nemo_timestamp(line) + step = parser.extract_step_from_start(start_match) + start_time = parser.extract_start_time(start_match, line) if ckpt_write_times.get(step, {}).get("start_time"): if generate_warnings: @@ -108,10 +141,14 @@ def process_ckpt_write_times( ckpt_write_times[step] = {"start_time": start_time} continue - end_match = re.search(log_patterns.CHECKPOINT_WRITE_END, line) + # Only check end pattern if a parser has been determined. + if parser is None: + continue + + end_match = parser.checkpoint_end_pattern.search(line) if end_match: - step = end_match.group(1) - end_time = utils.parse_nemo_timestamp(line) + step = parser.extract_step_from_end(end_match) + end_time = parser.extract_end_time(end_match, line) if ckpt_write_times.get(step, {}).get("start_time") is None: raise ValueError( @@ -201,19 +238,27 @@ def compute_write_duration_per_step(write_times: list[dict[str, any]]): if __name__ == "__main__": - parser = argparse.ArgumentParser( + arg_parser = argparse.ArgumentParser( description="Process checkpointing metrics from the logs." ) - parser.add_argument( + arg_parser.add_argument( "--gcs_logs_path", required=True, + help="The path to the NeMo logs in a GCS bucket.", + ) + arg_parser.add_argument( + "--log_format", + choices=["auto"] + available_parsers(), + default="auto", help=( - "The path to the NeMo logs in a GCS bucket" + "The log format to parse. Available formats: auto, " + + ", ".join(available_parsers()) + + ". (default: auto)" ), ) - args = parser.parse_args() + args = arg_parser.parse_args() generate_warnings = os.getenv("GENERATE_LOG_WARNINGS", "False").lower() == "true" - process_metrics_from_logs(args.gcs_logs_path) \ No newline at end of file + process_metrics_from_logs(args.gcs_logs_path, log_format=args.log_format) \ No newline at end of file diff --git a/src/utils/checkpointing_metrics/calculate_checkpoint_metrics_test.py b/src/utils/checkpointing_metrics/calculate_checkpoint_metrics_test.py new file mode 100644 index 0000000..c169985 --- /dev/null +++ b/src/utils/checkpointing_metrics/calculate_checkpoint_metrics_test.py @@ -0,0 +1,473 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Tests for checkpointing metrics processing (NeMo 1 and NeMo 2 formats).""" + +import io +import re +import sys +import os +import unittest +from unittest import mock + +# Add the module directory to sys.path so we can import the modules. +sys.path.insert( + 0, os.path.dirname(os.path.abspath(__file__)) +) + +# Mock google.cloud.storage before importing modules that depend on it. +sys.modules["google"] = mock.MagicMock() +sys.modules["google.cloud"] = mock.MagicMock() +sys.modules["google.cloud.storage"] = mock.MagicMock() + +import log_patterns +import utils +import calculate_checkpoint_metrics +from log_parser import get_parser, available_parsers, detect_format_from_line + + +# --- Sample NeMo 2 log lines --- +SAMPLE_NEMO2_LOG = """\ +[NeMo I 2026-02-17 17:58:40 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 24 : Start time: 1771351120.135s : Save duration: 25.537s +[NeMo I 2026-02-17 17:58:41 nemo_logging:393] Scheduled async checkpoint save for /ckpt/step=24.ckpt +[NeMo I 2026-02-17 17:58:41 nemo_logging:393] Async finalization time took 0.018 s +[NeMo I 2026-02-17 17:58:50 nemo_logging:393] Successfully saved checkpoint from iteration 24 to /ckpt/step=24.ckpt +[NeMo I 2026-02-17 17:58:50 nemo_logging:393] Async checkpoint save for step 25 (/ckpt/step=24.ckpt) finalized successfully. +[NeMo I 2026-02-17 17:59:05 nemo_logging:393] Global Checkpoint Save : Rank: 0 : Iteration: 49 : Start time: 1771351145.879s : Save duration: 5.330s +[NeMo I 2026-02-17 17:59:06 nemo_logging:393] Scheduled async checkpoint save for /ckpt/step=49.ckpt +[NeMo I 2026-02-17 17:59:15 nemo_logging:393] Successfully saved checkpoint from iteration 49 to /ckpt/step=49.ckpt +[NeMo I 2026-02-17 17:59:15 nemo_logging:393] Async checkpoint save for step 50 (/ckpt/step=49.ckpt) finalized successfully. +""" + +# --- Sample NeMo 1 log lines --- +SAMPLE_NEMO1_LOG = """\ +[NeMo I 2024-08-15 10:30:00 nemo_logging:393] Checkpoint save for step 100 started +[NeMo I 2024-08-15 10:30:20 nemo_logging:393] Scheduled async checkpoint save for /ckpt/step=100.ckpt +[NeMo I 2024-08-15 10:30:25 nemo_logging:393] Async checkpoint save for step 100 (/ckpt/step=100.ckpt) finalized successfully. +[NeMo I 2024-08-15 10:31:00 nemo_logging:393] Checkpoint save for step 200 started +[NeMo I 2024-08-15 10:31:18 nemo_logging:393] Scheduled async checkpoint save for /ckpt/step=200.ckpt +[NeMo I 2024-08-15 10:31:30 nemo_logging:393] Async checkpoint save for step 200 (/ckpt/step=200.ckpt) finalized successfully. +""" + + +class TestParserRegistry(unittest.TestCase): + """Tests for the parser registry.""" + + def test_nemo1_parser_registered(self): + self.assertIn("nemo1", available_parsers()) + + def test_nemo2_parser_registered(self): + self.assertIn("nemo2", available_parsers()) + + def test_get_parser_nemo1(self): + parser = get_parser("nemo1") + self.assertEqual(parser.name, "nemo1") + + def test_get_parser_nemo2(self): + parser = get_parser("nemo2") + self.assertEqual(parser.name, "nemo2") + + def test_get_parser_unknown_raises(self): + with self.assertRaises(ValueError): + get_parser("unknown_format") + + +class TestNemo2LogPatterns(unittest.TestCase): + """Tests for NeMo 2 log pattern regex matching via parser.""" + + def setUp(self): + self.parser = get_parser("nemo2") + + def test_checkpoint_write_start_matches_nemo2(self): + line = ( + "[NeMo I 2026-02-17 17:58:40 nemo_logging:393] Global Checkpoint" + " Save : Rank: 0 : Iteration: 24 : Start time: 1771351095.135s" + " : Save duration: 25.537s" + ) + match = self.parser.checkpoint_start_pattern.search(line) + self.assertIsNotNone(match) + self.assertEqual(self.parser.extract_step_from_start(match), "24") + self.assertEqual( + self.parser.extract_start_time(match, line), 1771351095.135 + ) + + def test_checkpoint_write_start_multi_digit_rank(self): + line = ( + "[NeMo I 2026-02-17 17:58:40 nemo_logging:393] Global Checkpoint" + " Save : Rank: 63 : Iteration: 99 : Start time: 1771351190.030s" + " : Save duration: 1.453s" + ) + match = self.parser.checkpoint_start_pattern.search(line) + self.assertIsNotNone(match) + self.assertEqual(self.parser.extract_step_from_start(match), "99") + self.assertEqual( + self.parser.extract_start_time(match, line), 1771351190.030 + ) + + def test_checkpoint_write_start_does_not_match_nemo1(self): + line = ( + "[NeMo I 2024-08-15 10:30:00 nemo_logging:393]" + " Checkpoint save for step 100 started" + ) + match = self.parser.checkpoint_start_pattern.search(line) + self.assertIsNone(match) + + def test_checkpoint_write_end_matches_nemo2(self): + line = ( + "[NeMo I 2026-02-17 17:58:50 nemo_logging:393] Async checkpoint" + " save for step 25 (/ckpt/step=24.ckpt) finalized successfully." + ) + match = self.parser.checkpoint_end_pattern.search(line) + self.assertIsNotNone(match) + self.assertEqual(self.parser.extract_step_from_end(match), "24") + + +class TestNemo1LogPatterns(unittest.TestCase): + """Tests for NeMo 1 log pattern regex matching via parser.""" + + def setUp(self): + self.parser = get_parser("nemo1") + + def test_checkpoint_write_start_matches_nemo1(self): + line = ( + "[NeMo I 2024-08-15 10:30:00 nemo_logging:393]" + " Checkpoint save for step 100 started" + ) + match = self.parser.checkpoint_start_pattern.search(line) + self.assertIsNotNone(match) + self.assertEqual(self.parser.extract_step_from_start(match), "100") + + def test_checkpoint_write_start_does_not_match_nemo2(self): + line = ( + "[NeMo I 2026-02-17 17:58:40 nemo_logging:393] Global Checkpoint" + " Save : Rank: 0 : Iteration: 24 : Start time: 1771351095.135s" + " : Save duration: 25.537s" + ) + match = self.parser.checkpoint_start_pattern.search(line) + self.assertIsNone(match) + + def test_checkpoint_write_end_matches_nemo1(self): + line = ( + "[NeMo I 2024-08-15 10:30:25 nemo_logging:393] Async checkpoint" + " save for step 100 (/ckpt/step=100.ckpt) finalized successfully." + ) + match = self.parser.checkpoint_end_pattern.search(line) + self.assertIsNotNone(match) + self.assertEqual(self.parser.extract_step_from_end(match), "100") + + def test_filename_validation(self): + self.assertTrue( + self.parser.validate_filename( + "logs/nemo_log_globalrank-0_localrank-0.txt" + ) + ) + self.assertFalse( + self.parser.validate_filename("logs/invalid_file_name.txt") + ) + + +class TestSharedPatterns(unittest.TestCase): + """Tests for shared patterns in log_patterns.py.""" + + def test_nemo_log_file_name_pattern(self): + path = "logs/nemo_log_globalrank-0_localrank-0.txt" + match = re.search(log_patterns.NEMO_LOG_FILE_NAME, path) + self.assertIsNotNone(match) + self.assertEqual(match.group(1), "0") + self.assertEqual(match.group(2), "0") + + def test_nemo_timestamp_pattern(self): + line = "[NeMo I 2026-02-17 17:58:50 nemo_logging:393] Some message" + match = re.search(log_patterns.NEMO_LOG_TIMESTAMP, line) + self.assertIsNotNone(match) + self.assertEqual(match.group(1), "2026-02-17 17:58:50") + + +class TestParseNemoTimestamp(unittest.TestCase): + """Tests for timestamp parsing from NeMo log lines.""" + + def test_parse_valid_timestamp(self): + line = "[NeMo I 2026-02-17 17:58:50 nemo_logging:393] Some message" + timestamp = utils.parse_nemo_timestamp(line) + self.assertIsInstance(timestamp, float) + self.assertGreater(timestamp, 0) + + def test_parse_invalid_timestamp_raises(self): + line = "no timestamp here" + with self.assertRaises(ValueError): + utils.parse_nemo_timestamp(line) + + +class TestProcessCkptWriteTimesNemo2(unittest.TestCase): + """Tests for processing checkpoint write times from NeMo 2 logs.""" + + def setUp(self): + calculate_checkpoint_metrics.generate_warnings = False + + @mock.patch.object( + calculate_checkpoint_metrics, "generate_warnings", False + ) + def test_process_nemo2_log_extracts_checkpoints(self): + """Verify that NeMo 2 log lines are correctly parsed.""" + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_blob.download_as_string.return_value = ( + SAMPLE_NEMO2_LOG.encode("utf-8") + ) + + parser = get_parser("nemo2") + file_path = "logs/nemo_log_globalrank-0_localrank-0.txt" + results = calculate_checkpoint_metrics.process_ckpt_write_times( + mock_bucket, file_path, parser + ) + + self.assertIsNotNone(results) + self.assertEqual(len(results), 2) + + # First checkpoint: iteration 24. + self.assertEqual(results[0]["global_rank"], 0) + self.assertEqual(results[0]["local_rank"], 0) + self.assertEqual(results[0]["checkpoint_step"], "24") + self.assertEqual(results[0]["start_time"], 1771351120.135) + self.assertGreater(results[0]["end_time"], results[0]["start_time"]) + self.assertGreater(results[0]["checkpoint_write_duration"], 0) + + # Second checkpoint: iteration 49. + self.assertEqual(results[1]["checkpoint_step"], "49") + self.assertEqual(results[1]["start_time"], 1771351145.879) + + +class TestProcessCkptWriteTimesNemo1(unittest.TestCase): + """Tests for processing checkpoint write times from NeMo 1 logs.""" + + def setUp(self): + calculate_checkpoint_metrics.generate_warnings = False + + @mock.patch.object( + calculate_checkpoint_metrics, "generate_warnings", False + ) + def test_process_nemo1_log_extracts_checkpoints(self): + """Verify that NeMo 1 log lines are correctly parsed.""" + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_blob.download_as_string.return_value = ( + SAMPLE_NEMO1_LOG.encode("utf-8") + ) + + parser = get_parser("nemo1") + file_path = "logs/nemo_log_globalrank-0_localrank-0.txt" + results = calculate_checkpoint_metrics.process_ckpt_write_times( + mock_bucket, file_path, parser + ) + + self.assertIsNotNone(results) + self.assertEqual(len(results), 2) + + # First checkpoint: step 100. + self.assertEqual(results[0]["global_rank"], 0) + self.assertEqual(results[0]["local_rank"], 0) + self.assertEqual(results[0]["checkpoint_step"], "100") + self.assertGreater(results[0]["end_time"], results[0]["start_time"]) + self.assertGreater(results[0]["checkpoint_write_duration"], 0) + + # Second checkpoint: step 200. + self.assertEqual(results[1]["checkpoint_step"], "200") + self.assertGreater(results[1]["end_time"], results[1]["start_time"]) + + +class TestAutoDetection(unittest.TestCase): + """Tests for auto-detection of log format.""" + + def setUp(self): + calculate_checkpoint_metrics.generate_warnings = False + + def test_detect_nemo2_from_line(self): + line = ( + "[NeMo I 2026-02-17 17:58:40 nemo_logging:393] Global Checkpoint" + " Save : Rank: 0 : Iteration: 24 : Start time: 1771351120.135s" + " : Save duration: 25.537s" + ) + parser, match = detect_format_from_line(line) + self.assertIsNotNone(parser) + self.assertEqual(parser.name, "nemo2") + self.assertIsNotNone(match) + + def test_detect_nemo1_from_line(self): + line = ( + "[NeMo I 2024-08-15 10:30:00 nemo_logging:393]" + " Checkpoint save for step 100 started" + ) + parser, match = detect_format_from_line(line) + self.assertIsNotNone(parser) + self.assertEqual(parser.name, "nemo1") + self.assertIsNotNone(match) + + def test_detect_no_match(self): + line = "[NeMo I 2026-02-17 17:58:41 nemo_logging:393] Some other message" + parser, match = detect_format_from_line(line) + self.assertIsNone(parser) + self.assertIsNone(match) + + @mock.patch.object( + calculate_checkpoint_metrics, "generate_warnings", False + ) + def test_auto_detect_nemo2_log(self): + """Auto-detect should correctly parse NeMo 2 logs without explicit parser.""" + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_blob.download_as_string.return_value = ( + SAMPLE_NEMO2_LOG.encode("utf-8") + ) + + file_path = "logs/nemo_log_globalrank-0_localrank-0.txt" + results = calculate_checkpoint_metrics.process_ckpt_write_times( + mock_bucket, file_path, parser=None + ) + + self.assertIsNotNone(results) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["checkpoint_step"], "24") + self.assertEqual(results[1]["checkpoint_step"], "49") + + @mock.patch.object( + calculate_checkpoint_metrics, "generate_warnings", False + ) + def test_auto_detect_nemo1_log(self): + """Auto-detect should correctly parse NeMo 1 logs without explicit parser.""" + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_blob.download_as_string.return_value = ( + SAMPLE_NEMO1_LOG.encode("utf-8") + ) + + file_path = "logs/nemo_log_globalrank-0_localrank-0.txt" + results = calculate_checkpoint_metrics.process_ckpt_write_times( + mock_bucket, file_path, parser=None + ) + + self.assertIsNotNone(results) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["checkpoint_step"], "100") + self.assertEqual(results[1]["checkpoint_step"], "200") + + +class TestProcessCkptWriteTimesInvalidFile(unittest.TestCase): + """Tests for error handling with invalid file paths.""" + + def setUp(self): + calculate_checkpoint_metrics.generate_warnings = False + + def test_process_invalid_file_path(self): + mock_bucket = mock.MagicMock() + mock_blob = mock.MagicMock() + mock_bucket.blob.return_value = mock_blob + mock_blob.download_as_string.return_value = b"" + + parser = get_parser("nemo2") + file_path = "logs/invalid_file_name.txt" + result = calculate_checkpoint_metrics.process_ckpt_write_times( + mock_bucket, file_path, parser + ) + # Should print an error and return None. + self.assertIsNone(result) + + +class TestComputeWriteDurationPerStep(unittest.TestCase): + """Tests for computing write duration per step.""" + + def test_single_rank_multiple_steps(self): + write_times = [ + { + "global_rank": 0, + "local_rank": 0, + "checkpoint_step": "24", + "checkpoint_write_duration": 10.0, + "start_time": 100.0, + "end_time": 110.0, + }, + { + "global_rank": 0, + "local_rank": 0, + "checkpoint_step": "49", + "checkpoint_write_duration": 12.0, + "start_time": 200.0, + "end_time": 212.0, + }, + ] + # Should print stats without error. + calculate_checkpoint_metrics.compute_write_duration_per_step( + write_times + ) + + def test_multi_rank_multiple_steps(self): + write_times = [ + { + "global_rank": 0, + "local_rank": 0, + "checkpoint_step": "24", + "checkpoint_write_duration": 10.0, + "start_time": 100.0, + "end_time": 110.0, + }, + { + "global_rank": 1, + "local_rank": 1, + "checkpoint_step": "24", + "checkpoint_write_duration": 12.0, + "start_time": 99.0, + "end_time": 111.0, + }, + { + "global_rank": 0, + "local_rank": 0, + "checkpoint_step": "49", + "checkpoint_write_duration": 8.0, + "start_time": 200.0, + "end_time": 208.0, + }, + { + "global_rank": 1, + "local_rank": 1, + "checkpoint_step": "49", + "checkpoint_write_duration": 9.0, + "start_time": 199.0, + "end_time": 208.0, + }, + ] + # Duration per step: + # step 24: max(110,111) - min(100,99) = 12 + # step 49: max(208,208) - min(200,199) = 9 + calculate_checkpoint_metrics.compute_write_duration_per_step( + write_times + ) + + def test_empty_write_times(self): + # Should print a warning and not crash. + calculate_checkpoint_metrics.compute_write_duration_per_step([]) + + def test_missing_fields(self): + write_times = [{"global_rank": 0}] + # Should print a warning about missing fields. + calculate_checkpoint_metrics.compute_write_duration_per_step( + write_times + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/src/utils/checkpointing_metrics/log_parser.py b/src/utils/checkpointing_metrics/log_parser.py new file mode 100644 index 0000000..31259b1 --- /dev/null +++ b/src/utils/checkpointing_metrics/log_parser.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Abstract base class and registry for framework-specific log parsers.""" + +import abc +import re + + +class LogParser(abc.ABC): + """Strategy for framework-specific log parsing differences. + + Each subclass encapsulates the regex patterns and extraction logic that + differ between NeMo versions (or other frameworks). The shared + file-processing loop in calculate_checkpoint_metrics.py calls these + methods instead of hardcoding behavior. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """Unique identifier for this parser (e.g. 'nemo1', 'nemo2').""" + + @property + @abc.abstractmethod + def log_file_pattern(self) -> str: + """Regex pattern to match valid log file paths.""" + + @property + @abc.abstractmethod + def checkpoint_start_pattern(self) -> re.Pattern: + """Compiled regex for checkpoint write start log lines.""" + + @property + @abc.abstractmethod + def checkpoint_end_pattern(self) -> re.Pattern: + """Compiled regex for checkpoint write end log lines.""" + + @abc.abstractmethod + def extract_step_from_start(self, match: re.Match) -> str: + """Extract the step number from a checkpoint start match.""" + + @abc.abstractmethod + def extract_start_time(self, match: re.Match, line: str) -> float: + """Extract the start time (epoch seconds) from a checkpoint start match.""" + + @abc.abstractmethod + def extract_step_from_end(self, match: re.Match) -> str: + """Extract the step number from a checkpoint end match. + + Some frameworks adjust the step (e.g. NeMo 2 reports step = iteration + 1 + in the end message, so we subtract 1). + """ + + @abc.abstractmethod + def extract_end_time(self, match: re.Match, line: str) -> float: + """Extract the end time (epoch seconds) from a checkpoint end match.""" + + def validate_filename(self, file_path: str) -> bool: + """Check whether a file path matches this parser's log file pattern.""" + return re.search(self.log_file_pattern, file_path) is not None + + +# ---- Parser Registry ---- + +_PARSERS: dict[str, LogParser] = {} + + +def register_parser(cls): + """Class decorator that registers a LogParser subclass.""" + instance = cls() + _PARSERS[instance.name] = instance + return cls + + +def get_parser(name: str) -> LogParser: + """Retrieve a registered parser by name.""" + if name not in _PARSERS: + raise ValueError( + f"Unknown log format: '{name}'. " + f"Available formats: {available_parsers()}" + ) + return _PARSERS[name] + + +def available_parsers() -> list[str]: + """Return a list of registered parser names.""" + return list(_PARSERS.keys()) + + +def detect_format_from_line(line: str): + """Try all registered parsers' start patterns against a line. + + Args: + line: A log line to test. + + Returns: + A (parser, match) tuple if a parser's start pattern matches, + or (None, None) if no parser matches. + """ + for parser in _PARSERS.values(): + match = parser.checkpoint_start_pattern.search(line) + if match: + return parser, match + return None, None + + +def default_filename_validator(file_path: str) -> bool: + """Validate filenames using any registered parser's log_file_pattern. + + Used in auto-detection mode before a parser is selected. + """ + return any(p.validate_filename(file_path) for p in _PARSERS.values()) diff --git a/src/utils/checkpointing_metrics/log_patterns.py b/src/utils/checkpointing_metrics/log_patterns.py index 960ae50..18db198 100644 --- a/src/utils/checkpointing_metrics/log_patterns.py +++ b/src/utils/checkpointing_metrics/log_patterns.py @@ -20,9 +20,5 @@ # The timestamp pattern in NeMo logs. NEMO_LOG_TIMESTAMP = r"(\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2})" -# The pattern of the checkpoint saving start log. -CHECKPOINT_WRITE_START = r"Checkpoint save for step (\d+) started" -# The pattern of the checkpoint saving end log. -CHECKPOINT_WRITE_END = ( - r"Async checkpoint save for step (\d+) .* finalized successfully" -) \ No newline at end of file +# Note: Framework-specific checkpoint start/end patterns are defined +# in the parser classes under parsers/. \ No newline at end of file diff --git a/src/utils/checkpointing_metrics/nemo1_parser.py b/src/utils/checkpointing_metrics/nemo1_parser.py new file mode 100644 index 0000000..3a399ef --- /dev/null +++ b/src/utils/checkpointing_metrics/nemo1_parser.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""NeMo 1 log parser for checkpointing metrics. + +Patterns restored from the pre-migration code (git commit 04a01d6). +""" + +import re + +import log_patterns +import utils +from log_parser import LogParser, register_parser + + +# NeMo 1 checkpoint start pattern. +_CHECKPOINT_WRITE_START = re.compile( + r"Checkpoint save for step (\d+) started" +) + +# NeMo 1 checkpoint end pattern (same as NeMo 2). +_CHECKPOINT_WRITE_END = re.compile( + r"Async checkpoint save for step (\d+) .* finalized successfully" +) + + +@register_parser +class Nemo1Parser(LogParser): + """Parser for NeMo 1 checkpoint log format.""" + + @property + def name(self) -> str: + return "nemo1" + + @property + def log_file_pattern(self) -> str: + return log_patterns.NEMO_LOG_FILE_NAME + + @property + def checkpoint_start_pattern(self) -> re.Pattern: + return _CHECKPOINT_WRITE_START + + @property + def checkpoint_end_pattern(self) -> re.Pattern: + return _CHECKPOINT_WRITE_END + + def extract_step_from_start(self, match: re.Match) -> str: + return match.group(1) + + def extract_start_time(self, match: re.Match, line: str) -> float: + # NeMo 1 gets the start time from the log line timestamp. + return utils.parse_nemo_timestamp(line) + + def extract_step_from_end(self, match: re.Match) -> str: + # NeMo 1 uses the step directly (no adjustment). + return match.group(1) + + def extract_end_time(self, match: re.Match, line: str) -> float: + return utils.parse_nemo_timestamp(line) diff --git a/src/utils/checkpointing_metrics/nemo2_parser.py b/src/utils/checkpointing_metrics/nemo2_parser.py new file mode 100644 index 0000000..735b059 --- /dev/null +++ b/src/utils/checkpointing_metrics/nemo2_parser.py @@ -0,0 +1,69 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""NeMo 2 log parser for checkpointing metrics.""" + +import re + +import log_patterns +import utils +from log_parser import LogParser, register_parser + + +# NeMo 2 checkpoint start pattern. +_CHECKPOINT_WRITE_START = re.compile( + r"Global Checkpoint Save : Rank: \d+ : Iteration: (\d+)" + r" : Start time: ([\d.]+)s : Save duration: ([\d.]+)s" +) + +# NeMo 2 checkpoint end pattern. +_CHECKPOINT_WRITE_END = re.compile( + r"Async checkpoint save for step (\d+) .* finalized successfully" +) + + +@register_parser +class Nemo2Parser(LogParser): + """Parser for NeMo 2 checkpoint log format.""" + + @property + def name(self) -> str: + return "nemo2" + + @property + def log_file_pattern(self) -> str: + return log_patterns.NEMO_LOG_FILE_NAME + + @property + def checkpoint_start_pattern(self) -> re.Pattern: + return _CHECKPOINT_WRITE_START + + @property + def checkpoint_end_pattern(self) -> re.Pattern: + return _CHECKPOINT_WRITE_END + + def extract_step_from_start(self, match: re.Match) -> str: + return match.group(1) + + def extract_start_time(self, match: re.Match, line: str) -> float: + # NeMo 2 includes the epoch start time in the log message itself. + return float(match.group(2)) + + def extract_step_from_end(self, match: re.Match) -> str: + # NeMo 2 reports step = iteration + 1 in the end message. + return str(int(match.group(1)) - 1) + + def extract_end_time(self, match: re.Match, line: str) -> float: + return utils.parse_nemo_timestamp(line)