1

我用 Python 实现了用于二维模式搜索的 Rabin-Karp 算法,但在 1000×2000 的字符矩阵上,我的实现反而比蛮力版本还慢。请帮我找出性能瓶颈。谢谢,非常感谢各位的意见。

注意 1:代码能正确找到模式匹配的位置,只是运行得更慢:在我的计算机上,Rabin-Karp 版本耗时 1.23 秒,而蛮力版本只需 0.54 秒。

注意 2:虽然可以构造出让 Rabin-Karp 退化到与蛮力一样慢(O(m(n-m+1)))的最坏情况,但这里给出的测试用例并不是刻意构造的最坏情况。

免责声明:虽然这个问题是 Sedgewick 和 Wayne 所著《Algorithms, 4th Edition》中的一道习题,但这不是我的作业,我只是在学习这个算法。

这是代码:

'''
Searches for a 2D pattern in a 2D text. Assumes that both the pattern and the
text are rectangles of characters.

Expected running time O(Mr * Mc + Nr * Nc), where Mr x Mc is the pattern size
and Nr x Nc is the text size (rows x columns). The per-column hashes are
rolled down one text row at a time instead of being recomputed from scratch
(which would cost O(Mr * Nr * Nc)). Each row of column hashes is then scanned
with an ordinary 1D rolling hash. Every hash hit is verified character by
character, so a hash collision costs time but never produces a wrong answer.
'''
# Modulus for every hash value. Python ints never overflow, so it only keeps
# the intermediate numbers small.
MOD = 10**9+7


class RabinKarp2DV3(object):
    """Rabin-Karp searcher for a fixed rectangular 2D pattern.

    Why this can lose to brute force in pure Python: every text cell costs a
    few big-int multiplications and a modulo reduction, while a brute-force
    scan on random text typically rejects a position after one character
    comparison. The hot loops below therefore keep their invariants in local
    variables, reduce each rolling step to a single ``% MOD``, and avoid a
    method call per cell.
    """

    def __init__(self, rad, pattern):
        """Precompute radix powers and the pattern hash.

        :param rad: radix of the polynomial hash (e.g. 256 for characters whose
            ``ord()`` is below 256).
        :param pattern: sequence of equal-length rows (e.g. strings).
        """
        self.RADIX = rad
        self.pattern = pattern
        self.height = len(pattern)
        self.width = len(pattern[0])
        # factors_col[k] == RADIX**k % MOD for k < height (vertical roll);
        # factors_row[k] == RADIX**k % MOD for k < width (horizontal roll).
        self.factors_col = [0] * self.height
        self.factors_row = [0] * self.width
        self.factors_col[0] = 1
        for k in range(1, self.height):
            self.factors_col[k] = (self.RADIX * self.factors_col[k - 1]) % MOD
        self.factors_row[0] = 1
        for k in range(1, self.width):
            self.factors_row[k] = (self.RADIX * self.factors_row[k - 1]) % MOD
        hash1d_p = [0] * self.width
        self.hash2D(self.pattern, hash1d_p, self.width)
        self.patternHash = self.SingleHash(hash1d_p)

    def hash2D(self, data, hash1d, hei):
        """Store in hash1d[i] the hash of column i over rows 0..height-1 of data.

        Despite its name, ``hei`` is the number of *columns* to hash.
        """
        radix, mod = self.RADIX, MOD
        rows = range(self.height)
        for i in range(hei):
            h = 0
            for j in rows:
                h = (radix * h + ord(data[j][i])) % mod
            hash1d[i] = h

    def rehash2D(self, data, hash1d, hei, j):
        """Slide the first ``hei`` column hashes down by one row, in place.

        On entry hash1d[i] must be the hash of column i over rows
        j .. j+height-1 of data. On exit it is the hash over rows j+1 .. j+height.
        """
        radix, mod = self.RADIX, MOD
        lead = self.factors_col[self.height - 1]  # RADIX**(height-1) % MOD
        old_row, new_row = data[j], data[j + self.height]
        # Python's % returns a value in [0, mod) for a positive modulus even
        # when the left operand is negative, so no "+ MOD" correction is needed
        # and one reduction per column suffices.
        hash1d[:hei] = [
            ((hash1d[i] - lead * ord(old_row[i])) * radix + ord(new_row[i])) % mod
            for i in range(hei)
        ]

    def SingleHash(self, hash1d):
        """Return the 1D hash of hash1d[0:width]."""
        radix, mod = self.RADIX, MOD
        res = 0
        for i in range(self.width):
            res = (radix * res + hash1d[i]) % mod
        return res

    def SingleReHash(self, hash, hash1d, pos):
        """Roll a 1D hash one step right.

        Given the hash of hash1d[pos:pos+width], return the hash of
        hash1d[pos+1:pos+1+width]. The parameter name ``hash`` shadows the
        builtin; it is kept for backward compatibility with keyword callers.
        """
        lead = self.factors_row[self.width - 1]  # RADIX**(width-1) % MOD
        return ((hash - lead * hash1d[pos]) * self.RADIX
                + hash1d[pos + self.width]) % MOD

    def check(self, text, i, j):
        """Return True if the pattern sits in text with its top-left at (i, j)."""
        for a in range(self.height):
            text_row, pattern_row = text[i + a], self.pattern[a]
            for b in range(self.width):
                if text_row[j + b] != pattern_row[b]:
                    return False
        return True

    def search(self, text):
        """Return [row, col] of the first match (top-most, then left-most).

        Returns None if there is no match. An empty text, or a text with fewer
        rows or columns than the pattern, also yields None.
        """
        height, width = self.height, self.width
        if len(text) < height:
            return None
        n_cols = len(text[0])
        if n_cols < width:
            return None
        radix, mod, target = self.RADIX, MOD, self.patternHash
        lead = self.factors_row[width - 1]  # RADIX**(width-1) % MOD
        check = self.check

        # col_hash[c] is the hash of column c over the current window of rows.
        col_hash = [0] * n_cols
        self.hash2D(text, col_hash, n_cols)
        for top in range(len(text) - height + 1):
            if top:
                self.rehash2D(text, col_hash, n_cols, top - 1)
            h = self.SingleHash(col_hash)
            if h == target and check(text, top, 0):
                return [top, 0]
            # Same step as SingleReHash, inlined because a method call per
            # text cell would dominate the running time.
            pairs = zip(col_hash, col_hash[width:])
            for left, (outgoing, incoming) in enumerate(pairs, 1):
                h = ((h - lead * outgoing) * radix + incoming) % mod
                if h == target and check(text, top, left):
                    return [top, left]
        return None

