chore: optimize string search with Aho–Corasick algorithm (#1476)

* chore: optimize string search with Aho–Corasick algorithm

* chore: optimize keywords replacer

* fix: replacer bugs

* chore: reorder members
master
Kevin Wan 3 years ago committed by GitHub
parent 09d1fad6e0
commit f1102fb262
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,8 @@ package stringx
type node struct {
children map[rune]*node
fail *node
depth int
end bool
}
@ -12,17 +14,19 @@ func (n *node) add(word string) {
}
nd := n
for _, char := range chars {
var depth int
for i, char := range chars {
if nd.children == nil {
child := new(node)
nd.children = map[rune]*node{
char: child,
}
child.depth = i + 1
nd.children = map[rune]*node{char: child}
nd = child
} else if child, ok := nd.children[char]; ok {
nd = child
depth++
} else {
child := new(node)
child.depth = i + 1
nd.children[char] = child
nd = child
}
@ -30,3 +34,68 @@ func (n *node) add(word string) {
nd.end = true
}
func (n *node) build() {
n.fail = n
for _, child := range n.children {
child.fail = n
n.buildNode(child)
}
}
func (n *node) buildNode(nd *node) {
if nd.children == nil {
return
}
var fifo []*node
for key, child := range nd.children {
fifo = append(fifo, child)
if fail, ok := nd.fail.children[key]; ok {
child.fail = fail
} else {
child.fail = n
}
}
for _, val := range fifo {
n.buildNode(val)
}
}
func (n *node) find(chars []rune) []scope {
var scopes []scope
size := len(chars)
cur := n
for i := 0; i < size; i++ {
child, ok := cur.children[chars[i]]
if ok {
cur = child
} else if cur == n {
continue
} else {
cur = cur.fail
if child, ok = cur.children[chars[i]]; !ok {
continue
}
cur = child
}
if child.end {
scopes = append(scopes, scope{
start: i + 1 - child.depth,
stop: i + 1,
})
}
if child.fail != n && child.fail.end {
scopes = append(scopes, scope{
start: i + 1 - child.fail.depth,
stop: i + 1,
})
}
}
return scopes
}

@ -0,0 +1,25 @@
package stringx
import "testing"
func BenchmarkNodeFind(b *testing.B) {
b.ReportAllocs()
keywords := []string{
"A",
"AV",
"AV演员",
"无名氏",
"AV演员色情",
"日本AV女优",
}
trie := new(node)
for _, keyword := range keywords {
trie.add(keyword)
}
trie.build()
for i := 0; i < b.N; i++ {
trie.find([]rune("日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演"))
}
}

@ -9,7 +9,7 @@ type (
}
replacer struct {
node
*node
mapping map[string]string
}
)
@ -17,58 +17,81 @@ type (
// NewReplacer returns a Replacer.
func NewReplacer(mapping map[string]string) Replacer {
rep := &replacer{
node: new(node),
mapping: mapping,
}
for k := range mapping {
rep.add(k)
}
rep.build()
return rep
}
// Replace replaces text with given substitutes.
func (r *replacer) Replace(text string) string {
var builder strings.Builder
var start int
chars := []rune(text)
size := len(chars)
start := -1
for i := 0; i < size; i++ {
child, ok := r.children[chars[i]]
if !ok {
for start < size {
cur := r.node
if start > 0 {
builder.WriteString(string(chars[:start]))
}
for i := start; i < size; i++ {
child, ok := cur.children[chars[i]]
if ok {
cur = child
} else if cur == r.node {
builder.WriteRune(chars[i])
// cur already points to root, set start only
start = i + 1
continue
} else {
curDepth := cur.depth
cur = cur.fail
child, ok = cur.children[chars[i]]
if !ok {
// write this path
builder.WriteString(string(chars[i-curDepth : i+1]))
// go to root
cur = r.node
start = i + 1
continue
}
if start < 0 {
start = i
}
end := -1
if child.end {
end = i + 1
failDepth := cur.depth
// write path before jump
builder.WriteString(string(chars[start : start+curDepth-failDepth]))
start += curDepth - failDepth
cur = child
}
j := i + 1
for ; j < size; j++ {
grandchild, ok := child.children[chars[j]]
if !ok {
break
if cur.end {
val := string(chars[i+1-cur.depth : i+1])
builder.WriteString(r.mapping[val])
builder.WriteString(string(chars[i+1:]))
// only matching this path, all previous paths are done
if start >= i+1-cur.depth && i+1 >= size {
return builder.String()
}
child = grandchild
if child.end {
end = j + 1
i = j
chars = []rune(builder.String())
size = len(chars)
builder.Reset()
break
}
}
if end > 0 {
i = j - 1
builder.WriteString(r.mapping[string(chars[start:end])])
} else {
builder.WriteRune(chars[i])
if !cur.end {
builder.WriteString(string(chars[start:]))
return builder.String()
}
start = -1
}
return builder.String()
return string(chars)
}

@ -0,0 +1,42 @@
//go:build go1.18
// +build go1.18
package stringx
import (
"fmt"
"math/rand"
"strings"
"testing"
)
func FuzzReplacerReplace(f *testing.F) {
keywords := make(map[string]string)
for i := 0; i < 20; i++ {
keywords[Randn(rand.Intn(10)+5)] = Randn(rand.Intn(5) + 1)
}
rep := NewReplacer(keywords)
printableKeywords := func() string {
var buf strings.Builder
for k, v := range keywords {
fmt.Fprintf(&buf, "%q: %q,\n", k, v)
}
return buf.String()
}
f.Add(50)
f.Fuzz(func(t *testing.T, n int) {
text := Randn(rand.Intn(n%50+50) + 1)
defer func() {
if r := recover(); r != nil {
t.Errorf("mapping: %s\ntext: %s", printableKeywords(), text)
}
}()
val := rep.Replace(text)
keys := rep.(*replacer).node.find([]rune(val))
if len(keys) > 0 {
t.Errorf("mapping: %s\ntext: %s\nresult: %s\nmatch: %v",
printableKeywords(), text, val, keys)
}
})
}

@ -15,6 +15,14 @@ func TestReplacer_Replace(t *testing.T) {
assert.Equal(t, "零1234五", NewReplacer(mapping).Replace("零一二三四五"))
}
func TestReplacer_ReplaceOverlap(t *testing.T) {
mapping := map[string]string{
"3d": "34",
"bc": "23",
}
assert.Equal(t, "a234e", NewReplacer(mapping).Replace("abcde"))
}
func TestReplacer_ReplaceSingleChar(t *testing.T) {
mapping := map[string]string{
"二": "2",
@ -42,3 +50,99 @@ func TestReplacer_ReplaceMultiMatches(t *testing.T) {
}
assert.Equal(t, "零一23四五一23四五", NewReplacer(mapping).Replace("零一二三四五一二三四五"))
}
func TestReplacer_ReplaceJumpToFail(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"cde": "234",
}
assert.Equal(t, "ab234fg", NewReplacer(mapping).Replace("abcdefg"))
}
func TestReplacer_ReplaceJumpToFailDup(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"ccde": "2234",
}
assert.Equal(t, "ab2234fg", NewReplacer(mapping).Replace("abccdefg"))
}
func TestReplacer_ReplaceJumpToFailEnding(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"cdef": "2345",
}
assert.Equal(t, "ab2345", NewReplacer(mapping).Replace("abcdef"))
}
func TestReplacer_ReplaceEmpty(t *testing.T) {
mapping := map[string]string{
"bcdf": "1235",
"cdef": "2345",
}
assert.Equal(t, "", NewReplacer(mapping).Replace(""))
}
func TestFuzzCase1(t *testing.T) {
keywords := map[string]string{
"yQyJykiqoh": "xw",
"tgN70z": "Q2P",
"tXKhEn": "w1G8",
"5nfOW1XZO": "GN",
"f4Ov9i9nHD": "cT",
"1ov9Q": "Y",
"7IrC9n": "400i",
"JQLxonpHkOjv": "XI",
"DyHQ3c7": "Ygxux",
"ffyqJi": "u",
"UHuvXrbD8pni": "dN",
"LIDzNbUlTX": "g",
"yN9WZh2rkc8Q": "3U",
"Vhk11rz8CObceC": "jf",
"R0Rt4H2qChUQf": "7U5M",
"MGQzzPCVKjV9": "yYz",
"B5jUUl0u1XOY": "l4PZ",
"pdvp2qfLgG8X": "BM562",
"ZKl9qdApXJ2": "T",
"37jnugkSevU66": "aOHFX",
}
rep := NewReplacer(keywords)
text := "yjF8fyqJiiqrczOCVyoYbLvrMpnkj"
val := rep.Replace(text)
keys := rep.(*replacer).node.find([]rune(val))
if len(keys) > 0 {
t.Errorf("result: %s, match: %v", val, keys)
}
}
func TestFuzzCase2(t *testing.T) {
keywords := map[string]string{
"dmv2SGZvq9Yz": "TE",
"rCL5DRI9uFP8": "hvsc8",
"7pSA2jaomgg": "v",
"kWSQvjVOIAxR": "Oje",
"hgU5bYYkD3r6": "qCXu",
"0eh6uI": "MMlt",
"3USZSl85EKeMzw": "Pc",
"JONmQSuXa": "dX",
"EO1WIF": "G",
"uUmFJGVmacjF": "1N",
"DHpw7": "M",
"NYB2bm": "CPya",
"9FiNvBAHHNku5": "7FlDE",
"tJi3I4WxcY": "q5",
"sNJ8Z1ToBV0O": "tl",
"0iOg72QcPo": "RP",
"pSEqeL": "5KZ",
"GOyYqTgmvQ": "9",
"Qv4qCsj": "nl52E",
"wNQ5tOutYu5s8": "6iGa",
}
rep := NewReplacer(keywords)
text := "AoRxrdKWsGhFpXwVqMLWRL74OukwjBuBh0g7pSrk"
val := rep.Replace(text)
keys := rep.(*replacer).node.find([]rune(val))
if len(keys) > 0 {
t.Errorf("result: %s, match: %v", val, keys)
}
}

@ -39,6 +39,8 @@ func NewTrie(words []string, opts ...TrieOption) Trie {
n.add(word)
}
n.build()
return n
}
@ -48,7 +50,7 @@ func (n *trieNode) Filter(text string) (sentence string, keywords []string, foun
return text, nil, false
}
scopes := n.findKeywordScopes(chars)
scopes := n.find(chars)
keywords = n.collectKeywords(chars, scopes)
for _, match := range scopes {
@ -65,7 +67,7 @@ func (n *trieNode) FindKeywords(text string) []string {
return nil
}
scopes := n.findKeywordScopes(chars)
scopes := n.find(chars)
return n.collectKeywords(chars, scopes)
}
@ -85,48 +87,6 @@ func (n *trieNode) collectKeywords(chars []rune, scopes []scope) []string {
return keywords
}
func (n *trieNode) findKeywordScopes(chars []rune) []scope {
var scopes []scope
size := len(chars)
start := -1
for i := 0; i < size; i++ {
child, ok := n.children[chars[i]]
if !ok {
continue
}
if start < 0 {
start = i
}
if child.end {
scopes = append(scopes, scope{
start: start,
stop: i + 1,
})
}
for j := i + 1; j < size; j++ {
grandchild, ok := child.children[chars[j]]
if !ok {
break
}
child = grandchild
if child.end {
scopes = append(scopes, scope{
start: start,
stop: j + 1,
})
}
}
start = -1
}
return scopes
}
func (n *trieNode) replaceWithAsterisk(chars []rune, start, stop int) {
for i := start; i < stop; i++ {
chars[i] = n.mask

@ -6,6 +6,17 @@ import (
"github.com/stretchr/testify/assert"
)
func TestTrieSimple(t *testing.T) {
trie := NewTrie([]string{
"bc",
"cd",
})
output, keywords, found := trie.Filter("abcd")
assert.True(t, found)
assert.Equal(t, "a***", output)
assert.ElementsMatch(t, []string{"bc", "cd"}, keywords)
}
func TestTrie(t *testing.T) {
tests := []struct {
input string
@ -14,11 +25,11 @@ func TestTrie(t *testing.T) {
found bool
}{
{
input: "日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演",
input: "日本AV演员兼电视、电影演员。无名氏AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演",
output: "日本****兼电视、电影演员。*****女优是xx出道, ******们最精彩的表演是******表演",
keywords: []string{
"AV演员",
"苍井空",
"无名氏",
"AV",
"日本AV女优",
"AV演员色情",
@ -89,7 +100,7 @@ func TestTrie(t *testing.T) {
"一不",
"AV",
"AV演员",
"苍井空",
"无名氏",
"AV演员色情",
"日本AV女优",
})
@ -145,20 +156,3 @@ func TestTrieNested(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, "零########九十", output)
}
func BenchmarkTrie(b *testing.B) {
b.ReportAllocs()
trie := NewTrie([]string{
"A",
"AV",
"AV演员",
"苍井空",
"AV演员色情",
"日本AV女优",
})
for i := 0; i < b.N; i++ {
trie.Filter("日本AV演员兼电视、电影演员。苍井空AV女优是xx出道, 日本AV女优们最精彩的表演是AV演员色情表演")
}
}

@ -3,10 +3,6 @@ package syncx
import "sync"
type (
// SharedCalls is an alias of SingleFlight.
// Deprecated: use SingleFlight.
SharedCalls = SingleFlight
// SingleFlight lets the concurrent calls with the same key to share the call result.
// For example, A called F, before it's done, B called F. Then B would not execute F,
// and shared the result returned by F which called by A.
@ -37,12 +33,6 @@ func NewSingleFlight() SingleFlight {
}
}
// NewSharedCalls returns a SingleFlight.
// Deprecated: use NewSingleFlight.
func NewSharedCalls() SingleFlight {
return NewSingleFlight()
}
func (g *flightGroup) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
c, done := g.createCall(key)
if done {

Loading…
Cancel
Save