3

我正在编写自己的矩阵模块以供娱乐和练习(时间和空间复杂度无关紧要)。现在我想实现矩阵乘法,我正在努力解决它。这可能是我使用 Haskell 的原因,但我没有太多经验。这是我的数据类型:

data Matrix a =
M {
  rows::Int,
  cols::Int,
  values::[a]
}

它将这样的 3x2 矩阵存储在数组中:

1 2
3 4
5 6
= [1,2,3,4,5,6]

我有一个有点工作的转置功能

transpose::(Matrix a)->(Matrix a)
transpose (M rows cols values) = M cols rows (aux values 0 0 [])
  where
   aux::[a]->Int->Int->[a]->[a]
   aux values row col transposed 
     | cols > col =
       if rows > row then 
         aux values (row+1) col (transposed ++ [valueAtIndex (M rows cols values) (row,col)])
       else aux values 0 (col+1) transposed
     | otherwise = transposed

要索引数组中的元素,我正在使用这个函数

valueAtIndex::(Matrix a)->(Int, Int)->a
valueAtIndex (M rows cols values) (row, col) 
  | rows <= row || cols <= col = error "indices too large for given Matrix"
  | otherwise = values !! (cols * row + col)

据我了解,我必须为 m1: 2x3 和 m2: 3x2 获取这样的元素

m1(0,0)*m2(0,0)+m1(0,1)*m2(0,1)+m1(0,2)*m2(0,2)
m1(0,0)*m2(1,0)+m1(0,1)*m2(1,1)+m1(0,2)*m2(1,2)
m1(1,0)*m2(0,0)+m1(1,1)*m2(0,1)+m1(1,2)*m2(0,2)
m1(1,0)*m2(1,0)+m1(1,1)*m2(1,1)+m1(1,2)*m2(2,2)

现在我需要一个函数,它需要两个矩阵,rows m1 == cols m2然后以某种方式递归计算正确的矩阵。

multiplyMatrix::Num a=>(Matrix a)->(Matrix a)->(Matrix a)
4

1 回答 1

3

首先,我不太相信这样的线性列表是个好主意。Haskell 中的列表被建模为链表。这意味着通常访问第k个元素将在O(k)中运行。因此,对于m×n矩阵,这意味着需要O(mn)才能访问最后一个元素。通过使用二维链表:包含链表的链表,我们将其缩小到O(m+n),这通常更快。是的,由于您使用了更多“缺点”数据构造函数,因此存在一些开销,但遍历的数量通常较低。如果您真的想要快速访问,您应该使用数组、向量等。但是还有其他设计决策要做。

所以我建议我们将矩阵建模为:

data Matrix a = M {
  rows :: Int,
  cols :: Int,
  values :: [[a]]
}

现在有了这个数据构造函数,我们可以将转置定义为:

transpose' :: Matrix a -> Matrix a
transpose' (M r c as) = M c r (trans as)
    where trans [] = []
          trans xs = map head xs : trans (map tail xs)

(这里我们假设列表的列表总是矩形的)

所以现在进行矩阵乘法。如果AB是两个矩阵,并且C = A × B,那么这基本上意味着a i, jA的第i行和B的第j列的点积。或A的第i行和B T的第j行(B的转置)。因此,我们可以将点积定义为:

dot_prod :: Num a => [a] -> [a] -> a
dot_prod xs ys = sum (zipWith (*) xs ys)

现在只需遍历行和列,并将元素放在正确的列表中。喜欢:

mat_mul :: Num a => Matrix a -> Matrix a -> Matrix a
mat_mul (M r ca xss) m2 | ca /= ra = error "Invalid matrix shapes"
                        | otherwise = M r c (matmul xss)
    where (M c rb yss) = transpose m2
          matmul [] = []
          matmul (xs:xss) = generaterow yss xs : matmul xss
          generaterow [] _ = []
          generaterow (ys:yss) xs = dot_prod xs ys : generaterow yss xs
于 2018-06-18T21:18:14.333 回答