aboutsummaryrefslogtreecommitdiff
path: root/lua/trie.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua/trie.lua')
-rw-r--r--lua/trie.lua96
1 files changed, 51 insertions, 45 deletions
diff --git a/lua/trie.lua b/lua/trie.lua
index e74d9eb..9d84487 100644
--- a/lua/trie.lua
+++ b/lua/trie.lua
@@ -14,6 +14,7 @@
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
local ffi = require 'ffi'
+local bit = require 'bit'
local bnot = bit.bnot
local band, bor, bxor = bit.band, bit.bor, bit.bxor
@@ -76,32 +77,30 @@ local function verify_byte_to_index()
end
end
-local function new_trie()
+local function trie_create()
local ptr = ffi.C.malloc(Trie_size)
ffi.fill(ptr, Trie_size)
return ffi.cast(Trie_ptr_t, ptr)
end
-local INDEX_LOOKUP_TABLE = ffi.new('uint8_t[256]')
-local b_a = string.byte('a')
-local b_z = string.byte('z')
-local b_A = string.byte('A')
-local b_Z = string.byte('Z')
-local b_0 = string.byte('0')
-local b_9 = string.byte('9')
-for i = 0, 255 do
- if i >= b_0 and i <= b_9 then
- INDEX_LOOKUP_TABLE[i] = i - b_0
- elseif i >= b_A and i <= b_Z then
- INDEX_LOOKUP_TABLE[i] = i - b_A + 10
- elseif i >= b_a and i <= b_z then
- INDEX_LOOKUP_TABLE[i] = i - b_a + 10 + 26
- else
- INDEX_LOOKUP_TABLE[i] = 255
+local INDEX_LOOKUP_TABLE = ffi.new 'uint8_t[256]'
+local CHAR_LOOKUP_TABLE = ffi.new('char[62]', '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
+do
+ local b = string.byte
+ for i = 0, 255 do
+ if i >= b'0' and i <= b'9' then
+ INDEX_LOOKUP_TABLE[i] = i - b'0'
+ elseif i >= b'A' and i <= b'Z' then
+ INDEX_LOOKUP_TABLE[i] = i - b'A' + 10
+ elseif i >= b'a' and i <= b'z' then
+ INDEX_LOOKUP_TABLE[i] = i - b'a' + 10 + 26
+ else
+ INDEX_LOOKUP_TABLE[i] = 255
+ end
end
end
-local function insert(trie, value)
+local function trie_insert(trie, value)
if trie == nil then return false end
local node = trie
for i = 1, #value do
@@ -110,7 +109,7 @@ local function insert(trie, value)
return false
end
if node.character[index] == nil then
- node.character[index] = new_trie()
+ node.character[index] = trie_create()
end
node = node.character[index]
end
@@ -118,7 +117,7 @@ local function insert(trie, value)
return node, trie
end
-local function search(trie, value)
+local function trie_search(trie, value)
if trie == nil then return false end
local node = trie
for i = 1, #value do
@@ -135,7 +134,7 @@ local function search(trie, value)
return node.is_leaf
end
-local function longest_prefix(trie, value)
+local function trie_longest_prefix(trie, value)
if trie == nil then return false end
local node = trie
local last_i = nil
@@ -158,19 +157,21 @@ local function longest_prefix(trie, value)
end
end
+local function trie_extend(trie, t)
+ assert(type(t) == 'table')
+ for _, v in ipairs(t) do
+ trie_insert(trie, v)
+ end
+end
+
--- Printing utilities
local function index_to_char(index)
- if index < 10 then
- return string.char(index + b_0)
- elseif index < 36 then
- return string.char(index - 10 + b_A)
- else
- return string.char(index - 26 - 10 + b_a)
- end
+ if index < 0 or index > 61 then return end
+ return CHAR_LOOKUP_TABLE[index]
end
-local function trie_structure(trie)
+local function trie_as_table(trie)
if trie == nil then
return nil
end
@@ -178,7 +179,7 @@ local function trie_structure(trie)
for i = 0, 61 do
local child = trie.character[i]
if child ~= nil then
- local child_table = trie_structure(child)
+ local child_table = trie_as_table(child)
child_table.c = index_to_char(i)
table.insert(children, child_table)
end
@@ -189,10 +190,10 @@ local function trie_structure(trie)
}
end
-local function print_structure(s)
+local function print_trie_table(s)
local mark
if not s then
- return nil
+ return {'nil'}
end
if s.c then
if s.is_leaf then
@@ -208,7 +209,7 @@ local function print_structure(s)
end
local lines = {}
for _, child in ipairs(s.children) do
- local child_lines = print_structure(child)
+ local child_lines = print_trie_table(child)
for _, child_line in ipairs(child_lines) do
table.insert(lines, child_line)
end
@@ -235,35 +236,40 @@ local function print_structure(s)
return lines
end
-local function free_trie(trie)
+local function trie_destroy(trie)
if trie == nil then
return
end
for i = 0, 61 do
local child = trie.character[i]
if child ~= nil then
- free_trie(child)
+ trie_destroy(child)
end
end
ffi.C.free(trie)
end
local Trie_mt = {
- __new = new_trie;
+ __new = function(_, init)
+ local trie = trie_create()
+ if type(init) == 'table' then
+ trie_extend(trie, init)
+ end
+ return trie
+ end;
__index = {
- insert = insert;
- search = search;
- longest_prefix = longest_prefix;
+ insert = trie_insert;
+ search = trie_search;
+ longest_prefix = trie_longest_prefix;
+ extend = trie_extend;
};
__tostring = function(trie)
- local structure = trie_structure(trie)
- if structure then
- return table.concat(print_structure(structure), '\n')
- else
+ if trie == nil then
return 'nil'
end
+ return table.concat(print_trie_table(trie_as_table(trie)), '\n')
end;
- __gc = free_trie;
+ __gc = trie_destroy;
}
return ffi.metatype('struct Trie', Trie_mt)
@@ -305,4 +311,4 @@ return ffi.metatype('struct Trie', Trie_mt)
-- end
-- print(os.clock() - start)
--- print(table.concat(print_structure(trie_structure(trie)), '\n'))
+-- print(table.concat(print_trie_table(trie_as_table(trie)), '\n'))