diff --git a/framework/py/flwr/serverapp/strategy/strategy_utils.py b/framework/py/flwr/serverapp/strategy/strategy_utils.py index 1053f8b96a41..d2b87a6867b9 100644 --- a/framework/py/flwr/serverapp/strategy/strategy_utils.py +++ b/framework/py/flwr/serverapp/strategy/strategy_utils.py @@ -111,26 +111,19 @@ def aggregate_metricrecords( ) -> MetricRecord: """Perform weighted aggregation all MetricRecords using a specific key.""" # Retrieve weighting factor from MetricRecord - weights: list[float] = [] - for record in records: - # Get the first (and only) MetricRecord in the record - metricrecord = next(iter(record.metric_records.values())) - # Because replies have been checked for consistency, - # we can safely cast the weighting factor to float - w = cast(float, metricrecord[weighting_metric_name]) - weights.append(w) + weights: list[float] = [ + cast(float, next(iter(record.metric_records.values()))[weighting_metric_name]) + for record in records + ] - # Average total_weight = sum(weights) weight_factors = [w / total_weight for w in weights] aggregated_metrics = MetricRecord() for record, weight in zip(records, weight_factors): for record_item in record.metric_records.values(): - # aggregate in-place for key, value in record_item.items(): if key == weighting_metric_name: - # We exclude the weighting key from the aggregated MetricRecord continue if key not in aggregated_metrics: if isinstance(value, list): @@ -139,14 +132,14 @@ def aggregate_metricrecords( aggregated_metrics[key] = value * weight else: if isinstance(value, list): - current_list = cast(list[float], aggregated_metrics[key]) - aggregated_metrics[key] = [ - curr + val * weight - for curr, val in zip(current_list, value) - ] + curr_list = cast(list[float], aggregated_metrics[key]) + # avoid zip if possible: pre-allocated, update in place + for i, val in enumerate(value): + curr_list[i] += val * weight else: - current_value = cast(float, aggregated_metrics[key]) - aggregated_metrics[key] = current_value + value * weight + aggregated_metrics[key] = ( + cast(float, aggregated_metrics[key]) + value * weight + ) return aggregated_metrics