Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions src/utils/checkpointing_metrics/README.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,83 @@
# Checkpoint statistics calculator

This Python utility calculates checkpoint write time statistics from NVIDIA NeMo log files.
This Python utility calculates checkpoint write time statistics from NVIDIA log files. It supports **multiple log formats** through a plugin-based parser architecture.

## Supported Log Formats

| Format | Flag Value | Description |
|---|---|---|
| NeMo 2 | `nemo2` (default) | NeMo 2.x checkpoint logs with `Global Checkpoint Save` messages |
| NeMo 1 | `nemo1` | NeMo 1.x checkpoint logs with `Checkpoint save for step` messages |

## Usage

```
python calculate_checkpoint_metrics.py --gcs_logs_path <path_to_logs>

python calculate_checkpoint_metrics.py --gcs_logs_path <path_to_logs> [--log_format auto|nemo1|nemo2]
```

### Required arguments

- `--gcs_logs_path`: The path to NeMo logs in a GCS bucket. E.g. `gs://logs_bucket/experiment_name/experiment_version`

### Optional arguments

- `--log_format`: The log format to parse. Choices: `auto`, `nemo1`, `nemo2`. Default: `auto` (auto-detects from log content).

### Examples

**NeMo 2 logs** (default):
```
python calculate_checkpoint_metrics.py \
--gcs_logs_path gs://tess-benchmark-outputs/muzi-8b-dl-ckpt-20260217-175559
```

**NeMo 1 logs**:
```
python calculate_checkpoint_metrics.py \
--gcs_logs_path gs://tess-benchmark-outputs/nemo1-experiment \
--log_format nemo1
```

### Sample output

```
> python calculate_checkpoint_metrics.py \
--gcs_logs_path gs://tess-benchmark-outputs/muzi-8b-dl-ckpt-20260217-175559
Analyzing file: muzi-8b-dl-ckpt-20260217-175559/nemo_log_globalrank-1_localrank-1.txt, Global rank: 1, Local rank: 1
Auto-detected log format: nemo2
Analyzing file: muzi-8b-dl-ckpt-20260217-175559/run_0/nemo_log_globalrank-2_localrank-2.txt, Global rank: 2, Local rank: 2
Auto-detected log format: nemo2
Analyzing file: muzi-8b-dl-ckpt-20260217-175559/nemo_log_globalrank-0_localrank-0.txt, Global rank: 0, Local rank: 0
Auto-detected log format: nemo2
Analyzing file: muzi-8b-dl-ckpt-20260217-175559/nemo_log_globalrank-3_localrank-3.txt, Global rank: 3, Local rank: 3
Auto-detected log format: nemo2
min checkpoint write duration: 4.9700000286102295s
max checkpoint write duration: 34.86500000953674s
average checkpoint write duration: 17.880999982357025s
checkpoint write time standard deviation: 12.443055009810607
```

### Dependencies

The utility uses the `google-cloud-storage` Python package. You can install the package to your Python environment using the following command.

