NULAPACK
NUmerical Linear Algebra PACKage
Loading...
Searching...
No Matches
thomas.py
Go to the documentation of this file.
1# ====================================================================
2# N U L A P A C K
3# U U L A P A C K
4# L L L A P A C K
5# A A A A P A C K
6# P P P P P A C K
7# A A A A A A C K
8# C C C C C C C K
9# K K K K K K K K
10#
11# This file is part of NULAPACK - NUmerical Linear Algebra PACKage
12#
13# Copyright (C) 2025 Saud Zahir
14#
15# NULAPACK is free software: you can redistribute it and/or modify
16# it under the terms of the GNU General Public License as published by
17# the Free Software Foundation, either version 3 of the License, or
18# (at your option) any later version.
19#
20# NULAPACK is distributed in the hope that it will be useful,
21# but WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
23# GNU General Public License for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with NULAPACK. If not, see <https://www.gnu.org/licenses/>.
27# ====================================================================
28
29import _nulapack
30import numpy as np
31
32
def thomas(a: np.ndarray, b: np.ndarray):
    """
    Solve a tridiagonal linear system A * X = B using the Thomas algorithm.

    The dtype of ``a`` selects which ``_nulapack`` routine is called:
    ``sgttsv`` (float32), ``dgttsv`` (float64), ``cgttsv`` (complex64) or
    ``zgttsv`` (complex128).

    Parameters
    ----------
    a : ndarray
        Coefficient matrix (n x n) stored as a full matrix. Must have one
        of the four supported dtypes listed above.
    b : ndarray
        Right-hand side vector (n,)

    Returns
    -------
    x : ndarray
        Solution vector, allocated with the same shape and dtype as ``b``.
    info : int
        0 if success, <0 if zero diagonal detected. 0 is also reported
        when the underlying routine returns ``None``.

    Raises
    ------
    TypeError
        If the dtype of ``a`` is not float32, float64, complex64 or
        complex128 (this includes float16, extended precision and
        non-native byte order variants).
    ValueError
        If ``a`` is not a two-dimensional square array.
    """
    a = np.ascontiguousarray(a)
    b = np.asfortranarray(b)

    # Dispatch on the exact dtype. A plain "else" after the float32 /
    # complex64 checks would also route float16, longdouble, clongdouble
    # and byte-swapped arrays to the double precision routines, handing
    # them a buffer of the wrong element type.
    routine_names = {
        np.dtype(np.float32): "sgttsv",
        np.dtype(np.float64): "dgttsv",
        np.dtype(np.complex64): "cgttsv",
        np.dtype(np.complex128): "zgttsv",
    }
    routine_name = routine_names.get(a.dtype)
    if routine_name is None:
        raise TypeError(f"Unsupported array dtype: {a.dtype}")

    # The routine receives the flattened matrix together with n, so the
    # buffer must hold exactly n * n entries.
    if a.ndim != 2 or a.shape[0] != a.shape[1]:
        raise ValueError(
            f"Coefficient matrix must be square (n x n), got shape {a.shape}"
        )
    n = a.shape[0]

    # NOTE(review): b is not converted to a's dtype here; presumably the
    # caller is expected to pass matching dtypes -- confirm against the
    # _nulapack binding.
    x = np.zeros_like(b)
    a_flat = a.ravel()

    # Looked up lazily by name so that the validation above runs before
    # the extension module is touched.
    status = getattr(_nulapack, routine_name)(a_flat, b, x, 0, n)

    return x, int(status) if status is not None else 0