diff options
Diffstat (limited to 'src/matrix.c')
-rw-r--r-- | src/matrix.c | 96 |
1 files changed, 93 insertions, 3 deletions
diff --git a/src/matrix.c b/src/matrix.c index 22dd171..0891734 100644 --- a/src/matrix.c +++ b/src/matrix.c @@ -1,5 +1,6 @@ #include "lizfcm.h" #include <assert.h> +#include <math.h> #include <stdio.h> #include <string.h> @@ -71,7 +72,7 @@ Matrix_double **lu_decomp(Matrix_double *m) { for (size_t y = 0; y < m->rows; y++) { if (u->data[y]->data[y] == 0) { printf("ERROR: a pivot is zero in given matrix\n"); - exit(-1); + assert(false); } } @@ -82,7 +83,7 @@ Matrix_double **lu_decomp(Matrix_double *m) { if (denom == 0) { printf("ERROR: non-factorable matrix\n"); - exit(-1); + assert(false); } double factor = -(u->data[y]->data[x] / denom); @@ -129,7 +130,7 @@ Array_double *fsubst(Matrix_double *l, Array_double *b) { return x; } -Array_double *solve_matrix(Matrix_double *m, Array_double *b) { +Array_double *solve_matrix_lu_bsubst(Matrix_double *m, Array_double *b) { assert(b->size == m->rows); assert(m->rows == m->cols); @@ -144,10 +145,99 @@ Array_double *solve_matrix(Matrix_double *m, Array_double *b) { free_matrix(u); free_matrix(l); + free(u_l); return x; } +Matrix_double *gaussian_elimination(Matrix_double *m) { + uint64_t h = 0; + uint64_t k = 0; + + Matrix_double *m_cp = copy_matrix(m); + + while (h < m_cp->rows && k < m_cp->cols) { + uint64_t max_row = 0; + double total_max = 0.0; + + for (uint64_t row = h; row < m_cp->rows; row++) { + double this_max = c_max(fabs(m_cp->data[row]->data[k]), total_max); + if (c_max(this_max, total_max) == this_max) { + max_row = row; + } + } + + if (max_row == 0) { + k++; + continue; + } + + Array_double *swp = m_cp->data[max_row]; + m_cp->data[max_row] = m_cp->data[h]; + m_cp->data[h] = swp; + + for (uint64_t row = h + 1; row < m_cp->rows; row++) { + double factor = m_cp->data[row]->data[k] / m_cp->data[h]->data[k]; + m_cp->data[row]->data[k] = 0.0; + + for (uint64_t col = k + 1; col < m_cp->cols; col++) { + m_cp->data[row]->data[col] -= m_cp->data[h]->data[col] * factor; + } + } + + h++; + k++; + } + + return m_cp; +} + +Array_double *solve_matrix_gaussian(Matrix_double *m, Array_double *b) { + assert(b->size == m->rows); + assert(m->rows == m->cols); + + Matrix_double *m_augment_b = add_column(m, b); + Matrix_double *eliminated = gaussian_elimination(m_augment_b); + + Array_double *b_gauss = col_v(eliminated, m->cols); + Matrix_double *u = slice_column(eliminated, m->rows); + + Array_double *solution = bsubst(u, b_gauss); + + free_matrix(m_augment_b); + free_matrix(eliminated); + free_matrix(u); + free_vector(b_gauss); + + return solution; +} + +Matrix_double *slice_column(Matrix_double *m, size_t x) { + Matrix_double *sliced = copy_matrix(m); + + for (size_t row = 0; row < m->rows; row++) { + Array_double *old_row = sliced->data[row]; + sliced->data[row] = slice_element(old_row, x); + free_vector(old_row); + } + sliced->cols--; + + return sliced; +} + +Matrix_double *add_column(Matrix_double *m, Array_double *v) { + Matrix_double *pushed = copy_matrix(m); + + for (size_t row = 0; row < m->rows; row++) { + Array_double *old_row = pushed->data[row]; + pushed->data[row] = add_element(old_row, v->data[row]); + free_vector(old_row); + } + + pushed->cols++; + return pushed; +} + void free_matrix(Matrix_double *m) { for (size_t y = 0; y < m->rows; ++y) free_vector(m->data[y]); |