Skip to content

API Reference

tsam_xarray

tsam_xarray: Lightweight xarray wrapper for tsam time series aggregation.

ClusteringInfo module-attribute

ClusteringInfo = ClusteringResult

Backwards-compatible alias for `ClusteringResult`.

ClusteringResult dataclass

Reusable clustering result with xarray dimension metadata.

Wraps one or more tsam ClusteringResult objects alongside the dimension names needed to apply the clustering to new data. Exposes clustering metadata as cached xarray DataArrays.

Attributes:

Name Type Description
time_dim str

Name of the time dimension.

cluster_dim list[str]

Dimension(s) clustered together.

slice_dims list[str]

Dimension(s) aggregated independently.

clusterings dict[tuple[Hashable, ...], ClusteringResult]

Per-slice tsam clustering. Single entry {(): result} when no slicing.

n_clusters int

Number of clusters.

n_original_periods int

Number of original periods.

n_timesteps_per_period int

Timesteps per period.

n_segments int | None

Segments per period, or None.

cluster_assignments DataArray

Cluster ID per period. Dims: (period, *slice_dims).

cluster_occurrences DataArray

Periods per cluster. Dims: (cluster, *slice_dims).

cluster_centers DataArray

Representative period per cluster. Dims: (cluster, *slice_dims).

segment_durations DataArray | None

Duration per segment, or None. Dims: (cluster, timestep, *slice_dims).

segment_assignments DataArray | None

Segment ID per timestep, or None. Dims: (cluster, timestep, *slice_dims).

segment_centers DataArray | None

Representative timestep per segment, or None. Dims: (cluster, segment, *slice_dims).

