1from typing import Optional, Tuple
2
3import torch
4import torch.linalg as LA
5from torch import Tensor
6
7from mbrs import timer
8
9
[docs]
class MatrixFactorizationALS:
    """Alternating least squares (ALS) implemented in PyTorch.

    Minimizes ``sum_{(i, j) observed} (M_ij - x_i . y_j)^2
    + regularization_weight * (||X||_F^2 + ||Y||_F^2)`` by alternately solving
    for ``X`` with ``Y`` fixed and for ``Y`` with ``X`` fixed.

    Args:
        regularization_weight (float): Weight of L2 regularization.
        rank (int): Rank of the factorized matrices.
    """

    def __init__(self, regularization_weight: float = 0.1, rank: int = 8) -> None:
        self.regularization_weight = regularization_weight
        self.rank = rank

    def compute_loss(
        self,
        matrix: Tensor,
        x: Tensor,
        y: Tensor,
        observed_mask: Optional[Tensor] = None,
    ) -> float:
        """Compute the objective loss function.

        Args:
            matrix (Tensor): Target matrix of shape `(N, M)`.
            x (Tensor): Left-side low-rank matrix of shape `(N, r)`.
            y (Tensor): Right-side low-rank matrix of shape `(M, r)`.
            observed_mask (Tensor, optional): Valid indices boolean mask of shape
                `(N, M)`. If omitted, every entry is treated as observed.

        Returns:
            float: Objective loss, i.e., the squared error summed over the observed
            entries plus ``regularization_weight * (||x||_F^2 + ||y||_F^2)``.
        """
        residual = matrix - x @ y.T
        if observed_mask is not None:
            # masked_fill (rather than multiplying by the mask) guarantees that
            # NaN/inf placeholders stored in unobserved cells contribute exactly 0.
            residual = residual.masked_fill(~observed_mask.bool(), 0.0)
        mse_loss = (residual**2).sum()
        l2_regularization_loss = x.norm() ** 2 + y.norm() ** 2
        loss = mse_loss + self.regularization_weight * l2_regularization_loss
        return loss.item()

    @staticmethod
    def _solve_factor(
        fixed: Tensor,
        observed_matrix: Tensor,
        observed_mask: Tensor,
        regularization_term: Tensor,
    ) -> Tensor:
        """Solve the regularized least squares for one factor, the other held fixed.

        For every row ``l`` this solves
        ``(sum_{k: mask[l, k]} f_k f_k^T + regularization_term) z_l
        = sum_k observed_matrix[l, k] f_k`` via ``LA.solve`` (no explicit inverse).

        Args:
            fixed (Tensor): Factor held fixed, shape `(K, r)`.
            observed_matrix (Tensor): Target with unobserved entries zeroed,
                shape `(L, K)`.
            observed_mask (Tensor): Boolean mask of observed entries, shape `(L, K)`.
            regularization_term (Tensor): `regularization_weight * I`, shape `(r, r)`.

        Returns:
            Tensor: Updated factor of shape `(L, r)`.
        """
        # A: (L, r, r) -- one Gram matrix per row, restricted to its observed columns.
        lhs = (
            fixed.T[None, :, :] @ (fixed[None, :, :] * observed_mask[:, :, None])
            + regularization_term
        )
        # b: (L, r) -- unobserved targets are already zero, so they add nothing.
        rhs = observed_matrix @ fixed
        # Pass b with an explicit trailing dimension so LA.solve never has to guess
        # between "batch of vectors" and "matrix RHS", then drop it again.
        return LA.solve(lhs, rhs.unsqueeze(-1)).squeeze(-1)

    def factorize(
        self,
        matrix: Tensor,
        observed_mask: Optional[Tensor] = None,
        niter: int = 30,
        tolerance: float = 1e-4,
        seed: int = 0,
    ) -> Tuple[Tensor, Tensor]:
        """Factorize the given matrix.

        The input matrix of shape `(N, M)` is decomposed into `X @ Y.T`,
        where `X` and `Y` have shapes `(N, r)` and `(M, r)`, respectively.

        This implementation does not compute the inverse matrix directly in
        `X = A^-1 @ b`. Instead, `AX = b` is solved.

        Entries outside `observed_mask` are ignored entirely (they may hold any
        value, including NaN). Note that if `regularization_weight` is 0 and some
        row or column has no observed entry, the corresponding system is singular.

        Args:
            matrix (Tensor): Input matrix of shape `(N, M)`. Non-floating inputs
                are converted to the default floating dtype.
            observed_mask (Tensor, optional): Boolean mask of valid indices of
                shape `(N, M)`. If omitted, every entry is treated as observed.
            niter (int): The maximum number of alternating steps performed.
            tolerance (float): If the difference between the previous and current
                loss is smaller than this value, ALS is regarded as converged.
            seed (int): A seed for the random number generator.

        Returns:
            Tensor: Low-rank matrix `X` of shape `(N, r)`.
            Tensor: Low-rank matrix `Y` of shape `(M, r)`.
        """
        if not matrix.is_floating_point():
            # Integer/bool inputs cannot be multiplied with floating-point factors.
            matrix = matrix.to(torch.get_default_dtype())
        device, dtype = matrix.device, matrix.dtype

        rng = torch.Generator(device)
        rng = rng.manual_seed(seed)

        N, M = matrix.size()
        # Initialization:
        # Empirically observed the convergence to be much better with the scaled
        # initialization. Factors share the matrix dtype so float64 inputs work.
        X = (
            torch.rand((N, self.rank), generator=rng, device=device, dtype=dtype)
            * (N * self.rank) ** -0.5
        )
        Y = (
            torch.rand((M, self.rank), generator=rng, device=device, dtype=dtype)
            * (M * self.rank) ** -0.5
        )
        if observed_mask is None:
            observed_mask = matrix.new_ones((N, M), dtype=torch.bool)
        else:
            observed_mask = observed_mask.to(device=device, dtype=torch.bool)
        # Unobserved entries must not enter the right-hand side of the normal
        # equations; zero-filling them once up front also neutralizes NaNs.
        observed_matrix = matrix.masked_fill(~observed_mask, 0.0)

        regularization_term = self.regularization_weight * torch.eye(
            self.rank, device=device, dtype=dtype
        )
        # Start from +inf so the first iteration can never be mistaken for
        # convergence, however large the initial loss is.
        prev_loss = float("inf")
        for _ in range(niter):
            with timer.measure("ALS/iteration"):
                X = self._solve_factor(
                    Y, observed_matrix, observed_mask, regularization_term
                )
                Y = self._solve_factor(
                    X, observed_matrix.T, observed_mask.T, regularization_term
                )
            loss = self.compute_loss(matrix, X, Y, observed_mask=observed_mask)
            if prev_loss - loss <= tolerance:
                break
            prev_loss = loss
        return X, Y