mbrs.modules.kmeans module
- class mbrs.modules.kmeans.Kmeans(cfg: Config)
Bases: object
k-means clustering implemented in PyTorch.
- Parameters:
cfg (Kmeans.Config) – Configuration for k-means.
- class Config(ncentroids: int = 8, niter: int = 5, kmeanspp: bool = True, seed: int = 0)
Bases: object
Configuration for k-means.
ncentroids (int): Number of centroids.
niter (int): Number of k-means iterations.
kmeanspp (bool): Whether to initialize the centroids using k-means++.
seed (int): Random seed.
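As a rough illustration of the fields and defaults listed above, the configuration can be mirrored by a plain dataclass. The name `KmeansConfigSketch` is hypothetical and is not part of mbrs; the sketch only restates the signature shown in this reference and does not claim to match how the library defines `Kmeans.Config`.

```python
from dataclasses import dataclass


@dataclass
class KmeansConfigSketch:
    """Hypothetical stand-in mirroring the documented Config fields."""

    ncentroids: int = 8    # number of centroids
    niter: int = 5         # number of k-means iterations
    kmeanspp: bool = True  # whether to initialize with k-means++
    seed: int = 0          # random seed


# Override only the fields that differ from the documented defaults.
cfg = KmeansConfigSketch(ncentroids=16, niter=10)
```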
- assign(x: Tensor, centroids: Tensor) → Tensor
Assigns each input vector the ID of its nearest centroid.
- Parameters:
x (torch.Tensor) – Input vectors to assign, of shape (n, dim).
centroids (torch.Tensor) – Centroids tensor of shape (ncentroids, dim).
- Returns:
Assigned centroid IDs of shape (n,).
- Return type:
torch.Tensor
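The shapes described above can be illustrated with a minimal nearest-centroid sketch. It assumes Euclidean distance, which this reference does not state, and the function name `assign_sketch` is hypothetical; it is not the library's implementation.

```python
import torch


def assign_sketch(x: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """Return the index of the nearest centroid for each row of x.

    Assumes Euclidean distance (an assumption, not confirmed by the docs).
    """
    # Pairwise distances between vectors and centroids, shape (n, ncentroids).
    distances = torch.cdist(x, centroids)
    # Index of the closest centroid for each vector, shape (n,).
    return distances.argmin(dim=1)


x = torch.tensor([[0.0, 0.0], [0.1, -0.1], [5.0, 5.0]])
centroids = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
ids = assign_sketch(x, centroids)
```

Here the first two vectors lie next to the first centroid and the third coincides with the second, so `ids` has shape (3,) and holds one centroid index per input row.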
- init_kmeanspp(x: Tensor, rng: Generator, ncentroids: int) → Tensor
Initializes the centroids via k-means++.
- Parameters:
x (Tensor) – Input vectors of shape (n, dim).
rng (Generator) – Random number generator.
ncentroids (int) – Number of centroids.
- Returns:
Centroid vectors of shape (ncentroids, dim) obtained using k-means++.
- Return type:
Tensor
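For readers unfamiliar with k-means++ seeding, the following is a sketch of the standard procedure using the same argument names as the signature above. The function name `init_kmeanspp_sketch` is hypothetical, and the code follows the textbook algorithm (squared Euclidean distance, sampling proportional to distance from the nearest chosen centroid); the library's actual implementation may differ.

```python
import torch


def init_kmeanspp_sketch(
    x: torch.Tensor, rng: torch.Generator, ncentroids: int
) -> torch.Tensor:
    """Pick ncentroids rows of x as initial centroids via k-means++ seeding.

    Assumes x is a float tensor containing at least ncentroids distinct rows.
    """
    n = x.size(0)
    # The first centroid is drawn uniformly at random.
    first = int(torch.randint(n, (1,), generator=rng).item())
    chosen = [first]
    # Squared distance from every point to its nearest chosen centroid.
    min_sq_dist = ((x - x[first]) ** 2).sum(dim=1)
    for _ in range(1, ncentroids):
        # Draw the next centroid with probability proportional to that distance.
        idx = int(torch.multinomial(min_sq_dist, 1, generator=rng).item())
        chosen.append(idx)
        new_sq_dist = ((x - x[idx]) ** 2).sum(dim=1)
        min_sq_dist = torch.minimum(min_sq_dist, new_sq_dist)
    # Gather the selected rows, shape (ncentroids, dim).
    return x[chosen]


rng = torch.Generator().manual_seed(0)
x = torch.randn(100, 4, generator=rng)
centroids = init_kmeanspp_sketch(x, rng, ncentroids=8)
```

Passing an explicitly seeded `torch.Generator` keeps the initialization reproducible, which matches the role of the `seed` field in the configuration.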