tensor.rs 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. use ndarray::{Array, ArrayD};
  2. use ndarray_rand::rand_distr::StandardNormal;
  3. use ndarray_rand::RandomExt;
  4. // Define a Tensor struct
  5. #[derive(Debug)]
  6. pub struct Tensor {
  7. data: ArrayD<f64>,
  8. }
  9. impl Tensor {
  10. // Create a new tensor with random values
  11. pub fn new(shape: Vec<usize>) -> Self {
  12. let data = ArrayD::random(shape, StandardNormal);
  13. Tensor { data }
  14. }
  15. pub fn apply(&self, func: fn(f64) -> f64) -> Tensor {
  16. let applied_data = self.data.mapv(func);
  17. Tensor { data: applied_data }
  18. }
  19. // Create a new tensor with zeros
  20. pub fn zeros(shape: Vec<usize>) -> Self {
  21. let data = Array::from_elem(shape, 0.0);
  22. Tensor { data }
  23. }
  24. // Create a new tensor with ones
  25. pub fn ones(shape: Vec<usize>) -> Self {
  26. let data = Array::from_elem(shape, 1.0);
  27. Tensor { data }
  28. }
  29. // Perform element-wise addition
  30. pub fn add(&self, other: &Tensor) -> Self {
  31. let result_data = &self.data + &other.data;
  32. Tensor { data: result_data }
  33. }
  34. // Perform element-wise multiplication
  35. pub fn multiply(&self, other: &Tensor) -> Self {
  36. let result_data = &self.data * &other.data;
  37. Tensor { data: result_data }
  38. }
  39. // Perform matrix multiplication
  40. pub fn matmul(&self, other: &Tensor) -> Self {
  41. let dim_lhs = self.data.shape().to_vec();
  42. let dim_rhs = other.data.shape().to_vec();
  43. let lhs = self.data
  44. .to_owned()
  45. .into_shape((dim_lhs[0], dim_lhs[1]))
  46. .expect("Invalid dimensions");
  47. let rhs = other.data
  48. .to_owned()
  49. .into_shape((dim_rhs[0], dim_rhs[1]))
  50. .expect("Invalid dimensions");
  51. let result_data = lhs.dot(&rhs);
  52. Tensor { data: result_data.into_dimensionality().expect("Dimensions!") }
  53. }
  54. // Print the tensor's data
  55. pub fn print(&self) {
  56. println!("{:?}", self.data);
  57. }
  58. }