How to implement matrix multiplication in Scala

1 Answer

0 votes
// Multiply rows of A by columns of B.

type Matrix = Vector[Vector[Int]]

/** Writes a matrix to standard output, one row per line.
  *
  * Every element is formatted with `%4d`, i.e. right-aligned and padded
  * to a minimum width of four characters, with no separator between cells.
  *
  * @param arr2d the matrix to print
  */
def printMatrix(arr2d: Matrix): Unit =
  for row <- arr2d do
    // Build the whole line first, then emit it together with the newline.
    println(row.map(cell => f"$cell%4d").mkString)

/** Computes the matrix product `a * b`.
  *
  * Cell `(i, j)` of the result is the dot product of row `i` of `a` and
  * column `j` of `b`. Arithmetic is plain `Int` arithmetic, so sufficiently
  * large values overflow without warning.
  *
  * An empty `a` yields an empty result. Otherwise the result has `a.length`
  * rows and as many columns as the first row of `b` (zero if `b` is empty).
  *
  * @param a left operand, `m x n`; all rows must have the same length
  * @param b right operand, `n x p`; all rows must have the same length
  * @return the `m x p` product matrix
  * @throws IllegalArgumentException if either matrix is ragged, or if the
  *         number of columns of `a` differs from the number of rows of `b`
  */
def multipleMatrix(a: Matrix, b: Matrix): Matrix =
  val rows = a.length
  // headOption instead of head: an empty matrix must not throw NoSuchElementException.
  val inner = a.headOption.fold(0)(_.length)
  val cols = b.headOption.fold(0)(_.length)

  require(a.forall(_.length == inner), "all rows of a must have the same length")
  require(b.forall(_.length == cols), "all rows of b must have the same length")
  // An empty `a` carries no column count, so there is nothing to check against `b`.
  require(
    rows == 0 || b.length == inner,
    s"cannot multiply: a has $inner columns but b has ${b.length} rows"
  )

  Vector.tabulate(rows, cols) { (i, j) =>
    // The iterator avoids materialising an intermediate collection for every cell.
    (0 until inner).iterator.map(k => a(i)(k) * b(k)(j)).sum
  }

/** Entry point: prints two sample 3x3 matrices followed by their product. */
@main def run(): Unit =
  // Sample operands
  val left: Matrix = Vector(
    Vector(1, 8, 5),
    Vector(6, 7, 1),
    Vector(8, 7, 6)
  )
  val right: Matrix = Vector(
    Vector(4, 8, 1),
    Vector(6, 5, 3),
    Vector(4, 6, 5)
  )

  // Each product cell is a row-by-column dot product, e.g.
  // product[0][0] = left[0][0] * right[0][0] + left[0][1] * right[1][0] + left[0][2] * right[2][0]

  // Show both operands, each followed by a blank separator line.
  Vector(left, right).foreach { operand =>
    printMatrix(operand)
    println()
  }

  printMatrix(multipleMatrix(left, right))


/*
run:
        
     1   8   5
     6   7   1
     8   7   6
    
     4   8   1
     6   5   3
     4   6   5
    
    72  78  50
    70  89  32
    98 135  59

*/

 



answered May 25 by avibootz
...