Calculate the inverse of matrices
This commit is contained in:
parent
c2777e2a70
commit
971206d325
|
@ -2,32 +2,25 @@ use crate::types::{eq_f64, Tuple};
|
|||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Matrix {
|
||||
width: usize,
|
||||
height: usize,
|
||||
size: usize,
|
||||
values: Vec<f64>,
|
||||
}
|
||||
|
||||
impl Matrix {
|
||||
pub fn new(width: usize, height: usize) -> Self {
|
||||
pub fn new(size: usize) -> Self {
|
||||
Self {
|
||||
width,
|
||||
height,
|
||||
values: vec![0.; width * height],
|
||||
size,
|
||||
values: vec![0.; size * size],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn identity() -> Self {
|
||||
let mut values = vec![0.; 16];
|
||||
values[0] = 1.;
|
||||
values[5] = 1.;
|
||||
values[10] = 1.;
|
||||
values[15] = 1.;
|
||||
|
||||
Self {
|
||||
width: 4,
|
||||
height: 4,
|
||||
values,
|
||||
}
|
||||
Self::from([
|
||||
[1., 0., 0., 0.],
|
||||
[0., 1., 0., 0.],
|
||||
[0., 0., 1., 0.],
|
||||
[0., 0., 0., 1.],
|
||||
])
|
||||
}
|
||||
|
||||
pub fn cell(&self, row: usize, column: usize) -> f64 {
|
||||
|
@ -39,11 +32,11 @@ impl Matrix {
|
|||
&mut self.values[addr]
|
||||
}
|
||||
|
||||
pub fn transpose(self) -> Self {
|
||||
let mut m = Self::new(self.width, self.height);
|
||||
pub fn transpose(&self) -> Self {
|
||||
let mut m = Self::new(self.size);
|
||||
|
||||
for row in 0..self.height {
|
||||
for column in 0..self.width {
|
||||
for row in 0..self.size {
|
||||
for column in 0..self.size {
|
||||
*m.cell_mut(row, column) = self.cell(column, row);
|
||||
}
|
||||
}
|
||||
|
@ -51,17 +44,97 @@ impl Matrix {
|
|||
m
|
||||
}
|
||||
|
||||
pub fn invert(&self) -> Self {
|
||||
let mut m = Matrix::new(self.size);
|
||||
let determinant = self.determinant();
|
||||
for row in 0..self.size {
|
||||
for column in 0..self.size {
|
||||
*m.cell_mut(row, column) = self.cofactor(row, column);
|
||||
}
|
||||
}
|
||||
|
||||
let mut m = m.transpose();
|
||||
for row in 0..self.size {
|
||||
for column in 0..self.size {
|
||||
let value = m.cell(row, column);
|
||||
*m.cell_mut(row, column) = value / determinant;
|
||||
}
|
||||
}
|
||||
|
||||
m
|
||||
}
|
||||
|
||||
fn determinant(&self) -> f64 {
|
||||
// TODO: optimization may not be necessary, but this can be optimized by memoizing
|
||||
// submatrices and cofactors.
|
||||
if self.size == 2 {
|
||||
self.cell(0, 0) * self.cell(1, 1) - self.cell(0, 1) * self.cell(1, 0)
|
||||
} else if self.size == 3 {
|
||||
self.cell(0, 0) * self.cofactor(0, 0)
|
||||
+ self.cell(0, 1) * self.cofactor(0, 1)
|
||||
+ self.cell(0, 2) * self.cofactor(0, 2)
|
||||
} else if self.size == 4 {
|
||||
self.cell(0, 0) * self.cofactor(0, 0)
|
||||
+ self.cell(0, 1) * self.cofactor(0, 1)
|
||||
+ self.cell(0, 2) * self.cofactor(0, 2)
|
||||
+ self.cell(0, 3) * self.cofactor(0, 3)
|
||||
} else {
|
||||
0.
|
||||
}
|
||||
}
|
||||
|
||||
fn invertible(&self) -> bool {
|
||||
!eq_f64(self.determinant(), 0.)
|
||||
}
|
||||
|
||||
fn submatrix(&self, row: usize, column: usize) -> Self {
|
||||
let mut m = Self::new(self.size - 1);
|
||||
|
||||
for r in 0..self.size {
|
||||
for c in 0..self.size {
|
||||
let dest_r = if r < row {
|
||||
r
|
||||
} else if r > row {
|
||||
r - 1
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
let dest_c = if c < column {
|
||||
c
|
||||
} else if c > column {
|
||||
c - 1
|
||||
} else {
|
||||
continue;
|
||||
};
|
||||
*m.cell_mut(dest_r, dest_c) = self.cell(r, c);
|
||||
}
|
||||
}
|
||||
|
||||
m
|
||||
}
|
||||
|
||||
fn minor(&self, row: usize, column: usize) -> f64 {
|
||||
self.submatrix(row, column).determinant()
|
||||
}
|
||||
|
||||
fn cofactor(&self, row: usize, column: usize) -> f64 {
|
||||
if (row + column) % 2 == 0 {
|
||||
self.minor(row, column)
|
||||
} else {
|
||||
-self.minor(row, column)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn addr(&self, row: usize, column: usize) -> usize {
|
||||
row * self.width + column
|
||||
row * self.size + column
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[[f64; 2]; 2]> for Matrix {
|
||||
fn from(s: [[f64; 2]; 2]) -> Self {
|
||||
Self {
|
||||
width: 2,
|
||||
height: 2,
|
||||
size: 2,
|
||||
values: s.concat(),
|
||||
}
|
||||
}
|
||||
|
@ -70,8 +143,7 @@ impl From<[[f64; 2]; 2]> for Matrix {
|
|||
impl From<[[f64; 3]; 3]> for Matrix {
|
||||
fn from(s: [[f64; 3]; 3]) -> Self {
|
||||
Self {
|
||||
width: 3,
|
||||
height: 3,
|
||||
size: 3,
|
||||
values: s.concat(),
|
||||
}
|
||||
}
|
||||
|
@ -80,8 +152,7 @@ impl From<[[f64; 3]; 3]> for Matrix {
|
|||
impl From<[[f64; 4]; 4]> for Matrix {
|
||||
fn from(s: [[f64; 4]; 4]) -> Self {
|
||||
Self {
|
||||
width: 4,
|
||||
height: 4,
|
||||
size: 4,
|
||||
values: s.concat(),
|
||||
}
|
||||
}
|
||||
|
@ -89,7 +160,7 @@ impl From<[[f64; 4]; 4]> for Matrix {
|
|||
|
||||
impl PartialEq for Matrix {
|
||||
fn eq(&self, rside: &Matrix) -> bool {
|
||||
if self.width != rside.width || self.height != rside.height {
|
||||
if self.size != rside.size {
|
||||
return false;
|
||||
};
|
||||
|
||||
|
@ -103,12 +174,10 @@ impl PartialEq for Matrix {
|
|||
impl std::ops::Mul for Matrix {
|
||||
type Output = Matrix;
|
||||
fn mul(self, rside: Matrix) -> Matrix {
|
||||
assert_eq!(self.width, 4);
|
||||
assert_eq!(self.height, 4);
|
||||
assert_eq!(rside.width, 4);
|
||||
assert_eq!(rside.height, 4);
|
||||
assert_eq!(self.size, 4);
|
||||
assert_eq!(rside.size, 4);
|
||||
|
||||
let mut m = Matrix::new(self.width, self.height);
|
||||
let mut m = Matrix::new(self.size);
|
||||
for row in 0..4 {
|
||||
for column in 0..4 {
|
||||
*m.cell_mut(row, column) = self.cell(row, 0) * rside.cell(0, column)
|
||||
|
@ -125,8 +194,7 @@ impl std::ops::Mul for Matrix {
|
|||
impl std::ops::Mul<Tuple> for Matrix {
|
||||
type Output = Tuple;
|
||||
fn mul(self, rside: Tuple) -> Tuple {
|
||||
assert_eq!(self.width, 4);
|
||||
assert_eq!(self.height, 4);
|
||||
assert_eq!(self.size, 4);
|
||||
|
||||
let mut t = [0.; 4];
|
||||
|
||||
|
@ -266,4 +334,103 @@ mod tests {
|
|||
|
||||
assert_eq!(a.transpose(), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calculates_2x2_determinant() {
|
||||
let m = Matrix::from([[1., 5.], [-3., 2.]]);
|
||||
assert_eq!(m.determinant(), 17.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calculates_3x3_determinant() {
|
||||
let m = Matrix::from([[1., 2., 6.], [-5., 8., -4.], [2., 6., 4.]]);
|
||||
assert_eq!(m.cofactor(0, 0), 56.);
|
||||
assert_eq!(m.cofactor(0, 1), 12.);
|
||||
assert_eq!(m.cofactor(0, 2), -46.);
|
||||
assert_eq!(m.determinant(), -196.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calculates_4x4_determinant() {
|
||||
let m = Matrix::from([
|
||||
[-2., -8., 3., 5.],
|
||||
[-3., 1., 7., 3.],
|
||||
[1., 2., -9., 6.],
|
||||
[-6., 7., 7., -9.],
|
||||
]);
|
||||
assert_eq!(m.cofactor(0, 0), 690.);
|
||||
assert_eq!(m.cofactor(0, 1), 447.);
|
||||
assert_eq!(m.cofactor(0, 2), 210.);
|
||||
assert_eq!(m.cofactor(0, 3), 51.);
|
||||
assert_eq!(m.determinant(), -4071.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calculates_submatrix() {
|
||||
let m = Matrix::from([[1., 5., 0.], [-3., 2., 7.], [0., 6., -3.]]);
|
||||
let expected = Matrix::from([[-3., 2.], [0., 6.]]);
|
||||
assert_eq!(m.submatrix(0, 2), expected);
|
||||
|
||||
let m = Matrix::from([
|
||||
[-6., 1., 1., 6.],
|
||||
[-8., 5., 8., 6.],
|
||||
[-1., 0., 8., 2.],
|
||||
[-7., 1., -1., 1.],
|
||||
]);
|
||||
let expected = Matrix::from([[-6., 1., 6.], [-8., 8., 6.], [-7., -1., 1.]]);
|
||||
assert_eq!(m.submatrix(2, 1), expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn calculates_minors_and_cofactors() {
|
||||
let m = Matrix::from([[3., 5., 0.], [2., -1., -7.], [6., -1., 5.]]);
|
||||
assert_eq!(m.submatrix(1, 0).determinant(), 25.);
|
||||
assert_eq!(m.minor(0, 0), -12.);
|
||||
assert_eq!(m.cofactor(0, 0), -12.);
|
||||
assert_eq!(m.minor(1, 0), 25.);
|
||||
assert_eq!(m.cofactor(1, 0), -25.);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invert_4x4_matrix() {
|
||||
let m = Matrix::from([
|
||||
[6., 4., 4., 4.],
|
||||
[5., 5., 7., 6.],
|
||||
[4., -9., 3., -7.],
|
||||
[9., 1., 7., -6.],
|
||||
]);
|
||||
|
||||
assert_eq!(m.determinant(), -2120.);
|
||||
assert!(m.invertible());
|
||||
|
||||
let m = Matrix::from([
|
||||
[-4., 2., -2., -3.],
|
||||
[9., 6., 2., 6.],
|
||||
[0., -5., 1., -5.],
|
||||
[0., 0., 0., 0.],
|
||||
]);
|
||||
assert_eq!(m.determinant(), 0.);
|
||||
assert!(!m.invertible());
|
||||
|
||||
let m = Matrix::from([
|
||||
[-5., 2., 6., -8.],
|
||||
[1., -5., 1., 8.],
|
||||
[7., 7., -6., -7.],
|
||||
[1., -3., 7., 4.],
|
||||
]);
|
||||
let expected = Matrix::from([
|
||||
[0.21805, 0.45113, 0.24060, -0.04511],
|
||||
[-0.80827, -1.45677, -0.44361, 0.52068],
|
||||
[-0.07895, -0.22368, -0.05263, 0.19737],
|
||||
[-0.52256, -0.81391, -0.30075, 0.30639],
|
||||
]);
|
||||
|
||||
assert_eq!(m.determinant(), 532.);
|
||||
assert_eq!(m.cofactor(2, 3), -160.);
|
||||
assert!(eq_f64(expected.cell(3, 2), -160. / 532.));
|
||||
assert_eq!(m.cofactor(3, 2), 105.);
|
||||
assert!(eq_f64(expected.cell(2, 3), 105. / 532.));
|
||||
|
||||
assert_eq!(m.invert(), expected);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue