Skip to content

Commit

Permalink
feat(poc): support running testify suite test function
Browse files Browse the repository at this point in the history
This uses vim.treesitter instead of nvim-treesitter. There are hacks
injected all over the place and the solution is brittle.
  • Loading branch information
fredrikaverpil committed Jul 7, 2024
1 parent b521556 commit 72bec4a
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 5 deletions.
66 changes: 63 additions & 3 deletions lua/neotest-golang/ast.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

local lib = require("neotest.lib")

local testify = require("neotest-golang.testify")

local ts = require("nvim-treesitter.ts_utils")
local parsers = require("nvim-treesitter.parsers")

local M = {}

--- Detect test names in Go *._test.go files.
Expand All @@ -27,6 +32,16 @@ function M.detect_tests(file_path)
name: (field_identifier) @test.name (#match? @test.name "^(Test|Example)")) @test.definition
]]

local receiver_method = [[
; query for receiver method, to be used as test suite namespace
(method_declaration
receiver: (parameter_list
(parameter_declaration
; name: (identifier)
type: (pointer_type
(type_identifier) @namespace.name )))) @namespace.definition
]]

local table_tests = [[
;; query for list table tests
(block
Expand Down Expand Up @@ -127,13 +142,58 @@ function M.detect_tests(file_path)
(#eq? @test.key.name @test.key.name1))))))))
]]

local query = test_function .. test_method .. table_tests
local query = test_function .. test_method .. table_tests .. receiver_method
local opts = { nested_tests = true }

---@type neotest.Tree
local positions = lib.treesitter.parse_positions(file_path, query, opts)
local tree = lib.treesitter.parse_positions(file_path, query, opts)

-- HACK: code below for testify suite support.
-- TODO: hide functionality behind opt-in option.
local tree_with_merged_namespaces =
testify.merge_duplicate_namespaces(tree:root())
local testify_query = [[
; query
(function_declaration ; [38, 0] - [40, 1]
name: (identifier) @testify.function_name ; [38, 5] - [38, 14]
;parameters: (parameter_list ; [38, 14] - [38, 28]
; (parameter_declaration ; [38, 15] - [38, 27]
; name: (identifier) ; [38, 15] - [38, 16]
; type: (pointer_type ; [38, 17] - [38, 27]
; (qualified_type ; [38, 18] - [38, 27]
; package: (package_identifier) ; [38, 18] - [38, 25]
; name: (type_identifier))))) ; [38, 26] - [38, 27]
body: (block ; [38, 29] - [40, 1]
(expression_statement ; [39, 1] - [39, 34]
(call_expression ; [39, 1] - [39, 34]
function: (selector_expression ; [39, 1] - [39, 10]
operand: (identifier) @testify.module ; [39, 1] - [39, 6]
field: (field_identifier) @testify.run ) @testify.call ; [39, 7] - [39, 10]
arguments: (argument_list ; [39, 10] - [39, 34]
(identifier) @testify.t ; [39, 11] - [39, 12]
(call_expression ; [39, 14] - [39, 33]
function: (identifier) ; [39, 14] - [39, 17]
arguments: (argument_list ; [39, 17] - [39, 33]
(type_identifier) @testify.receiver ))))))) @testify.definition
]]

local testify_nodes = testify.run_query_on_file(file_path, testify_query)

for test_function, data in pairs(testify_nodes) do
local function_name = nil
local receiver = nil
for _, node in ipairs(data) do
if node.name == "testify.function_name" then
function_name = node.text
end
if node.name == "testify.receiver" then
receiver = node.text
end
end
testify.add(file_path, function_name, receiver) -- FIXME: accumulates forever
end

return positions
return tree_with_merged_namespaces
end

return M
7 changes: 7 additions & 0 deletions lua/neotest-golang/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ local options = require("neotest-golang.options")
local ast = require("neotest-golang.ast")
local runspec_dir = require("neotest-golang.runspec_dir")
local runspec_file = require("neotest-golang.runspec_file")
local runspec_namespace = require("neotest-golang.runspec_namespace")
local runspec_test = require("neotest-golang.runspec_test")
local parse = require("neotest-golang.parse")

Expand Down Expand Up @@ -115,6 +116,12 @@ function M.Adapter.build_spec(args)
-- A runspec is to be created, based on on running all tests in the given
-- file.
return runspec_file.build(pos, tree)
elseif pos.type == "namespace" then
-- A runspec is to be created, based on running all tests in the given
-- namespace.

-- return runspec_namespace.build(pos)
return -- delegate to type 'test'
elseif pos.type == "test" then
-- A runspec is to be created, based on on running the given test.
return runspec_test.build(pos, args.strategy)
Expand Down
21 changes: 20 additions & 1 deletion lua/neotest-golang/parse.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ local async = require("neotest.async")
local options = require("neotest-golang.options")
local convert = require("neotest-golang.convert")
local json = require("neotest-golang.json")
local testify = require("neotest-golang.testify")

-- TODO: remove pos_type when properly supporting all position types.
-- and instead get this from the pos.type field.
Expand Down Expand Up @@ -192,6 +193,22 @@ function M.gather_neotest_data_and_set_defaults(tree)
return res
end

local function hack(test_name)
-- HACK: replace receiver with suite for testify.
-- TODO: place this under opt-in option.
-- TODO: could make more efficient by matching on filename first?
for filename, data in pairs(testify.get()) do
for _, entry in ipairs(data) do
-- TODO: better, more reliable matching needed
if string.match(test_name, "^" .. entry.suite .. "/") then
test_name = string.gsub(test_name, entry.suite, entry.receiver)
return test_name
end
end
end
return test_name
end

--- Decorate the internal test result data with go package and test name.
--- This is an important step, in which we figure out exactly which test output
--- belongs to which test in the Neotest position tree.
Expand Down Expand Up @@ -225,14 +242,16 @@ function M.decorate_with_go_package_and_test_name(
if gotestline.Package == golistline.ImportPath then
local pattern = convert.to_lua_pattern(folderpath)
.. "/(.-)/"
.. convert.to_lua_pattern(gotestline.Test)
.. convert.to_lua_pattern(hack(gotestline.Test))
.. "$"
match = tweaked_pos_id:find(pattern, 1, false)
if match ~= nil then
test_data.gotest_data.pkg = gotestline.Package
test_data.gotest_data.name = gotestline.Test
break
end

-- HACK: testify suites
end
if match ~= nil then
break
Expand Down
15 changes: 15 additions & 0 deletions lua/neotest-golang/runspec_namespace.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
local M = {}

--- Build runspec for a namespace.
--- @param pos neotest.Position
--- @return neotest.RunSpec | neotest.RunSpec[] | nil
function M.build(pos)
-- vim.notify(vim.inspect(pos), vim.levels.log.DEBUG) -- FIXME: remove when done implementing/debugging

-- TODO: Implement a runspec for a namespace of tests.
-- A bare return will delegate test execution to per-test execution, which
-- will have to do for now.
return
end

return M
15 changes: 14 additions & 1 deletion lua/neotest-golang/runspec_test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ local convert = require("neotest-golang.convert")
local options = require("neotest-golang.options")
local cmd = require("neotest-golang.cmd")
local dap = require("neotest-golang.dap")
local testify = require("neotest-golang.testify")

local M = {}

Expand All @@ -16,8 +17,20 @@ function M.build(pos, strategy)
local test_folder_absolute_path = string.match(pos.path, "(.+)/")
local golist_data = cmd.golist_data(test_folder_absolute_path)

local pos_id = pos.id

-- HACK: replace receiver with suite for testify.
-- TODO: place this under opt-in option.
for filename, data in pairs(testify.get()) do
for _, entry in ipairs(data) do
if string.match(pos_id, "::" .. entry.receiver .. "::") then
pos_id = string.gsub(pos_id, entry.receiver, entry.suite)
end
end
end

--- @type string
local test_name = convert.to_gotest_test_name(pos.id)
local test_name = convert.to_gotest_test_name(pos_id)
test_name = convert.to_gotest_regex_pattern(test_name)

local test_cmd, json_filepath = cmd.test_command_in_package_with_regexp(
Expand Down
121 changes: 121 additions & 0 deletions lua/neotest-golang/testify.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
local M = {}

--- A lookup map between receiver method name and suite name.
--- Example:

local lookup_map = {}

function M.get()
return lookup_map
end

function M.add(file_name, suite_name, receiver_name)
if not lookup_map[file_name] then
lookup_map[file_name] = {}
end
table.insert(
lookup_map[file_name],
{ suite = suite_name, receiver = receiver_name }
)
end

function M.clear()
lookup_map = {}
end

function M.merge_duplicate_namespaces(node)
if not node._children or #node._children == 0 then
return node
end

local namespaces = {}
local new_children = {}

for _, child in ipairs(node._children) do
if child._data.type == "namespace" then
local existing = namespaces[child._data.name]
if existing then
-- Merge children of duplicate namespace
for _, grandchild in ipairs(child._children) do
table.insert(existing._children, grandchild)
grandchild._parent = existing
end
else
namespaces[child._data.name] = child
table.insert(new_children, child)
end
else
table.insert(new_children, child)
end
end

-- Recursively process children
for _, child in ipairs(new_children) do
M.merge_duplicate_namespaces(child)
end

node._children = new_children
return node
end

function M.find_parent_function(node)
while node do
if node:type() == "function_declaration" then
return node
end
node = node:parent()
end
return nil
end

function M.get_function_name(func_node, content)
for child in func_node:iter_children() do
if child:type() == "identifier" then
return vim.treesitter.get_node_text(child, content)
end
end
return "anonymous"
end

function M.run_query_on_file(filepath, query_string)
local file = io.open(filepath, "r")
if not file then
error("Could not open file: " .. filepath)
end
local content = file:read("*all")
file:close()

local lang = "go"
local parser = vim.treesitter.get_string_parser(content, lang)
local tree = parser:parse()[1]
local root = tree:root()

local query = vim.treesitter.query.parse(lang, query_string)
local matches = {}

for id, node, metadata in query:iter_captures(root, content, 0, -1) do
local name = query.captures[id]
local text = vim.treesitter.get_node_text(node, content)

local func_node = M.find_parent_function(node)
if func_node then
local func_name = M.get_function_name(func_node, content)
if not matches[func_name] then
matches[func_name] = {}
end
table.insert(
matches[func_name],
{ name = name, node = node, text = text }
)
else
if not matches["global"] then
matches["global"] = {}
end
table.insert(matches["global"], { name = name, node = node, text = text })
end
end

return matches
end

return M
8 changes: 8 additions & 0 deletions tests/go/go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
module github.com/fredrikaverpil/neotest-golang

go 1.22.2

require github.com/stretchr/testify v1.9.0

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
10 changes: 10 additions & 0 deletions tests/go/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
36 changes: 36 additions & 0 deletions tests/go/testify_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package main

// Basic imports
import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)

// Define the suite, and absorb the built-in basic suite
// functionality from testify - including a T() method which
// returns the current testing context
type ExampleTestSuite struct {
suite.Suite
VariableThatShouldStartAtFive int
}

// Make sure that VariableThatShouldStartAtFive is set to five
// before each test
func (suite *ExampleTestSuite) SetupTest() {
suite.VariableThatShouldStartAtFive = 5
}

// All methods that begin with "Test" are run as tests within a
// suite.
func (suite *ExampleTestSuite) TestExample() {
assert.Equal(suite.T(), 5, suite.VariableThatShouldStartAtFive)
suite.Equal(5, suite.VariableThatShouldStartAtFive)
}

// In order for 'go test' to run this suite, we need to create
// a normal test function and pass our suite to suite.Run
func TestExampleTestSuite(t *testing.T) {
suite.Run(t, new(ExampleTestSuite))
}

0 comments on commit 72bec4a

Please sign in to comment.