// Java port of the C++ band_matrix class.
import java.util.ArrayList; import java.util.List;
/**
 * Square band matrix with an in-place LU decomposition (no pivoting) and the
 * matching triangular solvers.
 *
 * <p>Only the main diagonal, {@code num_upper()} super-diagonals and
 * {@code num_lower()} sub-diagonals are stored; every other entry is an
 * implicit zero. The snake_case names are kept so existing callers of this
 * class continue to compile.
 *
 * <p>Storage layout:
 * <ul>
 *   <li>{@code m_lower.get(0).get(i)} holds the diagonal entry A(i, i);</li>
 *   <li>{@code m_lower.get(k).get(j)} holds A(j + k, j) for {@code k >= 1};</li>
 *   <li>{@code m_upper.get(k).get(i)} holds A(i, i + k) for {@code k >= 1};</li>
 *   <li>{@code m_upper.get(0)} only carries the dimension (see {@link #dim()}).</li>
 * </ul>
 */
public class band_matrix {
    /** Super-diagonal bands; index 0 is used only to record the dimension. */
    private List<List<Double>> m_upper = new ArrayList<>();
    /** Index 0 is the main diagonal, index k >= 1 the k-th sub-diagonal. */
    private List<List<Double>> m_lower = new ArrayList<>();

    /** Creates an empty matrix with {@code dim() == 0}; call {@link #resize} before use. */
    public band_matrix() {
    }

    /**
     * Creates a zero-filled {@code dim x dim} band matrix.
     *
     * @param dim matrix dimension
     * @param n_u number of super-diagonals
     * @param n_l number of sub-diagonals
     * @throws IllegalArgumentException if any argument is negative
     */
    public band_matrix(int dim, int n_u, int n_l) {
        resize(dim, n_u, n_l);
    }

    /**
     * Discards the current contents and re-initialises the matrix with zeros.
     *
     * @param dim matrix dimension
     * @param n_u number of super-diagonals
     * @param n_l number of sub-diagonals
     * @throws IllegalArgumentException if any argument is negative
     */
    public void resize(int dim, int n_u, int n_l) {
        if (dim < 0 || n_u < 0 || n_l < 0) {
            throw new IllegalArgumentException(
                    "dim, n_u and n_l must be non-negative: " + dim + ", " + n_u + ", " + n_l);
        }
        m_upper = zeroBands(n_u + 1, dim);
        m_lower = zeroBands(n_l + 1, dim);
    }

    /** Builds {@code count} bands, each holding {@code length} zeros. */
    private static List<List<Double>> zeroBands(int count, int length) {
        List<List<Double>> bands = new ArrayList<>(count);
        for (int b = 0; b < count; b++) {
            List<Double> band = new ArrayList<>(length);
            for (int p = 0; p < length; p++) {
                band.add(0.0);
            }
            bands.add(band);
        }
        return bands;
    }

    /** @return the matrix dimension, or 0 for a default-constructed instance */
    public int dim() {
        return m_upper.isEmpty() ? 0 : m_upper.get(0).size();
    }

    /** @return the number of stored super-diagonals (-1 before the first resize) */
    public int num_upper() {
        return m_upper.size() - 1;
    }

    /** @return the number of stored sub-diagonals (-1 before the first resize) */
    public int num_lower() {
        return m_lower.size() - 1;
    }

    /**
     * Reads entry A(i, j).
     *
     * @return the stored value, or 0.0 when (i, j) lies outside the band
     * @throws IndexOutOfBoundsException if i or j is not in {@code [0, dim())}
     */
    public double get(int i, int j) {
        checkIndices(i, j);
        int band = j - i; // 0: diagonal, > 0: upper part, < 0: lower part
        if (band > num_upper() || -band > num_lower()) {
            return 0.0;
        }
        if (band == 0) {
            return saved_diag(i);
        }
        if (band < 0) {
            return m_lower.get(-band).get(j);
        }
        return m_upper.get(band).get(i);
    }

    /**
     * Writes entry A(i, j).
     *
     * @throws IndexOutOfBoundsException if i or j is not in {@code [0, dim())}
     * @throws IllegalArgumentException if (i, j) lies outside the stored band
     */
    public void set(int i, int j, double value) {
        checkIndices(i, j);
        int band = j - i;
        if (band > num_upper() || -band > num_lower()) {
            throw new IllegalArgumentException("Invalid indices");
        }
        if (band == 0) {
            saved_diag(i, value);
        } else if (band < 0) {
            m_lower.get(-band).set(j, value);
        } else {
            m_upper.get(band).set(i, value);
        }
    }

