{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, Rank2Types #-}
module Text.EditDistance.Bits (
levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
) where
import Data.Bits
import Data.Char
import Data.Word
import Data.List
import qualified Data.IntMap as IM
{-# INLINE foldl'3k #-}
-- | Strict left fold over a list in continuation-passing style, with a
-- triple as the accumulator.
--
-- All three accumulator components are forced to weak head normal form
-- before each step and before the final triple is returned, so no thunks
-- pile up across iterations. The step function gets the current
-- accumulator, the next element and a continuation; it passes the new
-- accumulator to that continuation to keep folding.
foldl'3k :: (forall res. (a, b, c) -> x -> ((a, b, c) -> res) -> res)
         -> (a, b, c) -> [x] -> (a, b, c)
foldl'3k step = loop
  where
    loop (!p, !q, !r) remaining = case remaining of
      []     -> (p, q, r)
      y : ys -> step (p, q, r) y (\acc -> loop acc ys)
{-# INLINE foldl'5k #-}
-- | Strict left fold over a list in continuation-passing style, with a
-- 5-tuple as the accumulator.
--
-- All five accumulator components are forced to weak head normal form
-- before each step and before the final tuple is returned. The step
-- function gets the current accumulator, the next element and a
-- continuation; it passes the new accumulator to that continuation.
foldl'5k :: (forall res. (a, b, c, d, e) -> x -> ((a, b, c, d, e) -> res) -> res)
         -> (a, b, c, d, e) -> [x] -> (a, b, c, d, e)
foldl'5k step = loop
  where
    loop (!p, !q, !r, !s, !t) remaining = case remaining of
      []     -> (p, q, r, s, t)
      y : ys -> step (p, q, r, s, t) y (\acc -> loop acc ys)
-- | Levenshtein edit distance between two strings.
--
-- Measures both strings and hands them to 'levenshteinDistanceWithLengths'.
levenshteinDistance :: String -> String -> Int
levenshteinDistance source target =
  levenshteinDistanceWithLengths (length source) (length target) source target
-- | Levenshtein distance for callers that already know both string lengths.
--
-- The two 'Int' arguments must be the lengths of the first and second
-- 'String' respectively.
--
-- The shorter string becomes the bit-vector pattern (one bit per
-- character) and the longer string is only scanned. So the shorter length
-- alone decides the representation: a 'Word64' when it is at most 64
-- characters, an arbitrary-precision 'Integer' otherwise. The scanned
-- string may be any length without leaving the machine-word fast path.
levenshteinDistanceWithLengths :: Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !m !n str1 str2
  | m <= n    = viaBits m n str1 str2
  | otherwise = viaBits n m str2 str1
  where
    -- Only the pattern (first string here) is packed into a bit vector,
    -- so only its length has to fit in 64 bits. The text length never
    -- reaches the bit arithmetic.
    viaBits patLen textLen pat text
      | patLen <= 64 = levenshteinDistance' (undefined :: Word64) patLen textLen pat text
      | otherwise    = levenshteinDistance' (undefined :: Integer) patLen textLen pat text
{-# SPECIALIZE levenshteinDistance' :: Word64 -> Int -> Int -> String -> String -> Int #-}
{-# SPECIALIZE levenshteinDistance' :: Integer -> Int -> Int -> String -> String -> Int #-}
-- | Bit-vector Levenshtein distance.
--
-- The first argument is never evaluated; its type selects the bit-vector
-- representation. @m@ and @n@ are the lengths of @str1@ and @str2@.
-- @str1@ is packed into an @m@-bit vector and @str2@ is consumed one
-- character per step. An empty @str1@ gives distance @n@ directly.
levenshteinDistance' :: (Num bv, Bits bv) => bv -> Int -> Int -> String -> String -> Int
levenshteinDistance' (_ :: bv) !m !n str1 str2 =
  case str1 of
    [] -> n
    _  ->
      let fullMask = 2 ^ m - 1 :: bv
          lastRow  = 1 `shiftL` (m - 1) :: bv
      in case foldl'3k (levenshteinDistanceWorker (matchVectors str1) lastRow fullMask)
                       (fullMask, 0, m)
                       str2 of
           (_, _, dist) -> dist
{-# SPECIALIZE INLINE levenshteinDistanceWorker :: IM.IntMap Word64 -> Word64 -> Word64 -> (Word64, Word64, Int) -> Char -> ((Word64, Word64, Int) -> res) -> res #-}
{-# SPECIALIZE INLINE levenshteinDistanceWorker :: IM.IntMap Integer -> Integer -> Integer -> (Integer, Integer, Int) -> Char -> ((Integer, Integer, Int) -> res) -> res #-}
-- | One column step of the bit-parallel Levenshtein recurrence.
--
-- The variable names (pm, d0, hp, hn, vp, vn) and the update formulas
-- follow Hyyro's formulation of Myers' bit-vector algorithm.
--
-- * @str1_mvs@: match vectors of the pattern string, keyed by 'ord' of
--   each character (as built by 'matchVectors': bit i set when pattern
--   position i holds that character).
-- * @top_bit_mask@: only bit m-1 set; selects the last pattern row.
-- * @vector_mask@: the m low bits set; truncates carries and shifts, and
--   is the width passed to 'sizedComplement'.
-- * @(vp, vn, distance)@: vertical +1 / -1 delta vectors and the distance
--   for the prefix of the scanned string consumed so far.
-- * @char2@: next character of the scanned string.
-- * @k@: continuation that receives the updated accumulator.
levenshteinDistanceWorker :: (Num bv, Bits bv)
                          => IM.IntMap bv -> bv -> bv -> (bv, bv, Int) -> Char
                          -> ((bv, bv, Int) -> res) -> res
levenshteinDistanceWorker !str1_mvs !top_bit_mask !vector_mask (!vp, !vn, !distance) !char2 k
  -- Evaluate the new vectors and distance before handing them on, so the
  -- continuation never receives thunks.
  = vp' `seq` vn' `seq` distance'' `seq` k (vp', vn', distance'')
  where
    -- Match vector for char2; 0 when the character is absent from the pattern.
    pm' = IM.findWithDefault 0 (ord char2) str1_mvs
    -- Diagonal-zero vector. The addition can carry past bit m-1, hence the
    -- mask before the xor.
    d0' = ((((pm' .&. vp) + vp) .&. vector_mask) `xor` vp) .|. pm' .|. vn
    -- Horizontal +1 / -1 delta vectors for this column.
    hp' = vn .|. sizedComplement vector_mask (d0' .|. vp)
    hn' = d0' .&. vp
    -- Shift the horizontal deltas down one row. OR-ing 1 into hp' supplies
    -- the delta of the boundary row above the pattern, which grows by one
    -- per consumed character (global distance rather than substring search).
    hp'_shift = ((hp' `shiftL` 1) .|. 1) .&. vector_mask
    hn'_shift = (hn' `shiftL` 1) .&. vector_mask
    -- Vertical delta vectors for the next column.
    vp' = hn'_shift .|. sizedComplement vector_mask (d0' .|. hp'_shift)
    vn' = d0' .&. hp'_shift
    -- Track the last row's value by applying the horizontal delta at bit m-1.
    distance' = if hp' .&. top_bit_mask /= 0 then distance + 1 else distance
    distance'' = if hn' .&. top_bit_mask /= 0 then distance' - 1 else distance'
