Source code for mbrs.modules.als

  1from typing import Optional, Tuple
  2
  3import torch
  4import torch.linalg as LA
  5from torch import Tensor
  6
  7from mbrs import timer
  8
  9
class MatrixFactorizationALS:
    """Alternating least squares (ALS) implemented in PyTorch.

    A (partially observed) matrix `A` of shape `(N, M)` is approximated by
    `X @ Y.T` by minimizing
    `sum(mask * (A - X @ Y.T)) ** 2 + regularization_weight * (|X|^2 + |Y|^2)`,
    where the norms are Frobenius norms.

    Args:
        regularization_weight (float): Weight of L2 regularization.
        rank (int): Rank of the factorized matrices.
    """

    def __init__(self, regularization_weight: float = 0.1, rank: int = 8) -> None:
        self.regularization_weight = regularization_weight
        self.rank = rank

    def compute_loss(
        self,
        matrix: Tensor,
        x: Tensor,
        y: Tensor,
        observed_mask: Optional[Tensor] = None,
    ) -> float:
        """Compute the objective loss function.

        Args:
            matrix (Tensor): Target matrix of shape `(N, M)`.
            x (Tensor): Left-side low-rank matrix of shape `(N, r)`.
            y (Tensor): Right-side low-rank matrix of shape `(M, r)`.
            observed_mask (Tensor, optional): Valid indices boolean mask of shape
                `(N, M)`. If not given, every entry is treated as observed.

        Returns:
            float: Masked squared reconstruction error plus the weighted squared
            Frobenius norms of `x` and `y`.
        """
        residual = matrix - x @ y.T
        if observed_mask is not None:
            # Unobserved entries contribute nothing to the reconstruction error.
            residual = observed_mask * residual
        mse_loss = (residual**2).sum()
        l2_regularization_loss = x.norm() ** 2 + y.norm() ** 2
        loss = mse_loss + self.regularization_weight * l2_regularization_loss
        return loss.item()

    @staticmethod
    def _solve_factor(
        target: Tensor,
        fixed: Tensor,
        mask: Tensor,
        regularization_term: Tensor,
    ) -> Tensor:
        """Update one factor with the other held fixed.

        Row `i` of the result minimizes
        `sum_j mask[i, j] * (target[i, j] - row_i @ fixed[j]) ** 2
        + lambda * |row_i|^2` by solving its normal equations `A_i x = b_i`
        instead of inverting `A_i` explicitly.

        Args:
            target (Tensor): Matrix to approximate, of shape `(K, L)`.
            fixed (Tensor): The factor held fixed, of shape `(L, r)`.
            mask (Tensor): Observed-entry mask of shape `(K, L)`.
            regularization_term (Tensor): `lambda * I` of shape `(r, r)`.

        Returns:
            Tensor: The updated factor of shape `(K, r)`.
        """
        # A: (K, r, r) -- one r x r system per row of `target`, built only from
        # the observed columns of that row.
        lhs = (
            fixed.T[None, :, :] @ (fixed[None, :, :] * mask[:, :, None])
            + regularization_term
        )
        # b: (K, r) -- masked so unobserved cells cannot leak into the solution.
        rhs = (mask * target) @ fixed
        return LA.solve(lhs, rhs)

    def factorize(
        self,
        matrix: Tensor,
        observed_mask: Optional[Tensor] = None,
        niter: int = 30,
        tolerance: float = 1e-4,
        seed: int = 0,
    ) -> Tuple[Tensor, Tensor]:
        """Factorize the given matrix.

        The input matrix of shape `(N, M)` is decomposed into `X @ Y.T`,
        where `X` and `Y` have shapes `(N, r)` and `(M, r)`, respectively.

        This implementation does not compute the inverse matrix directly in
        `X = A^-1 @ b`. Instead, `AX = b` is solved.

        Args:
            matrix (Tensor): Input matrix of shape `(N, M)`.
            observed_mask (Tensor, optional): Boolean mask of valid indices of
                shape `(N, M)`. If not given, every entry is treated as observed.
            niter (int): The maximum number of alternating steps performed.
            tolerance (float): If the loss decreases by no more than this value
                between two consecutive steps (or increases), ALS is regarded
                as converged and stops.
            seed (int): A seed for the random number generator.

        Returns:
            Tensor: Low-rank matrix `X` of shape `(N, r)`.
            Tensor: Low-rank matrix `Y` of shape `(M, r)`.
        """
        # Keep the factors in the input's floating dtype so that e.g. float64
        # inputs do not clash with float32 factors in the matrix products.
        if matrix.is_floating_point():
            dtype = matrix.dtype
        else:
            dtype = torch.get_default_dtype()
        matrix = matrix.to(dtype)

        rng = torch.Generator(matrix.device)
        rng = rng.manual_seed(seed)

        N, M = matrix.size()
        # Initialization:
        # Empirically observed the convergence to be much better with the scaled
        # initialization.
        X = (
            torch.rand((N, self.rank), generator=rng, device=matrix.device, dtype=dtype)
            * (N * self.rank) ** -0.5
        )
        Y = (
            torch.rand((M, self.rank), generator=rng, device=matrix.device, dtype=dtype)
            * (M * self.rank) ** -0.5
        )
        if observed_mask is None:
            observed_mask = matrix.new_ones((N, M), dtype=torch.bool)

        regularization_term = self.regularization_weight * torch.eye(
            self.rank, device=matrix.device, dtype=dtype
        )
        # Start from +inf so the first step never trips the convergence check,
        # however large the initial loss is (a finite sentinel such as 1e5
        # stopped ALS after one step on large matrices).
        prev_loss = float("inf")
        for _ in range(niter):
            with timer.measure("ALS/iteration"):
                X = self._solve_factor(matrix, Y, observed_mask, regularization_term)
                Y = self._solve_factor(
                    matrix.T, X, observed_mask.T, regularization_term
                )
            loss = self.compute_loss(matrix, X, Y, observed_mask=observed_mask)
            if prev_loss - loss <= tolerance:
                break
            prev_loss = loss
        return X, Y