Source code in src/tsam_xarray/_clustering.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
@dataclass(frozen=True, repr=False)
class ClusteringResult:
    """Reusable clustering result with xarray dimension metadata.

    Wraps one or more tsam ``ClusteringResult`` objects alongside
    the dimension names needed to apply the clustering to new data.
    Exposes clustering metadata as cached xarray DataArrays.

    Attributes:
        time_dim: Name of the time dimension.
        cluster_dim: Dimension(s) clustered together.
        slice_dims: Dimension(s) aggregated independently.
        clusterings: Per-slice tsam clustering.
            Single entry ``{(): result}`` when no slicing.
        n_clusters: Number of clusters.
        n_original_periods: Number of original periods.
        n_timesteps_per_period: Timesteps per period.
        n_segments: Segments per period, or ``None``.
        cluster_assignments: Cluster ID per period.
            Dims: ``(period, *slice_dims)``.
        cluster_occurrences: Periods per cluster.
            Dims: ``(cluster, *slice_dims)``.
        cluster_centers: Representative period per cluster.
            Dims: ``(cluster, *slice_dims)``.
        segment_durations: Duration per segment, or ``None``.
            Dims: ``(cluster, timestep, *slice_dims)``.
        segment_assignments: Segment ID per timestep, or
            ``None``. Dims: ``(cluster, timestep,
            *slice_dims)``.
        segment_centers: Representative timestep per segment,
            or ``None``.
            Dims: ``(cluster, segment, *slice_dims)``.
    """

    time_dim: str
    cluster_dim: list[str]
    slice_dims: list[str]
    clusterings: dict[tuple[Hashable, ...], tsam.ClusteringResult]
    # Memo for the DataArray properties below. The dataclass is frozen, so
    # computed values are stored by mutating this dict instead of assigning
    # attributes. Excluded from __init__, repr and comparison.
    _cache: dict[str, Any] = field(
        default_factory=dict, repr=False, init=False, compare=False
    )

    def __repr__(self) -> str:
        """Return a compact summary; empty/unset optional parts are omitted."""
        seg = f", n_segments={self.n_segments}" if self.n_segments else ""
        slices = f", slice_dims={self.slice_dims}" if self.slice_dims else ""
        return (
            f"ClusteringResult("
            f"n_clusters={self.n_clusters}, "
            f"n_periods={self.n_original_periods}, "
            f"timesteps_per_period={self.n_timesteps_per_period}, "
            f"time_dim={self.time_dim!r}, "
            f"cluster_dim={self.cluster_dim}"
            f"{slices}{seg})"
        )

    # -- scalar accessors (uniform across slices) --
    # Each accessor reads the value from the first stored clustering only.

    @property
    def n_clusters(self) -> int:
        """Number of clusters."""
        return next(iter(self.clusterings.values())).n_clusters

    @property
    def n_original_periods(self) -> int:
        """Number of original periods (e.g., days)."""
        return next(iter(self.clusterings.values())).n_original_periods

    @property
    def n_timesteps_per_period(self) -> int:
        """Number of timesteps per period (e.g., 24 for hourly with daily periods)."""
        return next(iter(self.clusterings.values())).n_timesteps_per_period

    @property
    def n_segments(self) -> int | None:
        """Number of segments per period, or None if no segmentation."""
        return next(iter(self.clusterings.values())).n_segments

    # -- DataArray properties (cached, concatenated across slices) --

    @property
    def _slice_coords(self) -> dict[str, Any]:
        """Reconstruct slice coordinates from clusterings keys.

        Returns:
            Mapping of each slice dimension to its distinct key values in
            first-seen order; empty when there are no slice dimensions.
        """
        if not self.slice_dims:
            return {}
        keys = list(self.clusterings.keys())
        return {
            # dict.fromkeys de-duplicates while preserving insertion order.
            dim: list(dict.fromkeys(k[i] for k in keys))
            for i, dim in enumerate(self.slice_dims)
        }

    def _slice_layout(self) -> tuple[dict[str, Any], list[tuple[Any, ...]]]:
        """Return slice coordinates and the ordered slice keys built from them.

        Returns:
            ``(coords, keys)`` where ``coords`` is :attr:`_slice_coords` and
            ``keys`` is the cartesian product of the coordinate values, in
            ``slice_dims`` order.
        """
        import itertools

        coords = self._slice_coords
        keys = list(itertools.product(*(coords[d] for d in self.slice_dims)))
        return coords, keys

    @property
    def cluster_assignments(self) -> xr.DataArray:
        """Cluster assignment for each period, as DataArray.

        Dims: ``(period, *slice_dims)``.
        """
        if "cluster_assignments" not in self._cache:
            self._cache["cluster_assignments"] = self._build_assignments()
        result: xr.DataArray = self._cache["cluster_assignments"]
        return result

    def _build_assignments(self) -> xr.DataArray:
        """Build the uncached ``cluster_assignments`` array."""
        if not self.slice_dims:
            cr = self.clusterings[()]
            return xr.DataArray(list(cr.cluster_assignments), dims=["period"])

        sc, keys = self._slice_layout()
        arrays = [
            xr.DataArray(list(self.clusterings[k].cluster_assignments), dims=["period"])
            for k in keys
        ]
        return _concat_along_dims(arrays, self.slice_dims, sc)

    @property
    def cluster_occurrences(self) -> xr.DataArray:
        """Number of periods assigned to each cluster.

        Dims: ``(cluster, *slice_dims)``.
        """
        if "cluster_occurrences" not in self._cache:
            self._cache["cluster_occurrences"] = self._build_occurrences()
        result: xr.DataArray = self._cache["cluster_occurrences"]
        return result

    def _build_occurrences(self) -> xr.DataArray:
        """Build the uncached ``cluster_occurrences`` array."""

        def _single(cr: tsam.ClusteringResult) -> xr.DataArray:
            # minlength keeps a slot for clusters with no assigned period.
            counts = np.bincount(cr.cluster_assignments, minlength=cr.n_clusters)
            return xr.DataArray(
                counts,
                dims=["cluster"],
                coords={"cluster": np.arange(cr.n_clusters)},
            )

        if not self.slice_dims:
            return _single(self.clusterings[()])

        sc, keys = self._slice_layout()
        arrays = [_single(self.clusterings[k]) for k in keys]
        return _concat_along_dims(arrays, self.slice_dims, sc)

    @property
    def segment_durations(self) -> xr.DataArray | None:
        """Duration of each segment per cluster, or None if no segmentation.

        Dims: ``(cluster, timestep, *slice_dims)``.
        """
        if "segment_durations" not in self._cache:
            self._cache["segment_durations"] = self._build_segment_durations()
        result: xr.DataArray | None = self._cache["segment_durations"]
        return result

    def _build_segment_durations(self) -> xr.DataArray | None:
        """Build the uncached ``segment_durations`` array.

        Raises:
            ValueError: If the first slice has segment durations but a later
                slice does not.
        """
        if not self.slice_dims:
            return _segment_durations_to_da(self.clusterings[()].segment_durations)

        sc, keys = self._slice_layout()
        # The first slice decides whether segmentation is present at all.
        first = _segment_durations_to_da(self.clusterings[keys[0]].segment_durations)
        if first is None:
            return None
        das: list[xr.DataArray] = [first]
        for k in keys[1:]:
            da = _segment_durations_to_da(self.clusterings[k].segment_durations)
            if da is None:
                msg = (
                    f"Slice {k} has no segment durations but the first "
                    f"slice does. Segmentation must be uniform across slices."
                )
                raise ValueError(msg)
            das.append(da)
        return _concat_along_dims(das, self.slice_dims, sc)

    @property
    def cluster_centers(self) -> xr.DataArray:
        """Representative period index for each cluster.

        Dims: ``(cluster, *slice_dims)``.
        """
        if "cluster_centers" not in self._cache:
            self._cache["cluster_centers"] = self._build_cluster_centers()
        result: xr.DataArray = self._cache["cluster_centers"]
        return result

    def _build_cluster_centers(self) -> xr.DataArray:
        """Build the uncached ``cluster_centers`` array.

        Raises:
            ValueError: If any clustering has ``cluster_centers`` set to None.
        """

        def _single(cr: tsam.ClusteringResult) -> xr.DataArray:
            centers = cr.cluster_centers
            if centers is None:
                msg = "No cluster centers available."
                raise ValueError(msg)
            return xr.DataArray(
                list(centers),
                dims=["cluster"],
                coords={"cluster": np.arange(cr.n_clusters)},
            )

        if not self.slice_dims:
            return _single(self.clusterings[()])

        sc, keys = self._slice_layout()
        arrays = [_single(self.clusterings[k]) for k in keys]
        return _concat_along_dims(arrays, self.slice_dims, sc)

    @property
    def segment_assignments(self) -> xr.DataArray | None:
        """Segment assignment for each timestep per cluster, or None.

        Dims: ``(cluster, timestep, *slice_dims)``.
        """
        if "segment_assignments" not in self._cache:
            self._cache["segment_assignments"] = self._build_segment_assignments()
        result: xr.DataArray | None = self._cache["segment_assignments"]
        return result

    def _build_segment_assignments(self) -> xr.DataArray | None:
        """Build the uncached ``segment_assignments`` array.

        Raises:
            ValueError: If the first slice has segment assignments but a
                later slice does not.
        """

        def _single(cr: tsam.ClusteringResult) -> xr.DataArray | None:
            if cr.segment_assignments is None:
                return None
            return xr.DataArray(
                np.array(cr.segment_assignments),
                dims=["cluster", "timestep"],
                coords={
                    "cluster": np.arange(cr.n_clusters),
                    "timestep": np.arange(cr.n_timesteps_per_period),
                },
            )

        if not self.slice_dims:
            return _single(self.clusterings[()])

        sc, keys = self._slice_layout()
        # The first slice decides whether segmentation is present at all.
        first = _single(self.clusterings[keys[0]])
        if first is None:
            return None
        das: list[xr.DataArray] = [first]
        for k in keys[1:]:
            da = _single(self.clusterings[k])
            if da is None:
                msg = (
                    f"Slice {k} has no segment assignments but the first "
                    f"slice does. Segmentation must be uniform across slices."
                )
                raise ValueError(msg)
            das.append(da)
        return _concat_along_dims(das, self.slice_dims, sc)

    @property
    def segment_centers(self) -> xr.DataArray | None:
        """Representative timestep index for each segment per cluster, or None.

        Dims: ``(cluster, segment, *slice_dims)``.
        """
        if "segment_centers" not in self._cache:
            self._cache["segment_centers"] = self._build_segment_centers()
        result: xr.DataArray | None = self._cache["segment_centers"]
        return result

    def _build_segment_centers(self) -> xr.DataArray | None:
        """Build the uncached ``segment_centers`` array.

        Raises:
            ValueError: If the first slice has segment centers but a later
                slice does not.
        """

        def _single(cr: tsam.ClusteringResult) -> xr.DataArray | None:
            if cr.segment_centers is None:
                return None
            # Fall back to the width of the first row when n_segments is unset.
            n_segments = cr.n_segments or len(cr.segment_centers[0])
            return xr.DataArray(
                np.array(cr.segment_centers),
                dims=["cluster", "segment"],
                coords={
                    "cluster": np.arange(cr.n_clusters),
                    "segment": np.arange(n_segments),
                },
            )

        if not self.slice_dims:
            return _single(self.clusterings[()])

        sc, keys = self._slice_layout()
        # The first slice decides whether segmentation is present at all.
        first = _single(self.clusterings[keys[0]])
        if first is None:
            return None
        das: list[xr.DataArray] = [first]
        for k in keys[1:]:
            da = _single(self.clusterings[k])
            if da is None:
                msg = (
                    f"Slice {k} has no segment centers but the first "
                    f"slice does. Segmentation must be uniform across slices."
                )
                raise ValueError(msg)
            das.append(da)
        return _concat_along_dims(das, self.slice_dims, sc)

    def apply(
        self,
        da: xr.DataArray,
        *,
        time_dim: str | None = None,
        cluster_dim: Sequence[str] | str | None = None,
        **tsam_kwargs: Any,
    ) -> Any:
        """Apply this clustering to new data.

        Args:
            da: New data with compatible time dimension
                length.
            time_dim: Time dimension name. Defaults to the
                stored value.
            cluster_dim: Cluster dimension(s). Defaults to the
                stored value. Can differ from the original if
                the new data has different dimension names.
            **tsam_kwargs: Additional keyword arguments passed
                to ``ClusteringResult.apply()``.

        Returns:
            Aggregation result using the stored clustering.
        """
        from tsam_xarray._result import AggregationResult

        td = time_dim if time_dim is not None else self.time_dim
        cd = (
            _resolve_cluster_dim(cluster_dim)
            if cluster_dim is not None
            else self.cluster_dim
        )

        _validate_apply(da, td, cd, self.slice_dims, self.clusterings)

        # Use stored slice_dims for canonical ordering
        slice_dims = self.slice_dims

        if not slice_dims:
            cr = self.clusterings[()]
            return _apply_single(da, cr, td, cd, tsam_kwargs)

        import itertools

        # Slice keys come from the new data's coordinates, not from the
        # stored clusterings; each key is then looked up in the stored mapping.
        slice_coords: dict[str, Any] = {d: da.coords[d].values for d in slice_dims}
        slice_keys = list(itertools.product(*(slice_coords[d] for d in slice_dims)))

        results: list[AggregationResult] = []

        for key in slice_keys:
            sel = dict(zip(slice_dims, key, strict=True))
            da_slice = da.sel(sel)
            cr = _lookup_clustering(self.clusterings, key)
            r = _apply_single(da_slice, cr, td, cd, tsam_kwargs)
            results.append(r)

        return _concat_results(results, slice_dims, slice_coords, slice_keys)

    def disaggregate(self, data: xr.DataArray) -> xr.DataArray:
        """Map data on ``(cluster, timestep)`` back to original time.

        This is the inverse of ``aggregate()``. Use it to expand
        data computed on the compact cluster-representative grid
        (e.g., optimization results) back to the full time axis.

        Unlike ``AggregationResult.disaggregate()``, this method
        works on a ``ClusteringResult`` loaded from JSON — no
        original data needed.

        Args:
            data: Data with ``cluster`` and ``timestep`` dims,
                matching the shape of the original cluster
                representatives. Additional dims (including
                auto-sliced dims like scenario) are supported.

        Returns:
            Data with ``cluster`` and ``timestep`` replaced by
            the original ``time`` dimension.
        """
        slice_dims = self.slice_dims
        if not slice_dims:
            return _disaggregate_single(self.clusterings[()], data)

        import itertools

        # Slice keys come from the coordinates of ``data`` itself.
        slice_coords = {d: data.coords[d].values for d in slice_dims}
        keys = list(itertools.product(*(slice_coords[d] for d in slice_dims)))
        results = []
        for key in keys:
            sel = dict(zip(slice_dims, key, strict=True))
            data_slice = data.sel(sel)
            cr = _lookup_clustering(self.clusterings, key)
            results.append(_disaggregate_single(cr, data_slice))

        return _concat_along_dims(results, slice_dims, slice_coords)

    def to_dict(self) -> dict[str, Any]:
        """Serialize clustering to a dictionary.

        Returns:
            Plain dict suitable for ``json.dump()`` or
            storage in databases, APIs, etc. The dimension lists are
            copies, so editing the returned dict does not alter this
            instance.
        """
        entries = [
            {
                "key": list(_native_key(key)),
                "clustering": cr.to_dict(),
            }
            for key, cr in self.clusterings.items()
        ]
        return {
            "time_dim": self.time_dim,
            # Copy so callers cannot mutate this frozen instance's lists.
            "cluster_dim": list(self.cluster_dim),
            "slice_dims": list(self.slice_dims),
            "clusterings": entries,
        }

    def to_json(self, path: str | Path, **json_kwargs: Any) -> None:
        """Save clustering to JSON file.

        Args:
            path: Output file path.
            **json_kwargs: Additional keyword arguments passed
                to ``json.dump()``. Default: ``indent=2``.
        """
        # Apply the documented default while still letting callers override it.
        json_kwargs.setdefault("indent", 2)
        with Path(path).open("w") as f:
            json.dump(self.to_dict(), f, **json_kwargs)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ClusteringResult:
        """Load clustering from a dictionary.

        The input dictionary is not modified.

        Args:
            data: Dict as returned by :meth:`to_dict`.

        Returns:
            The loaded ``ClusteringResult``.
        """
        # Backcompat: pre-0.6 wrappers stored the time index as an outer
        # ``time_coords`` key while the inner tsam blob (written by tsam<3.4)
        # had no ``time_index``. Forward it so disaggregate keeps datetimes.
        legacy = "time_coords" in data
        if legacy:
            import warnings

            warnings.warn(
                "Loading a legacy tsam_xarray JSON with an outer 'time_coords' "
                "field; re-save with to_json() to silence this warning.",
                DeprecationWarning,
                stacklevel=2,
            )

        clusterings: dict[tuple[Hashable, ...], tsam.ClusteringResult] = {}
        for entry in data["clusterings"]:
            blob = entry["clustering"]
            if legacy and "time_index" not in blob:
                # Shallow copy so the caller's dict is left untouched.
                blob = {**blob, "time_index": data["time_coords"]}
            key = tuple(entry["key"])
            clusterings[key] = tsam.ClusteringResult.from_dict(blob)

        return cls(
            time_dim=data["time_dim"],
            cluster_dim=data["cluster_dim"],
            slice_dims=data.get("slice_dims", []),
            clusterings=clusterings,
        )

    @classmethod
    def from_json(cls, path: str | Path) -> ClusteringResult:
        """Load clustering from JSON file.

        Args:
            path: Input file path.

        Returns:
            The loaded ``ClusteringResult``.
        """
        with Path(path).open() as f:
            return cls.from_dict(json.load(f))