-- | Restricted Damerau-Levenshtein distance between two strings.
--
-- This is edit distance that also counts swapping two adjacent characters
-- as a single edit. Measures both strings and hands them to
-- 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: String -> String -> Int
restrictedDamerauLevenshteinDistance source target =
  restrictedDamerauLevenshteinDistanceWithLengths (length source) (length target) source target
-- | Restricted Damerau-Levenshtein distance for callers that already know
-- both string lengths.
--
-- The two 'Int' arguments must be the lengths of the first and second
-- 'String' respectively.
--
-- The shorter string becomes the bit-vector pattern (one bit per
-- character) and the longer string is only scanned. So the shorter length
-- alone decides the representation: a 'Word64' when it is at most 64
-- characters, an arbitrary-precision 'Integer' otherwise. The scanned
-- string may be any length without leaving the machine-word fast path.
restrictedDamerauLevenshteinDistanceWithLengths :: Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths !m !n str1 str2
  | m <= n    = viaBits m n str1 str2
  | otherwise = viaBits n m str2 str1
  where
    -- Only the pattern (first string here) is packed into a bit vector,
    -- so only its length has to fit in 64 bits. The text length never
    -- reaches the bit arithmetic.
    viaBits patLen textLen pat text
      | patLen <= 64 = restrictedDamerauLevenshteinDistance' (undefined :: Word64) patLen textLen pat text
      | otherwise    = restrictedDamerauLevenshteinDistance' (undefined :: Integer) patLen textLen pat text
{-# SPECIALIZE restrictedDamerauLevenshteinDistance' :: Word64 -> Int -> Int -> String -> String -> Int #-}
{-# SPECIALIZE restrictedDamerauLevenshteinDistance' :: Integer -> Int -> Int -> String -> String -> Int #-}
-- | Bit-vector restricted Damerau-Levenshtein distance.
--
-- The first argument is never evaluated; its type selects the bit-vector
-- representation. @m@ and @n@ are the lengths of @str1@ and @str2@.
-- @str1@ is packed into an @m@-bit vector and @str2@ is consumed one
-- character per step. An empty @str1@ gives distance @n@ directly.
restrictedDamerauLevenshteinDistance' :: (Num bv, Bits bv) => bv -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistance' (_ :: bv) !m !n str1 str2 =
  case str1 of
    [] -> n
    _  ->
      let fullMask = 2 ^ m - 1 :: bv
          lastRow  = 1 `shiftL` (m - 1) :: bv
          -- No previous character has been seen yet (previous match and
          -- d0 vectors are 0), every vertical delta starts at +1, and the
          -- starting distance is m.
          initial  = (0, 0, fullMask, 0, m) :: (bv, bv, bv, bv, Int)
      in case foldl'5k (restrictedDamerauLevenshteinDistanceWorker (matchVectors str1) lastRow fullMask)
                       initial
                       str2 of
           (_, _, _, _, dist) -> dist
{-# SPECIALIZE INLINE restrictedDamerauLevenshteinDistanceWorker :: IM.IntMap Word64 -> Word64 -> Word64 -> (Word64, Word64, Word64, Word64, Int) -> Char -> ((Word64, Word64, Word64, Word64, Int) -> res) -> res #-}
{-# SPECIALIZE INLINE restrictedDamerauLevenshteinDistanceWorker :: IM.IntMap Integer -> Integer -> Integer -> (Integer, Integer, Integer, Integer, Int) -> Char -> ((Integer, Integer, Integer, Integer, Int) -> res) -> res #-}
-- | One column step of the bit-parallel restricted Damerau-Levenshtein
-- recurrence.
--
-- This is the Levenshtein step with an extra transposition term in @d0'@.
-- That term is why the accumulator also carries the previous character's
-- match vector (@pm@) and the previous diagonal-zero vector (@d0@).
--
-- * @str1_mvs@: match vectors of the pattern string, keyed by 'ord' of
--   each character (as built by 'matchVectors': bit i set when pattern
--   position i holds that character).
-- * @top_bit_mask@: only bit m-1 set; selects the last pattern row.
-- * @vector_mask@: the m low bits set; truncates carries and shifts, and
--   is the width passed to 'sizedComplement'.
-- * @(pm, d0, vp, vn, distance)@: previous match vector, previous
--   diagonal-zero vector, vertical +1 / -1 delta vectors, and the distance
--   for the prefix of the scanned string consumed so far.
-- * @char2@: next character of the scanned string.
-- * @k@: continuation that receives the updated accumulator.
restrictedDamerauLevenshteinDistanceWorker :: (Num bv, Bits bv) => IM.IntMap bv -> bv -> bv -> (bv, bv, bv, bv, Int) -> Char -> ((bv, bv, bv, bv, Int) -> res) -> res
restrictedDamerauLevenshteinDistanceWorker !str1_mvs !top_bit_mask !vector_mask (!pm, !d0, !vp, !vn, !distance) !char2 k
  -- Evaluate every new component before handing it on, so the
  -- continuation never receives thunks.
  = pm' `seq` d0' `seq` vp' `seq` vn' `seq` distance'' `seq` k (pm', d0', vp', vn', distance'')
  where
    -- Match vector for char2; 0 when the character is absent from the pattern.
    pm' = IM.findWithDefault 0 (ord char2) str1_mvs
    -- First line: the transposition term. Bit i+1 is set when:
    --   * the current character matches pattern position i,
    --   * the previous character (pm) matches position i+1, and
    --   * bit i of the previous d0 was clear.
    -- The final .&. pm also drops anything shifted past bit m-1.
    -- Second line: the plain Levenshtein diagonal-zero terms.
    d0' = ((((sizedComplement vector_mask d0) .&. pm') `shiftL` 1) .&. pm)
          .|. ((((pm' .&. vp) + vp) .&. vector_mask) `xor` vp) .|. pm' .|. vn
    -- Horizontal +1 / -1 delta vectors for this column.
    hp' = vn .|. sizedComplement vector_mask (d0' .|. vp)
    hn' = d0' .&. vp
    -- Shift the horizontal deltas down one row. OR-ing 1 into hp' supplies
    -- the delta of the boundary row above the pattern, which grows by one
    -- per consumed character.
    hp'_shift = ((hp' `shiftL` 1) .|. 1) .&. vector_mask
    hn'_shift = (hn' `shiftL` 1) .&. vector_mask
    -- Vertical delta vectors for the next column.
    vp' = hn'_shift .|. sizedComplement vector_mask (d0' .|. hp'_shift)
    vn' = d0' .&. hp'_shift
    -- Track the last row's value by applying the horizontal delta at bit m-1.
    distance' = if hp' .&. top_bit_mask /= 0 then distance + 1 else distance
    distance'' = if hn' .&. top_bit_mask /= 0 then distance' - 1 else distance'
{-# SPECIALIZE INLINE sizedComplement :: Word64 -> Word64 -> Word64 #-}
{-# SPECIALIZE INLINE sizedComplement :: Integer -> Integer -> Integer #-}
-- | Bitwise complement confined to a fixed width.
--
-- Flips exactly the bits that are set in the mask (first argument) and
-- leaves every other bit of the vector (second argument) as it is. Unlike
-- 'complement', the result stays non-negative for non-negative 'Integer'
-- inputs, which keeps 'Integer' bit vectors finite.
sizedComplement :: (Num bv, Bits bv) => bv -> bv -> bv
sizedComplement mask v = xor v mask
{-# SPECIALIZE matchVectors :: String -> IM.IntMap Word64 #-}
{-# SPECIALIZE matchVectors :: String -> IM.IntMap Integer #-}
-- | Pattern-match vectors for a string.
--
-- For every character @c@ that occurs in the input, maps @'ord' c@ to a
-- bit vector whose bit @i@ is set exactly when the character at 0-based
-- index @i@ equals @c@. Characters that do not occur have no entry.
matchVectors :: (Num bv, Bits bv) => String -> IM.IntMap bv
matchVectors = snd . foldl' go (0 :: Int, IM.empty)
  where
    go (!ix, !im) char =
      let key = ord char
          -- Data.IntMap is value-lazy, so insertWith (.|.) would store an
          -- unevaluated (.|.) chain for every repeated character. Combine
          -- the mask eagerly here instead. 'bit' also avoids the repeated
          -- multiplication of 2 ^ ix.
          !mask = IM.findWithDefault 0 key im .|. bit ix
      in (ix + 1, IM.insert key mask im)