aboutsummaryrefslogtreecommitdiff
path: root/lua
diff options
context:
space:
mode:
Diffstat (limited to 'lua')
-rw-r--r--lua/trie.lua77
1 files changed, 19 insertions, 58 deletions
diff --git a/lua/trie.lua b/lua/trie.lua
index 9d84487..b665f76 100644
--- a/lua/trie.lua
+++ b/lua/trie.lua
@@ -14,11 +14,6 @@
-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.
local ffi = require 'ffi'
-local bit = require 'bit'
-
-local bnot = bit.bnot
-local band, bor, bxor = bit.band, bit.bor, bit.bxor
-local lshift, rshift, rol = bit.lshift, bit.rshift, bit.rol
ffi.cdef [[
struct Trie {
@@ -33,50 +28,6 @@ local Trie_t = ffi.typeof('struct Trie')
local Trie_ptr_t = ffi.typeof('$ *', Trie_t)
local Trie_size = ffi.sizeof(Trie_t)
-local function byte_to_index(b)
- -- 0-9 starts at string.byte('0') == 0x30 == 48 == 0b0011_0000
- -- A-Z starts at string.byte('A') == 0x41 == 65 == 0b0100_0001
- -- a-z starts at string.byte('a') == 0x61 == 97 == 0b0110_0001
-
- -- This works for mapping characters to
- -- 0-9 A-Z a-z in that order
- -- Letters have bit 0x40 set, so we use that as an indicator for
- -- an additional offset from the space of the digits, and then
- -- add the 10 allocated for the range of digits.
- -- Then, within that indicator for letters, we subtract another
- -- (65 - 97) which is the difference between lower and upper case
- -- and add back another 26 to allocate for the range of uppercase
- -- letters.
- -- return b - 0x30
- -- + rshift(b, 6) * (
- -- 0x30 - 0x41
- -- + 10
- -- + band(1, rshift(b, 5)) * (
- -- 0x61 - 0x41
- -- + 26
- -- ))
- return b - 0x30 - rshift(b, 6) * (7 + band(1, rshift(b, 5)) * 6)
-end
-
-local function insensitive_byte_to_index(b)
- -- return b - 0x30
- -- + rshift(b, 6) * (
- -- 0x30 - 0x61
- -- + 10
- -- )
- b = bor(b, 0x20)
- return b - 0x30 - rshift(b, 6) * 39
-end
-
-local function verify_byte_to_index()
- local chars = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
- for i = 1, #chars do
- local c = chars:sub(i,i)
- local index = byte_to_index(string.byte(c))
- assert((i-1) == index, vim.inspect{index=index,c=c})
- end
-end
-
local function trie_create()
local ptr = ffi.C.malloc(Trie_size)
ffi.fill(ptr, Trie_size)
@@ -84,7 +35,7 @@ local function trie_create()
end
local INDEX_LOOKUP_TABLE = ffi.new 'uint8_t[256]'
-local CHAR_LOOKUP_TABLE = ffi.new('char[62]', '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
+local CHAR_LOOKUP_TABLE = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
do
local b = string.byte
for i = 0, 255 do
@@ -117,10 +68,10 @@ local function trie_insert(trie, value)
return node, trie
end
-local function trie_search(trie, value)
+local function trie_search(trie, value, start)
if trie == nil then return false end
local node = trie
- for i = 1, #value do
+ for i = (start or 1), #value do
local index = INDEX_LOOKUP_TABLE[value:byte(i)]
if index == 255 then
return
@@ -134,12 +85,15 @@ local function trie_search(trie, value)
return node.is_leaf
end
-local function trie_longest_prefix(trie, value)
+local function trie_longest_prefix(trie, value, start)
if trie == nil then return false end
+ -- insensitive = insensitive and 0x20 or 0
+ start = start or 1
local node = trie
local last_i = nil
- for i = 1, #value do
+ for i = start, #value do
local index = INDEX_LOOKUP_TABLE[value:byte(i)]
+-- local index = INDEX_LOOKUP_TABLE[bor(insensitive, value:byte(i))]
if index == 255 then
break
end
@@ -153,7 +107,12 @@ local function trie_longest_prefix(trie, value)
node = child
end
if last_i then
- return value:sub(1, last_i)
+ -- Avoid a copy if the whole string is a match.
+ if start == 1 and last_i == #value then
+ return value
+ else
+ return value:sub(start, last_i)
+ end
end
end
@@ -168,7 +127,7 @@ end
local function index_to_char(index)
if index < 0 or index > 61 then return end
- return CHAR_LOOKUP_TABLE[index]
+ return CHAR_LOOKUP_TABLE:sub(index+1, index+1)
end
local function trie_as_table(trie)
@@ -214,11 +173,13 @@ local function print_trie_table(s)
table.insert(lines, child_line)
end
end
+ local child_count = 0
for i, v in ipairs(lines) do
if v:match("^[%w%d]") then
+ child_count = child_count + 1
if i == 1 then
lines[i] = mark.."─"..v
- elseif i == #lines then
+ elseif i == #lines or child_count == #s.children then
lines[i] = "└──"..v
else
lines[i] = "├──"..v
@@ -226,7 +187,7 @@ local function print_trie_table(s)
else
if i == 1 then
lines[i] = mark.."─"..v
- elseif #s.children > 1 then
+ elseif #s.children > 1 and child_count ~= #s.children then
lines[i] = "│ "..v
else
lines[i] = " "..v