
[PR-12292] Adding full_matrices parameter to dask.array.linalg.svd

dask.array.linalg.svd does not conform to the Array API spec because it lacks the full_matrices keyword argument.


dask/dask Issue #10389: https://github.com/dask/dask/issues/10389

Description

dask.array.linalg.svd did not accept the full_matrices keyword argument, making it non-conformant with the Python Array API standard. This prevented any Array API-consuming library from using dask as a drop-in backend for SVD operations, even when only the reduced SVD (which dask already computes) was needed.

Background

Dask relies on NumPy's np.linalg.svd and its own TSQR (Tall-Skinny QR) algorithm to compute Singular Value Decomposition on chunked arrays. The Array API standard defines a common interface that array libraries (NumPy, CuPy, Dask, JAX, PyTorch) must follow so downstream libraries like scikit-learn can write backend-agnostic code.

Problem:

  • dask.array.linalg.svd was defined without the full_matrices parameter.
  • The Array API spec defines the signature as svd(x, /, *, full_matrices=True), with full_matrices as a keyword-only argument.
  • Any library calling svd(x, full_matrices=False) on a dask array failed immediately with a TypeError.
  • Example error encountered:
    TypeError: svd() got an unexpected keyword argument 'full_matrices'
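
The failure mode can be reproduced without dask installed. The sketch below uses a hypothetical stand-in, `svd_without_full_matrices`, that copies only the shape of the old signature; it is not dask's code, and the helper name `call_like_array_api_consumer` is likewise an assumption made for illustration.

```python
def svd_without_full_matrices(a, coerce_signs=True):
    """Hypothetical stand-in mirroring the old signature (no full_matrices)."""
    return a  # the body is irrelevant; the call fails before it runs


def call_like_array_api_consumer(svd_func, x):
    """Call svd the way a spec-following library would, passing full_matrices."""
    try:
        svd_func(x, full_matrices=False)
    except TypeError as exc:
        return str(exc)
    return None


message = call_like_array_api_consumer(svd_without_full_matrices, [[1.0, 2.0]])
print(message)
```

The printed message names the stand-in function rather than `svd`, but the cause is the same: Python rejects the unknown keyword before any SVD logic runs.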
    

Impact:

  • Array API-consuming libraries (e.g., scikit-learn) cannot use dask as a backend for SVD
  • Users forced to fall back to NumPy, losing lazy/parallel computation benefits
  • Dask cannot be used as a drop-in replacement in backend-agnostic code paths

Objective

Ensure that dask.array.linalg.svd:

  • Accepts the full_matrices keyword argument per the Array API spec
  • Works correctly when full_matrices=False (reduced SVD — already implemented)
  • Raises a clear NotImplementedError when full_matrices=True (not feasible to compute efficiently on chunked arrays)
  • Maintains full backward compatibility with existing code

Root Cause Analysis

  • Current dask.array.linalg.svd defines the function without full_matrices:

    def svd(a, coerce_signs=True):
    
  • Internally, it already computes the reduced SVD (equivalent to full_matrices=False):

    u, s, v = delayed(np.linalg.svd, nout=3)(a, full_matrices=False)
    
  • The Array API spec defines svd(x, /, *, full_matrices=True), so a spec-compliant caller that wants the reduced SVD must pass full_matrices=False explicitly.

  • This problem arises because dask's SVD was written before the Array API standard was finalized. The internal behavior was correct, but the public interface didn't expose the parameter.

  • Computing full SVD (full_matrices=True) produces U of shape (M, M) and Vh of shape (N, N), which requires constructing the complete orthogonal basis — not feasible to do efficiently across distributed chunks.
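
The shape difference between the two modes can be seen with plain NumPy. This is a small illustration on an in-memory array, not dask code; the sizes `M = 6` and `N = 4` are arbitrary choices for the example.

```python
import numpy as np

M, N = 6, 4
K = min(M, N)
x = np.random.default_rng(0).standard_normal((M, N))

# Full SVD: square orthogonal factors of shape (M, M) and (N, N).
u_full, s_full, vh_full = np.linalg.svd(x, full_matrices=True)

# Reduced SVD: only the first K columns of U and the first K rows of Vh.
u_red, s_red, vh_red = np.linalg.svd(x, full_matrices=False)

print(u_full.shape, vh_full.shape)
print(u_red.shape, vh_red.shape)

# The reduced factors alone are enough to reconstruct x.
reconstructed = (u_red * s_red) @ vh_red
print(np.allclose(reconstructed, x))
```

For a tall array the extra `M - K` columns of the full `U` do not contribute to the reconstruction, which is why the reduced form is sufficient for most consumers.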

Proposed Solution

  1. Add full_matrices parameter to the function signature with default False:

    def svd(a, coerce_signs=True, full_matrices=False):
    
  2. Add guard clause before existing logic:

    if full_matrices:
        raise NotImplementedError(
            "full_matrices=True is not implemented for dask arrays. "
            "Use full_matrices=False to compute the reduced SVD."
        )
    
  3. Update docstring to document the new parameter:

    full_matrices : bool, optional
        If True, raises ``NotImplementedError``. Only ``full_matrices=False``
        (reduced SVD) is currently supported. Default is ``False``.
    
        .. versionadded:: 2024.1.0
    
  4. Add tests:

    • Verify full_matrices=True raises NotImplementedError with descriptive message
    • Verify full_matrices=False produces correct results against NumPy for tall, wide, and square matrices
  5. Default is False (not True as in spec) because:

    • Dask only supports reduced SVD
    • Defaulting to True would break every existing da.linalg.svd(x) call
    • Issue author (ogrisel) explicitly recommends this approach
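
Steps 1 and 2 can be sketched in isolation as follows. `svd_sketch` is a hypothetical NumPy-backed stand-in that only mirrors the proposed signature and guard; it does not reproduce dask's chunked algorithm, and it accepts `coerce_signs` without acting on it.

```python
import numpy as np


def svd_sketch(a, coerce_signs=True, full_matrices=False):
    """Hypothetical stand-in mirroring the proposed signature and guard."""
    if full_matrices:
        raise NotImplementedError(
            "full_matrices=True is not implemented for dask arrays. "
            "Use full_matrices=False to compute the reduced SVD."
        )
    # coerce_signs is accepted for signature parity but ignored here.
    return np.linalg.svd(np.asarray(a), full_matrices=False)


x = np.arange(12.0).reshape(4, 3)

# Existing call style (no full_matrices) keeps working.
u, s, vh = svd_sketch(x)

# Explicitly requesting the full SVD hits the guard.
try:
    svd_sketch(x, full_matrices=True)
    guard_message = None
except NotImplementedError as exc:
    guard_message = str(exc)

print(u.shape, s.shape, vh.shape)
print(guard_message)
```

Because the default is `False`, a call that omits the argument never reaches the guard, which is the backward-compatibility property listed in step 5.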

Implementation Details

  • dask/array/linalg.py modifications:

    • Added full_matrices=False to svd() signature after coerce_signs
    • Added guard clause raising NotImplementedError when full_matrices=True
    • Placed guard before all other validation (ndim, chunk checks) for clearest error
    • Used NotImplementedError (not ValueError) to match dask convention and allow Array API consumers to catch and fall back
  • Docstring updates:

    • Added full_matrices parameter documentation in numpydoc format
    • Added .. versionadded:: 2024.1.0 directive per dask CalVer convention
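  • Test additions:

    • full_matrices=True raises NotImplementedError with a descriptive message
    • full_matrices=False matches NumPy for tall, wide, and square matrices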
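
The choice of NotImplementedError noted above lets a caller detect the unsupported mode and fall back. A minimal sketch of that pattern, under stated assumptions: `reduced_only_svd` and `svd_with_fallback` are hypothetical names, the backend is a NumPy stand-in rather than dask, and the fallback assumes the input fits in memory.

```python
import numpy as np


def reduced_only_svd(x, full_matrices=False):
    """Hypothetical backend that, like the proposed guard, rejects full SVD."""
    if full_matrices:
        raise NotImplementedError("full_matrices=True is not implemented")
    return np.linalg.svd(np.asarray(x), full_matrices=False)


def svd_with_fallback(svd_func, x, full_matrices=True):
    """Try the backend first; on NotImplementedError, fall back to NumPy."""
    try:
        return svd_func(x, full_matrices=full_matrices), "backend"
    except NotImplementedError:
        # Assumption: x fits in memory, since np.asarray materializes it.
        return np.linalg.svd(np.asarray(x), full_matrices=full_matrices), "numpy"


x = np.arange(12.0).reshape(4, 3)
(u, s, vh), used = svd_with_fallback(reduced_only_svd, x, full_matrices=True)
print(used, u.shape, vh.shape)
```

A ValueError would be harder to handle this way, because a caller could not tell an unsupported mode apart from genuinely invalid input.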

Testing & Validation

| ID | Scenario | Expected Result | Status |
|----|----------|-----------------|--------|
| 1 | full_matrices=True on single-chunk array | Raises NotImplementedError with "full_matrices" in message | |
| 2 | full_matrices=False on wide matrix (10×20) | Results match np.linalg.svd(x, full_matrices=False) | |
| 3 | full_matrices=False on tall matrix (20×10) | Results match np.linalg.svd(x, full_matrices=False) | |
| 4 | full_matrices=False on square matrix (10×10) | Results match np.linalg.svd(x, full_matrices=False) | |
| 5 | All existing SVD tests (no full_matrices arg) | Backward compatible, all pass unchanged | |
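
One way scenarios 2 to 4 could be checked is sketched below. Singular vectors are only determined up to sign, so this hypothetical helper, `check_reduced_svd`, compares shapes, singular values, and the reconstructed matrix rather than U and Vh element by element. It is exercised here against `np.linalg.svd` itself so that it runs without dask; this is not the test code from the PR.

```python
import numpy as np


def check_reduced_svd(svd_func, shape, seed=0):
    """Compare a reduced SVD from svd_func against NumPy for a random matrix."""
    x = np.random.default_rng(seed).standard_normal(shape)
    u, s, vh = svd_func(x, full_matrices=False)
    # np.asarray also materializes lazy results if svd_func returns them.
    u, s, vh = (np.asarray(part) for part in (u, s, vh))

    _, s_ref, _ = np.linalg.svd(x, full_matrices=False)
    k = min(shape)

    assert u.shape == (shape[0], k)
    assert vh.shape == (k, shape[1])
    np.testing.assert_allclose(s, s_ref)
    # Sign-independent check: the factors must reconstruct the input.
    np.testing.assert_allclose((u * s) @ vh, x, atol=1e-10)
    return True


# Wide, tall, and square cases from the table above.
results = [check_reduced_svd(np.linalg.svd, shape)
           for shape in [(10, 20), (20, 10), (10, 10)]]
print(results)
```

To run the same check against dask, `svd_func` would be a small wrapper that builds a dask array from `x` and calls the patched function; that wrapper is left out here because it depends on the installed dask version.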

Validation command:

python -m pytest dask/array/tests/test_linalg.py -x -v -k "svd"

Output:

97 passed, 4 skipped, 294 deselected in 10.79s

Tags

dask, array-api, svd, full_matrices