{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, Rank2Types #-}

module Text.EditDistance.Bits (
        levenshteinDistance, levenshteinDistanceWithLengths, {-levenshteinDistanceCutoff,-} restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
    ) where

import Data.Bits
import Data.Char
import Data.Word
import Data.List
import qualified Data.IntMap as IM

--import Debug.Trace

--type BitVector = Integer

-- Continuation-passing foldl's to work around the lack of recursive CPR optimisation in GHC

{-# INLINE foldl'3k #-}
-- | Strict left fold over a list whose accumulator is a 3-tuple, written in
-- continuation-passing style to work around the lack of recursive CPR
-- optimisation in GHC.
--
-- The step function receives the current accumulator, the next list element
-- and a continuation; it must hand the new accumulator to the continuation.
-- All three accumulator components are forced to WHNF before every step and
-- before the result is returned, so no thunks accumulate.
foldl'3k :: (forall res. (a, b, c) -> x -> ((a, b, c) -> res) -> res)
         -> (a, b, c) -> [x] -> (a, b, c)
foldl'3k f = go
  where
    -- The bang patterns force each component on every iteration (tuple
    -- first, then the list), replacing the old `| False = undefined` trick.
    go (!a, !b, !c) []       = (a, b, c)
    go (!a, !b, !c) (x : xs) = f (a, b, c) x $ \acc -> go acc xs

{-# INLINE foldl'5k #-}
-- | Strict left fold over a list whose accumulator is a 5-tuple, written in
-- continuation-passing style to work around the lack of recursive CPR
-- optimisation in GHC.
--
-- The step function receives the current accumulator, the next list element
-- and a continuation; it must hand the new accumulator to the continuation.
-- All five accumulator components are forced to WHNF before every step and
-- before the result is returned, so no thunks accumulate.
foldl'5k :: (forall res. (a, b, c, d, e) -> x -> ((a, b, c, d, e) -> res) -> res)
         -> (a, b, c, d, e) -> [x] -> (a, b, c, d, e)
foldl'5k f = go
  where
    -- The bang patterns force each component on every iteration (tuple
    -- first, then the list), replacing the old `| False = undefined` trick.
    go (!a, !b, !c, !d, !e) []       = (a, b, c, d, e)
    go (!a, !b, !c, !d, !e) (x : xs) = f (a, b, c, d, e) x $ \acc -> go acc xs

-- | Levenshtein edit distance between two strings: the minimum number of
-- single-character insertions, deletions and substitutions (each costing 1)
-- needed to turn one string into the other.
--
-- Based on the algorithm presented in "A Bit-Vector Algorithm for Computing
-- Levenshtein and Damerau Edit Distances" in PSC'02 (Heikki Hyyro).
-- See http://www.cs.uta.fi/~helmu/pubs/psc02.pdf and
-- http://www.cs.uta.fi/~helmu/pubs/PSCerr.html for an explanation.
levenshteinDistance :: String -> String -> Int
levenshteinDistance str1 str2 =
  levenshteinDistanceWithLengths (length str1) (length str2) str1 str2

-- | Like 'levenshteinDistance', but takes the string lengths as arguments
-- (@m@ is the length of @str1@, @n@ the length of @str2@), so callers that
-- already know them avoid another traversal. The lengths are not validated.
--
-- The shorter string is used as the pattern, so the bit vectors need only
-- @min m n@ bits. A 'Word64' is used whenever that fits; otherwise an
-- arbitrary-precision 'Integer' is used.
levenshteinDistanceWithLengths :: Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !m !n str1 str2
  | m <= n    = dispatch m n str1 str2
  | otherwise = dispatch n m str2 str1
  where
    -- Only the pattern (shorter) length determines the vector width; the text
    -- length does not. Checking the longer length here would needlessly fall
    -- back to 'Integer' for a short pattern against a long text.
    dispatch short long pat txt
      | short <= 64 = levenshteinDistance' (undefined :: Word64)  short long pat txt
      | otherwise   = levenshteinDistance' (undefined :: Integer) short long pat txt

{-# SPECIALIZE levenshteinDistance' :: Word64 -> Int -> Int -> String -> String -> Int #-}
{-# SPECIALIZE levenshteinDistance' :: Integer -> Int -> Int -> String -> String -> Int #-}
-- | Bit-parallel Levenshtein distance (Hyyro, PSC'02), using @str1@ as the
-- pattern and folding over the characters of @str2@.
--
-- The first argument is never evaluated; it only fixes the bit-vector type
-- @bv@. @m@ and @n@ must be the lengths of @str1@ and @str2@, and @bv@ must
-- be able to represent at least @m@ bits.
levenshteinDistance' :: (Num bv, Bits bv) => bv -> Int -> Int -> String -> String -> Int
levenshteinDistance' (_bv_dummy :: bv) !m !n str1 str2
  | [] <- str1 = n
  | otherwise  = extractAnswer $
      foldl'3k (levenshteinDistanceWorker (matchVectors str1) top_bit_mask vector_mask)
               (m_ones, 0, m)
               str2
  where
    -- The low m bits set: masks every vector to the pattern length.
    vector_mask = (2 ^ m) - 1 :: bv
    -- Initial vertical +1 deltas: in the first column each row is one more
    -- than the row above, so every bit is set (and the bottom cell is m).
    m_ones = vector_mask
    -- Bit of the last pattern row, whose cell holds the running distance.
    top_bit_mask = 1 `shiftL` (m - 1) :: bv
    extractAnswer (_, _, distance) = distance

{-# SPECIALIZE INLINE levenshteinDistanceWorker :: IM.IntMap Word64  -> Word64  -> Word64  -> (Word64, Word64, Int)   -> Char -> ((Word64,  Word64,  Int) -> res) -> res #-}
{-# SPECIALIZE INLINE levenshteinDistanceWorker :: IM.IntMap Integer -> Integer -> Integer -> (Integer, Integer, Int) -> Char -> ((Integer, Integer, Int) -> res) -> res #-}
-- | One column step of the bit-parallel Levenshtein recurrence.
--
-- Consumes one text character @char2@ and passes the updated
-- @(vp, vn, distance)@ state to the continuation @k@. @vp@ / @vn@ mark the
-- rows whose vertical delta is +1 / -1, and @distance@ is the bottom cell of
-- the current column. @str1_mvs@ are the pattern's match vectors.
levenshteinDistanceWorker :: (Num bv, Bits bv)
                          => IM.IntMap bv -> bv -> bv -> (bv, bv, Int) -> Char
                          -> ((bv, bv, Int) -> res) -> res
levenshteinDistanceWorker !str1_mvs !top_bit_mask !vector_mask (!vp, !vn, !distance) !char2 k
  = vp' `seq` vn' `seq` distance'' `seq` k (vp', vn', distance'')
  where
    -- Pattern positions where char2 occurs (0 when it does not occur at all).
    pm' = IM.findWithDefault 0 (ord char2) str1_mvs

    -- Diagonal zero-delta vector. The addition can carry past bit m-1, so the
    -- sum is masked back to the pattern width.
    d0' = ((((pm' .&. vp) + vp) .&. vector_mask) `xor` vp) .|. pm' .|. vn
    -- Horizontal +1 / -1 delta vectors.
    hp' = vn .|. sizedComplement vector_mask (d0' .|. vp)
    hn' = d0' .&. vp

    -- Horizontal deltas shifted down one row. The `.|. 1` supplies the +1
    -- delta of row 0 (the empty pattern prefix).
    hp'_shift = ((hp' `shiftL` 1) .|. 1) .&. vector_mask
    hn'_shift = (hn' `shiftL` 1) .&. vector_mask
    -- Vertical delta vectors for the next column.
    vp' = hn'_shift .|. sizedComplement vector_mask (d0' .|. hp'_shift)
    vn' = d0' .&. hp'_shift

    -- The bottom cell moves by the horizontal delta of the last row.
    distance'  = if hp' .&. top_bit_mask /= 0 then distance + 1 else distance
    distance'' = if hn' .&. top_bit_mask /= 0 then distance' - 1 else distance'


