|
| 1 | +"""Test for issue #10851: Dataset Index not included in to_dataframe when name differs from dimension.""" |
| 2 | +import numpy as np |
| 3 | +import pandas as pd |
| 4 | + |
| 5 | +import xarray as xr |
| 6 | + |
| 7 | + |
| 8 | +class TestToDataFrameIndexColumn: |
| 9 | + """Tests for to_dataframe including index coordinates with different names.""" |
| 10 | + |
| 11 | + def test_to_dataframe_includes_index_with_different_name(self): |
| 12 | + """Index coordinates with name different from dimension should be in columns.""" |
| 13 | + ds_temp = xr.Dataset( |
| 14 | + data_vars=dict(temp=(["time", "pos"], np.array([[5, 10, 15, 20, 25]]))), |
| 15 | + coords=dict( |
| 16 | + pf=("pos", [1.0, 2.0, 4.2, 8.0, 10.0]), |
| 17 | + time=("time", [pd.to_datetime("2025-01-01")]), |
| 18 | + ), |
| 19 | + ).set_xindex("pf") |
| 20 | + |
| 21 | + df = ds_temp.to_dataframe() |
| 22 | + |
| 23 | + assert "pf" in df.columns |
| 24 | + assert "temp" in df.columns |
| 25 | + np.testing.assert_array_equal(df["pf"].values, [1.0, 2.0, 4.2, 8.0, 10.0]) |
| 26 | + |
| 27 | + def test_to_dataframe_still_excludes_matching_dim_index(self): |
| 28 | + """Index coordinates where name matches dimension should not be in columns.""" |
| 29 | + ds = xr.Dataset( |
| 30 | + data_vars=dict(temp=(["x"], [1, 2, 3])), |
| 31 | + coords=dict(x=("x", [10, 20, 30])), |
| 32 | + ) |
| 33 | + |
| 34 | + df = ds.to_dataframe() |
| 35 | + |
| 36 | + assert "temp" in df.columns |
| 37 | + assert "x" not in df.columns |
| 38 | + |
| 39 | + def test_to_dataframe_roundtrip_with_set_xindex(self): |
| 40 | + """Dataset with set_xindex should roundtrip to DataFrame correctly.""" |
| 41 | + ds = xr.Dataset( |
| 42 | + data_vars=dict(val=(["dim"], [100, 200, 300])), |
| 43 | + coords=dict(coord_idx=("dim", ["a", "b", "c"])), |
| 44 | + ).set_xindex("coord_idx") |
| 45 | + |
| 46 | + df = ds.to_dataframe() |
| 47 | + |
| 48 | + assert "coord_idx" in df.columns |
| 49 | + assert "val" in df.columns |
| 50 | + assert list(df["coord_idx"]) == ["a", "b", "c"] |
| 51 | + assert list(df["val"]) == [100, 200, 300] |
0 commit comments