n_clusters property

n_clusters: int

Number of clusters.

n_original_periods property

n_original_periods: int

Number of original periods (e.g., days).

n_timesteps_per_period property

n_timesteps_per_period: int

Number of timesteps per period (e.g., 24 for hourly with daily periods).

n_segments property

n_segments: int | None

Number of segments per period, or None if no segmentation.

cluster_assignments property

cluster_assignments: DataArray

Cluster assignment for each period, as DataArray.

Dims: (period, *slice_dims).

cluster_occurrences property

cluster_occurrences: DataArray

Number of periods assigned to each cluster.

Dims: (cluster, *slice_dims).

segment_durations property

segment_durations: DataArray | None

Duration of each segment per cluster, or None if no segmentation.

Dims: (cluster, timestep, *slice_dims).

cluster_centers property

cluster_centers: DataArray

Representative period index for each cluster.

Dims: (cluster, *slice_dims).

segment_assignments property

segment_assignments: DataArray | None

Segment assignment for each timestep per cluster, or None.

Dims: (cluster, timestep, *slice_dims).

segment_centers property

segment_centers: DataArray | None

Representative timestep index for each segment per cluster, or None.

Dims: (cluster, segment, *slice_dims).

apply

apply(
    da: DataArray,
    *,
    time_dim: str | None = None,
    cluster_dim: Sequence[str] | str | None = None,
    **tsam_kwargs: Any,
) -> Any

Apply this clustering to new data.

Parameters:

Name Type Description Default
da DataArray

New data with compatible time dimension length.

required
time_dim str | None

Time dimension name. Defaults to the stored value.

None
cluster_dim Sequence[str] | str | None

Cluster dimension(s). Defaults to the stored value. Can differ from the original if the new data has different dimension names.

None
**tsam_kwargs Any

Additional keyword arguments passed to ClusteringResult.apply().

{}

Returns:

Type Description
Any

Aggregation result using the stored clustering.

Source code in src/tsam_xarray/_clustering.py
def apply(
    self,
    da: xr.DataArray,
    *,
    time_dim: str | None = None,
    cluster_dim: Sequence[str] | str | None = None,
    **tsam_kwargs: Any,
) -> Any:
    """Re-use the stored clustering to aggregate another DataArray.

    Args:
        da: New data with compatible time dimension length.
        time_dim: Time dimension name; falls back to the stored one
            when omitted.
        cluster_dim: Cluster dimension(s); falls back to the stored
            value when omitted. May differ from the original if the
            new data uses different dimension names.
        **tsam_kwargs: Extra keyword arguments forwarded to
            ``ClusteringResult.apply()``.

    Returns:
        Aggregation result using the stored clustering.
    """
    from tsam_xarray._result import AggregationResult

    resolved_time = self.time_dim if time_dim is None else time_dim
    if cluster_dim is None:
        resolved_cluster = self.cluster_dim
    else:
        resolved_cluster = _resolve_cluster_dim(cluster_dim)

    _validate_apply(
        da, resolved_time, resolved_cluster, self.slice_dims, self.clusterings
    )

    # The stored slice_dims define the canonical ordering.
    dims = self.slice_dims
    if dims:
        import itertools

        # Slice keys are taken from the new data's own coordinates.
        coords: dict[str, Any] = {name: da.coords[name].values for name in dims}
        combos = list(itertools.product(*(coords[name] for name in dims)))

        per_slice: list[AggregationResult] = []
        for combo in combos:
            subset = da.sel(dict(zip(dims, combo, strict=True)))
            clustering = _lookup_clustering(self.clusterings, combo)
            per_slice.append(
                _apply_single(
                    subset, clustering, resolved_time, resolved_cluster, tsam_kwargs
                )
            )

        return _concat_results(per_slice, dims, coords, combos)

    # No slicing: a single clustering is stored under the empty-tuple key.
    return _apply_single(
        da, self.clusterings[()], resolved_time, resolved_cluster, tsam_kwargs
    )

disaggregate

disaggregate(data: DataArray) -> xr.DataArray

Map data on (cluster, timestep) back to original time.

This is the inverse of aggregate(). Use it to expand data computed on the compact cluster-representative grid (e.g., optimization results) back to the full time axis.

Unlike AggregationResult.disaggregate(), this method works on a ClusteringInfo loaded from JSON — no original data needed.

Parameters:

Name Type Description Default
data DataArray

Data with cluster and timestep dims, matching the shape of the original cluster representatives. Additional dims (including auto-sliced dims like scenario) are supported.

required

Returns:

Type Description
DataArray

Data with cluster and timestep replaced by the original time dimension.

Source code in src/tsam_xarray/_clustering.py
def disaggregate(self, data: xr.DataArray) -> xr.DataArray:
    """Expand ``(cluster, timestep)`` data back onto the original time axis.

    This is the inverse of ``aggregate()``: values computed on the
    compact cluster-representative grid (e.g., optimization results)
    are mapped back to the full time axis.

    Unlike ``AggregationResult.disaggregate()``, this method works on
    a ``ClusteringInfo`` loaded from JSON — no original data needed.

    Args:
        data: Data with ``cluster`` and ``timestep`` dims, matching
            the shape of the original cluster representatives.
            Additional dims (including auto-sliced dims like
            scenario) are supported.

    Returns:
        Data with ``cluster`` and ``timestep`` replaced by the
        original ``time`` dimension.
    """
    dims = self.slice_dims
    if dims:
        import itertools

        # Slice keys are taken from the coordinates of ``data`` itself.
        coords = {name: data.coords[name].values for name in dims}
        expanded = []
        for combo in itertools.product(*(coords[name] for name in dims)):
            subset = data.sel(dict(zip(dims, combo, strict=True)))
            clustering = _lookup_clustering(self.clusterings, combo)
            expanded.append(_disaggregate_single(clustering, subset))

        return _concat_along_dims(expanded, dims, coords)

    # No slicing: a single clustering is stored under the empty-tuple key.
    return _disaggregate_single(self.clusterings[()], data)

to_dict

to_dict() -> dict[str, Any]

Serialize clustering to a dictionary.

Returns:

Type Description
dict[str, Any]

Plain dict suitable for json.dump() or storage in databases, APIs, etc.

Source code in src/tsam_xarray/_clustering.py
def to_dict(self) -> dict[str, Any]:
    """Serialize clustering to a dictionary.

    Returns:
        Plain dict suitable for ``json.dump()`` or
        storage in databases, APIs, etc. The dimension lists are
        copies, so editing the returned dict does not alter this
        instance.
    """
    entries = [
        {
            "key": list(_native_key(key)),
            "clustering": cr.to_dict(),
        }
        for key, cr in self.clusterings.items()
    ]
    return {
        "time_dim": self.time_dim,
        # Copy so callers cannot mutate the instance's own lists.
        "cluster_dim": list(self.cluster_dim),
        "slice_dims": list(self.slice_dims),
        "clusterings": entries,
    }

to_json

to_json(path: str | Path, **json_kwargs: Any) -> None

Save clustering to JSON file.

Parameters:

Name Type Description Default
path str | Path

Output file path.

required
**json_kwargs Any

Additional keyword arguments passed to json.dump(). Default: indent=2.

{}
Source code in src/tsam_xarray/_clustering.py
def to_json(self, path: str | Path, **json_kwargs: Any) -> None:
    """Save clustering to JSON file.

    Args:
        path: Output file path.
        **json_kwargs: Additional keyword arguments passed
            to ``json.dump()``. Default: ``indent=2``.
    """
    # Apply the documented default while still letting callers override it
    # (e.g. ``indent=None`` for compact output).
    json_kwargs.setdefault("indent", 2)
    with Path(path).open("w") as f:
        json.dump(self.to_dict(), f, **json_kwargs)

from_dict classmethod

from_dict(data: dict[str, Any]) -> ClusteringResult

Load clustering from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dict as returned by `to_dict()`.

required

Returns:

Type Description
ClusteringResult

The loaded ClusteringResult.

