diff --git a/ray-tracer/src/types/matrix.rs b/ray-tracer/src/types/matrix.rs index 822480e..40257cd 100644 --- a/ray-tracer/src/types/matrix.rs +++ b/ray-tracer/src/types/matrix.rs @@ -2,32 +2,25 @@ use crate::types::{eq_f64, Tuple}; #[derive(Clone, Debug)] pub struct Matrix { - width: usize, - height: usize, + size: usize, values: Vec, } impl Matrix { - pub fn new(width: usize, height: usize) -> Self { + pub fn new(size: usize) -> Self { Self { - width, - height, - values: vec![0.; width * height], + size, + values: vec![0.; size * size], } } pub fn identity() -> Self { - let mut values = vec![0.; 16]; - values[0] = 1.; - values[5] = 1.; - values[10] = 1.; - values[15] = 1.; - - Self { - width: 4, - height: 4, - values, - } + Self::from([ + [1., 0., 0., 0.], + [0., 1., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.], + ]) } pub fn cell(&self, row: usize, column: usize) -> f64 { @@ -39,11 +32,11 @@ impl Matrix { &mut self.values[addr] } - pub fn transpose(self) -> Self { - let mut m = Self::new(self.width, self.height); + pub fn transpose(&self) -> Self { + let mut m = Self::new(self.size); - for row in 0..self.height { - for column in 0..self.width { + for row in 0..self.size { + for column in 0..self.size { *m.cell_mut(row, column) = self.cell(column, row); } } @@ -51,17 +44,97 @@ impl Matrix { m } + pub fn invert(&self) -> Self { + let mut m = Matrix::new(self.size); + let determinant = self.determinant(); + for row in 0..self.size { + for column in 0..self.size { + *m.cell_mut(row, column) = self.cofactor(row, column); + } + } + + let mut m = m.transpose(); + for row in 0..self.size { + for column in 0..self.size { + let value = m.cell(row, column); + *m.cell_mut(row, column) = value / determinant; + } + } + + m + } + + fn determinant(&self) -> f64 { + // TODO: optimization may not be necessary, but this can be optimized by memoizing + // submatrices and cofactors. + if self.size == 2 { + self.cell(0, 0) * self.cell(1, 1) - self.cell(0, 1) * self.cell(1, 0) + } else if self.size == 3 { + self.cell(0, 0) * self.cofactor(0, 0) + + self.cell(0, 1) * self.cofactor(0, 1) + + self.cell(0, 2) * self.cofactor(0, 2) + } else if self.size == 4 { + self.cell(0, 0) * self.cofactor(0, 0) + + self.cell(0, 1) * self.cofactor(0, 1) + + self.cell(0, 2) * self.cofactor(0, 2) + + self.cell(0, 3) * self.cofactor(0, 3) + } else { + 0. + } + } + + fn invertible(&self) -> bool { + !eq_f64(self.determinant(), 0.) + } + + fn submatrix(&self, row: usize, column: usize) -> Self { + let mut m = Self::new(self.size - 1); + + for r in 0..self.size { + for c in 0..self.size { + let dest_r = if r < row { + r + } else if r > row { + r - 1 + } else { + continue; + }; + let dest_c = if c < column { + c + } else if c > column { + c - 1 + } else { + continue; + }; + *m.cell_mut(dest_r, dest_c) = self.cell(r, c); + } + } + + m + } + + fn minor(&self, row: usize, column: usize) -> f64 { + self.submatrix(row, column).determinant() + } + + fn cofactor(&self, row: usize, column: usize) -> f64 { + if (row + column) % 2 == 0 { + self.minor(row, column) + } else { + -self.minor(row, column) + } + } + #[inline] fn addr(&self, row: usize, column: usize) -> usize { - row * self.width + column + row * self.size + column } } impl From<[[f64; 2]; 2]> for Matrix { fn from(s: [[f64; 2]; 2]) -> Self { Self { - width: 2, - height: 2, + size: 2, values: s.concat(), } } @@ -70,8 +143,7 @@ impl From<[[f64; 2]; 2]> for Matrix { impl From<[[f64; 3]; 3]> for Matrix { fn from(s: [[f64; 3]; 3]) -> Self { Self { - width: 3, - height: 3, + size: 3, values: s.concat(), } } @@ -80,8 +152,7 @@ impl From<[[f64; 3]; 3]> for Matrix { impl From<[[f64; 4]; 4]> for Matrix { fn from(s: [[f64; 4]; 4]) -> Self { Self { - width: 4, - height: 4, + size: 4, values: s.concat(), } } @@ -89,7 +160,7 @@ impl From<[[f64; 4]; 4]> for Matrix { impl PartialEq for Matrix { fn eq(&self, rside: &Matrix) -> bool { - if self.width != rside.width || self.height != rside.height { + if self.size != rside.size { return false; }; @@ -103,12 +174,10 @@ impl PartialEq for Matrix { impl std::ops::Mul for Matrix { type Output = Matrix; fn mul(self, rside: Matrix) -> Matrix { - assert_eq!(self.width, 4); - assert_eq!(self.height, 4); - assert_eq!(rside.width, 4); - assert_eq!(rside.height, 4); + assert_eq!(self.size, 4); + assert_eq!(rside.size, 4); - let mut m = Matrix::new(self.width, self.height); + let mut m = Matrix::new(self.size); for row in 0..4 { for column in 0..4 { *m.cell_mut(row, column) = self.cell(row, 0) * rside.cell(0, column) @@ -125,8 +194,7 @@ impl std::ops::Mul for Matrix { impl std::ops::Mul for Matrix { type Output = Tuple; fn mul(self, rside: Tuple) -> Tuple { - assert_eq!(self.width, 4); - assert_eq!(self.height, 4); + assert_eq!(self.size, 4); let mut t = [0.; 4]; @@ -266,4 +334,103 @@ mod tests { assert_eq!(a.transpose(), expected); } + + #[test] + fn calculates_2x2_determinant() { + let m = Matrix::from([[1., 5.], [-3., 2.]]); + assert_eq!(m.determinant(), 17.); + } + + #[test] + fn calculates_3x3_determinant() { + let m = Matrix::from([[1., 2., 6.], [-5., 8., -4.], [2., 6., 4.]]); + assert_eq!(m.cofactor(0, 0), 56.); + assert_eq!(m.cofactor(0, 1), 12.); + assert_eq!(m.cofactor(0, 2), -46.); + assert_eq!(m.determinant(), -196.); + } + + #[test] + fn calculates_4x4_determinant() { + let m = Matrix::from([ + [-2., -8., 3., 5.], + [-3., 1., 7., 3.], + [1., 2., -9., 6.], + [-6., 7., 7., -9.], + ]); + assert_eq!(m.cofactor(0, 0), 690.); + assert_eq!(m.cofactor(0, 1), 447.); + assert_eq!(m.cofactor(0, 2), 210.); + assert_eq!(m.cofactor(0, 3), 51.); + assert_eq!(m.determinant(), -4071.); + } + + #[test] + fn calculates_submatrix() { + let m = Matrix::from([[1., 5., 0.], [-3., 2., 7.], [0., 6., -3.]]); + let expected = Matrix::from([[-3., 2.], [0., 6.]]); + assert_eq!(m.submatrix(0, 2), expected); + + let m = Matrix::from([ + [-6., 1., 1., 6.], + [-8., 5., 8., 6.], + [-1., 0., 8., 2.], + [-7., 1., -1., 1.], + ]); + let expected = Matrix::from([[-6., 1., 6.], [-8., 8., 6.], [-7., -1., 1.]]); + assert_eq!(m.submatrix(2, 1), expected); + } + + #[test] + fn calculates_minors_and_cofactors() { + let m = Matrix::from([[3., 5., 0.], [2., -1., -7.], [6., -1., 5.]]); + assert_eq!(m.submatrix(1, 0).determinant(), 25.); + assert_eq!(m.minor(0, 0), -12.); + assert_eq!(m.cofactor(0, 0), -12.); + assert_eq!(m.minor(1, 0), 25.); + assert_eq!(m.cofactor(1, 0), -25.); + } + + #[test] + fn invert_4x4_matrix() { + let m = Matrix::from([ + [6., 4., 4., 4.], + [5., 5., 7., 6.], + [4., -9., 3., -7.], + [9., 1., 7., -6.], + ]); + + assert_eq!(m.determinant(), -2120.); + assert!(m.invertible()); + + let m = Matrix::from([ + [-4., 2., -2., -3.], + [9., 6., 2., 6.], + [0., -5., 1., -5.], + [0., 0., 0., 0.], + ]); + assert_eq!(m.determinant(), 0.); + assert!(!m.invertible()); + + let m = Matrix::from([ + [-5., 2., 6., -8.], + [1., -5., 1., 8.], + [7., 7., -6., -7.], + [1., -3., 7., 4.], + ]); + let expected = Matrix::from([ + [0.21805, 0.45113, 0.24060, -0.04511], + [-0.80827, -1.45677, -0.44361, 0.52068], + [-0.07895, -0.22368, -0.05263, 0.19737], + [-0.52256, -0.81391, -0.30075, 0.30639], + ]); + + assert_eq!(m.determinant(), 532.); + assert_eq!(m.cofactor(2, 3), -160.); + assert!(eq_f64(expected.cell(3, 2), -160. / 532.)); + assert_eq!(m.cofactor(3, 2), 105.); + assert!(eq_f64(expected.cell(2, 3), 105. / 532.)); + + assert_eq!(m.invert(), expected); + } }