mbrs.modules.als module
- class mbrs.modules.als.MatrixFactorizationALS(regularization_weight: float = 0.1, rank: int = 8)
Bases: object
Alternating least squares (ALS) implemented in PyTorch.
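- Parameters:
regularization_weight (float) – Weight of the L2 regularization term.
rank (int) – Rank r of the low-rank factor matrices.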
- compute_loss(matrix: Tensor, x: Tensor, y: Tensor, observed_mask: Tensor | None = None) → float
Compute the objective loss function.
- Parameters:
matrix (Tensor) – Target matrix of shape (N, M).
x (Tensor) – Left-side low-rank matrix of shape (N, r).
y (Tensor) – Right-side low-rank matrix of shape (M, r).
observed_mask (Tensor, optional) – Boolean mask of valid (observed) entries of shape (N, M).
- Returns:
Objective loss.
- Return type:
float
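This page does not spell out the objective. A minimal sketch, assuming the common masked ALS loss: squared reconstruction error over observed entries plus an L2 (squared Frobenius) penalty on both factors, weighted by regularization_weight. The standalone function below is hypothetical and is not the mbrs method; the real implementation may weight or normalize the terms differently.

```python
from typing import Optional

import torch


def compute_loss(
    matrix: torch.Tensor,
    x: torch.Tensor,
    y: torch.Tensor,
    observed_mask: Optional[torch.Tensor] = None,
    regularization_weight: float = 0.1,
) -> float:
    """Assumed ALS objective: masked squared error + L2 penalty on X and Y."""
    if observed_mask is None:
        # No mask given: treat every entry as observed.
        observed_mask = torch.ones_like(matrix, dtype=torch.bool)
    # Zero out residuals on unobserved entries so they do not contribute.
    residual = (matrix - x @ y.T) * observed_mask
    data_term = residual.pow(2).sum()
    # Squared Frobenius norms of both factors (assumed regularizer).
    reg_term = regularization_weight * (x.pow(2).sum() + y.pow(2).sum())
    return (data_term + reg_term).item()


matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x = torch.tensor([[1.0], [2.0]])
y = torch.tensor([[1.0], [2.0]])
# 1.0 (data term) + 0.1 * 10.0 (penalty) = 2.0
print(compute_loss(matrix, x, y))
```

Passing an observed_mask that sets entry (1, 0) to False drops the only nonzero residual in this example, leaving just the penalty term.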
- factorize(matrix: Tensor, observed_mask: Tensor | None = None, niter: int = 30, tolerance: float = 0.0001, seed: int = 0) → Tuple[Tensor, Tensor]
Factorize the given matrix.
The input matrix of shape (N, M) is decomposed into X @ Y.T, where X and Y have shapes (N, r) and (M, r), respectively.
Rather than computing the inverse explicitly as X = A^-1 @ b, this implementation solves the linear system AX = b.
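The linear systems can be made concrete with a short derivation. It assumes the usual masked objective with an L2 penalty of weight λ = regularization_weight; R denotes the input matrix, W the 0/1 observed mask, and x_i, y_j the i-th and j-th rows of X and Y. These symbols are introduced for illustration, and the exact objective used by mbrs may differ.

```latex
\begin{aligned}
\mathcal{L}(X, Y) &= \sum_{i=1}^{N}\sum_{j=1}^{M} W_{ij}\,(R_{ij} - x_i^\top y_j)^2
  + \lambda\,(\lVert X\rVert_F^2 + \lVert Y\rVert_F^2), \\
\frac{\partial \mathcal{L}}{\partial x_i} &= -2\sum_{j=1}^{M} W_{ij}\,(R_{ij} - x_i^\top y_j)\,y_j + 2\lambda\,x_i = 0 \\
\Longrightarrow\quad A_i\,x_i &= b_i, \qquad
  A_i = \sum_{j=1}^{M} W_{ij}\,y_j y_j^\top + \lambda I_r, \qquad
  b_i = \sum_{j=1}^{M} W_{ij} R_{ij}\,y_j .
\end{aligned}
```

With Y fixed, each row of X therefore needs only one r × r system A_i x_i = b_i, which is solved directly instead of forming A_i^{-1}. The update of y_j is symmetric, with the roles of X and Y (and of the rows and columns of R and W) swapped.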
- Parameters:
matrix (Tensor) – Input matrix of shape (N, M).
observed_mask (Tensor, optional) – Boolean mask of valid (observed) entries of shape (N, M).
niter (int) – Maximum number of alternating steps.
tolerance (float) – If the difference between the previous and current loss is smaller than this value, ALS is regarded as converged.
seed (int) – A seed for the random number generator.
- Returns:
A tuple (X, Y), where X is the low-rank matrix of shape (N, r) and Y is the low-rank matrix of shape (M, r).
- Return type:
Tuple[Tensor, Tensor]
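A minimal sketch of the whole alternating loop with the documented niter, tolerance, and seed parameters, under the same assumed masked L2-regularized objective. The function als_factorize is hypothetical: rank and regularization_weight are passed as arguments instead of being set on the class, the factors are initialized from a seeded standard normal, and all tensors are assumed to be on the CPU. These are illustrative choices, not necessarily what mbrs does.

```python
from typing import Optional, Tuple

import torch


def als_factorize(
    matrix: torch.Tensor,
    observed_mask: Optional[torch.Tensor] = None,
    rank: int = 8,
    regularization_weight: float = 0.1,
    niter: int = 30,
    tolerance: float = 1e-4,
    seed: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical ALS sketch: returns X (N, r) and Y (M, r) with matrix ~ X @ Y.T."""
    if observed_mask is None:
        observed_mask = torch.ones_like(matrix, dtype=torch.bool)
    weights = observed_mask.to(matrix.dtype)  # 1.0 = observed, 0.0 = missing
    n, m = matrix.shape
    # A seeded CPU generator makes the (assumed) random initialization reproducible.
    generator = torch.Generator().manual_seed(seed)
    x = torch.randn(n, rank, generator=generator, dtype=matrix.dtype)
    y = torch.randn(m, rank, generator=generator, dtype=matrix.dtype)
    eye = torch.eye(rank, dtype=matrix.dtype)

    def solve_rows(target, fixed, w):
        # Per-row normal equations A_i u_i = b_i with
        # A_i = fixed^T diag(w_i) fixed + lambda * I and b_i = fixed^T diag(w_i) target_i.
        a = torch.einsum("nm,mr,ms->nrs", w, fixed, fixed) + regularization_weight * eye
        b = torch.einsum("nm,nm,mr->nr", w, target, fixed)
        # Batched linear solve; no explicit inverse is formed.
        return torch.linalg.solve(a, b.unsqueeze(-1)).squeeze(-1)

    def loss(x, y):
        # Assumed objective: masked squared error + L2 penalty (may differ from mbrs).
        data = ((matrix - x @ y.T) * weights).pow(2).sum()
        reg = regularization_weight * (x.pow(2).sum() + y.pow(2).sum())
        return (data + reg).item()

    prev_loss = loss(x, y)
    for _ in range(niter):
        x = solve_rows(matrix, y, weights)  # update X with Y fixed
        y = solve_rows(matrix.T, x, weights.T)  # update Y with X fixed
        cur_loss = loss(x, y)
        if abs(prev_loss - cur_loss) < tolerance:
            break  # converged: loss changed by less than `tolerance`
        prev_loss = cur_loss
    return x, y
```

Each half-step solves all N (or M) small r × r systems in one batched torch.linalg.solve call, so no inverse matrix is ever materialized. Entries masked out by observed_mask receive zero weight in both the linear systems and the loss, so their values do not affect the result.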