mbrs.modules.kmeans module

class mbrs.modules.kmeans.Kmeans(cfg: Config)

Bases: object

k-means clustering implemented in PyTorch.

Parameters:

cfg (Kmeans.Config) – Configuration for k-means.

class Config(ncentroids: int = 8, niter: int = 5, kmeanspp: bool = True, seed: int = 0)

Bases: object

Configuration for k-means.

  • ncentroids (int): Number of centroids.

  • niter (int): Number of k-means iterations.

  • kmeanspp (bool): Whether to initialize the centroids using k-means++.

  • seed (int): Random seed.

kmeanspp: bool = True
ncentroids: int = 8
niter: int = 5
seed: int = 0
assign(x: Tensor, centroids: Tensor) → Tensor

Assigns each input vector the ID of its nearest centroid.

Parameters:
  • x (torch.Tensor) – Input vectors to assign, of shape (n, dim).

  • centroids (torch.Tensor) – Centroids tensor of shape (ncentroids, dim).

Returns:

Assigned IDs of shape (n,).

Return type:

torch.Tensor
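The assignment step can be sketched as follows. This is a minimal illustration, not the mbrs implementation: `assign_nearest` is a hypothetical name, and the sketch assumes Euclidean distance, which this page does not specify.

```python
import torch


def assign_nearest(x: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the closest centroid for each row of x.

    Assumes Euclidean distance. x: (n, dim), centroids: (ncentroids, dim).
    Returns int64 IDs of shape (n,).
    """
    # Pairwise Euclidean distances, shape (n, ncentroids).
    dists = torch.cdist(x, centroids)
    # Smallest distance along the centroid axis gives the assigned ID.
    return dists.argmin(dim=1)


x = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
centroids = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
ids = assign_nearest(x, centroids)
print(ids.tolist())  # [0, 0, 1]
```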

init_kmeanspp(x: Tensor, rng: Generator, ncentroids: int) → Tensor

Initializes the centroids via k-means++.

Parameters:
  • x (Tensor) – Input vectors of shape (n, dim).

  • rng (Generator) – Random number generator.

  • ncentroids (int) – Number of centroids.

Returns:

Centroid vectors of shape (ncentroids, dim) obtained using k-means++.

Return type:

Tensor
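k-means++ picks the first centroid uniformly at random and each subsequent one with probability proportional to its squared distance from the nearest centroid chosen so far. The sketch below shows that seeding under two assumptions not stated on this page: `rng` is a `torch.Generator`, and squared Euclidean distance is used. `kmeanspp_init` is a hypothetical name, not the mbrs method.

```python
import torch


def kmeanspp_init(x: torch.Tensor, rng: torch.Generator, ncentroids: int) -> torch.Tensor:
    """Hypothetical k-means++ seeding; returns centroids of shape (ncentroids, dim)."""
    n = x.size(0)
    # First centroid: one input vector chosen uniformly at random.
    first = torch.randint(n, (1,), generator=rng)
    chosen = [x[first]]  # each entry has shape (1, dim)
    # Squared distance from every vector to its nearest chosen centroid.
    min_sq = (x - chosen[0]).pow(2).sum(dim=1)
    for _ in range(1, ncentroids):
        # Sample the next centroid with probability proportional to min_sq.
        # Already-chosen vectors have weight 0, so they are never picked again
        # (assumes at least ncentroids distinct vectors, else the weights sum to 0).
        idx = torch.multinomial(min_sq, 1, generator=rng)
        chosen.append(x[idx])
        min_sq = torch.minimum(min_sq, (x - x[idx]).pow(2).sum(dim=1))
    return torch.cat(chosen, dim=0)


rng = torch.Generator().manual_seed(0)
x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
init = kmeanspp_init(x, rng, ncentroids=2)
print(init.shape)  # torch.Size([2, 2])
```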

train(x: Tensor) → Tuple[Tensor, Tensor]

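Trains k-means on the input vectors and returns the learned centroids together with each vector's assigned centroid ID.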

Parameters:

x (torch.Tensor) – Input vectors of shape (n, dim).

Returns:

A tuple (centroids, ids):

  • Tensor – Centroids of shape (ncentroids, dim).

  • Tensor – Assigned centroid IDs of shape (n,).

Return type:

Tuple[Tensor, Tensor]
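A minimal sketch of a training loop, assuming the standard Lloyd iteration (alternate an assignment step and a mean-update step for `niter` rounds). This page does not state the exact update rule or how empty clusters are handled, so both are assumptions here. `train_kmeans` is a hypothetical name, and for brevity it seeds with randomly chosen input vectors rather than k-means++ (which `Config` enables by default).

```python
from typing import Tuple

import torch


def train_kmeans(
    x: torch.Tensor, ncentroids: int, niter: int, seed: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical Lloyd's k-means; returns (centroids, ids)."""
    rng = torch.Generator().manual_seed(seed)
    # Random initialization (stand-in for k-means++): distinct input vectors.
    perm = torch.randperm(x.size(0), generator=rng)
    centroids = x[perm[:ncentroids]].clone()
    for _ in range(niter):
        # Assignment step: nearest centroid (Euclidean) for each vector.
        ids = torch.cdist(x, centroids).argmin(dim=1)
        # Update step: each centroid becomes the mean of its assigned vectors.
        sums = torch.zeros_like(centroids).index_add_(0, ids, x)
        counts = torch.bincount(ids, minlength=ncentroids).unsqueeze(1).to(x.dtype)
        nonempty = counts.squeeze(1) > 0
        # Empty clusters keep their previous centroid.
        centroids[nonempty] = sums[nonempty] / counts[nonempty]
    # Final assignment against the updated centroids.
    ids = torch.cdist(x, centroids).argmin(dim=1)
    return centroids, ids


x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centroids, ids = train_kmeans(x, ncentroids=2, niter=5)
print(centroids.shape, ids.tolist())
```

Keeping the previous centroid for an empty cluster is one common choice; re-seeding it from a distant vector is another.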