Source code in src/tsam_xarray/_clustering.py
@classmethod
def from_dict(cls, data: dict[str, Any]) -> ClusteringResult:
    """Load clustering from a dictionary.

    The input dictionary is not modified.

    Args:
        data: Dict as returned by :meth:`to_dict`.

    Returns:
        The loaded ``ClusteringResult``.
    """
    # Backcompat: pre-0.6 wrappers stored the time index as an outer
    # ``time_coords`` key while the inner tsam blob (written by tsam<3.4)
    # had no ``time_index``. Forward it so disaggregate keeps datetimes.
    legacy = "time_coords" in data
    if legacy:
        import warnings

        warnings.warn(
            "Loading a legacy tsam_xarray JSON with an outer 'time_coords' "
            "field; re-save with to_json() to silence this warning.",
            DeprecationWarning,
            stacklevel=2,
        )

    clusterings: dict[tuple[Hashable, ...], tsam.ClusteringResult] = {}
    for entry in data["clusterings"]:
        blob = entry["clustering"]
        if legacy and "time_index" not in blob:
            # Shallow copy so the caller's dict is left untouched.
            blob = {**blob, "time_index": data["time_coords"]}
        key = tuple(entry["key"])
        clusterings[key] = tsam.ClusteringResult.from_dict(blob)

    return cls(
        time_dim=data["time_dim"],
        cluster_dim=data["cluster_dim"],
        slice_dims=data.get("slice_dims", []),
        clusterings=clusterings,
    )

from_json classmethod

from_json(path: str | Path) -> ClusteringResult

Load clustering from JSON file.

Parameters:

Name Type Description Default
path str | Path

Input file path.

required

Returns:

Type Description
ClusteringResult

The loaded ClusteringResult.

Source code in src/tsam_xarray/_clustering.py
@classmethod
def from_json(cls, path: str | Path) -> ClusteringResult:
    """Read a JSON file and build a ``ClusteringResult`` from it.

    Args:
        path: Input file path.

    Returns:
        The loaded ``ClusteringResult``.
    """
    source = Path(path)
    with source.open() as handle:
        payload = json.load(handle)
    # Construction is delegated to ``from_dict``.
    return cls.from_dict(payload)

AccuracyMetrics dataclass

Accuracy metrics from time series aggregation.

Attributes:

Name Type Description
rmse DataArray

Per-column RMSE. Dims: (*cluster_dims, *slice_dims).

mae DataArray

Per-column MAE. Dims: (*cluster_dims, *slice_dims).

rmse_duration DataArray

Per-column duration-curve RMSE. Dims: (*cluster_dims, *slice_dims).

weighted_rmse DataArray

RMSE weighted across columns. Dims: (*slice_dims) or scalar.

weighted_mae DataArray

MAE weighted across columns. Dims: (*slice_dims) or scalar.

weighted_rmse_duration DataArray

Duration-curve RMSE weighted across columns. Dims: (*slice_dims) or scalar.

Source code in src/tsam_xarray/_result.py
@dataclass(frozen=True, repr=False)
class AccuracyMetrics:
    """Error metrics describing how well an aggregation fits the input.

    Attributes:
        rmse: RMSE for each column.
            Dims: ``(*cluster_dims, *slice_dims)``.
        mae: MAE for each column.
            Dims: ``(*cluster_dims, *slice_dims)``.
        rmse_duration: Duration-curve RMSE for each column.
            Dims: ``(*cluster_dims, *slice_dims)``.
        weighted_rmse: RMSE combined across columns using the weights.
            Dims: ``(*slice_dims)`` or scalar.
        weighted_mae: MAE combined across columns using the weights.
            Dims: ``(*slice_dims)`` or scalar.
        weighted_rmse_duration: Duration-curve RMSE combined across
            columns using the weights.
            Dims: ``(*slice_dims)`` or scalar.
    """

    rmse: xr.DataArray
    mae: xr.DataArray
    rmse_duration: xr.DataArray
    weighted_rmse: xr.DataArray
    weighted_mae: xr.DataArray
    weighted_rmse_duration: xr.DataArray

    def __repr__(self) -> str:
        """Summarise the three weighted metrics on a single line.

        Each metric is shown as its mean; when it holds more than one
        value the ``[min-max]`` range is appended.
        """

        def _summarise(values: xr.DataArray) -> str:
            average = f"{float(values.mean()):.4f}"
            if values.size > 1:
                lowest = float(values.min())
                highest = float(values.max())
                return f"{average} [{lowest:.4f}-{highest:.4f}]"
            return average

        parts = [
            f"weighted_rmse={_summarise(self.weighted_rmse)}",
            f"weighted_mae={_summarise(self.weighted_mae)}",
            f"weighted_rmse_duration={_summarise(self.weighted_rmse_duration)}",
        ]
        return "AccuracyMetrics(" + ", ".join(parts) + ")"

AggregationResult dataclass

Result of tsam_xarray.aggregate().

Attributes:

Name Type Description
cluster_representatives DataArray

Typical periods. Dims: (cluster, timestep, *cluster_dims, *slice_dims).

cluster_assignments DataArray

Which cluster each period belongs to. Dims: (period, *slice_dims).

cluster_weights DataArray

Periods per cluster. Dims: (cluster, *slice_dims).

segment_durations DataArray | None

Duration of each segment, or None. Dims: (cluster, timestep, *slice_dims).

accuracy AccuracyMetrics

Per-column and weighted accuracy metrics.

reconstructed DataArray

Reconstructed time series (same shape as input).

original DataArray

The input data.

clustering ClusteringResult

Reusable clustering metadata. See ClusteringResult.

is_transferred bool

Whether this result came from apply() vs aggregate().

Source code in src/tsam_xarray/_result.py
@dataclass(frozen=True, repr=False)
class AggregationResult:
    """Result of ``tsam_xarray.aggregate()``.

    Attributes:
        cluster_representatives: Typical periods.
            Dims: ``(cluster, timestep, *cluster_dims,
            *slice_dims)``.
        cluster_assignments: Which cluster each period
            belongs to. Dims: ``(period, *slice_dims)``.
        cluster_weights: Periods per cluster.
            Dims: ``(cluster, *slice_dims)``.
        segment_durations: Duration of each segment, or
            ``None``. Dims: ``(cluster, timestep,
            *slice_dims)``.
        accuracy: Per-column and weighted accuracy metrics.
        reconstructed: Reconstructed time series
            (same shape as input).
        original: The input data.
        clustering: Reusable clustering metadata.
            See `ClusteringResult`.
        is_transferred: Whether this result came from
            ``apply()`` vs ``aggregate()``.
    """

    cluster_representatives: xr.DataArray
    cluster_assignments: xr.DataArray
    cluster_weights: xr.DataArray
    segment_durations: xr.DataArray | None
    accuracy: AccuracyMetrics
    reconstructed: xr.DataArray
    original: xr.DataArray
    clustering: ClusteringResult
    is_transferred: bool = False

    def __repr__(self) -> str:
        """Return a one-line summary of sizes, dims and mean weighted RMSE.

        ``slice_dims`` and ``n_segments`` are only shown when they are
        non-empty / truthy.
        """
        c = self.clustering
        slices = f", slice_dims={c.slice_dims}" if c.slice_dims else ""
        seg = f", n_segments={self.n_segments}" if self.n_segments else ""
        return (
            f"AggregationResult("
            f"n_clusters={self.n_clusters}, "
            f"n_periods={c.n_original_periods}, "
            f"cluster_dim={c.cluster_dim}"
            f"{slices}{seg}, "
            f"weighted_rmse={float(self.accuracy.weighted_rmse.mean()):.4f})"
        )

    @property
    def n_clusters(self) -> int:
        """Number of clusters (size of the ``cluster`` dimension)."""
        return int(self.cluster_weights.sizes["cluster"])

    @property
    def n_timesteps_per_period(self) -> int:
        """Number of timesteps per cluster representative."""
        return int(self.cluster_representatives.sizes["timestep"])

    @property
    def n_segments(self) -> int | None:
        """Number of segments per period, if segmentation was used.

        Read from the first stored per-slice clustering.
        """
        first_cr = next(iter(self.clustering.clusterings.values()))
        result: int | None = first_cr.n_segments
        return result

    @property
    def residuals(self) -> xr.DataArray:
        """Difference between original and reconstructed data."""
        return self.original - self.reconstructed

    def disaggregate(self, data: xr.DataArray) -> xr.DataArray:
        """Map data on ``(cluster, timestep)`` back to original time.

        This is the inverse of ``aggregate()``. Use it to expand
        external data computed on the compact cluster-representative
        grid (e.g., optimization results) back to the full time
        axis.

        Without segmentation, values are repeated for each timestep
        in the period. With segmentation, values are placed at
        segment boundaries and remaining timesteps are NaN — use
        ``.ffill(dim="time")``,
        ``.interpolate_na(dim="time")``, etc.

        Args:
            data: Data with ``cluster`` and ``timestep`` dims,
                matching the shape of
                ``result.cluster_representatives``. Additional
                dims (including auto-sliced dims like scenario)
                are supported.

        Returns:
            Data with ``cluster`` and ``timestep`` replaced by
            the original ``time`` dimension.
        """
        # Use stored slice_dims for canonical ordering
        slice_dims = self.clustering.slice_dims
        if not slice_dims:
            return self._disaggregate_single(data)

        import itertools

        # NOTE(review): imported lazily, presumably to avoid a circular
        # import between the modules — confirm before hoisting.
        from tsam_xarray._core import _concat_along_dims

        # Slice coordinates are taken from ``data`` itself, so ``data`` must
        # carry every stored slice dimension.
        slice_coords = {d: data.coords[d].values for d in slice_dims}
        keys = list(itertools.product(*(slice_coords[d] for d in slice_dims)))
        results = []
        for key in keys:
            sel = dict(zip(slice_dims, key, strict=True))
            data_slice = data.sel(sel)
            result_slice = self._make_slice_view(sel)
            results.append(result_slice._disaggregate_single(data_slice))

        return _concat_along_dims(results, slice_dims, slice_coords)

    def _make_slice_view(self, sel: dict[str, object]) -> AggregationResult:
        """Create a view of this result for a single slice.

        Args:
            sel: Mapping from every stored slice dimension to the
                coordinate value identifying the slice.

        Returns:
            A new ``AggregationResult`` whose arrays are restricted to
            ``sel`` and whose clustering holds only that slice's entry
            under the empty key ``()``.
        """
        from tsam_xarray._clustering import ClusteringResult as CR
        from tsam_xarray._clustering import _lookup_clustering

        # Build key in stored slice_dims order
        key = tuple(sel[d] for d in self.clustering.slice_dims)
        cr = _lookup_clustering(self.clustering.clusterings, key)

        return AggregationResult(
            cluster_representatives=self.cluster_representatives.sel(sel),
            cluster_assignments=self.cluster_assignments.sel(sel),
            cluster_weights=self.cluster_weights.sel(sel),
            segment_durations=(
                self.segment_durations.sel(sel)
                if self.segment_durations is not None
                else None
            ),
            accuracy=AccuracyMetrics(
                rmse=self.accuracy.rmse.sel(sel),
                mae=self.accuracy.mae.sel(sel),
                rmse_duration=self.accuracy.rmse_duration.sel(sel),
                weighted_rmse=self.accuracy.weighted_rmse.sel(sel),
                weighted_mae=self.accuracy.weighted_mae.sel(sel),
                weighted_rmse_duration=self.accuracy.weighted_rmse_duration.sel(sel),
            ),
            reconstructed=self.reconstructed.sel(sel),
            original=self.original.sel(sel),
            clustering=CR(
                time_dim=self.clustering.time_dim,
                cluster_dim=self.clustering.cluster_dim,
                slice_dims=[],
                clusterings={(): cr},
            ),
            # Carry the provenance flag over; otherwise the view of a
            # transferred result would silently report ``False``.
            is_transferred=self.is_transferred,
        )

    def _disaggregate_single(self, data: xr.DataArray) -> xr.DataArray:
        """Disaggregate without slice dims.

        Uses the single clustering stored under the empty key ``()``.
        """
        from tsam_xarray._clustering import _disaggregate_single

        cr = self.clustering.clusterings[()]
        return _disaggregate_single(cr, data)