{-

-- Disabled: this cutoff variant does not currently produce correct results.

-- Based on the algorithm presented in "A Bit-Vector Algorithm for Computing Levenshtein and Damerau Edit Distances" in PSC'02 (Heikki Hyyro).
-- See http://www.cs.uta.fi/~helmu/pubs/psc02.pdf and http://www.cs.uta.fi/~helmu/pubs/PSCerr.html for an explanation
levenshteinDistanceCutoff :: Int -> String -> String -> Int
levenshteinDistanceCutoff cutoff str1 str2
  | length str1 <= length str2 = levenshteinDistanceCutoff' cutoff str1 str2
  | otherwise = levenshteinDistanceCutoff' cutoff str2 str1

levenshteinDistanceCutoff' :: Int -> String -> String -> Int
levenshteinDistanceCutoff' cutoff str1 str2
  | [] <- str1 = n
  | otherwise  = extractAnswer $ foldl' (levenshteinDistanceCutoffFlatWorker (matchVectors str1))
                                    (foldl' (levenshteinDistanceCutoffDiagWorker (matchVectors str1)) (top_bit_mask, vector_mask, all_ones, 0, initial_pm_offset, initial_dist) str2_diag)
                                    str2_flat
  where m = length str1
        n = length str2
        vector_length = if testBit bottom_factor 0
                        then cutoff      -- Odd
                        else cutoff + 1  -- Even
        all_ones@vector_mask = (2 ^ vector_length) - 1
        top_bit_mask = trace (show bottom_factor ++ ", " ++ show vector_length) $ 1 `shiftL` (vector_length - 1)
        extractAnswer (_, _, _, _, _, distance) = distance

        len_difference = n - m
        top_factor = cutoff + len_difference
        bottom_factor = cutoff - len_difference
        bottom_factor_shift = (bottom_factor `shiftR` 1)

        initial_dist = bottom_factor_shift               -- The distance the virtual first vector ended on
        initial_pm_offset = (top_factor `shiftR` 1)      -- The amount of left shift to apply to the >next< pattern match vector
        diag_threshold = negate bottom_factor_shift + m  -- The index in str2 where we stop going diagonally down and start going across
        (str2_diag, str2_flat) = splitAt diag_threshold str2

levenshteinDistanceCutoffDiagWorker :: IM.IntMap BitVector -> (BitVector, BitVector, BitVector, BitVector, Int, Int) -> Char -> (BitVector, BitVector, BitVector, BitVector, Int, Int)
levenshteinDistanceCutoffDiagWorker !str1_mvs (!top_bit_mask, !vector_mask, !vp, !vn, !pm_offset, !distance) !char2
  = trace (unlines ["vp = " ++ show vp
                   ,"vn = " ++ show vn
                   ,"vector_mask = " ++ show vector_mask
                   ,"pm_offset = " ++ show pm_offset
                   ,"unshifted_pm = " ++ show unshifted_pm
                   ,"pm' = " ++ show pm'
                   ,"d0' = " ++ show d0'
                   ,"hp' = " ++ show hp'
                   ,"hn' = " ++ show hn'
                   ,"vp' = " ++ show vp'
                   ,"vn' = " ++ show vn'
                   ,"distance' = " ++ show distance']) (top_bit_mask, vector_mask, vp', vn', pm_offset - 1, distance')
  where
    unshifted_pm = IM.findWithDefault 0 (ord char2) str1_mvs
    pm' = (unshifted_pm `shift` pm_offset) .&. vector_mask

    d0' = ((((pm' .&. vp) + vp) .&. vector_mask) `xor` vp) .|. pm' .|. vn
    hp' = vn .|. sizedComplement vector_mask (d0' .|. vp)
    hn' = d0' .&. vp

    d0'_shift = d0' `shiftR` 1
    vp' = hn' .|. sizedComplement vector_mask (d0'_shift .|. hp')
    vn' = d0'_shift .&. hp'

    distance' = if d0' .&. top_bit_mask /= 0 then distance else distance + 1

