Source code for mbrs.modules.kmeans
1from __future__ import annotations
2
3from dataclasses import dataclass
4from typing import Tuple
5
6import torch
7from torch import Generator, Tensor
8
9from mbrs import timer
10
11
[docs]
class Kmeans:
    """k-means clustering implemented in PyTorch.

    Args:
        cfg (Kmeans.Config): Configuration for k-means.
    """

    @dataclass
    class Config:
        """Configuration for k-means.

        - ncentroids (int): Number of centroids.
        - niter (int): Maximum number of k-means iterations.
        - kmeanspp (bool): Initialize the centroids using k-means++.
        - seed (int): Random seed.
        """

        ncentroids: int = 8
        niter: int = 5
        kmeanspp: bool = True
        seed: int = 0

    def __init__(self, cfg: Config) -> None:
        self.cfg = cfg

    def assign(self, x: Tensor, centroids: Tensor) -> Tensor:
        """Assigns the nearest neighbor centroid ID.

        Args:
            x (torch.Tensor): Assigned vectors of shape `(n, dim)`.
            centroids (torch.Tensor): Centroids tensor of shape `(ncentroids, dim)`.

        Returns:
            torch.Tensor: Assigned IDs of shape `(n,)`.
        """
        return torch.cdist(x, centroids, p=2).argmin(dim=-1)

    def init_kmeanspp(self, x: Tensor, rng: Generator, ncentroids: int) -> Tensor:
        """Initializes the centroids via k-means++.

        The first centroid is a uniformly sampled row of `x`. Each further
        centroid is a row of `x` sampled with probability proportional to its
        squared distance to the nearest already-chosen centroid.

        Args:
            x (Tensor): Input vectors of shape `(n, dim)`.
            rng (Generator): Random number generator.
            ncentroids (int): Number of centroids.

        Returns:
            Tensor: Centroid vectors obtained using k-means++. Each row is a
                copy of a row of `x` (advanced indexing copies), so callers
                may modify it without touching `x`.
        """
        indices = [
            torch.randint(x.size(0), size=(1,), generator=rng, device=x.device)
        ]
        # Squared distance from every input vector to its nearest chosen
        # centroid. Only the newest centroid's distances are computed each
        # step and folded in with an element-wise minimum, instead of
        # recomputing the full (Nc x N) distance matrix every iteration.
        nearest_sqdists = None
        for _ in range(ncentroids - 1):
            latest_sqdists = (torch.cdist(x[indices[-1]], x, p=2) ** 2).squeeze(0)
            if nearest_sqdists is None:
                nearest_sqdists = latest_sqdists
            else:
                nearest_sqdists = torch.minimum(nearest_sqdists, latest_sqdists)
            # The floor keeps every weight positive, so the distribution stays
            # valid even when all vectors coincide with chosen centroids.
            weights = nearest_sqdists.float().clamp(min=1e-5)
            weights = weights / weights.sum()
            indices.append(torch.multinomial(weights, 1, generator=rng))
        return x[torch.cat(indices)]

    def train(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Trains k-means.

        Args:
            x (torch.Tensor): Input vectors of shape `(n, dim)`.

        Returns:
            Tensor: Centroids tensor of shape `(ncentroids, dim)`. When
                `n <= ncentroids` (and `ncentroids != 1`), `x` itself is
                returned, i.e. shape `(n, dim)`.
            Tensor: Assigned IDs of shape `(n,)`.
        """
        ncentroids = self.cfg.ncentroids
        if ncentroids == 1:
            with timer.measure("kmeans/iteration"):
                centroids = x.mean(dim=0, keepdim=True)
            return centroids, self.assign(x, centroids)
        elif x.size(0) <= ncentroids:
            # Every input vector becomes its own centroid.
            return x, torch.arange(x.size(0), device=x.device)

        with timer.measure("kmeans/initialize"):
            rng = torch.Generator(x.device)
            rng = rng.manual_seed(self.cfg.seed)
            if self.cfg.kmeanspp:
                centroids = self.init_kmeanspp(x, rng, ncentroids)
            else:
                perm = torch.randperm(x.size(0), generator=rng, device=x.device)
                centroids = x[perm[:ncentroids]]

        # Sentinel that never equals a real assignment. It must share the
        # integer dtype produced by `assign` (argmin -> int64) so that
        # `torch.equal` compares like with like; deriving it from `x` would
        # give it x's floating-point dtype.
        assigns = torch.full((x.size(0),), -1, dtype=torch.long, device=x.device)
        for _ in range(self.cfg.niter):
            with timer.measure("kmeans/iteration"):
                new_assigns = self.assign(x, centroids)
                if torch.equal(new_assigns, assigns):
                    break  # Converged: assignments no longer change.
                assigns = new_assigns
                for k in range(ncentroids):
                    members = assigns == k
                    # Empty clusters keep their previous centroid.
                    if members.any():
                        centroids[k] = x[members].mean(dim=0)
        # NOTE: if `niter` is exhausted without converging, `assigns` is the
        # assignment that produced the last centroid update, not a fresh
        # assignment against the returned centroids.
        return centroids, assigns