n_clusters property

n_clusters: int

Number of clusters (size of the cluster dimension).

n_timesteps_per_period property

n_timesteps_per_period: int

Number of timesteps per cluster representative.

n_segments property

n_segments: int | None

Number of segments per period, if segmentation was used.

residuals property

residuals: DataArray

Difference between original and reconstructed data.

disaggregate

disaggregate(data: DataArray) -> xr.DataArray

Map data on (cluster, timestep) back to original time.

This is the inverse of aggregate(). Use it to expand external data computed on the compact cluster-representative grid (e.g., optimization results) back to the full time axis.

Without segmentation, values are repeated for each timestep in the period. With segmentation, values are placed at segment boundaries and remaining timesteps are NaN — use .ffill(dim="time"), .interpolate_na(dim="time"), etc.

Parameters:

Name Type Description Default
data DataArray

Data with cluster and timestep dims, matching the shape of result.cluster_representatives. Additional dims (including auto-sliced dims like scenario) are supported.

required

Returns:

Type Description
DataArray

Data with cluster and timestep replaced by the original time dimension.

Source code in src/tsam_xarray/_result.py
def disaggregate(self, data: xr.DataArray) -> xr.DataArray:
    """Expand ``(cluster, timestep)`` data onto the original time axis.

    Inverse operation of ``aggregate()``: takes external data that
    lives on the compact cluster-representative grid (for example
    optimization results) and maps it back to the full time axis.

    If no segmentation was used, each value is repeated for every
    timestep of the period. With segmentation, values sit at the
    segment boundaries and all other timesteps are NaN; fill them
    with ``.ffill(dim="time")``, ``.interpolate_na(dim="time")``,
    etc.

    Args:
        data: Data carrying ``cluster`` and ``timestep`` dims with
            the same shape as ``result.cluster_representatives``.
            Extra dims (including auto-sliced dims such as
            scenario) are supported.

    Returns:
        Data in which ``cluster`` and ``timestep`` have been
        replaced by the original ``time`` dimension.
    """
    # The stored slice_dims define the canonical dimension order.
    dims = self.clustering.slice_dims
    if not dims:
        return self._disaggregate_single(data)

    import itertools

    from tsam_xarray._core import _concat_along_dims

    coords_by_dim = {dim: data.coords[dim].values for dim in dims}
    combos = itertools.product(*(coords_by_dim[dim] for dim in dims))

    pieces = []
    for combo in combos:
        selector = dict(zip(dims, combo, strict=True))
        selected = data.sel(selector)
        view = self._make_slice_view(selector)
        pieces.append(view._disaggregate_single(selected))

    return _concat_along_dims(pieces, dims, coords_by_dim)

TuningResult dataclass

Result of hyperparameter tuning.

Attributes:

Name Type Description
n_clusters int

Optimal number of typical periods.

n_segments int

Optimal number of segments per period.

rmse float

RMSE of the optimal configuration.

best_result AggregationResult

The AggregationResult for the optimal configuration.

history list[dict[str, Any]]

History of all tested configurations with their RMSE values.

all_results list[AggregationResult]

All AggregationResults from tuning (when save_all_results=True).

