Basic matrix operations

This commit is contained in:
Savanni D'Gerinel 2024-06-09 23:21:16 -04:00
parent 2569a48792
commit c2777e2a70
3 changed files with 277 additions and 0 deletions

View File

@ -0,0 +1,269 @@
use crate::types::{eq_f64, Tuple};
#[derive(Clone, Debug)]
pub struct Matrix {
width: usize,
height: usize,
values: Vec<f64>,
}
impl Matrix {
pub fn new(width: usize, height: usize) -> Self {
Self {
width,
height,
values: vec![0.; width * height],
}
}
pub fn identity() -> Self {
let mut values = vec![0.; 16];
values[0] = 1.;
values[5] = 1.;
values[10] = 1.;
values[15] = 1.;
Self {
width: 4,
height: 4,
values,
}
}
pub fn cell(&self, row: usize, column: usize) -> f64 {
self.values[self.addr(row, column)]
}
pub fn cell_mut<'a>(&'a mut self, row: usize, column: usize) -> &'a mut f64 {
let addr = self.addr(row, column);
&mut self.values[addr]
}
pub fn transpose(self) -> Self {
let mut m = Self::new(self.width, self.height);
for row in 0..self.height {
for column in 0..self.width {
*m.cell_mut(row, column) = self.cell(column, row);
}
}
m
}
#[inline]
fn addr(&self, row: usize, column: usize) -> usize {
row * self.width + column
}
}
impl From<[[f64; 2]; 2]> for Matrix {
fn from(s: [[f64; 2]; 2]) -> Self {
Self {
width: 2,
height: 2,
values: s.concat(),
}
}
}
impl From<[[f64; 3]; 3]> for Matrix {
fn from(s: [[f64; 3]; 3]) -> Self {
Self {
width: 3,
height: 3,
values: s.concat(),
}
}
}
impl From<[[f64; 4]; 4]> for Matrix {
fn from(s: [[f64; 4]; 4]) -> Self {
Self {
width: 4,
height: 4,
values: s.concat(),
}
}
}
impl PartialEq for Matrix {
fn eq(&self, rside: &Matrix) -> bool {
if self.width != rside.width || self.height != rside.height {
return false;
};
self.values
.iter()
.zip(rside.values.iter())
.all(|(l, r)| eq_f64(*l, *r))
}
}
impl std::ops::Mul for Matrix {
type Output = Matrix;
fn mul(self, rside: Matrix) -> Matrix {
assert_eq!(self.width, 4);
assert_eq!(self.height, 4);
assert_eq!(rside.width, 4);
assert_eq!(rside.height, 4);
let mut m = Matrix::new(self.width, self.height);
for row in 0..4 {
for column in 0..4 {
*m.cell_mut(row, column) = self.cell(row, 0) * rside.cell(0, column)
+ self.cell(row, 1) * rside.cell(1, column)
+ self.cell(row, 2) * rside.cell(2, column)
+ self.cell(row, 3) * rside.cell(3, column);
}
}
m
}
}
impl std::ops::Mul<Tuple> for Matrix {
type Output = Tuple;
fn mul(self, rside: Tuple) -> Tuple {
assert_eq!(self.width, 4);
assert_eq!(self.height, 4);
let mut t = [0.; 4];
for row in 0..4 {
t[row] = self.cell(row, 0) * rside.0
+ self.cell(row, 1) * rside.1
+ self.cell(row, 2) * rside.2
+ self.cell(row, 3) * rside.3;
}
Tuple::from(t)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_constructs_a_matrix() {
let m = Matrix::from([[-3., 5.], [1., -2.]]);
assert_eq!(m.cell(0, 0), -3.);
assert_eq!(m.cell(0, 1), 5.);
assert_eq!(m.cell(1, 0), 1.);
assert_eq!(m.cell(1, 1), -2.);
let m = Matrix::from([[-3., 5., 0.], [1., -2., -7.], [0., 1., 1.]]);
assert_eq!(m.cell(0, 0), -3.);
assert_eq!(m.cell(1, 1), -2.);
assert_eq!(m.cell(2, 2), 1.);
let m = Matrix::from([
[1., 2., 3., 4.],
[5.5, 6.5, 7.5, 8.5],
[9., 10., 11., 12.],
[13.5, 14.5, 15.5, 16.5],
]);
assert_eq!(m.cell(0, 0), 1.);
assert_eq!(m.cell(0, 3), 4.);
assert_eq!(m.cell(1, 0), 5.5);
assert_eq!(m.cell(1, 2), 7.5);
assert_eq!(m.cell(2, 2), 11.);
assert_eq!(m.cell(3, 0), 13.5);
assert_eq!(m.cell(3, 2), 15.5);
}
#[test]
fn it_compares_two_matrices() {
let a = Matrix::from([
[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 8., 7., 6.],
[5., 4., 3., 2.],
]);
let b = Matrix::from([
[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 8., 7., 6.],
[5., 4., 3., 2.],
]);
let c = Matrix::from([
[2., 3., 4., 5.],
[6., 7., 8., 9.],
[8., 7., 6., 5.],
[4., 3., 2., 1.],
]);
assert_eq!(a, b);
assert_ne!(a, c);
}
#[test]
fn multiply_two_matrices() {
let a = Matrix::from([
[1., 2., 3., 4.],
[5., 6., 7., 8.],
[9., 8., 7., 6.],
[5., 4., 3., 2.],
]);
let b = Matrix::from([
[-2., 1., 2., 3.],
[3., 2., 1., -1.],
[4., 3., 6., 5.],
[1., 2., 7., 8.],
]);
let expected = Matrix::from([
[20., 22., 50., 48.],
[44., 54., 114., 108.],
[40., 58., 110., 102.],
[16., 26., 46., 42.],
]);
assert_eq!(a * b, expected);
}
#[test]
fn multiply_matrix_by_tuple() {
let a = Matrix::from([
[1., 2., 3., 4.],
[2., 4., 4., 2.],
[8., 6., 4., 1.],
[0., 0., 0., 1.],
]);
let b = Tuple(1., 2., 3., 1.);
assert_eq!(a * b, Tuple(18., 24., 33., 1.));
}
#[test]
fn identity_matrix() {
let a = Matrix::from([
[0., 1., 2., 4.],
[1., 2., 4., 8.],
[2., 4., 8., 16.],
[4., 8., 16., 32.],
]);
assert_eq!(a.clone() * Matrix::identity(), a);
}
#[test]
fn transpose_a_matrix() {
let a = Matrix::from([
[0., 9., 3., 0.],
[9., 8., 0., 8.],
[1., 8., 5., 3.],
[0., 0., 5., 8.],
]);
let expected = Matrix::from([
[0., 9., 1., 0.],
[9., 8., 8., 0.],
[3., 0., 5., 5.],
[0., 8., 3., 8.],
]);
assert_eq!(a.transpose(), expected);
}
}

View File

@ -1,11 +1,13 @@
mod canvas;
mod color;
mod matrix;
mod point;
mod tuple;
mod vector;
pub use canvas::Canvas;
pub use color::Color;
pub use matrix::Matrix;
pub use point::Point;
pub use tuple::Tuple;
pub use vector::Vector;

View File

@ -15,6 +15,12 @@ impl Tuple {
}
}
impl From<[f64; 4]> for Tuple {
fn from(source: [f64; 4]) -> Self {
Self(source[0], source[1], source[2], source[3])
}
}
impl PartialEq for Tuple {
fn eq(&self, r: &Tuple) -> bool {
eq_f64(self.0, r.0) && eq_f64(self.1, r.1) && eq_f64(self.2, r.2) && eq_f64(self.3, r.3)