diff --git a/.github/workflows/pull_request.yaml b/.github/workflows/pull_request.yaml index e7be9e6f..33baadef 100644 --- a/.github/workflows/pull_request.yaml +++ b/.github/workflows/pull_request.yaml @@ -53,6 +53,9 @@ jobs: run: make data env: TESTING: "1" + MODAL_CALIBRATE: "1" + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} - name: Save calibration log (constituencies) uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/push.yaml b/.github/workflows/push.yaml index d4575eb6..9c019934 100644 --- a/.github/workflows/push.yaml +++ b/.github/workflows/push.yaml @@ -58,6 +58,10 @@ jobs: HUGGING_FACE_TOKEN: ${{ secrets.HUGGING_FACE_TOKEN }} - name: Build datasets run: make data + env: + MODAL_CALIBRATE: "1" + MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }} + MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }} - name: Save calibration log (constituencies) uses: actions/upload-artifact@v4 with: diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..6d4d9bd5 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: minor + changes: + added: + - Modal GPU calibration support to speed up CI runs diff --git a/policyengine_uk_data/datasets/create_datasets.py b/policyengine_uk_data/datasets/create_datasets.py index 07f6b362..41c2c149 100644 --- a/policyengine_uk_data/datasets/create_datasets.py +++ b/policyengine_uk_data/datasets/create_datasets.py @@ -1,7 +1,11 @@ from policyengine_uk_data.datasets.frs import create_frs from policyengine_uk_data.storage import STORAGE_FOLDER +import gc import logging import os +import io +import numpy as np +import h5py from policyengine_uk_data.utils.uprating import uprate_dataset from policyengine_uk_data.utils.progress import ( ProcessingProgress, @@ -11,17 +15,159 @@ logging.basicConfig(level=logging.INFO) +USE_MODAL = os.environ.get("MODAL_CALIBRATE", "0") == "1" + + +def _dump(arr) -> bytes: + buf = io.BytesIO() + np.save(buf, arr) + return buf.getvalue() + + +def _build_weights_init(dataset, area_count, r): + areas_per_household = np.maximum(r.sum(axis=0), 1) + original_weights = np.log( + dataset.household.household_weight.values / areas_per_household + + np.random.random(len(dataset.household.household_weight.values)) + * 0.01 + ) + return np.ones((area_count, len(original_weights))) * original_weights + + +def _build_log(checkpoints, get_performance, m_c, y_c, m_n, y_n, log_csv): + import pandas as pd + + performance = pd.DataFrame() + for epoch, w_bytes in checkpoints: + w = np.load(io.BytesIO(w_bytes)) + perf = get_performance(w, m_c, y_c, m_n, y_n, []) + perf["epoch"] = epoch + perf["loss"] = perf.rel_abs_error**2 + perf["target_name"] = [ + f"{a}/{m}" for a, m in zip(perf.name, perf.metric) + ] + performance = pd.concat([performance, perf], ignore_index=True) + performance.to_csv(log_csv, index=False) + final_epoch, final_bytes = checkpoints[-1] + return np.load(io.BytesIO(final_bytes)) + + +def _run_modal_calibrations( + frs, + epochs, + create_constituency_target_matrix, + create_local_authority_target_matrix, + create_national_target_matrix, + get_constituency_performance, + get_la_performance, +): + """ + Dispatch both calibrations concurrently to Modal GPU containers. + Returns (constituency_weights, la_weights) as numpy arrays and + writes constituency_calibration_log.csv / la_calibration_log.csv. + """ + import modal + import pandas as pd + from policyengine_uk_data.utils.modal_calibrate import ( + app, + run_calibration, + ) + + def _arr(x): + return x.values if hasattr(x, "values") else x + + # Build matrices one at a time; serialise immediately and free the + # DataFrames (keeping only column/index metadata for log reconstruction). + + m_nat, y_nat = create_national_target_matrix(frs.copy()) + m_nat_np = _arr(m_nat) + y_nat_np = _arr(y_nat) + m_nat_cols = list(m_nat.columns) + y_nat_index = list(y_nat.index) + b_m_nat = _dump(m_nat_np) + b_y_nat = _dump(y_nat_np) + del m_nat, y_nat + gc.collect() + + frs_copy = frs.copy() + matrix_c, y_c, r_c = create_constituency_target_matrix(frs_copy) + matrix_c_np = _arr(matrix_c) + y_c_np = _arr(y_c) + matrix_c_cols = list(matrix_c.columns) + y_c_cols = list(y_c.columns) + wi_c = _build_weights_init(frs_copy, 650, r_c) + b_matrix_c = _dump(matrix_c_np) + b_y_c = _dump(y_c_np) + b_wi_c = _dump(wi_c) + b_r_c = _dump(r_c) + del matrix_c, y_c, wi_c, r_c, frs_copy + gc.collect() + + frs_copy = frs.copy() + matrix_la, y_la, r_la = create_local_authority_target_matrix(frs_copy) + matrix_la_np = _arr(matrix_la) + y_la_np = _arr(y_la) + matrix_la_cols = list(matrix_la.columns) + y_la_cols = list(y_la.columns) + wi_la = _build_weights_init(frs_copy, 360, r_la) + b_matrix_la = _dump(matrix_la_np) + b_y_la = _dump(y_la_np) + b_wi_la = _dump(wi_la) + b_r_la = _dump(r_la) + del matrix_la, y_la, wi_la, r_la, frs_copy + gc.collect() + + with modal.enable_output(), app.run(): + fut_c = run_calibration.spawn( + b_matrix_c, b_y_c, b_r_c, b_m_nat, b_y_nat, b_wi_c, epochs + ) + fut_la = run_calibration.spawn( + b_matrix_la, b_y_la, b_r_la, b_m_nat, b_y_nat, b_wi_la, epochs + ) + del b_r_c, b_wi_c, b_r_la, b_wi_la + gc.collect() + + checkpoints_c = fut_c.get() + checkpoints_la = fut_la.get() + + # Reconstruct DataFrames with correct labels for get_performance + matrix_c_df = pd.DataFrame(matrix_c_np, columns=matrix_c_cols) + y_c_df = pd.DataFrame(y_c_np, columns=y_c_cols) + m_nat_df = pd.DataFrame(m_nat_np, columns=m_nat_cols) + y_nat_df = pd.Series(y_nat_np, index=y_nat_index) + matrix_la_df = pd.DataFrame(matrix_la_np, columns=matrix_la_cols) + y_la_df = pd.DataFrame(y_la_np, columns=y_la_cols) + + weights_c = _build_log( + checkpoints_c, + get_constituency_performance, + matrix_c_df, + y_c_df, + m_nat_df, + y_nat_df, + "constituency_calibration_log.csv", + ) + weights_la = _build_log( + checkpoints_la, + get_la_performance, + matrix_la_df, + y_la_df, + m_nat_df, + y_nat_df, + "la_calibration_log.csv", + ) + + return weights_c, weights_la + def main(): """Create enhanced FRS dataset with rich progress tracking.""" try: - # Use reduced epochs and fidelity for testing is_testing = os.environ.get("TESTING", "0") == "1" epochs = 32 if is_testing else 512 progress_tracker = ProcessingProgress() - # Define dataset creation steps steps = [ "Create base FRS dataset", "Impute consumption", @@ -43,7 +189,6 @@ def main(): update_dataset, nested_progress, ): - # Create base FRS dataset update_dataset("Create base FRS dataset", "processing") frs = create_frs( raw_frs_folder=STORAGE_FOLDER / "frs_2023_24", @@ -52,7 +197,6 @@ def main(): frs.save(STORAGE_FOLDER / "frs_2023_24.h5") update_dataset("Create base FRS dataset", "completed") - # Import imputation functions from policyengine_uk_data.datasets.imputations import ( impute_consumption, impute_wealth, @@ -64,9 +208,6 @@ def main(): impute_student_loan_plan, ) - # Apply imputations with progress tracking - # Wealth must be imputed before consumption because consumption - # uses num_vehicles as a predictor for fuel spending update_dataset("Impute wealth", "processing") frs = impute_wealth(frs) update_dataset("Impute wealth", "completed") @@ -99,19 +240,10 @@ def main(): frs = impute_student_loan_plan(frs, year=2025) update_dataset("Impute student loan plan", "completed") - # Uprate dataset update_dataset("Uprate to 2025", "processing") frs = uprate_dataset(frs, 2025) update_dataset("Uprate to 2025", "completed") - # Calibrate constituency weights with nested progress - - update_dataset("Calibrate constituency weights", "processing") - - # Use a separate progress tracker for calibration with nested display - from policyengine_uk_data.utils.calibrate import ( - calibrate_local_areas, - ) from policyengine_uk_data.datasets.local_areas.constituencies.loss import ( create_constituency_target_matrix, ) @@ -121,23 +253,6 @@ def main(): from policyengine_uk_data.datasets.local_areas.constituencies.calibrate import ( get_performance, ) - - # Run calibration with verbose progress - frs_calibrated_constituencies = calibrate_local_areas( - dataset=frs, - epochs=epochs, - matrix_fn=create_constituency_target_matrix, - national_matrix_fn=create_national_target_matrix, - area_count=650, - weight_file="parliamentary_constituency_weights.h5", - excluded_training_targets=[], - log_csv="constituency_calibration_log.csv", - verbose=True, # Enable nested progress display - area_name="Constituency", - get_performance=get_performance, - nested_progress=nested_progress, # Pass the nested progress manager - ) - from policyengine_uk_data.datasets.local_areas.local_authorities.calibrate import ( get_performance as get_la_performance, ) @@ -145,25 +260,85 @@ def main(): create_local_authority_target_matrix, ) - # Run calibration with verbose progress - calibrate_local_areas( - dataset=frs, - epochs=epochs, - matrix_fn=create_local_authority_target_matrix, - national_matrix_fn=create_national_target_matrix, - area_count=360, - weight_file="local_authority_weights.h5", - excluded_training_targets=[], - log_csv="la_calibration_log.csv", - verbose=True, # Enable nested progress display - area_name="Local Authority", - get_performance=get_la_performance, - nested_progress=nested_progress, # Pass the nested progress manager - ) + if USE_MODAL: + update_dataset("Calibrate constituency weights", "processing") + update_dataset( + "Calibrate local authority weights", "processing" + ) + + weights_c, weights_la = _run_modal_calibrations( + frs, + epochs, + create_constituency_target_matrix, + create_local_authority_target_matrix, + create_national_target_matrix, + get_performance, + get_la_performance, + ) + + with h5py.File( + STORAGE_FOLDER / "parliamentary_constituency_weights.h5", + "w", + ) as f: + f.create_dataset("2025", data=weights_c) + + with h5py.File( + STORAGE_FOLDER / "local_authority_weights.h5", "w" + ) as f: + f.create_dataset("2025", data=weights_la) + + frs_calibrated_constituencies = frs.copy() + frs_calibrated_constituencies.household.household_weight = ( + weights_c.sum(axis=0) + ) + + update_dataset("Calibrate constituency weights", "completed") + update_dataset( + "Calibrate local authority weights", "completed" + ) + else: + from policyengine_uk_data.utils.calibrate import ( + calibrate_local_areas, + ) + + update_dataset("Calibrate constituency weights", "processing") + frs_calibrated_constituencies = calibrate_local_areas( + dataset=frs, + epochs=epochs, + matrix_fn=create_constituency_target_matrix, + national_matrix_fn=create_national_target_matrix, + area_count=650, + weight_file="parliamentary_constituency_weights.h5", + excluded_training_targets=[], + log_csv="constituency_calibration_log.csv", + verbose=True, + area_name="Constituency", + get_performance=get_performance, + nested_progress=nested_progress, + ) + update_dataset("Calibrate constituency weights", "completed") - update_dataset("Calibrate dataset", "completed") + update_dataset( + "Calibrate local authority weights", "processing" + ) + calibrate_local_areas( + dataset=frs, + epochs=epochs, + matrix_fn=create_local_authority_target_matrix, + national_matrix_fn=create_national_target_matrix, + area_count=360, + weight_file="local_authority_weights.h5", + excluded_training_targets=[], + log_csv="la_calibration_log.csv", + verbose=True, + area_name="Local Authority", + get_performance=get_la_performance, + nested_progress=nested_progress, + ) + update_dataset( + "Calibrate local authority weights", "completed" + ) - # Downrate and save update_dataset("Downrate to 2023", "processing") frs_calibrated = uprate_dataset( frs_calibrated_constituencies, 2023 @@ -174,14 +349,14 @@ def main(): frs_calibrated.save(STORAGE_FOLDER / "enhanced_frs_2023_24.h5") update_dataset("Save final dataset", "completed") - # Display success message display_success_panel( "Dataset creation completed successfully", details={ "base_dataset": "frs_2023_24.h5", "enhanced_dataset": "enhanced_frs_2023_24.h5", "imputations_applied": "consumption, wealth, VAT, services, income, capital_gains, salary_sacrifice, student_loan_plan", - "calibration": "national, LA and constituency targets", + "calibration": "national, LA and constituency targets", + "calibration_backend": "Modal GPU" if USE_MODAL else "CPU", }, ) diff --git a/policyengine_uk_data/utils/calibrate.py b/policyengine_uk_data/utils/calibrate.py index 6e31402c..ae6f29f5 100644 --- a/policyengine_uk_data/utils/calibrate.py +++ b/policyengine_uk_data/utils/calibrate.py @@ -8,119 +8,84 @@ from policyengine_uk_data.utils.progress import ProcessingProgress -def calibrate_local_areas( - dataset: UKSingleYearDataset, - matrix_fn, - national_matrix_fn, - area_count: int, - weight_file: str, - dataset_key: str = "2025", - epochs: int = 512, - excluded_training_targets=[], - log_csv=None, +def _run_optimisation( + matrix_np: np.ndarray, + y_np: np.ndarray, + r_np: np.ndarray, + matrix_national_np: np.ndarray, + y_national_np: np.ndarray, + weights_init_np: np.ndarray, + epochs: int, + device: torch.device, + excluded_training_targets_local: np.ndarray | None = None, + excluded_training_targets_national: np.ndarray | None = None, verbose: bool = False, area_name: str = "area", - get_performance=None, + progress_tracker=None, nested_progress=None, -): + log_csv: str | None = None, + get_performance=None, + m_c_orig=None, + y_c_orig=None, + m_n_orig=None, + y_n_orig=None, + weight_file: str | None = None, + dataset_key: str = "2025", + dataset=None, +) -> np.ndarray: """ - Generic calibration function for local areas (constituencies, local authorities, etc.) + Pure optimisation loop (Adam, PyTorch). Device-agnostic — pass + ``device=torch.device("cuda")`` for GPU or ``"cpu"`` for CPU. - Args: - dataset: The dataset to calibrate - matrix_fn: Function that returns (matrix, targets, mask) for the local areas - national_matrix_fn: Function that returns (matrix, targets) for national totals - area_count: Number of areas (e.g., 650 for constituencies, 360 for local authorities) - weight_file: Name of the h5 file to save weights to - dataset_key: Key to use in the h5 file - epochs: Number of training epochs - excluded_training_targets: List of targets to exclude from training (for validation) - log_csv: CSV file to log performance to - verbose: Whether to print progress - area_name: Name of the area type for logging + Returns the final weights array (area_count × n_households). """ - dataset = dataset.copy() - matrix, y, r = matrix_fn(dataset) - m_c, y_c = matrix.copy(), y.copy() - m_national, y_national = national_matrix_fn(dataset) - m_n, y_n = m_national.copy(), y_national.copy() - - # Weights - area_count x num_households - # Use country-aware initialization: divide each household's weight by the - # number of areas in its country, not the total area count. This ensures - # households start at approximately correct weight for their country's targets. - # The country_mask r[i,j]=1 iff household j is in same country as area i. - areas_per_household = r.sum( - axis=0 - ) # number of areas each household can contribute to - areas_per_household = np.maximum( - areas_per_household, 1 - ) # avoid division by zero - original_weights = np.log( - dataset.household.household_weight.values / areas_per_household - + np.random.random(len(dataset.household.household_weight.values)) - * 0.01 + metrics = torch.tensor(matrix_np, dtype=torch.float32, device=device) + y = torch.tensor(y_np, dtype=torch.float32, device=device) + matrix_national = torch.tensor( + matrix_national_np, dtype=torch.float32, device=device ) + y_national = torch.tensor( + y_national_np, dtype=torch.float32, device=device + ) + r = torch.tensor(r_np, dtype=torch.float32, device=device) + weights = torch.tensor( - np.ones((area_count, len(original_weights))) * original_weights, + weights_init_np, dtype=torch.float32, + device=device, requires_grad=True, ) - # Set up validation targets if specified - validation_targets_local = ( - matrix.columns.isin(excluded_training_targets) - if hasattr(matrix, "columns") - else None + dropout_targets = ( + excluded_training_targets_local is not None + and excluded_training_targets_local.any() ) - validation_targets_national = ( - m_national.columns.isin(excluded_training_targets) - if hasattr(m_national, "columns") - else None - ) - dropout_targets = len(excluded_training_targets) > 0 - - # Convert to tensors - metrics = torch.tensor( - matrix.values if hasattr(matrix, "values") else matrix, - dtype=torch.float32, - ) - y = torch.tensor( - y.values if hasattr(y, "values") else y, dtype=torch.float32 - ) - matrix_national = torch.tensor( - m_national.values if hasattr(m_national, "values") else m_national, - dtype=torch.float32, - ) - y_national = torch.tensor( - y_national.values if hasattr(y_national, "values") else y_national, - dtype=torch.float32, - ) - r = torch.tensor(r, dtype=torch.float32) - def sre(x, y): - one_way = ((1 + x) / (1 + y) - 1) ** 2 - other_way = ((1 + y) / (1 + x) - 1) ** 2 + def sre(x, y_t): + one_way = ((1 + x) / (1 + y_t) - 1) ** 2 + other_way = ((1 + y_t) / (1 + x) - 1) ** 2 return torch.min(one_way, other_way) - def loss(w, validation: bool = False): + def loss_fn(w, validation: bool = False): pred_local = (w.unsqueeze(-1) * metrics.unsqueeze(0)).sum(dim=1) - if dropout_targets and validation_targets_local is not None: - if validation: - mask = validation_targets_local - else: - mask = ~validation_targets_local + if dropout_targets and excluded_training_targets_local is not None: + mask = ( + excluded_training_targets_local + if validation + else ~excluded_training_targets_local + ) pred_local = pred_local[:, mask] mse_local = torch.mean(sre(pred_local, y[:, mask])) else: mse_local = torch.mean(sre(pred_local, y)) pred_national = (w.sum(axis=0) * matrix_national.T).sum(axis=1) - if dropout_targets and validation_targets_national is not None: - if validation: - mask = validation_targets_national - else: - mask = ~validation_targets_national + if dropout_targets and excluded_training_targets_national is not None: + mask = ( + excluded_training_targets_national + if validation + else ~excluded_training_targets_national + ) pred_national = pred_national[mask] mse_national = torch.mean(sre(pred_national, y_national[mask])) else: @@ -129,164 +94,179 @@ def loss(w, validation: bool = False): return mse_local + mse_national def pct_close(w, t=0.1, local=True, national=True): - """Return the percentage of metrics that are within t% of the target""" numerator = 0 denominator = 0 - if local: pred_local = (w.unsqueeze(-1) * metrics.unsqueeze(0)).sum(dim=1) - e_local = torch.sum( + numerator += torch.sum( torch.abs((pred_local / (1 + y) - 1)) < t ).item() - c_local = pred_local.shape[0] * pred_local.shape[1] - numerator += e_local - denominator += c_local - + denominator += pred_local.shape[0] * pred_local.shape[1] if national: pred_national = (w.sum(axis=0) * matrix_national.T).sum(axis=1) - e_national = torch.sum( + numerator += torch.sum( torch.abs((pred_national / (1 + y_national) - 1)) < t ).item() - c_national = pred_national.shape[0] - numerator += e_national - denominator += c_national - + denominator += pred_national.shape[0] return numerator / denominator - def dropout_weights(weights, p): + def dropout_weights(w, p): if p == 0: - return weights - # Replace p% of the weights with the mean value of the rest of them - mask = torch.rand_like(weights) < p - mean = weights[~mask].mean() - masked_weights = weights.clone() - masked_weights[mask] = mean - return masked_weights + return w + mask = torch.rand_like(w) < p + mean = w[~mask].mean() + w2 = w.clone() + w2[mask] = mean + return w2 optimizer = torch.optim.Adam([weights], lr=1e-1) - final_weights = (torch.exp(weights) * r).detach().numpy() + final_weights = (torch.exp(weights) * r).detach().cpu().numpy() performance = pd.DataFrame() - progress_tracker = ProcessingProgress() if verbose else None + def _epoch_step(epoch): + nonlocal final_weights, performance + optimizer.zero_grad() + weights_ = torch.exp(dropout_weights(weights, 0.05)) * r + l = loss_fn(weights_) + l.backward() + optimizer.step() - if verbose and progress_tracker: + local_close = pct_close(weights_, local=True, national=False) + national_close = pct_close(weights_, local=False, national=True) + + if epoch % 10 == 0: + final_weights = (torch.exp(weights) * r).detach().cpu().numpy() + + if log_csv and get_performance and m_c_orig is not None: + perf = get_performance( + final_weights, + m_c_orig, + y_c_orig, + m_n_orig, + y_n_orig, + [], + ) + perf["epoch"] = epoch + perf["loss"] = perf.rel_abs_error**2 + perf["target_name"] = [ + f"{a}/{m}" for a, m in zip(perf.name, perf.metric) + ] + performance = pd.concat([performance, perf], ignore_index=True) + performance.to_csv(log_csv, index=False) + + if weight_file: + with h5py.File(STORAGE_FOLDER / weight_file, "w") as f: + f.create_dataset(dataset_key, data=final_weights) + if dataset is not None: + dataset.household.household_weight = final_weights.sum(axis=0) + + return l, local_close, national_close + + if verbose and progress_tracker is not None: with progress_tracker.track_calibration( epochs, nested_progress ) as update_calibration: for epoch in range(epochs): update_calibration(epoch + 1, calculating_loss=True) - - optimizer.zero_grad() - weights_ = torch.exp(dropout_weights(weights, 0.05)) * r - l = loss(weights_) - l.backward() - optimizer.step() - - local_close = pct_close(weights_, local=True, national=False) - national_close = pct_close( - weights_, local=False, national=True - ) - + l, _, _ = _epoch_step(epoch) if dropout_targets: - validation_loss = loss(weights_, validation=True) + weights_ = torch.exp(dropout_weights(weights, 0.05)) * r + val_loss = loss_fn(weights_, validation=True) update_calibration( epoch + 1, - loss_value=validation_loss.item(), + loss_value=val_loss.item(), calculating_loss=False, ) else: update_calibration( - epoch + 1, loss_value=l.item(), calculating_loss=False - ) - - if epoch % 10 == 0: - final_weights = (torch.exp(weights) * r).detach().numpy() - - # Log performance if requested and get_performance function is available - if log_csv and get_performance: - performance_step = get_performance( - final_weights, - m_c, - y_c, - m_n, - y_n, - excluded_training_targets, - ) - performance_step["epoch"] = epoch - performance_step["loss"] = ( - performance_step.rel_abs_error**2 - ) - performance_step["target_name"] = [ - f"{area}/{metric}" - for area, metric in zip( - performance_step.name, performance_step.metric - ) - ] - performance = pd.concat( - [performance, performance_step], ignore_index=True - ) - performance.to_csv(log_csv, index=False) - - # Save weights - with h5py.File(STORAGE_FOLDER / weight_file, "w") as f: - f.create_dataset(dataset_key, data=final_weights) - - dataset.household.household_weight = final_weights.sum( - axis=0 + epoch + 1, + loss_value=l.item(), + calculating_loss=False, ) else: for epoch in range(epochs): - optimizer.zero_grad() - weights_ = torch.exp(dropout_weights(weights, 0.05)) * r - l = loss(weights_) - l.backward() - optimizer.step() + _epoch_step(epoch) - local_close = pct_close(weights_, local=True, national=False) - national_close = pct_close(weights_, local=False, national=True) + return final_weights - if verbose and (epoch % 1 == 0): - if dropout_targets: - validation_loss = loss(weights_, validation=True) - print( - f"Training loss: {l.item():,.3f}, Validation loss: {validation_loss.item():,.3f}, Epoch: {epoch}, " - f"{area_name}<10%: {local_close:.1%}, National<10%: {national_close:.1%}" - ) - else: - print( - f"Loss: {l.item()}, Epoch: {epoch}, {area_name}<10%: {local_close:.1%}, National<10%: {national_close:.1%}" - ) - if epoch % 10 == 0: - final_weights = (torch.exp(weights) * r).detach().numpy() +def calibrate_local_areas( + dataset: UKSingleYearDataset, + matrix_fn, + national_matrix_fn, + area_count: int, + weight_file: str, + dataset_key: str = "2025", + epochs: int = 512, + excluded_training_targets=[], + log_csv=None, + verbose: bool = False, + area_name: str = "area", + get_performance=None, + nested_progress=None, +) -> UKSingleYearDataset: + """ + Calibrate local-area weights on CPU using the extracted optimisation loop. + """ + dataset = dataset.copy() + matrix, y, r = matrix_fn(dataset) + m_c, y_c = matrix.copy(), y.copy() + m_national, y_national = national_matrix_fn(dataset) + m_n, y_n = m_national.copy(), y_national.copy() - # Log performance if requested and get_performance function is available - if log_csv: - performance_step = get_performance( - final_weights, - m_c, - y_c, - m_n, - y_n, - excluded_training_targets, - ) - performance_step["epoch"] = epoch - performance_step["loss"] = performance_step.rel_abs_error**2 - performance_step["target_name"] = [ - f"{area}/{metric}" - for area, metric in zip( - performance_step.name, performance_step.metric - ) - ] - performance = pd.concat( - [performance, performance_step], ignore_index=True - ) - performance.to_csv(log_csv, index=False) + areas_per_household = r.sum(axis=0) + areas_per_household = np.maximum(areas_per_household, 1) + original_weights = np.log( + dataset.household.household_weight.values / areas_per_household + + np.random.random(len(dataset.household.household_weight.values)) + * 0.01 + ) + weights_init = ( + np.ones((area_count, len(original_weights))) * original_weights + ) - # Save weights - with h5py.File(STORAGE_FOLDER / weight_file, "w") as f: - f.create_dataset(dataset_key, data=final_weights) + validation_targets_local = ( + matrix.columns.isin(excluded_training_targets) + if hasattr(matrix, "columns") + else None + ) + validation_targets_national = ( + m_national.columns.isin(excluded_training_targets) + if hasattr(m_national, "columns") + else None + ) - dataset.household.household_weight = final_weights.sum(axis=0) + progress_tracker = ProcessingProgress() if verbose else None + + final_weights = _run_optimisation( + matrix_np=matrix.values if hasattr(matrix, "values") else matrix, + y_np=y.values if hasattr(y, "values") else y, + r_np=r, + matrix_national_np=( + m_national.values if hasattr(m_national, "values") else m_national + ), + y_national_np=( + y_national.values if hasattr(y_national, "values") else y_national + ), + weights_init_np=weights_init, + epochs=epochs, + device=torch.device("cpu"), + excluded_training_targets_local=validation_targets_local, + excluded_training_targets_national=validation_targets_national, + verbose=verbose, + area_name=area_name, + progress_tracker=progress_tracker, + nested_progress=nested_progress, + log_csv=log_csv, + get_performance=get_performance, + m_c_orig=m_c, + y_c_orig=y_c, + m_n_orig=m_n, + y_n_orig=y_n, + weight_file=weight_file, + dataset_key=dataset_key, + dataset=dataset, + ) + dataset.household.household_weight = final_weights.sum(axis=0) return dataset diff --git a/policyengine_uk_data/utils/modal_calibrate.py b/policyengine_uk_data/utils/modal_calibrate.py new file mode 100644 index 00000000..60bb1d88 --- /dev/null +++ b/policyengine_uk_data/utils/modal_calibrate.py @@ -0,0 +1,96 @@ +import io +import modal +import numpy as np + +app = modal.App("policyengine-uk-calibration") + +image_gpu = modal.Image.debian_slim().pip_install( + "torch", "numpy", "h5py", "pandas" +) + + +@app.function(gpu="A10G", image=image_gpu, timeout=3600, serialized=True) +def run_calibration( + matrix: bytes, + y: bytes, + r: bytes, + matrix_national: bytes, + y_national: bytes, + weights_init: bytes, + epochs: int, +) -> bytes: + """ + Run the Adam calibration loop on a GPU container. All arrays are + serialised with ``np.save`` / deserialised with ``np.load``. + + Returns checkpoints as [(epoch, weights_bytes), ...] every 10 epochs. + """ + import io + import numpy as np + import torch + + def load(b): + return np.load(io.BytesIO(b)) + + matrix_np = load(matrix) + y_np = load(y) + r_np = load(r) + matrix_national_np = load(matrix_national) + y_national_np = load(y_national) + weights_init_np = load(weights_init) + + device = torch.device("cuda") + + metrics = torch.tensor(matrix_np, dtype=torch.float32, device=device) + y_t = torch.tensor(y_np, dtype=torch.float32, device=device) + m_national = torch.tensor( + matrix_national_np, dtype=torch.float32, device=device + ) + y_nat = torch.tensor(y_national_np, dtype=torch.float32, device=device) + r_t = torch.tensor(r_np, dtype=torch.float32, device=device) + + weights = torch.tensor( + weights_init_np, + dtype=torch.float32, + device=device, + requires_grad=True, + ) + + def sre(x, y_ref): + one_way = ((1 + x) / (1 + y_ref) - 1) ** 2 + other_way = ((1 + y_ref) / (1 + x) - 1) ** 2 + return torch.min(one_way, other_way) + + def loss_fn(w): + pred_local = (w.unsqueeze(-1) * metrics.unsqueeze(0)).sum(dim=1) + mse_local = torch.mean(sre(pred_local, y_t)) + pred_national = (w.sum(axis=0) * m_national.T).sum(axis=1) + mse_national = torch.mean(sre(pred_national, y_nat)) + return mse_local + mse_national + + def dropout_weights(w, p): + if p == 0: + return w + mask = torch.rand_like(w) < p + mean = w[~mask].mean() + w2 = w.clone() + w2[mask] = mean + return w2 + + optimizer = torch.optim.Adam([weights], lr=1e-1) + checkpoints = [] + + for epoch in range(epochs): + optimizer.zero_grad() + weights_ = torch.exp(dropout_weights(weights, 0.05)) * r_t + l = loss_fn(weights_) + l.backward() + optimizer.step() + + if epoch % 10 == 0: + w = (torch.exp(weights) * r_t).detach().cpu().numpy() + buf = io.BytesIO() + np.save(buf, w) + checkpoints.append((epoch, buf.getvalue())) + + return checkpoints diff --git a/pyproject.toml b/pyproject.toml index 77c33c9a..6ba27a85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dev = [ "itables", "quantile-forest", "build", + "modal", ] [tool.setuptools]