```
pip install google-cloud-storage`
pip install google-cloud-storage
```

## Adding a New Log Format

To add support for a new framework:

1. Create `<framework>_parser.py` in the `checkpointing_metrics/` directory
2. Subclass `LogParser` from `log_parser.py`
3. Implement the abstract methods (regex patterns, time extraction, step normalization)
4. Decorate the class with `@register_parser`
5. Add an import in `calculate_checkpoint_metrics.py` to trigger registration
6. Add tests — no changes needed to core logic

## Testing

```
python3 -m unittest calculate_checkpoint_metrics_test -v
```
79 changes: 62 additions & 17 deletions src/utils/checkpointing_metrics/calculate_checkpoint_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,31 @@
from google.cloud import storage
import log_patterns
import utils
from log_parser import get_parser, available_parsers, detect_format_from_line, default_filename_validator
import nemo1_parser
import nemo2_parser


def process_metrics_from_logs(
gcs_logs_path: str,
log_format: str = "auto",
):
"""Process NeMo logs stored in a GCS bucket and calculate checkpointing
metrics.

Args:
gcs_logs_path: The path to the NeMo logs in a GCS bucket.
log_format: The log format to parse ('nemo1', 'nemo2', 'auto', etc.).
When 'auto', the format is detected from the log content.
"""

if log_format == "auto":
parser = None
filename_val = default_filename_validator
else:
parser = get_parser(log_format)
filename_val = parser.validate_filename

storage_client = storage.Client()
logs_bucket_name = gcs_logs_path.split("/")[2]
match_glob = f'{"/".join(gcs_logs_path.split("/")[3:])}/**'
Expand All @@ -43,11 +56,10 @@ def process_metrics_from_logs(
ckpt_write_times = utils.process_logs_files(
logs_bucket=logs_bucket,
match_glob=match_glob,
process_logs_file=process_ckpt_write_times,
filename_val=lambda file_path: re.search(
log_patterns.NEMO_LOG_FILE_NAME, file_path
)
is not None,
process_logs_file=lambda bucket, path: process_ckpt_write_times(
bucket, path, parser
),
filename_val=filename_val,
)

compute_write_duration_per_step(ckpt_write_times)
Expand All @@ -56,20 +68,24 @@ def process_metrics_from_logs(
def process_ckpt_write_times(
logs_bucket: storage.bucket.Bucket,
file_path: str,
parser=None,
):
"""Process checkpoint write times from NeMo logs.

Args:
logs_bucket: The bucket which contains the logs from
the benchmark run.
file_path: The path to the NeMo log file.
parser: A LogParser instance for framework-specific parsing.

Returns:
A list of dictionaries, representing ckpt write data per global_rank.

"""
global generate_warnings

auto_detect = parser is None

ckpt_write_results = []
ckpt_write_times = {}
blob = logs_bucket.blob(file_path)
Expand All @@ -78,7 +94,13 @@ def process_ckpt_write_times(
content = blob.download_as_string()
stream = io.TextIOWrapper(io.BytesIO(content), encoding="utf-8")

file_path_match = re.search(log_patterns.NEMO_LOG_FILE_NAME, file_path)
# For auto-detect mode, find file path match using any parser.
if auto_detect:
file_path_match = re.search(
log_patterns.NEMO_LOG_FILE_NAME, file_path
)
else:
file_path_match = re.search(parser.log_file_pattern, file_path)
if not file_path_match:
raise ValueError(
f"Invalid file path: {file_path}. Valid pattern:"
Expand All @@ -92,10 +114,21 @@ def process_ckpt_write_times(
)

for line in stream:
start_match = re.search(log_patterns.CHECKPOINT_WRITE_START, line)
# Auto-detect: try all parsers until one matches.
if auto_detect:
detected_parser, start_match = detect_format_from_line(line)
if detected_parser:
parser = detected_parser
auto_detect = False
print(f"Auto-detected log format: {parser.name}")
else:
start_match = None
else:
start_match = parser.checkpoint_start_pattern.search(line)

if start_match:
step = start_match.group(1)
start_time = utils.parse_nemo_timestamp(line)
step = parser.extract_step_from_start(start_match)
start_time = parser.extract_start_time(start_match, line)

if ckpt_write_times.get(step, {}).get("start_time"):
if generate_warnings:
Expand All @@ -108,10 +141,14 @@ def process_ckpt_write_times(
ckpt_write_times[step] = {"start_time": start_time}
continue

end_match = re.search(log_patterns.CHECKPOINT_WRITE_END, line)
# Only check end pattern if a parser has been determined.
if parser is None:
continue

end_match = parser.checkpoint_end_pattern.search(line)
if end_match:
step = end_match.group(1)
end_time = utils.parse_nemo_timestamp(line)
step = parser.extract_step_from_end(end_match)
end_time = parser.extract_end_time(end_match, line)

if ckpt_write_times.get(step, {}).get("start_time") is None:
raise ValueError(
Expand Down Expand Up @@ -201,19 +238,27 @@ def compute_write_duration_per_step(write_times: list[dict[str, any]]):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
arg_parser = argparse.ArgumentParser(
description="Process checkpointing metrics from the logs."
)
parser.add_argument(
arg_parser.add_argument(
"--gcs_logs_path",
required=True,
help="The path to the NeMo logs in a GCS bucket.",
)
arg_parser.add_argument(
"--log_format",
choices=["auto"] + available_parsers(),
default="auto",
help=(
"The path to the NeMo logs in a GCS bucket"
"The log format to parse. Available formats: auto, "
+ ", ".join(available_parsers())
+ ". (default: auto)"
),
)

args = parser.parse_args()
args = arg_parser.parse_args()

generate_warnings = os.getenv("GENERATE_LOG_WARNINGS", "False").lower() == "true"

process_metrics_from_logs(args.gcs_logs_path)
process_metrics_from_logs(args.gcs_logs_path, log_format=args.log_format)
Loading