Source code in src/tsam_xarray/_tuning.py
@dataclass
class TuningResult:
    """Result of hyperparameter tuning.

    Attributes:
        n_clusters: Optimal number of typical periods.
        n_segments: Optimal number of segments per period.
        rmse: RMSE of the optimal configuration.
        best_result: The AggregationResult for the optimal
            configuration.
        history: History of all tested configurations with
            their RMSE values. The methods below read the keys
            ``n_clusters``, ``n_segments``, ``timesteps`` and ``rmse``
            from every entry.
        all_results: All AggregationResults from tuning
            (when ``save_all_results=True``). Expected to be aligned
            index-by-index with ``history``.
    """

    n_clusters: int
    n_segments: int
    rmse: float
    best_result: AggregationResult
    history: list[dict[str, Any]] = field(repr=False)
    all_results: list[AggregationResult] = field(default_factory=list, repr=False)
    # Per-instance store for the lazily built ``reconstructed`` / ``accuracy``.
    _cache: dict[str, Any] = field(
        default_factory=dict, repr=False, init=False, compare=False
    )

    @property
    def summary(self) -> pd.DataFrame:
        """Summary table of all tested configurations, sorted by RMSE."""
        import pandas as pd

        return pd.DataFrame(self.history).sort_values("rmse")

    @property
    def summary_matrix(self) -> xr.Dataset:
        """Metrics as Dataset with ``(n_clusters, n_segments)`` dims.

        Contains ``rmse`` and ``timesteps`` as variables.
        NaN where a combination was not tested.
        """
        import pandas as pd

        df = pd.DataFrame(self.history)
        return df.set_index(["n_clusters", "n_segments"]).to_xarray()

    def _require_all_results(self) -> None:
        """Ensure per-configuration results are stored and match ``history``.

        Raises:
            ValueError: If ``all_results`` is empty, or its length differs
                from the length of ``history``.
        """
        if not self.all_results:
            msg = (
                "No results available. Use save_all_results=True "
                "in the tuning function."
            )
            raise ValueError(msg)
        if len(self.all_results) != len(self.history):
            msg = (
                f"Results/history mismatch: "
                f"{len(self.all_results)} results "
                f"vs {len(self.history)} history entries."
            )
            raise ValueError(msg)

    @property
    def reconstructed(self) -> xr.DataArray:
        """Reconstructed time series for each tested config.

        Lazy and cached.  Returns an xarray DataArray with the
        original dimensions plus ``(n_clusters, n_segments)``.
        NaN where a combination was not tested.

        Requires ``save_all_results=True``.
        """
        if "reconstructed" not in self._cache:
            self._require_all_results()
            import xarray as xr

            # Tag each reconstruction with its configuration so the pieces
            # can be combined along the two new dimensions.
            arrays = [
                res.reconstructed.expand_dims(
                    n_clusters=[h["n_clusters"]],
                    n_segments=[h["n_segments"]],
                )
                for h, res in zip(self.history, self.all_results, strict=True)
            ]
            self._cache["reconstructed"] = xr.combine_by_coords(arrays, join="outer")
        return self._cache["reconstructed"]  # type: ignore[no-any-return]

    @property
    def accuracy(self) -> xr.Dataset:
        """Per-column accuracy metrics for each tested config.

        Lazy and cached.  Returns an xarray Dataset with variables
        ``rmse``, ``mae``, and ``rmse_duration``, each with the
        cluster dimensions plus ``(n_clusters, n_segments)``.
        NaN where a combination was not tested.

        Requires ``save_all_results=True``.
        """
        if "accuracy" not in self._cache:
            self._require_all_results()
            import xarray as xr

            datasets = []
            for h, res in zip(self.history, self.all_results, strict=True):
                dims = {
                    "n_clusters": [h["n_clusters"]],
                    "n_segments": [h["n_segments"]],
                }
                ds = xr.Dataset(
                    {
                        "rmse": res.accuracy.rmse.expand_dims(dims),
                        "mae": res.accuracy.mae.expand_dims(dims),
                        "rmse_duration": res.accuracy.rmse_duration.expand_dims(dims),
                    }
                )
                datasets.append(ds)
            self._cache["accuracy"] = xr.combine_by_coords(datasets, join="outer")
        return self._cache["accuracy"]  # type: ignore[no-any-return]

    def find_by_timesteps(self, target: int) -> AggregationResult:
        """Find the result closest to a target timestep count.

        When several configurations are equally close, the one that
        appears first in ``history`` is returned.

        Requires ``save_all_results=True``.

        Raises:
            ValueError: If the per-configuration results are unavailable.
        """
        self._require_all_results()
        # ``min`` keeps the first of equally close entries.
        best_idx = min(
            range(len(self.history)),
            key=lambda i: abs(self.history[i]["timesteps"] - target),
        )
        return self.all_results[best_idx]

    def find_by_rmse(self, threshold: float) -> AggregationResult:
        """Find the smallest configuration that achieves a target RMSE.

        Returns the configuration with the fewest timesteps whose RMSE
        is at or below ``threshold``. Ties on the timestep count go to
        the entry that appears first in ``history``.

        Requires ``save_all_results=True``.

        Raises:
            ValueError: If the per-configuration results are unavailable,
                or no configuration reaches ``threshold``.
        """
        self._require_all_results()
        candidates = [
            i for i, h in enumerate(self.history) if h["rmse"] <= threshold
        ]

        if not candidates:
            best_available = min(h["rmse"] for h in self.history)
            msg = (
                f"No configuration achieves RMSE <= {threshold}. "
                f"Best available: {best_available:.4f}"
            )
            raise ValueError(msg)

        # ``min`` keeps the first of equally small entries.
        best_idx = min(candidates, key=lambda i: self.history[i]["timesteps"])
        return self.all_results[best_idx]

    def plot(self, show_labels: bool = True, **kwargs: Any) -> go.Figure:
        """Plot RMSE vs timesteps.

        Requires ``plotly`` (``pip install plotly``).

        Args:
            show_labels: Show a custom hover label (configuration,
                timesteps, RMSE) on each point instead of plain x/y.
            **kwargs: Extra keyword arguments forwarded to
                ``plotly.graph_objects.Scatter``.

        Returns:
            The figure containing a single scatter trace.

        Raises:
            ImportError: If ``plotly`` is not installed.
        """
        try:
            import plotly.graph_objects as go
        except ImportError as exc:
            msg = "plotly is required for plot(): pip install plotly"
            raise ImportError(msg) from exc

        summary = self.summary
        # Build labels column-wise: DataFrame.iterrows() upcasts each row to
        # a single dtype, which rendered integer counts as "8.0x24.0".
        hover_text = [
            f"{n_clust}x{n_seg}<br>Timesteps: {n_steps}<br>RMSE: {rmse:.4f}"
            for n_clust, n_seg, n_steps, rmse in zip(
                summary["n_clusters"],
                summary["n_segments"],
                summary["timesteps"],
                summary["rmse"],
                strict=True,
            )
        ]

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=summary["timesteps"],
                y=summary["rmse"],
                mode="lines+markers" if len(summary) > 1 else "markers",
                marker={"size": 10},
                hovertext=hover_text if show_labels else None,
                hoverinfo="text" if show_labels else "x+y",
                **kwargs,
            )
        )
        fig.update_layout(
            title="Tuning Results: Complexity vs Accuracy",
            xaxis_title="Timesteps (n_clusters x n_segments)",
            yaxis_title="RMSE",
            hovermode="closest",
        )
        return fig

    def __len__(self) -> int:
        """Return the number of stored results (0 when none were saved)."""
        return len(self.all_results)

    def __getitem__(self, index: int) -> AggregationResult:
        """Return the stored result at ``index``; needs saved results."""
        self._require_all_results()
        return self.all_results[index]

    def __iter__(self) -> Any:
        """Iterate over the stored results; needs saved results."""
        self._require_all_results()
        return iter(self.all_results)

summary property

summary: DataFrame

Summary table of all tested configurations, sorted by RMSE.

summary_matrix property

summary_matrix: Dataset

Metrics as Dataset with (n_clusters, n_segments) dims.

Contains rmse and timesteps as variables. NaN where a combination was not tested.

reconstructed property

reconstructed: DataArray

Reconstructed time series for each tested config.

Lazy and cached. Returns an xarray DataArray with the original dimensions plus (n_clusters, n_segments). NaN where a combination was not tested.

Requires save_all_results=True.

accuracy property

accuracy: Dataset

Per-column accuracy metrics for each tested config.

Lazy and cached. Returns an xarray Dataset with variables rmse, mae, and rmse_duration, each with the cluster dimensions plus (n_clusters, n_segments). NaN where a combination was not tested.

Requires save_all_results=True.

find_by_timesteps

find_by_timesteps(target: int) -> AggregationResult

Find the result closest to a target timestep count.

Requires save_all_results=True.

Source code in src/tsam_xarray/_tuning.py
def find_by_timesteps(self, target: int) -> AggregationResult:
    """Return the stored result whose timestep count is nearest ``target``.

    If several configurations are equally near, the earliest one in
    ``history`` wins.

    Requires ``save_all_results=True``.
    """
    self._require_all_results()

    def _distance(position: int) -> int:
        return abs(self.history[position]["timesteps"] - target)

    # min() keeps the first of equally distant entries.
    nearest = min(range(len(self.history)), key=_distance)
    return self.all_results[nearest]

find_by_rmse

find_by_rmse(threshold: float) -> AggregationResult

Find the smallest configuration that achieves a target RMSE.

Returns the configuration with the fewest timesteps whose RMSE is at or below threshold.

Requires save_all_results=True.

Source code in src/tsam_xarray/_tuning.py
def find_by_rmse(self, threshold: float) -> AggregationResult:
    """Return the most compact configuration meeting an RMSE target.

    Among all configurations whose RMSE does not exceed ``threshold``,
    the one with the fewest timesteps is returned; on a tie the
    earliest entry in ``history`` wins.

    Requires ``save_all_results=True``.
    """
    self._require_all_results()

    qualifying = [
        position
        for position, entry in enumerate(self.history)
        if entry["rmse"] <= threshold
    ]
    if qualifying:
        # min() keeps the first of equally small entries.
        smallest = min(qualifying, key=lambda pos: self.history[pos]["timesteps"])
        return self.all_results[smallest]

    best_available = min(entry["rmse"] for entry in self.history)
    msg = (
        f"No configuration achieves RMSE <= {threshold}. "
        f"Best available: {best_available:.4f}"
    )
    raise ValueError(msg)

plot

plot(show_labels: bool = True, **kwargs: Any) -> go.Figure

Plot RMSE vs timesteps.

Requires plotly (pip install plotly).