levenshteinDistanceCutoffFlatWorker :: IM.IntMap BitVector -> (BitVector, BitVector, BitVector, BitVector, Int, Int) -> Char -> (BitVector, BitVector, BitVector, BitVector, Int, Int)
levenshteinDistanceCutoffFlatWorker !str1_mvs (!top_bit_mask, !vector_mask, !vp, !vn, !pm_offset, !distance) !char2
  = trace (unlines ["pm_offset = " ++ show pm_offset
                   ,"top_bit_mask' = " ++ show top_bit_mask'
                   ,"vector_mask' = " ++ show vector_mask'
                   ,"pm = " ++ show pm'
                   ,"d0 = " ++ show d0'
                   ,"hp = " ++ show hp'
                   ,"hn = " ++ show hn'
                   ,"vp = " ++ show vp'
                   ,"vn = " ++ show vn'
                   ,"distance' = " ++ show distance'
                   ,"distance'' = " ++ show distance'']) (top_bit_mask', vector_mask', vp', vn', pm_offset - 1, distance'')
  where
    top_bit_mask' = top_bit_mask `shiftR` 1
    vector_mask' = vector_mask `shiftR` 1
    pm' = (IM.findWithDefault 0 (ord char2) str1_mvs `rotate` pm_offset) .&. vector_mask'

    d0' = ((((pm' .&. vp) + vp) `xor` vp) .|. pm' .|. vn) .&. vector_mask'
    hp' = vn .|. sizedComplement vector_mask' (d0' .|. vp)
    hn' = d0' .&. vp

    d0'_shift = d0' `shiftR` 1
    vp' = hn' .|. sizedComplement vector_mask' (d0'_shift .|. hp')
    vn' = d0'_shift .&. hp'

    distance' = if hp' .&. top_bit_mask' /= 0 then distance + 1 else distance
    distance'' = if hn' .&. top_bit_mask' /= 0 then distance' - 1 else distance'

-}