    /** Rejects row/column indices outside {@code [0, dim())}. */
    private void checkIndices(int i, int j) {
        int n = dim();
        if (i < 0 || i >= n || j < 0 || j >= n) {
            throw new IndexOutOfBoundsException(
                    "Index (" + i + ", " + j + ") out of bounds for dimension " + n);
        }
    }

    /** Reads the i-th main-diagonal entry, which is kept in {@code m_lower.get(0)}. */
    public double saved_diag(int i) {
        return m_lower.get(0).get(i);
    }

    /** Writes the i-th main-diagonal entry, which is kept in {@code m_lower.get(0)}. */
    public void saved_diag(int i, double value) {
        m_lower.get(0).set(i, value);
    }

    /**
     * Replaces the matrix in place by its LU factors.
     *
     * <p>After the call the strictly lower part holds the multipliers of L
     * (whose unit diagonal is implicit) and the diagonal plus upper part hold U.
     * No pivoting is performed, so a zero pivot produces non-finite values.
     */
    public void lu_decompose() {
        int n = dim();
        int n_u = num_upper();
        int n_l = num_lower();
        for (int k = 0; k < n - 1; k++) {
            int lastRow = Math.min(k + n_l, n - 1);
            int lastCol = Math.min(k + n_u, n - 1);
            double pivot = get(k, k);
            for (int row = k + 1; row <= lastRow; row++) {
                double factor = get(row, k) / pivot;
                set(row, k, factor); // multiplier becomes the entry of L
                for (int col = k + 1; col <= lastCol; col++) {
                    set(row, col, get(row, col) - factor * get(k, col));
                }
            }
        }
    }

    /**
     * Solves {@code R x = b} by back substitution, where R is the upper
     * triangular factor (diagonal and super-diagonals) of a decomposed matrix.
     *
     * @param b right-hand side of length {@code dim()}
     * @return the solution as a new list
     * @throws IllegalArgumentException if {@code b.size() != dim()}
     */
    public List<Double> r_solve(List<Double> b) {
        int n = dim();
        requireLength(b, n);
        int n_u = num_upper();
        double[] x = new double[n];
        for (int i = n - 1; i >= 0; i--) {
            double sum = b.get(i);
            int last = Math.min(i + n_u, n - 1);
            for (int j = i + 1; j <= last; j++) {
                sum -= get(i, j) * x[j];
            }
            x[i] = sum / get(i, i);
        }
        List<Double> result = new ArrayList<>(n);
        for (double v : x) {
            result.add(v);
        }
        return result;
    }

    /**
     * Solves {@code L y = b} by forward substitution, where L is the lower
     * triangular factor of a decomposed matrix. L has an implicit unit
     * diagonal, so no division takes place.
     *
     * @param b right-hand side of length {@code dim()}
     * @return the solution as a new list
     * @throws IllegalArgumentException if {@code b.size() != dim()}
     */
    public List<Double> l_solve(List<Double> b) {
        int n = dim();
        requireLength(b, n);
        int n_l = num_lower();
        List<Double> y = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            double sum = b.get(i);
            for (int j = Math.max(0, i - n_l); j < i; j++) {
                sum -= get(i, j) * y.get(j);
            }
            y.add(sum);
        }
        return y;
    }

    /**
     * Solves {@code A x = b}.
     *
     * @param b right-hand side of length {@code dim()}
     * @param is_lu_decomposed pass {@code true} if {@link #lu_decompose()} has
     *        already been applied to this matrix; otherwise it is applied first
     *        (and the matrix contents are overwritten by the factors)
     * @return the solution as a new list
     * @throws IllegalArgumentException if {@code b.size() != dim()}
     */
    public List<Double> lu_solve(List<Double> b, boolean is_lu_decomposed) {
        requireLength(b, dim());
        if (!is_lu_decomposed) {
            lu_decompose();
        }
        List<Double> y = l_solve(b); // forward: L y = b
        return r_solve(y);           // backward: R x = y
    }

    /** Ensures a right-hand side has exactly {@code n} entries. */
    private static void requireLength(List<Double> b, int n) {
        if (b.size() != n) {
            throw new IllegalArgumentException(
                    "Right-hand side has length " + b.size() + " but matrix dimension is " + n);
        }
    }
}
// Original source: https://www.cveoy.top/t/topic/fA8C — copyright belongs to the author. Do not reproduce or scrape.