diff options
author | Ashkan Kiani <ashkan.k.kiani@gmail.com> | 2019-10-17 15:15:32 -0700 |
---|---|---|
committer | Ashkan Kiani <ashkan.k.kiani@gmail.com> | 2019-10-17 15:15:32 -0700 |
commit | f0c8b75cc423ff0decfdf8ebba0dc7c3015c9aac (patch) | |
tree | a49d90c84ffa249a82ba58ba8e867ac9f5159f2e /lua/trie.lua |
Initial commit.
Diffstat (limited to 'lua/trie.lua')
-rw-r--r-- | lua/trie.lua | 306 |
1 files changed, 306 insertions, 0 deletions
diff --git a/lua/trie.lua b/lua/trie.lua new file mode 100644 index 0000000..60750f9 --- /dev/null +++ b/lua/trie.lua @@ -0,0 +1,306 @@ +-- Trie implementation in luajit +-- Copyright © 2019 Ashkan Kiani + +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. + +-- This program is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. + +-- You should have received a copy of the GNU General Public License +-- along with this program. If not, see <http://www.gnu.org/licenses/>. +local ffi = require 'ffi' + +local bnot = bit.bnot +local band, bor, bxor = bit.band, bit.bor, bit.bxor +local lshift, rshift, rol = bit.lshift, bit.rshift, bit.rol + +ffi.cdef [[ +struct Trie { + bool is_leaf; + struct Trie* character[62]; +}; +void *malloc(size_t size); +void free(void *ptr); +]] + +local Trie_t = ffi.typeof('struct Trie') +local Trie_ptr_t = ffi.typeof('$ *', Trie_t) + +local function byte_to_index(b) + -- 0-9 starts at string.byte('0') == 0x30 == 48 == 0b0011_0000 + -- A-Z starts at string.byte('A') == 0x41 == 65 == 0b0100_0001 + -- a-z starts at string.byte('a') == 0x61 == 97 == 0b0110_0001 + + -- This works for mapping characters to + -- 0-9 A-Z a-z in that order + -- Letters have bit 0x40 set, so we use that as an indicator for + -- an additional offset from the space of the digits, and then + -- add the 10 allocated for the range of digits. + -- Then, within that indicator for letters, we subtract another + -- (65 - 97) which is the difference between lower and upper case + -- and add back another 26 to allocate for the range of uppercase + -- letters. + -- return b - 0x30 + -- + rshift(b, 6) * ( + -- 0x30 - 0x41 + -- + 10 + -- + band(1, rshift(b, 5)) * ( + -- 0x61 - 0x41 + -- + 26 + -- )) + return b - 0x30 - rshift(b, 6) * (7 + band(1, rshift(b, 5)) * 6) +end + +local function insensitive_byte_to_index(b) + -- return b - 0x30 + -- + rshift(b, 6) * ( + -- 0x30 - 0x61 + -- + 10 + -- ) + b = bor(b, 0x20) + return b - 0x30 - rshift(b, 6) * 39 +end + +local function verify_byte_to_index() + local chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + for i = 1, #chars do + local c = chars:sub(i,i) + local index = byte_to_index(string.byte(c)) + assert((i-1) == index, vim.inspect{index=index,c=c}) + end +end + +local Trie_size = ffi.sizeof(Trie_t) + +local function new_trie() + local ptr = ffi.C.malloc(Trie_size) + ffi.fill(ptr, Trie_size) + return ffi.cast(Trie_ptr_t, ptr) +end + +local INDEX_LOOKUP_TABLE = ffi.new('uint8_t[256]') +local b_a = string.byte('a') +local b_z = string.byte('z') +local b_A = string.byte('A') +local b_Z = string.byte('Z') +local b_0 = string.byte('0') +local b_9 = string.byte('9') +for i = 0, 255 do + if i >= b_0 and i <= b_9 then + INDEX_LOOKUP_TABLE[i] = i - b_0 + elseif i >= b_A and i <= b_Z then + INDEX_LOOKUP_TABLE[i] = i - b_A + 10 + elseif i >= b_a and i <= b_z then + INDEX_LOOKUP_TABLE[i] = i - b_a + 10 + 26 + else + INDEX_LOOKUP_TABLE[i] = 255 + end +end + +local function insert(trie, value) + if trie == nil then return false end + local node = trie + for i = 1, #value do + local index = INDEX_LOOKUP_TABLE[value:byte(i)] + if index == 255 then + return false + end + if node.character[index] == nil then + node.character[index] = new_trie() + end + node = node.character[index] + end + node.is_leaf = true + return node, trie +end + +local function search(trie, value) + if trie == nil then return false end + local node = trie + for i = 1, #value do + local index = INDEX_LOOKUP_TABLE[value:byte(i)] + if index == 255 then + return + end + local child = node.character[index] + if child == nil then + return false + end + node = child + end + return node.is_leaf +end + +local function longest_prefix(trie, value) + if trie == nil then return false end + local node = trie + local last_i = nil + for i = 1, #value do + local index = INDEX_LOOKUP_TABLE[value:byte(i)] + if index == 255 then + break + end + local child = node.character[index] + if child == nil then + break + end + if child.is_leaf then + last_i = i + end + node = child + end + if last_i then + return value:sub(1, last_i) + end +end + +--- Printing utilities + +local function index_to_char(index) + if index < 10 then + return string.char(index + b_0) + elseif index < 36 then + return string.char(index - 10 + b_A) + else + return string.char(index - 26 - 10 + b_a) + end +end + +local function trie_structure(trie) + if trie == nil then + return nil + end + local children = {} + for i = 0, 61 do + local child = trie.character[i] + if child ~= nil then + local child_table = trie_structure(child) + child_table.c = index_to_char(i) + table.insert(children, child_table) + end + end + return { + is_leaf = trie.is_leaf; + children = children; + } +end + +local function print_structure(s) + local mark + if not s then + return nil + end + if s.c then + if s.is_leaf then + mark = s.c.."*" + else + mark = s.c.."─" + end + else + mark = "├─" + end + if #s.children == 0 then + return {mark} + end + local lines = {} + for _, child in ipairs(s.children) do + local child_lines = print_structure(child) + for _, child_line in ipairs(child_lines) do + table.insert(lines, child_line) + end + end + for i, v in ipairs(lines) do + if v:match("^[%w%d]") then + if i == 1 then + lines[i] = mark.."─"..v + elseif i == #lines then + lines[i] = "└──"..v + else + lines[i] = "├──"..v + end + else + if i == 1 then + lines[i] = mark.."─"..v + elseif #s.children > 1 then + lines[i] = "│ "..v + else + lines[i] = " "..v + end + end + end + return lines +end + +local function free_trie(trie) + for i = 0, 61 do + local child = trie.character[i] + if child ~= nil then + free_trie(child) + end + end + ffi.C.free(trie) +end + +local Trie_mt = { + __index = { + insert = insert; + search = search; + longest_prefix = longest_prefix; + }; + __tostring = function(trie) + local structure = trie_structure(trie) + if structure then + return table.concat(print_structure(structure), '\n') + else + return 'nil' + end + end; + __gc = free_trie; + } +local Trie = ffi.metatype(Trie_t, Trie_mt) + +return Trie + +-- local tests = { +-- "cat"; +-- "car"; +-- "celtic"; +-- "carb"; +-- "carb0"; +-- "CART0"; +-- "CaRT0"; +-- "Cart0"; +-- "931"; +-- "191"; +-- "121"; +-- "cardio"; +-- "call"; +-- "calcium"; +-- "calciur"; +-- "carry"; +-- "dog"; +-- "catdog"; +-- } +-- local trie = Trie() +-- for i, v in ipairs(tests) do +-- trie:insert(v) +-- end + +-- print(trie) +-- print(trie.character[0]) +-- print("catdo", trie:longest_prefix("catdo")) +-- print("catastrophic", trie:longest_prefix("catastrophic")) + +-- local COLOR_MAP = vim.api.nvim_get_color_map() +-- local start = os.clock() +-- for k, v in pairs(COLOR_MAP) do +-- insert(trie, k) +-- end +-- print(os.clock() - start) + +-- print(table.concat(print_structure(trie_structure(trie)), '\n')) |