-- | Restricted Damerau-Levenshtein distance (also known as optimal string
-- alignment distance) between two strings. Like 'levenshteinDistance', but a
-- transposition of two adjacent characters also counts as a single edit.
--
-- Based on the algorithm presented in "A Bit-Vector Algorithm for Computing
-- Levenshtein and Damerau Edit Distances" in PSC'02 (Heikki Hyyro).
-- See http://www.cs.uta.fi/~helmu/pubs/psc02.pdf and
-- http://www.cs.uta.fi/~helmu/pubs/PSCerr.html for an explanation.
restrictedDamerauLevenshteinDistance :: String -> String -> Int
restrictedDamerauLevenshteinDistance str1 str2 =
  restrictedDamerauLevenshteinDistanceWithLengths (length str1) (length str2) str1 str2

-- | Like 'restrictedDamerauLevenshteinDistance', but takes the string lengths
-- as arguments (@m@ is the length of @str1@, @n@ the length of @str2@), so
-- callers that already know them avoid another traversal. The lengths are
-- not validated.
--
-- The shorter string is used as the pattern, so the bit vectors need only
-- @min m n@ bits. A 'Word64' is used whenever that fits; otherwise an
-- arbitrary-precision 'Integer' is used.
restrictedDamerauLevenshteinDistanceWithLengths :: Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths !m !n str1 str2
  | m <= n    = dispatch m n str1 str2
  | otherwise = dispatch n m str2 str1
  where
    -- Only the pattern (shorter) length determines the vector width; the text
    -- length does not. Checking the longer length here would needlessly fall
    -- back to 'Integer' for a short pattern against a long text.
    dispatch short long pat txt
      | short <= 64 = restrictedDamerauLevenshteinDistance' (undefined :: Word64)  short long pat txt
      | otherwise   = restrictedDamerauLevenshteinDistance' (undefined :: Integer) short long pat txt