class BruteForce(object):
    """Baseline 2D matcher that tries every candidate top-left corner.

    A candidate is abandoned at its first mismatching character, which makes
    this a useful reference point for the hash-based matcher.
    """

    def __init__(self, pattern):
        # Pattern rows are expected to share the length of the first row.
        self.pattern = pattern
        self.height = len(pattern)
        self.width = len(pattern[0])

    def check(self, text, i, j):
        """Tell whether the pattern occupies text with its top-left cell at (i, j)."""
        # Cells are compared row by row, left to right, stopping at the first
        # mismatch.
        return not any(
            text[i + a][j + b] != self.pattern[a][b]
            for a in range(self.height)
            for b in range(self.width)
        )

    def search(self, text):
        """Give [row, col] of the first match in row-major order, else None."""
        # The column range is re-evaluated for each row, so text[0] is only
        # touched when at least one candidate row exists.
        matches = (
            (row, col)
            for row in range(len(text) - self.height + 1)
            for col in range(len(text[0]) - self.width + 1)
            if self.check(text, row, col)
        )
        first = next(matches, None)
        if first is None:
            return None
        return list(first)




    
if __name__ == "__main__":
    # Benchmark: find a 20x40 block cut out of a random 1000x2000 uppercase
    # text with both matchers and report each one's running time (matcher
    # construction included).
    import random
    import string
    import timeit

    # default_timer is the most precise clock available on the platform;
    # time.time() is a wall clock and may jump.
    timer = timeit.default_timer
    chars = string.ascii_uppercase
    im, jm = 1000, 2000
    # str.join builds each row in one pass instead of growing it by +=.
    text = [''.join(random.choice(chars) for _ in range(jm)) for _ in range(im)]
    # The pattern is copied from the text, so a match always exists (at
    # [357, 478] unless an identical block happens to occur earlier).
    pattern = [text[357 + i][478:478 + 40] for i in range(20)]

    start_time = timer()
    matcher = RabinKarp2DV3(256, pattern)
    print(matcher.search(text))
    print("--- %s seconds ---" % (timer() - start_time))
    start_time = timer()
    matcher = BruteForce(pattern)
    print(matcher.search(text))
    print("--- %s seconds ---" % (timer() - start_time))
4

0 回答 0