Calculate the inverse of matrices

This commit is contained in:
Savanni D'Gerinel 2024-06-10 07:54:58 -04:00
parent c2777e2a70
commit 971206d325
1 changed files with 203 additions and 36 deletions

View File

@ -2,32 +2,25 @@ use crate::types::{eq_f64, Tuple};
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct Matrix { pub struct Matrix {
width: usize, size: usize,
height: usize,
values: Vec<f64>, values: Vec<f64>,
} }
impl Matrix { impl Matrix {
pub fn new(width: usize, height: usize) -> Self { pub fn new(size: usize) -> Self {
Self { Self {
width, size,
height, values: vec![0.; size * size],
values: vec![0.; width * height],
} }
} }
pub fn identity() -> Self { pub fn identity() -> Self {
let mut values = vec![0.; 16]; Self::from([
values[0] = 1.; [1., 0., 0., 0.],
values[5] = 1.; [0., 1., 0., 0.],
values[10] = 1.; [0., 0., 1., 0.],
values[15] = 1.; [0., 0., 0., 1.],
])
Self {
width: 4,
height: 4,
values,
}
} }
pub fn cell(&self, row: usize, column: usize) -> f64 { pub fn cell(&self, row: usize, column: usize) -> f64 {
@ -39,11 +32,11 @@ impl Matrix {
&mut self.values[addr] &mut self.values[addr]
} }
pub fn transpose(self) -> Self { pub fn transpose(&self) -> Self {
let mut m = Self::new(self.width, self.height); let mut m = Self::new(self.size);
for row in 0..self.height { for row in 0..self.size {
for column in 0..self.width { for column in 0..self.size {
*m.cell_mut(row, column) = self.cell(column, row); *m.cell_mut(row, column) = self.cell(column, row);
} }
} }
@ -51,17 +44,97 @@ impl Matrix {
m m
} }
pub fn invert(&self) -> Self {
let mut m = Matrix::new(self.size);
let determinant = self.determinant();
for row in 0..self.size {
for column in 0..self.size {
*m.cell_mut(row, column) = self.cofactor(row, column);
}
}
let mut m = m.transpose();
for row in 0..self.size {
for column in 0..self.size {
let value = m.cell(row, column);
*m.cell_mut(row, column) = value / determinant;
}
}
m
}
fn determinant(&self) -> f64 {
// TODO: optimization may not be necessary, but this can be optimized by memoizing
// submatrices and cofactors.
if self.size == 2 {
self.cell(0, 0) * self.cell(1, 1) - self.cell(0, 1) * self.cell(1, 0)
} else if self.size == 3 {
self.cell(0, 0) * self.cofactor(0, 0)
+ self.cell(0, 1) * self.cofactor(0, 1)
+ self.cell(0, 2) * self.cofactor(0, 2)
} else if self.size == 4 {
self.cell(0, 0) * self.cofactor(0, 0)
+ self.cell(0, 1) * self.cofactor(0, 1)
+ self.cell(0, 2) * self.cofactor(0, 2)
+ self.cell(0, 3) * self.cofactor(0, 3)
} else {
0.
}
}
fn invertible(&self) -> bool {
!eq_f64(self.determinant(), 0.)
}
fn submatrix(&self, row: usize, column: usize) -> Self {
let mut m = Self::new(self.size - 1);
for r in 0..self.size {
for c in 0..self.size {
let dest_r = if r < row {
r
} else if r > row {
r - 1
} else {
continue;
};
let dest_c = if c < column {
c
} else if c > column {
c - 1
} else {
continue;
};
*m.cell_mut(dest_r, dest_c) = self.cell(r, c);
}
}
m
}
fn minor(&self, row: usize, column: usize) -> f64 {
self.submatrix(row, column).determinant()
}
fn cofactor(&self, row: usize, column: usize) -> f64 {
if (row + column) % 2 == 0 {
self.minor(row, column)
} else {
-self.minor(row, column)
}
}
#[inline] #[inline]
fn addr(&self, row: usize, column: usize) -> usize { fn addr(&self, row: usize, column: usize) -> usize {
row * self.width + column row * self.size + column
} }
} }
impl From<[[f64; 2]; 2]> for Matrix { impl From<[[f64; 2]; 2]> for Matrix {
fn from(s: [[f64; 2]; 2]) -> Self { fn from(s: [[f64; 2]; 2]) -> Self {
Self { Self {
width: 2, size: 2,
height: 2,
values: s.concat(), values: s.concat(),
} }
} }
@ -70,8 +143,7 @@ impl From<[[f64; 2]; 2]> for Matrix {
impl From<[[f64; 3]; 3]> for Matrix { impl From<[[f64; 3]; 3]> for Matrix {
fn from(s: [[f64; 3]; 3]) -> Self { fn from(s: [[f64; 3]; 3]) -> Self {
Self { Self {
width: 3, size: 3,
height: 3,
values: s.concat(), values: s.concat(),
} }
} }
@ -80,8 +152,7 @@ impl From<[[f64; 3]; 3]> for Matrix {
impl From<[[f64; 4]; 4]> for Matrix { impl From<[[f64; 4]; 4]> for Matrix {
fn from(s: [[f64; 4]; 4]) -> Self { fn from(s: [[f64; 4]; 4]) -> Self {
Self { Self {
width: 4, size: 4,
height: 4,
values: s.concat(), values: s.concat(),
} }
} }
@ -89,7 +160,7 @@ impl From<[[f64; 4]; 4]> for Matrix {
impl PartialEq for Matrix { impl PartialEq for Matrix {
fn eq(&self, rside: &Matrix) -> bool { fn eq(&self, rside: &Matrix) -> bool {
if self.width != rside.width || self.height != rside.height { if self.size != rside.size {
return false; return false;
}; };
@ -103,12 +174,10 @@ impl PartialEq for Matrix {
impl std::ops::Mul for Matrix { impl std::ops::Mul for Matrix {
type Output = Matrix; type Output = Matrix;
fn mul(self, rside: Matrix) -> Matrix { fn mul(self, rside: Matrix) -> Matrix {
assert_eq!(self.width, 4); assert_eq!(self.size, 4);
assert_eq!(self.height, 4); assert_eq!(rside.size, 4);
assert_eq!(rside.width, 4);
assert_eq!(rside.height, 4);
let mut m = Matrix::new(self.width, self.height); let mut m = Matrix::new(self.size);
for row in 0..4 { for row in 0..4 {
for column in 0..4 { for column in 0..4 {
*m.cell_mut(row, column) = self.cell(row, 0) * rside.cell(0, column) *m.cell_mut(row, column) = self.cell(row, 0) * rside.cell(0, column)
@ -125,8 +194,7 @@ impl std::ops::Mul for Matrix {
impl std::ops::Mul<Tuple> for Matrix { impl std::ops::Mul<Tuple> for Matrix {
type Output = Tuple; type Output = Tuple;
fn mul(self, rside: Tuple) -> Tuple { fn mul(self, rside: Tuple) -> Tuple {
assert_eq!(self.width, 4); assert_eq!(self.size, 4);
assert_eq!(self.height, 4);
let mut t = [0.; 4]; let mut t = [0.; 4];
@ -266,4 +334,103 @@ mod tests {
assert_eq!(a.transpose(), expected); assert_eq!(a.transpose(), expected);
} }
#[test]
fn calculates_2x2_determinant() {
let m = Matrix::from([[1., 5.], [-3., 2.]]);
assert_eq!(m.determinant(), 17.);
}
#[test]
fn calculates_3x3_determinant() {
let m = Matrix::from([[1., 2., 6.], [-5., 8., -4.], [2., 6., 4.]]);
assert_eq!(m.cofactor(0, 0), 56.);
assert_eq!(m.cofactor(0, 1), 12.);
assert_eq!(m.cofactor(0, 2), -46.);
assert_eq!(m.determinant(), -196.);
}
#[test]
fn calculates_4x4_determinant() {
let m = Matrix::from([
[-2., -8., 3., 5.],
[-3., 1., 7., 3.],
[1., 2., -9., 6.],
[-6., 7., 7., -9.],
]);
assert_eq!(m.cofactor(0, 0), 690.);
assert_eq!(m.cofactor(0, 1), 447.);
assert_eq!(m.cofactor(0, 2), 210.);
assert_eq!(m.cofactor(0, 3), 51.);
assert_eq!(m.determinant(), -4071.);
}
#[test]
fn calculates_submatrix() {
let m = Matrix::from([[1., 5., 0.], [-3., 2., 7.], [0., 6., -3.]]);
let expected = Matrix::from([[-3., 2.], [0., 6.]]);
assert_eq!(m.submatrix(0, 2), expected);
let m = Matrix::from([
[-6., 1., 1., 6.],
[-8., 5., 8., 6.],
[-1., 0., 8., 2.],
[-7., 1., -1., 1.],
]);
let expected = Matrix::from([[-6., 1., 6.], [-8., 8., 6.], [-7., -1., 1.]]);
assert_eq!(m.submatrix(2, 1), expected);
}
#[test]
fn calculates_minors_and_cofactors() {
let m = Matrix::from([[3., 5., 0.], [2., -1., -7.], [6., -1., 5.]]);
assert_eq!(m.submatrix(1, 0).determinant(), 25.);
assert_eq!(m.minor(0, 0), -12.);
assert_eq!(m.cofactor(0, 0), -12.);
assert_eq!(m.minor(1, 0), 25.);
assert_eq!(m.cofactor(1, 0), -25.);
}
#[test]
fn invert_4x4_matrix() {
let m = Matrix::from([
[6., 4., 4., 4.],
[5., 5., 7., 6.],
[4., -9., 3., -7.],
[9., 1., 7., -6.],
]);
assert_eq!(m.determinant(), -2120.);
assert!(m.invertible());
let m = Matrix::from([
[-4., 2., -2., -3.],
[9., 6., 2., 6.],
[0., -5., 1., -5.],
[0., 0., 0., 0.],
]);
assert_eq!(m.determinant(), 0.);
assert!(!m.invertible());
let m = Matrix::from([
[-5., 2., 6., -8.],
[1., -5., 1., 8.],
[7., 7., -6., -7.],
[1., -3., 7., 4.],
]);
let expected = Matrix::from([
[0.21805, 0.45113, 0.24060, -0.04511],
[-0.80827, -1.45677, -0.44361, 0.52068],
[-0.07895, -0.22368, -0.05263, 0.19737],
[-0.52256, -0.81391, -0.30075, 0.30639],
]);
assert_eq!(m.determinant(), 532.);
assert_eq!(m.cofactor(2, 3), -160.);
assert!(eq_f64(expected.cell(3, 2), -160. / 532.));
assert_eq!(m.cofactor(3, 2), 105.);
assert!(eq_f64(expected.cell(2, 3), 105. / 532.));
assert_eq!(m.invert(), expected);
}
} }