{-# SPECIALIZE restrictedDamerauLevenshteinDistance' :: Word64 -> Int -> Int -> String -> String -> Int #-}
{-# SPECIALIZE restrictedDamerauLevenshteinDistance' :: Integer -> Int -> Int -> String -> String -> Int #-}
-- | Bit-parallel restricted Damerau-Levenshtein distance (Hyyro, PSC'02),
-- using @str1@ as the pattern and folding over the characters of @str2@.
--
-- The first argument is never evaluated; it only fixes the bit-vector type
-- @bv@. @m@ and @n@ must be the lengths of @str1@ and @str2@, and @bv@ must
-- be able to represent at least @m@ bits.
restrictedDamerauLevenshteinDistance' :: (Num bv, Bits bv) => bv -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistance' (_bv_dummy :: bv) !m !n str1 str2
  | [] <- str1 = n
  | otherwise  = extractAnswer $
      foldl'5k (restrictedDamerauLevenshteinDistanceWorker (matchVectors str1) top_bit_mask vector_mask)
               (0, 0, m_ones, 0, m)
               str2
  where
    -- The low m bits set: masks every vector to the pattern width.
    vector_mask = (2 ^ m) - 1 :: bv
    -- Initial vertical +1 deltas (every bit set, so the bottom cell starts
    -- at m). The previous-column match and diagonal vectors start empty.
    m_ones = vector_mask
    -- Bit of the last pattern row, whose cell holds the running distance.
    top_bit_mask = 1 `shiftL` (m - 1) :: bv
    extractAnswer (_, _, _, _, distance) = distance

{-# SPECIALIZE INLINE restrictedDamerauLevenshteinDistanceWorker :: IM.IntMap Word64 -> Word64 -> Word64 -> (Word64, Word64, Word64, Word64, Int) -> Char -> ((Word64, Word64, Word64, Word64, Int) -> res) -> res #-}
{-# SPECIALIZE INLINE restrictedDamerauLevenshteinDistanceWorker :: IM.IntMap Integer -> Integer -> Integer -> (Integer, Integer, Integer, Integer, Int) -> Char -> ((Integer, Integer, Integer, Integer, Int) -> res) -> res #-}
-- | One column step of the bit-parallel restricted Damerau-Levenshtein
-- recurrence.
--
-- Consumes one text character @char2@ and passes the updated
-- @(pm, d0, vp, vn, distance)@ state to the continuation @k@.
-- @pm@ / @d0@ are the match and diagonal-zero vectors of the previous
-- column, which the transposition term needs. @vp@ / @vn@ mark the rows
-- whose vertical delta is +1 / -1, and @distance@ is the bottom cell of the
-- current column.
restrictedDamerauLevenshteinDistanceWorker :: (Num bv, Bits bv)
                                           => IM.IntMap bv -> bv -> bv -> (bv, bv, bv, bv, Int) -> Char
                                           -> ((bv, bv, bv, bv, Int) -> res) -> res
restrictedDamerauLevenshteinDistanceWorker !str1_mvs !top_bit_mask !vector_mask (!pm, !d0, !vp, !vn, !distance) !char2 k
  = pm' `seq` d0' `seq` vp' `seq` vn' `seq` distance''
      `seq` k (pm', d0', vp', vn', distance'')
  where
    -- Pattern positions where char2 occurs (0 when it does not occur at all).
    pm' = IM.findWithDefault 0 (ord char2) str1_mvs

    -- Diagonal zero-delta vector. The first term adds zero-cost diagonals
    -- created by transposing adjacent characters. Its shiftL needs no mask
    -- because it is ANDed with pm, which stays within the pattern width.
    -- The addition in the second term can carry past bit m-1, hence its mask.
    d0' = (((sizedComplement vector_mask d0 .&. pm') `shiftL` 1) .&. pm)
      .|. ((((pm' .&. vp) + vp) .&. vector_mask) `xor` vp) .|. pm' .|. vn
    -- Horizontal +1 / -1 delta vectors.
    hp' = vn .|. sizedComplement vector_mask (d0' .|. vp)
    hn' = d0' .&. vp

    -- Horizontal deltas shifted down one row. The `.|. 1` supplies the +1
    -- delta of row 0 (the empty pattern prefix).
    hp'_shift = ((hp' `shiftL` 1) .|. 1) .&. vector_mask
    hn'_shift = (hn' `shiftL` 1) .&. vector_mask
    -- Vertical delta vectors for the next column.
    vp' = hn'_shift .|. sizedComplement vector_mask (d0' .|. hp'_shift)
    vn' = d0' .&. hp'_shift

    -- The bottom cell moves by the horizontal delta of the last row.
    distance'  = if hp' .&. top_bit_mask /= 0 then distance + 1 else distance
    distance'' = if hn' .&. top_bit_mask /= 0 then distance' - 1 else distance'


{-# SPECIALIZE INLINE sizedComplement :: Word64 -> Word64 -> Word64 #-}
{-# SPECIALIZE INLINE sizedComplement :: Integer -> Integer -> Integer #-}
-- | Bitwise complement of @vect@ within the bits of @vector_mask@.
--
-- Unlike 'complement', this never produces a negative 'Integer' and never
-- sets bits above the mask. It assumes @vect@ has no bits outside
-- @vector_mask@; any such bits would be kept set by the xor.
sizedComplement :: (Num bv, Bits bv) => bv -> bv -> bv
sizedComplement vector_mask vect = vector_mask `xor` vect

{-# SPECIALIZE matchVectors :: String -> IM.IntMap Word64 #-}
{-# SPECIALIZE matchVectors :: String -> IM.IntMap Integer #-}
-- | Pattern-match vectors of a string.
--
-- Maps each character's code point to a bit vector in which bit @i@ is set
-- iff that character occurs at index @i@ of the string. Characters that do
-- not occur have no entry.
matchVectors :: (Num bv, Bits bv) => String -> IM.IntMap bv
matchVectors = snd . foldl' step (0 :: Int, IM.empty)
  where
    -- Walk the string left to right, OR-ing the bit for the current index
    -- into that character's vector.
    step (!ix, !im) char = (ix + 1, IM.insertWith (.|.) (ord char) (bit ix) im)