From 2b0a3f838a6129da09fba3f9da324b7e16232819 Mon Sep 17 00:00:00 2001 From: Frederick Zhang Date: Wed, 3 Jan 2024 01:42:08 +1100 Subject: [PATCH] Fix mod path when both mod file and mod dir exist When both src/parent.rs and src/parent/ exist, construct_mod_path can be called with construct_mod_path('src/parent.rs', 'child'). In this case it should return src/parent/child.rs instead of nil. --- lua/neotest-rust/dap.lua | 5 ++++- tests/dap_spec.lua | 7 +++++++ tests/data/simple-package/src/main.rs | 1 + tests/data/simple-package/src/parent.rs | 1 + tests/data/simple-package/src/parent/child.rs | 7 +++++++ tests/init_spec.lua | 12 ++++++------ 6 files changed, 26 insertions(+), 7 deletions(-) create mode 100644 tests/data/simple-package/src/parent.rs create mode 100644 tests/data/simple-package/src/parent/child.rs diff --git a/lua/neotest-rust/dap.lua b/lua/neotest-rust/dap.lua index 886ca96..174dce9 100644 --- a/lua/neotest-rust/dap.lua +++ b/lua/neotest-rust/dap.lua @@ -83,15 +83,18 @@ end -- Determine if mod is in .rs or /mod.rs local function construct_mod_path(src_path, mod_name) local match_str = "(.-)[^\\/]-%.?(%w+)%.?[^\\/]*$" - local abs_path, _ = string.match(src_path, match_str) + local abs_path, parent_mod = string.match(src_path, match_str) local mod_file = abs_path .. mod_name .. ".rs" local mod_dir = abs_path .. mod_name .. sep .. "mod.rs" + local child_mod = abs_path .. parent_mod .. sep .. mod_name .. ".rs" if util.file_exists(mod_file) then return mod_file elseif util.file_exists(mod_dir) then return mod_dir + elseif util.file_exists(child_mod) then + return child_mod end return nil diff --git a/tests/dap_spec.lua b/tests/dap_spec.lua index 158f85b..344c3af 100644 --- a/tests/dap_spec.lua +++ b/tests/dap_spec.lua @@ -61,6 +61,13 @@ describe("get_test_binary", function() assert.equal(expected, actual) end) + async.it("returns the test binary for src/parent/child.rs", function() + local expected = main_actual + local actual = dap.get_test_binary(root, root .. "/src/parent/child.rs") + + assert.equal(expected, actual) + end) + async.it("returns the test binary for tests/test_it.rs", function() assert(test_it_actual) local expected = root .. "/target/debug/deps/test_it-" diff --git a/tests/data/simple-package/src/main.rs b/tests/data/simple-package/src/main.rs index e3457dd..bb0a2f4 100644 --- a/tests/data/simple-package/src/main.rs +++ b/tests/data/simple-package/src/main.rs @@ -1,4 +1,5 @@ mod mymod; +mod parent; fn main() { println!("Hello, world!"); diff --git a/tests/data/simple-package/src/parent.rs b/tests/data/simple-package/src/parent.rs new file mode 100644 index 0000000..21d0b0d --- /dev/null +++ b/tests/data/simple-package/src/parent.rs @@ -0,0 +1 @@ +pub mod child; diff --git a/tests/data/simple-package/src/parent/child.rs b/tests/data/simple-package/src/parent/child.rs new file mode 100644 index 0000000..c1e9866 --- /dev/null +++ b/tests/data/simple-package/src/parent/child.rs @@ -0,0 +1,7 @@ +#[cfg(test)] +mod tests { + #[test] + fn math() { + assert_eq!(1 + 1, 2); + } +} diff --git a/tests/init_spec.lua b/tests/init_spec.lua index a6c35a6..73e33fc 100644 --- a/tests/init_spec.lua +++ b/tests/init_spec.lua @@ -26,7 +26,7 @@ describe("discover_positions", function() id = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", name = "main.rs", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 0, 0, 25, 0 }, + range = { 0, 0, 26, 0 }, type = "file", }, { @@ -34,7 +34,7 @@ describe("discover_positions", function() id = "tests", name = "tests", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 7, 0, 24, 1 }, + range = { 8, 0, 25, 1 }, type = "namespace", }, { @@ -42,7 +42,7 @@ describe("discover_positions", function() id = "tests::basic_math", name = "basic_math", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 9, 4, 11, 5 }, + range = { 10, 4, 12, 5 }, type = "test", }, }, @@ -51,7 +51,7 @@ describe("discover_positions", function() id = "tests::failed_math", name = "failed_math", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 14, 4, 16, 5 }, + range = { 15, 4, 17, 5 }, type = "test", }, }, @@ -60,7 +60,7 @@ describe("discover_positions", function() id = "tests::nested", name = "nested", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 18, 4, 23, 5 }, + range = { 19, 4, 24, 5 }, type = "namespace", }, { @@ -68,7 +68,7 @@ describe("discover_positions", function() id = "tests::nested::nested_math", name = "nested_math", path = vim.loop.cwd() .. "/tests/data/simple-package/src/main.rs", - range = { 20, 8, 22, 9 }, + range = { 21, 8, 23, 9 }, type = "test", }, },