Source code in src/tsam_xarray/_tuning.py
def plot(self, show_labels: bool = True, **kwargs: Any) -> go.Figure:
    """Plot RMSE vs timesteps.

    Requires ``plotly`` (``pip install plotly``).

    Args:
        show_labels: Show a custom hover label (configuration,
            timesteps, RMSE) on each point instead of plain x/y.
        **kwargs: Extra keyword arguments forwarded to
            ``plotly.graph_objects.Scatter``.

    Returns:
        The figure containing a single scatter trace.

    Raises:
        ImportError: If ``plotly`` is not installed.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        msg = "plotly is required for plot(): pip install plotly"
        raise ImportError(msg) from exc

    summary = self.summary
    # Build labels column-wise: DataFrame.iterrows() upcasts each row to a
    # single dtype, which rendered integer counts as "8.0x24.0".
    hover_text = [
        f"{n_clust}x{n_seg}<br>Timesteps: {n_steps}<br>RMSE: {rmse:.4f}"
        for n_clust, n_seg, n_steps, rmse in zip(
            summary["n_clusters"],
            summary["n_segments"],
            summary["timesteps"],
            summary["rmse"],
            strict=True,
        )
    ]

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=summary["timesteps"],
            y=summary["rmse"],
            mode="lines+markers" if len(summary) > 1 else "markers",
            marker={"size": 10},
            hovertext=hover_text if show_labels else None,
            hoverinfo="text" if show_labels else "x+y",
            **kwargs,
        )
    )
    fig.update_layout(
        title="Tuning Results: Complexity vs Accuracy",
        xaxis_title="Timesteps (n_clusters x n_segments)",
        yaxis_title="RMSE",
        hovermode="closest",
    )
    return fig

aggregate

aggregate(
    da: DataArray,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    n_clusters: int,
    weights: Weights = None,
    **tsam_kwargs: Any,
) -> AggregationResult

Aggregate an xarray DataArray using tsam.

Parameters:

Name Type Description Default
da DataArray

Input data with a time dimension and optional extra dimensions.

required
time_dim str

Name of the time dimension.

required
cluster_dim Sequence[str] | str

Dimension(s) to cluster together. Multiple dims are stacked internally into a MultiIndex and unstacked in results. All remaining dims are sliced independently. Empty () for 1D time series with no column dimension.

required
n_clusters int

Number of cluster representatives.

required
weights Weights

Per-coordinate weights for clustering. Missing entries default to 1.0. Two formats:

  • Simple dict (single cluster_dim)::

    weights={"solar": 2.0, "wind": 1.0}

  • Dict-of-dicts (multiple cluster_dim)::

    weights={ "variable": {"solar": 2.0}, "region": {"north": 1.5}, }

Weights are multiplied across dimensions, e.g. ("solar", "north") gets weight 2.0 * 1.5 = 3.0.

None
**tsam_kwargs Any

Additional keyword arguments passed to tsam.aggregate().

{}
Source code in src/tsam_xarray/_core.py
def aggregate(
    da: xr.DataArray,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    n_clusters: int,
    weights: Weights = None,
    **tsam_kwargs: Any,
) -> AggregationResult:
    """Run tsam time series aggregation on an xarray DataArray.

    Args:
        da: Data to aggregate; must have a time dimension and may
            have further dimensions.
        time_dim: Name of the time dimension.
        cluster_dim: Dimension(s) that are clustered jointly.
            Several dims are stacked into a MultiIndex internally
            and unstacked again in the results. Every other
            dimension is sliced and aggregated independently.
            Pass an empty ``()`` for a 1D time series without a
            column dimension.
        n_clusters: Number of cluster representatives.
        weights: Clustering weights per coordinate; entries that
            are not given default to 1.0. Two formats:

            - **Simple dict** (single ``cluster_dim``)::

                  weights={"solar": 2.0, "wind": 1.0}

            - **Dict-of-dicts** (multiple ``cluster_dim``)::

                  weights={
                      "variable": {"solar": 2.0},
                      "region": {"north": 1.5},
                  }

              The per-dimension weights are multiplied,
              e.g. ``("solar", "north")`` gets weight
              ``2.0 * 1.5 = 3.0``.

        **tsam_kwargs: Further keyword arguments handed on to
            ``tsam.aggregate()``.
    """
    _validate_time_dim(da, time_dim)
    col_dims = _resolve_cluster_dim(cluster_dim)
    slice_dims = _infer_slice_dims(da, time_dim, col_dims)
    _validate(da, time_dim, col_dims, slice_dims)
    da = _validate_data(da, time_dim, col_dims, slice_dims)
    _validate_no_cluster_config_weights(tsam_kwargs)
    per_dim_weights = _normalize_weights(weights, da, col_dims)

    def _run(block: xr.DataArray) -> AggregationResult:
        # Aggregate one slice-free array with the shared settings.
        return _aggregate_single(
            block, n_clusters, time_dim, col_dims, per_dim_weights, tsam_kwargs
        )

    # Nothing to slice: aggregate the whole array in one go.
    if not slice_dims:
        return _run(da)

    slice_coords = {d: da.coords[d].values for d in slice_dims}
    slice_keys = list(itertools.product(*(slice_coords[d] for d in slice_dims)))

    results: list[AggregationResult] = [
        _run(da.sel(dict(zip(slice_dims, key, strict=True)))) for key in slice_keys
    ]

    # Cluster counts may differ between slices with extremes="append";
    # check they agree before concatenating.
    _validate_consistent_cluster_counts(results, slice_keys)

    return _concat_results(results, slice_dims, slice_coords, slice_keys)

find_best_combination

find_best_combination(
    *args: Any, **kwargs: Any
) -> TuningResult

Deprecated alias for grid_search.

Source code in src/tsam_xarray/_tuning.py
def find_best_combination(*args: Any, **kwargs: Any) -> TuningResult:
    """Deprecated alias for :func:`grid_search`.

    Emits a ``FutureWarning`` and then forwards every argument
    unchanged to :func:`grid_search`.
    """
    from warnings import warn

    warn(
        "find_best_combination is deprecated, use grid_search instead",
        category=FutureWarning,
        stacklevel=2,
    )
    return grid_search(*args, **kwargs)

find_optimal_combination

find_optimal_combination(
    da: Any,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    data_reduction: float,
    weights: Weights = None,
    period_duration: int | float | str = 24,
    show_progress: bool = True,
    save_all_results: bool = True,
    **tsam_kwargs: Any,
) -> TuningResult

Find optimal n_clusters/n_segments for a target data reduction.

Tests all (n_clusters, n_segments) combinations that achieve the target data reduction, evaluating each across all slices.

Parameters:

Name Type Description Default
da Any

Input data.

required
time_dim str

Name of the time dimension.

required
cluster_dim Sequence[str] | str

Dimension(s) to cluster together.

required
data_reduction float

Target data reduction (e.g., 0.01 for 1% of original).

required
weights Weights

Per-coordinate weights for clustering and RMSE evaluation.

None
period_duration int | float | str

Hours per period (default: 24 for daily).

24
show_progress bool

Show progress bar (requires tqdm).

True
save_all_results bool

Keep all AggregationResults (memory-intensive).

True
**tsam_kwargs Any

Additional keyword arguments passed to tsam.aggregate().

{}

Returns:

Type Description
TuningResult

Best combination with lowest overall RMSE.

Source code in src/tsam_xarray/_tuning.py
def find_optimal_combination(
    da: Any,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    data_reduction: float,
    weights: Weights = None,
    period_duration: int | float | str = 24,
    show_progress: bool = True,
    save_all_results: bool = True,
    **tsam_kwargs: Any,
) -> TuningResult:
    """Find optimal n_clusters/n_segments for a target data reduction.

    Tests all (n_clusters, n_segments) combinations that achieve
    the target data reduction, evaluating each across all slices.
    One candidate is generated per segment count (1 up to the number
    of timesteps per period), using the cluster count reported by
    ``find_clusters_for_reduction``; pairs with fewer than two
    clusters are discarded.

    Args:
        da: Input data.
        time_dim: Name of the time dimension.
        cluster_dim: Dimension(s) to cluster together.
        data_reduction: Target data reduction (e.g., 0.01 for
            1% of original).
        weights: Per-coordinate weights for clustering and
            RMSE evaluation.
        period_duration: Hours per period (default: 24 for
            daily).
        show_progress: Show progress bar (requires tqdm).
        save_all_results: Keep all AggregationResults
            (memory-intensive).
        **tsam_kwargs: Additional keyword arguments passed to
            ``tsam.aggregate()``.

    Returns:
        Best combination with lowest overall RMSE.

    Raises:
        ValueError: If no segment count yields at least two clusters
            for the requested ``data_reduction``.
        RuntimeError: If the evaluation produced no best result.
    """
    n_timesteps_per_period, _n_periods, n_timesteps = _infer_time_params(
        da, time_dim, period_duration
    )

    # One candidate per segment count. Each n_seg is visited exactly once,
    # so the (n_clust, n_seg) pairs are unique by construction and need no
    # separate de-duplication.
    candidates: list[tuple[int, int]] = []
    for n_seg in range(1, n_timesteps_per_period + 1):
        n_clust = find_clusters_for_reduction(n_timesteps, n_seg, data_reduction)
        if n_clust >= 2:
            candidates.append((n_clust, n_seg))

    if not candidates:
        msg = (
            f"No valid (n_clusters, n_segments) combinations "
            f"for data_reduction={data_reduction}"
        )
        raise ValueError(msg)

    history, all_results, best_rmse, best_result, best_nc, best_ns = (
        _evaluate_candidates(
            candidates,
            da,
            time_dim=time_dim,
            cluster_dim=cluster_dim,
            weights=weights,
            period_duration=period_duration,
            show_progress=show_progress,
            progress_desc="Testing configurations",
            save_all_results=save_all_results,
            tsam_kwargs=tsam_kwargs,
        )
    )

    # No best result means no candidate could be evaluated successfully.
    if best_result is None:
        msg = "All configurations failed"
        raise RuntimeError(msg)

    return TuningResult(
        n_clusters=best_nc,
        n_segments=best_ns,
        rmse=best_rmse,
        best_result=best_result,
        history=history,
        all_results=all_results,
    )

find_pareto_front

find_pareto_front(
    da: Any,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    max_timesteps: int | None = None,
    timesteps: Sequence[int] | None = None,
    weights: Weights = None,
    period_duration: int | float | str = 24,
    show_progress: bool = True,
    save_all_results: bool = True,
    **tsam_kwargs: Any,
) -> TuningResult

Find Pareto-optimal configs (RMSE vs complexity).

Runs the same grid search as `grid_search` but filters the results to the Pareto frontier -- configurations where no other tested combo has both lower RMSE and fewer timesteps.

Parameters:

Name Type Description Default
da Any

Input data.

required
time_dim str

Name of the time dimension.

required
cluster_dim Sequence[str] | str

Dimension(s) to cluster together.

required
max_timesteps int | None

Maximum total timesteps to test (n_clusters * n_segments). Defaults to total number of timesteps in the data.

None
timesteps Sequence[int] | None

Specific timestep counts to test. Only combinations where n_clusters * n_segments is in this list are evaluated. Mutually exclusive with max_timesteps.

None
weights Weights

Per-coordinate weights for clustering and RMSE evaluation.

None
period_duration int | float | str

Hours per period (default: 24).

24
show_progress bool

Show progress bar.

True
save_all_results bool

Keep all AggregationResults (memory-intensive).

True
**tsam_kwargs Any

Additional keyword arguments passed to tsam.aggregate().

{}

Returns:

Type Description
TuningResult

Pareto-optimal result with lowest RMSE on the frontier.

Source code in src/tsam_xarray/_tuning.py
def find_pareto_front(
    da: Any,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    max_timesteps: int | None = None,
    timesteps: Sequence[int] | None = None,
    weights: Weights = None,
    period_duration: int | float | str = 24,
    show_progress: bool = True,
    save_all_results: bool = True,
    **tsam_kwargs: Any,
) -> TuningResult:
    """Find Pareto-optimal configs (RMSE vs complexity).

    Delegates to :func:`grid_search` with the same arguments, then
    narrows its history and results to the Pareto frontier --
    configurations where no other tested combo has both lower
    RMSE and fewer timesteps.

    Args:
        da: Input data.
        time_dim: Name of the time dimension.
        cluster_dim: Dimension(s) to cluster together.
        max_timesteps: Maximum total timesteps to test
            (n_clusters * n_segments). Defaults to total
            number of timesteps in the data.
        timesteps: Specific timestep counts to test. Only
            combinations where ``n_clusters * n_segments``
            is in this list are evaluated. Mutually exclusive
            with ``max_timesteps``.
        weights: Per-coordinate weights for clustering and
            RMSE evaluation.
        period_duration: Hours per period (default: 24).
        show_progress: Show progress bar.
        save_all_results: Keep all AggregationResults
            (memory-intensive).
        **tsam_kwargs: Additional keyword arguments passed to
            ``tsam.aggregate()``.

    Returns:
        Pareto-optimal result with lowest RMSE on the frontier.
    """
    grid = grid_search(
        da,
        time_dim=time_dim,
        cluster_dim=cluster_dim,
        max_timesteps=max_timesteps,
        timesteps=timesteps,
        weights=weights,
        period_duration=period_duration,
        show_progress=show_progress,
        save_all_results=save_all_results,
        **tsam_kwargs,
    )

    front_history, front_results = _pareto_filter(grid.history, grid.all_results)

    # The frontier entry with the smallest RMSE is reported as the best;
    # on ties the earliest entry wins.
    winner_idx, winner = min(
        enumerate(front_history),
        key=lambda pair: pair[1]["rmse"],
    )

    is_grid_best = (
        winner["n_clusters"] == grid.n_clusters
        and winner["n_segments"] == grid.n_segments
    )
    if is_grid_best:
        # Same configuration the grid search picked: reuse its result.
        chosen_result = grid.best_result
    elif front_results:
        chosen_result = front_results[winner_idx]
    else:
        # No stored results to pick from: aggregate again for the winner.
        chosen_result = aggregate(
            da,
            time_dim=time_dim,
            cluster_dim=cluster_dim,
            n_clusters=winner["n_clusters"],
            weights=weights,
            segments=SegmentConfig(n_segments=winner["n_segments"]),
            period_duration=period_duration,
            **tsam_kwargs,
        )

    return TuningResult(
        n_clusters=winner["n_clusters"],
        n_segments=winner["n_segments"],
        rmse=winner["rmse"],
        best_result=chosen_result,
        history=front_history,
        all_results=front_results,
    )

grid_search

grid_search(
    da: Any,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    max_timesteps: int | None = None,
    timesteps: Sequence[int] | None = None,
    weights: Weights = None,
    period_duration: int | float | str = 24,
    show_progress: bool = True,
    save_all_results: bool = True,
    **tsam_kwargs: Any,
) -> TuningResult

Full grid search for best (n_clusters, n_segments).

Tests all valid (n_clusters, n_segments) pairs up to max_timesteps and returns the one with the lowest overall RMSE. The complete unfiltered history is preserved.

Parameters:

Name Type Description Default
da Any

Input data.

required
time_dim str

Name of the time dimension.

required
cluster_dim Sequence[str] | str

Dimension(s) to cluster together.

required
max_timesteps int | None

Maximum total timesteps to test (n_clusters * n_segments). Defaults to total number of timesteps in the data.

None
timesteps Sequence[int] | None

Specific timestep counts to test. Only combinations where n_clusters * n_segments is in this list are evaluated. Mutually exclusive with max_timesteps.

None
weights Weights

Per-coordinate weights for clustering and RMSE evaluation.

None
period_duration int | float | str

Hours per period (default: 24).

24
show_progress bool

Show progress bar.

True
save_all_results bool

Keep all AggregationResults (memory-intensive).

True
**tsam_kwargs Any

Additional keyword arguments passed to tsam.aggregate().

{}

Returns:

Type Description
TuningResult

Best combination with lowest overall RMSE and full history.

Source code in src/tsam_xarray/_tuning.py
def grid_search(
    da: Any,
    *,
    time_dim: str,
    cluster_dim: Sequence[str] | str,
    max_timesteps: int | None = None,
    timesteps: Sequence[int] | None = None,
    weights: Weights = None,
    period_duration: int | float | str = 24,
    show_progress: bool = True,
    save_all_results: bool = True,
    **tsam_kwargs: Any,
) -> TuningResult:
    """Full grid search for best (n_clusters, n_segments).

    Tests all valid (n_clusters, n_segments) pairs up to
    ``max_timesteps`` and returns the one with the lowest overall
    RMSE.  The complete unfiltered ``history`` is preserved.

    Args:
        da: Input data.
        time_dim: Name of the time dimension.
        cluster_dim: Dimension(s) to cluster together.
        max_timesteps: Maximum total timesteps to test
            (n_clusters * n_segments). Defaults to total
            number of timesteps in the data.
        timesteps: Specific timestep counts to test. Only
            combinations where ``n_clusters * n_segments``
            is in this list are evaluated. Mutually exclusive
            with ``max_timesteps``.
        weights: Per-coordinate weights for clustering and
            RMSE evaluation.
        period_duration: Hours per period (default: 24).
        show_progress: Show progress bar.
        save_all_results: Keep all AggregationResults
            (memory-intensive).
        **tsam_kwargs: Additional keyword arguments passed to
            ``tsam.aggregate()``.

    Returns:
        Best combination with lowest overall RMSE and full history.

    Raises:
        ValueError: If both ``timesteps`` and ``max_timesteps`` are
            given, or if no (n_clusters, n_segments) pair satisfies
            the constraints.
        RuntimeError: If the evaluation produced no best result.
    """
    # Reject contradictory arguments before inspecting the data.
    if timesteps is not None and max_timesteps is not None:
        msg = "Cannot specify both 'timesteps' and 'max_timesteps'"
        raise ValueError(msg)

    n_timesteps_per_period, n_periods, n_timesteps = _infer_time_params(
        da, time_dim, period_duration
    )

    allowed = None if timesteps is None else set(timesteps)
    if max_timesteps is None:
        # Bound the search by the largest requested total when explicit
        # totals were given, otherwise by the full series length.
        max_timesteps = max(allowed) if allowed else n_timesteps

    # Generate grid of candidates
    # Cap n_clusters at n_periods - 1 (n_periods = trivial perfect fit)
    max_clusters = n_periods - 1
    candidates: list[tuple[int, int]] = []
    for n_seg in range(1, n_timesteps_per_period + 1):
        # n_clust <= max_timesteps // n_seg already guarantees
        # n_clust * n_seg <= max_timesteps, so only the explicit
        # whitelist needs checking.
        upper = min(max_clusters, max_timesteps // n_seg)
        for n_clust in range(2, upper + 1):
            if allowed is None or n_clust * n_seg in allowed:
                candidates.append((n_clust, n_seg))

    if not candidates:
        # Report the constraint the caller actually supplied.
        if allowed is None:
            msg = f"No valid combinations for max_timesteps={max_timesteps}"
        else:
            msg = f"No valid combinations for timesteps={sorted(allowed)}"
        raise ValueError(msg)

    history, all_results, best_rmse, best_result, best_nc, best_ns = (
        _evaluate_candidates(
            candidates,
            da,
            time_dim=time_dim,
            cluster_dim=cluster_dim,
            weights=weights,
            period_duration=period_duration,
            show_progress=show_progress,
            progress_desc="Grid search",
            save_all_results=save_all_results,
            tsam_kwargs=tsam_kwargs,
        )
    )

    # No best result means no candidate could be evaluated successfully.
    if best_result is None:
        msg = "All configurations failed"
        raise RuntimeError(msg)

    return TuningResult(
        n_clusters=best_nc,
        n_segments=best_ns,
        rmse=best_rmse,
        best_result=best_result,
        history=history,
        all_results=all_results,
    )