Source code for mbrs.modules.kmeans

  1from __future__ import annotations
  2
  3from dataclasses import dataclass
  4from typing import Tuple
  5
  6import torch
  7from torch import Generator, Tensor
  8
  9from mbrs import timer
 10
 11
class Kmeans:
    """k-means clustering implemented in PyTorch.

    Args:
        cfg (Kmeans.Config): Configuration for k-means.
    """

    @dataclass
    class Config:
        """Configuration for k-means.

        - ncentroids (int): Number of centroids.
        - niter (int): Maximum number of k-means iterations; training stops
          earlier once the assignments no longer change.
        - kmeanspp (bool): Initialize the centroids using k-means++.
        - seed (int): Random seed.
        """

        ncentroids: int = 8
        niter: int = 5
        kmeanspp: bool = True
        seed: int = 0

    def __init__(self, cfg: Config) -> None:
        self.cfg = cfg

    def assign(self, x: Tensor, centroids: Tensor) -> Tensor:
        """Assigns the nearest neighbor centroid ID.

        Args:
            x (torch.Tensor): Assigned vectors of shape `(n, dim)`.
            centroids (torch.Tensor): Centroids tensor of shape `(ncentroids, dim)`.

        Returns:
            torch.Tensor: Assigned IDs of shape `(n,)`.
        """
        # Pairwise Euclidean distances are `(n, ncentroids)`; pick the closest
        # centroid for every row of `x`.
        return torch.cdist(x, centroids, p=2).argmin(dim=-1)

    def init_kmeanspp(self, x: Tensor, rng: Generator, ncentroids: int) -> Tensor:
        """Initializes the centroids via k-means++.

        The first centroid is drawn uniformly from `x`; every following one is
        drawn with probability proportional to the squared distance to its
        nearest already-chosen centroid.

        Args:
            x (Tensor): Input vectors of shape `(n, dim)`.
            rng (Generator): Random number generator.
            ncentroids (int): Number of centroids.

        Returns:
            Tensor: Centroid vectors obtained using k-means++.
        """
        first = torch.randint(x.size(0), size=(1,), generator=rng, device=x.device)
        centroids = x[first, :]
        for _ in range(ncentroids - 1):
            # Nc x N
            sqdists = torch.cdist(centroids, x, p=2) ** 2
            # The clamp keeps every weight strictly positive, so already
            # selected points (distance 0) never yield an all-zero
            # distribution for `torch.multinomial`.
            neighbor_sqdists = sqdists.min(dim=0).values.float().clamp(min=1e-5)
            weights = neighbor_sqdists / neighbor_sqdists.sum()
            new_centroid = x[torch.multinomial(weights, 1, generator=rng), :]
            centroids = torch.cat([centroids, new_centroid])
        return centroids

    def train(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Trains k-means.

        Special cases:

        - If `ncentroids == 1`, the single centroid is the mean of `x`.
        - If `n <= ncentroids`, `x` itself is returned as the centroids (so
          only `n` centroids are returned, without copying) and every vector
          is assigned to itself.

        Args:
            x (torch.Tensor): Input vectors of shape `(n, dim)`.

        Returns:
            Tensor: Centroids tensor of shape `(ncentroids, dim)`.
            Tensor: Assigned IDs of shape `(n,)`.
        """
        ncentroids = self.cfg.ncentroids
        if ncentroids == 1:
            with timer.measure("kmeans/iteration"):
                centroids = x.mean(dim=0, keepdim=True)
            return centroids, self.assign(x, centroids)
        elif x.size(0) <= ncentroids:
            return x, torch.arange(x.size(0), device=x.device)

        with timer.measure("kmeans/initialize"):
            rng = torch.Generator(x.device)
            rng = rng.manual_seed(self.cfg.seed)
            if self.cfg.kmeanspp:
                centroids = self.init_kmeanspp(x, rng, ncentroids)
            else:
                # Indexing with a tensor copies the rows, so the in-place
                # centroid updates below never write into `x`.
                perm = torch.randperm(x.size(0), generator=rng, device=x.device)
                centroids = x[perm[:ncentroids]]

        # `None` marks "no assignment computed yet". A sentinel tensor built
        # from `x` would inherit its floating dtype, differ from the integer
        # assignments, and leak out as the result when `niter` is 0.
        assigns = None
        for _ in range(self.cfg.niter):
            with timer.measure("kmeans/iteration"):
                new_assigns = self.assign(x, centroids)
                if assigns is not None and torch.equal(new_assigns, assigns):
                    # Converged: assignments did not change.
                    break
                assigns = new_assigns
                for k in range(ncentroids):
                    members = assigns == k
                    # Empty clusters keep their previous centroid.
                    if members.any():
                        centroids[k] = x[members].mean(dim=0)

        if assigns is None:
            # No iteration ran (`niter <= 0`): still return valid IDs for the
            # initial centroids.
            assigns = self.assign(x, centroids)
        return centroids, assigns