diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 64717a4c..6a2322a2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,16 +34,11 @@ jobs: with: python-version: ${{matrix.python}} - name: Lint - run: pipx run ruff check . + continue-on-error: true + run: pipx run ruff check . --output-format=github - name: Build - run: pip install -e . + run: pip install -v -e .[tests] env: - CFLAGS: "-O0 -g" + CFLAGS: -Wextra -Og -g -fno-omit-frame-pointer - name: Test - shell: python - # run: python -Wignore:::tree_sitter -munittest - run: |- - try: __import__('tree_sitter').Language(1048576) - except RuntimeError as err: print(err) - env: - PYTHONFAULTHANDLER: 1 + run: python -munittest -v diff --git a/.gitmodules b/.gitmodules index fa624605..a5f5accd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,21 +1,3 @@ [submodule "tree-sitter"] path = tree_sitter/core url = https://github.com/tree-sitter/tree-sitter -[submodule "tree-sitter-embedded-template"] - path = tests/fixtures/tree-sitter-embedded-template - url = https://github.com/tree-sitter/tree-sitter-embedded-template -[submodule "tree-sitter-html"] - path = tests/fixtures/tree-sitter-html - url = https://github.com/tree-sitter/tree-sitter-html -[submodule "tree-sitter-javascript"] - path = tests/fixtures/tree-sitter-javascript - url = https://github.com/tree-sitter/tree-sitter-javascript -[submodule "tree-sitter-json"] - path = tests/fixtures/tree-sitter-json - url = https://github.com/tree-sitter/tree-sitter-json -[submodule "tree-sitter-python"] - path = tests/fixtures/tree-sitter-python - url = https://github.com/tree-sitter/tree-sitter-python -[submodule "tree-sitter-rust"] - path = tests/fixtures/tree-sitter-rust - url = https://github.com/tree-sitter/tree-sitter-rust diff --git a/docs/conf.py b/docs/conf.py index 4331e53b..773f4701 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -54,30 +54,30 @@ html_favicon = "_static/favicon.png" -special_doc = regex('\S*self[^.]+') +special_doc = regex("\S*self[^.]+") def process_signature(_app, _what, name, _obj, _options, _signature, return_annotation): - if name == 'tree_sitter.Language': - return '(ptr)', return_annotation - if name == 'tree_sitter.Query': - return '(language, source)', return_annotation - if name == 'tree_sitter.Parser': - return '(language, *, included_ranges=None, timeout_micros=None)', return_annotation - if name == 'tree_sitter.Range': - return '(start_point, end_point, start_byte, end_byte)', return_annotation + if name == "tree_sitter.Language": + return "(ptr)", return_annotation + if name == "tree_sitter.Query": + return "(language, source)", return_annotation + if name == "tree_sitter.Parser": + return "(language, *, included_ranges=None, timeout_micros=None)", return_annotation + if name == "tree_sitter.Range": + return "(start_point, end_point, start_byte, end_byte)", return_annotation def process_docstring(_app, what, name, _obj, _options, lines): - if what == 'data': + if what == "data": lines.clear() - elif what == 'method': - if name.endswith('__index__'): - lines[0] = 'Converts ``self`` to an integer for use as an index.' - elif name.endswith('__') and lines and 'self' in lines[0]: - lines[0] = f'Implements ``{special_doc.search(lines[0]).group(0)}``.' + elif what == "method": + if name.endswith("__index__"): + lines[0] = "Converts ``self`` to an integer for use as an index." + elif name.endswith("__") and lines and "self" in lines[0]: + lines[0] = f"Implements ``{special_doc.search(lines[0]).group(0)}``." def setup(app): - app.connect('autodoc-process-signature', process_signature) - app.connect('autodoc-process-docstring', process_docstring) + app.connect("autodoc-process-signature", process_signature) + app.connect("autodoc-process-docstring", process_docstring) diff --git a/pyproject.toml b/pyproject.toml index c6b24b6d..d2ca87d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,13 @@ email = "maxbrunsfeld@gmail.com" [project.optional-dependencies] docs = ["sphinx~=7.3", "sphinx-book-theme"] +tests = [ + "tree-sitter-html", + "tree-sitter-javascript", + "tree-sitter-json", + "tree-sitter-python", + "tree-sitter-rust", +] [tool.ruff] target-version = "py39" @@ -39,7 +46,7 @@ indent-width = 4 extend-exclude = [ ".github", "__pycache__", - "tests/fixtures", + "setup.py", "tree_sitter/core", ] @@ -49,6 +56,7 @@ indent-style = "space" [tool.cibuildwheel] build-frontend = "build" +test-extras = ["tests"] test-command = "python -munittest discover -s {project}/tests" [tool.cibuildwheel.environment] diff --git a/setup.py b/setup.py index 4bf875bc..77ea76a9 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,12 @@ -"""Py-Tree-sitter""" - from platform import system -from setuptools import Extension, setup +from setuptools import Extension, setup # type: ignore setup( packages=["tree_sitter"], include_package_data=False, package_data={ - "tree_sitter": ["py.typed", "*.pyi"] + "tree_sitter": ["py.typed", "*.pyi"], }, ext_modules=[ Extension( @@ -41,8 +39,12 @@ extra_compile_args=[ "-std=c11", "-fvisibility=hidden", + "-Wno-cast-function-type", "-Werror=implicit-function-declaration", - ] if system() != "Windows" else None + ] if system() != "Windows" else [ + "/std:c11", + "/wd4244", + ], ) - ] + ], ) diff --git a/tests/fixtures/tree-sitter-embedded-template b/tests/fixtures/tree-sitter-embedded-template deleted file mode 160000 index 6d791b89..00000000 --- a/tests/fixtures/tree-sitter-embedded-template +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d791b897ecda59baa0689a85a9906348a2a6414 diff --git a/tests/fixtures/tree-sitter-html b/tests/fixtures/tree-sitter-html deleted file mode 160000 index b285e25c..00000000 --- a/tests/fixtures/tree-sitter-html +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b285e25c1ba8729399ce4f15ac5375cf6c3aa5be diff --git a/tests/fixtures/tree-sitter-javascript b/tests/fixtures/tree-sitter-javascript deleted file mode 160000 index de1e6822..00000000 --- a/tests/fixtures/tree-sitter-javascript +++ /dev/null @@ -1 +0,0 @@ -Subproject commit de1e682289a417354df5b4437a3e4f92e0722a0f diff --git a/tests/fixtures/tree-sitter-json b/tests/fixtures/tree-sitter-json deleted file mode 160000 index 3b129203..00000000 --- a/tests/fixtures/tree-sitter-json +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3b129203f4b72d532f58e72c5310c0a7db3b8e6d diff --git a/tests/fixtures/tree-sitter-python b/tests/fixtures/tree-sitter-python deleted file mode 160000 index b8a4c641..00000000 --- a/tests/fixtures/tree-sitter-python +++ /dev/null @@ -1 +0,0 @@ -Subproject commit b8a4c64121ba66b460cb878e934e3157ecbfb124 diff --git a/tests/fixtures/tree-sitter-rust b/tests/fixtures/tree-sitter-rust deleted file mode 160000 index 3a56481f..00000000 --- a/tests/fixtures/tree-sitter-rust +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3a56481f8d13b6874a28752502a58520b9139dc7 diff --git a/tests/test_language.py b/tests/test_language.py new file mode 100644 index 00000000..50d86047 --- /dev/null +++ b/tests/test_language.py @@ -0,0 +1,86 @@ +from unittest import TestCase + +from tree_sitter import Language, Query + +import tree_sitter_html +import tree_sitter_javascript +import tree_sitter_json +import tree_sitter_python +import tree_sitter_rust + + +class TestLanguage(TestCase): + def setUp(self): + self.html = tree_sitter_html.language() + self.javascript = tree_sitter_javascript.language() + self.json = tree_sitter_json.language() + self.python = tree_sitter_python.language() + self.rust = tree_sitter_rust.language() + + def test_init_not_positive(self): + self.assertRaises(ValueError, Language, -1) + + def test_init_segv(self): + self.assertRaises(RuntimeError, Language, 1024) + + def test_properties(self): + lang = Language(self.python) + self.assertEqual(lang.version, 14) + self.assertEqual(lang.node_kind_count, 274) + self.assertEqual(lang.parse_state_count, 2831) + self.assertEqual(lang.field_count, 32) + + def test_node_kind_for_id(self): + lang = Language(self.json) + self.assertEqual(lang.node_kind_for_id(1), "{") + self.assertEqual(lang.node_kind_for_id(3), "}") + + def test_id_for_node_kind(self): + lang = Language(self.json) + self.assertEqual(lang.id_for_node_kind(":", False), 4) + self.assertEqual(lang.id_for_node_kind("string", True), 20) + + def test_node_kind_is_named(self): + lang = Language(self.json) + self.assertFalse(lang.node_kind_is_named(4)) + self.assertTrue(lang.node_kind_is_named(20)) + + def test_node_kind_is_visible(self): + lang = Language(self.json) + self.assertTrue(lang.node_kind_is_visible(2)) + + def test_field_name_for_id(self): + lang = Language(self.json) + self.assertEqual(lang.field_name_for_id(1), "key") + self.assertEqual(lang.field_name_for_id(2), "value") + + def test_field_id_for_name(self): + lang = Language(self.json) + self.assertEqual(lang.field_id_for_name("key"), 1) + self.assertEqual(lang.field_id_for_name("value"), 2) + + def test_next_state(self): + lang = Language(self.javascript) + self.assertNotEqual(lang.next_state(1, 1), 0) + + def test_lookahead_iterator(self): + lang = Language(self.javascript) + self.assertIsNotNone(lang.lookahead_iterator(0)) + self.assertIsNone(lang.lookahead_iterator(9999)) + + def test_query(self): + lang = Language(self.json) + query = lang.query("(string) @string") + self.assertIsInstance(query, Query) + + def test_eq(self): + self.assertEqual(Language(self.json), Language(self.json)) + self.assertNotEqual(Language(self.rust), Language(self.html)) + + def test_int(self): + for name in ["html", "javascript", "json", "python", "rust"]: + with self.subTest(language=name): + ptr = getattr(self, name) + lang = Language(ptr) + self.assertEqual(int(lang), ptr) + self.assertEqual(hash(lang), ptr) diff --git a/tests/test_lookahead_iterator.py b/tests/test_lookahead_iterator.py new file mode 100644 index 00000000..540da6a1 --- /dev/null +++ b/tests/test_lookahead_iterator.py @@ -0,0 +1,43 @@ +from unittest import TestCase + +from tree_sitter import Language, Parser + +import tree_sitter_rust + + +class TestLookaheadIterator(TestCase): + @classmethod + def setUpClass(self): + self.rust = Language(tree_sitter_rust.language()) + + def test_lookahead_iterator(self): + parser = Parser(self.rust) + cursor = parser.parse(b"struct Stuff{}").walk() + + self.assertEqual(cursor.goto_first_child(), True) # struct + self.assertEqual(cursor.goto_first_child(), True) # struct keyword + + next_state = cursor.node.next_parse_state + + self.assertNotEqual(next_state, 0) + self.assertEqual( + next_state, self.rust.next_state(cursor.node.parse_state, cursor.node.grammar_id) + ) + self.assertLess(next_state, self.rust.parse_state_count) + self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier + self.assertEqual(next_state, cursor.node.parse_state) + self.assertEqual(cursor.node.grammar_name, "identifier") + self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id) + + expected_symbols = ["//", "/*", "identifier", "line_comment", "block_comment"] + lookahead = self.rust.lookahead_iterator(next_state) + self.assertEqual(lookahead.language, self.rust) + self.assertListEqual(list(lookahead.iter_names()), expected_symbols) + + lookahead.reset_state(next_state) + self.assertListEqual(list(lookahead.iter_names()), expected_symbols) + + lookahead.reset_state(next_state, self.rust) + self.assertListEqual( + list(map(self.rust.node_kind_for_id, list(iter(lookahead)))), expected_symbols + ) diff --git a/tests/test_node.py b/tests/test_node.py new file mode 100644 index 00000000..6df15656 --- /dev/null +++ b/tests/test_node.py @@ -0,0 +1,480 @@ +from unittest import TestCase + +import tree_sitter_python +import tree_sitter_javascript +import tree_sitter_json + +from tree_sitter import Language, Parser + +JSON_EXAMPLE = b""" + +[ + 123, + false, + { + "x": null + } +] +""" + + +def get_all_nodes(tree): + result = [] + visited_children = False + cursor = tree.walk() + while True: + if not visited_children: + result.append(cursor.node) + if not cursor.goto_first_child(): + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent(): + break + return result + + +class TestNode(TestCase): + @classmethod + def setUpClass(cls): + cls.javascript = Language(tree_sitter_javascript.language()) + cls.json = Language(tree_sitter_json.language()) + cls.python = Language(tree_sitter_python.language()) + + def test_child_by_field_id(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + root_node = tree.root_node + fn_node = tree.root_node.children[0] + + self.assertIsNone(self.python.field_id_for_name("noname")) + name_field = self.python.field_id_for_name("name") + alias_field = self.python.field_id_for_name("alias") + self.assertIsNone(root_node.child_by_field_id(alias_field)) + self.assertIsNone(root_node.child_by_field_id(name_field)) + self.assertIsNone(fn_node.child_by_field_id(alias_field)) + self.assertIsNone(fn_node.child_by_field_name("noname")) + self.assertEqual(fn_node.child_by_field_name("name"), fn_node.child_by_field_name("name")) + + def test_child_by_field_name(self): + parser = Parser(self.python) + tree = parser.parse(b"while a:\n pass") + while_node = tree.root_node.child(0) + self.assertIsNotNone(while_node) + self.assertEqual(while_node.type, "while_statement") + self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3)) + + def test_children_by_field_id(self): + parser = Parser(self.javascript) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + attribute_field = self.javascript.field_id_for_name("attribute") + attributes = jsx_node.children_by_field_id(attribute_field) + self.assertListEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) + + def test_children_by_field_name(self): + parser = Parser(self.javascript) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + attributes = jsx_node.children_by_field_name("attribute") + self.assertListEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) + + def test_field_name_for_child(self): + parser = Parser(self.javascript) + tree = parser.parse(b"
") + jsx_node = tree.root_node.children[0].children[0] + + self.assertIsNone(jsx_node.field_name_for_child(0)) + self.assertEqual(jsx_node.field_name_for_child(1), "name") + + def test_root_node_with_offset(self): + parser = Parser(self.javascript) + tree = parser.parse(b" if (a) b") + + node = tree.root_node_with_offset(6, (2, 2)) + self.assertIsNotNone(node) + self.assertEqual(node.byte_range, (8, 16)) + self.assertEqual(node.start_point, (2, 4)) + self.assertEqual(node.end_point, (2, 12)) + + child = node.child(0).child(2) + self.assertIsNotNone(child) + self.assertEqual(child.type, "expression_statement") + self.assertEqual(child.byte_range, (15, 16)) + self.assertEqual(child.start_point, (2, 11)) + self.assertEqual(child.end_point, (2, 12)) + + cursor = node.walk() + cursor.goto_first_child() + cursor.goto_first_child() + cursor.goto_next_sibling() + child = cursor.node + self.assertIsNotNone(child) + self.assertEqual(child.type, "parenthesized_expression") + self.assertEqual(child.byte_range, (11, 14)) + self.assertEqual(child.start_point, (2, 7)) + self.assertEqual(child.end_point, (2, 10)) + + def test_descendant_count(self): + parser = Parser(self.json) + tree = parser.parse(JSON_EXAMPLE) + value_node = tree.root_node + all_nodes = get_all_nodes(tree) + + self.assertEqual(value_node.descendant_count, len(all_nodes)) + + cursor = value_node.walk() + for i, node in enumerate(all_nodes): + cursor.goto_descendant(i) + self.assertEqual(cursor.node, node, f"index {i}") + + for i, node in reversed(list(enumerate(all_nodes))): + cursor.goto_descendant(i) + self.assertEqual(cursor.node, node, f"rev index {i}") + + def test_descendant_for_byte_range(self): + parser = Parser(self.json) + tree = parser.parse(JSON_EXAMPLE) + array_node = tree.root_node + + colon_index = JSON_EXAMPLE.index(b":") + + # Leaf node exactly matches the given bounds - byte query + colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # Leaf node exactly matches the given bounds - point query + colon_node = array_node.descendant_for_point_range((6, 7), (6, 8)) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # The given point is between two adjacent leaf nodes - byte query + colon_node = array_node.descendant_for_byte_range(colon_index, colon_index) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # The given point is between two adjacent leaf nodes - point query + colon_node = array_node.descendant_for_point_range((6, 7), (6, 7)) + self.assertIsNotNone(colon_node) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.start_byte, colon_index) + self.assertEqual(colon_node.end_byte, colon_index + 1) + self.assertEqual(colon_node.start_point, (6, 7)) + self.assertEqual(colon_node.end_point, (6, 8)) + + # Leaf node starts at the lower bound, ends after the upper bound - byte query + string_index = JSON_EXAMPLE.index(b'"x"') + string_node = array_node.descendant_for_byte_range(string_index, string_index + 2) + self.assertIsNotNone(string_node) + self.assertEqual(string_node.type, "string") + self.assertEqual(string_node.start_byte, string_index) + self.assertEqual(string_node.end_byte, string_index + 3) + self.assertEqual(string_node.start_point, (6, 4)) + self.assertEqual(string_node.end_point, (6, 7)) + + # Leaf node starts at the lower bound, ends after the upper bound - point query + string_node = array_node.descendant_for_point_range((6, 4), (6, 6)) + self.assertIsNotNone(string_node) + self.assertEqual(string_node.type, "string") + self.assertEqual(string_node.start_byte, string_index) + self.assertEqual(string_node.end_byte, string_index + 3) + self.assertEqual(string_node.start_point, (6, 4)) + self.assertEqual(string_node.end_point, (6, 7)) + + # Leaf node starts before the lower bound, ends at the upper bound - byte query + null_index = JSON_EXAMPLE.index(b"null") + null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4) + self.assertIsNotNone(null_node) + self.assertEqual(null_node.type, "null") + self.assertEqual(null_node.start_byte, null_index) + self.assertEqual(null_node.end_byte, null_index + 4) + self.assertEqual(null_node.start_point, (6, 9)) + self.assertEqual(null_node.end_point, (6, 13)) + + # Leaf node starts before the lower bound, ends at the upper bound - point query + null_node = array_node.descendant_for_point_range((6, 11), (6, 13)) + self.assertIsNotNone(null_node) + self.assertEqual(null_node.type, "null") + self.assertEqual(null_node.start_byte, null_index) + self.assertEqual(null_node.end_byte, null_index + 4) + self.assertEqual(null_node.start_point, (6, 9)) + self.assertEqual(null_node.end_point, (6, 13)) + + # The bounds span multiple leaf nodes - return the smallest node that does span it. + pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4) + self.assertIsNotNone(pair_node) + self.assertEqual(pair_node.type, "pair") + self.assertEqual(pair_node.start_byte, string_index) + self.assertEqual(pair_node.end_byte, string_index + 9) + self.assertEqual(pair_node.start_point, (6, 4)) + self.assertEqual(pair_node.end_point, (6, 13)) + + self.assertEqual(colon_node.parent, pair_node) + + # No leaf spans the given range - return the smallest node that does span it. + pair_node = array_node.descendant_for_point_range((6, 6), (6, 8)) + self.assertIsNotNone(pair_node) + self.assertEqual(pair_node.type, "pair") + self.assertEqual(pair_node.start_byte, string_index) + self.assertEqual(pair_node.end_byte, string_index + 9) + self.assertEqual(pair_node.start_point, (6, 4)) + self.assertEqual(pair_node.end_point, (6, 13)) + + def test_children(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + + root_node = tree.root_node + self.assertEqual(root_node.type, "module") + self.assertEqual(root_node.start_byte, 0) + self.assertEqual(root_node.end_byte, 18) + self.assertEqual(root_node.start_point, (0, 0)) + self.assertEqual(root_node.end_point, (1, 7)) + + # List object is reused + self.assertIs(root_node.children, root_node.children) + + fn_node = root_node.children[0] + self.assertEqual(fn_node, root_node.child(0)) + self.assertEqual(fn_node.type, "function_definition") + self.assertEqual(fn_node.start_byte, 0) + self.assertEqual(fn_node.end_byte, 18) + self.assertEqual(fn_node.start_point, (0, 0)) + self.assertEqual(fn_node.end_point, (1, 7)) + + def_node = fn_node.children[0] + self.assertEqual(def_node, fn_node.child(0)) + self.assertEqual(def_node.type, "def") + self.assertEqual(def_node.is_named, False) + + id_node = fn_node.children[1] + self.assertEqual(id_node, fn_node.child(1)) + self.assertEqual(id_node.type, "identifier") + self.assertEqual(id_node.is_named, True) + self.assertEqual(len(id_node.children), 0) + + params_node = fn_node.children[2] + self.assertEqual(params_node, fn_node.child(2)) + self.assertEqual(params_node.type, "parameters") + self.assertEqual(params_node.is_named, True) + + colon_node = fn_node.children[3] + self.assertEqual(colon_node, fn_node.child(3)) + self.assertEqual(colon_node.type, ":") + self.assertEqual(colon_node.is_named, False) + + statement_node = fn_node.children[4] + self.assertEqual(statement_node, fn_node.child(4)) + self.assertEqual(statement_node.type, "block") + self.assertEqual(statement_node.is_named, True) + + def test_is_extra(self): + parser = Parser(self.javascript) + tree = parser.parse(b"foo(/* hi */);") + + root_node = tree.root_node + comment_node = root_node.descendant_for_byte_range(7, 7) + self.assertIsNotNone(comment_node) + + self.assertEqual(root_node.type, "program") + self.assertEqual(comment_node.type, "comment") + self.assertEqual(root_node.is_extra, False) + self.assertEqual(comment_node.is_extra, True) + + def test_properties(self): + parser = Parser(self.python) + tree = parser.parse(b"[1, 2, 3]") + + root_node = tree.root_node + self.assertEqual(root_node.type, "module") + self.assertEqual(root_node.start_byte, 0) + self.assertEqual(root_node.end_byte, 9) + self.assertEqual(root_node.start_point, (0, 0)) + self.assertEqual(root_node.end_point, (0, 9)) + + exp_stmt_node = root_node.children[0] + self.assertEqual(exp_stmt_node, root_node.child(0)) + self.assertEqual(exp_stmt_node.type, "expression_statement") + self.assertEqual(exp_stmt_node.start_byte, 0) + self.assertEqual(exp_stmt_node.end_byte, 9) + self.assertEqual(exp_stmt_node.start_point, (0, 0)) + self.assertEqual(exp_stmt_node.end_point, (0, 9)) + self.assertEqual(exp_stmt_node.parent, root_node) + + list_node = exp_stmt_node.children[0] + self.assertEqual(list_node, exp_stmt_node.child(0)) + self.assertEqual(list_node.type, "list") + self.assertEqual(list_node.start_byte, 0) + self.assertEqual(list_node.end_byte, 9) + self.assertEqual(list_node.start_point, (0, 0)) + self.assertEqual(list_node.end_point, (0, 9)) + self.assertEqual(list_node.parent, exp_stmt_node) + + named_children = list_node.named_children + + open_delim_node = list_node.children[0] + self.assertEqual(open_delim_node, list_node.child(0)) + self.assertEqual(open_delim_node.type, "[") + self.assertEqual(open_delim_node.start_byte, 0) + self.assertEqual(open_delim_node.end_byte, 1) + self.assertEqual(open_delim_node.start_point, (0, 0)) + self.assertEqual(open_delim_node.end_point, (0, 1)) + self.assertEqual(open_delim_node.parent, list_node) + + first_num_node = list_node.children[1] + self.assertEqual(first_num_node, list_node.child(1)) + self.assertEqual(first_num_node, open_delim_node.next_named_sibling) + self.assertEqual(first_num_node.parent, list_node) + self.assertEqual(named_children[0], first_num_node) + self.assertEqual(first_num_node, list_node.named_child(0)) + + first_comma_node = list_node.children[2] + self.assertEqual(first_comma_node, list_node.child(2)) + self.assertEqual(first_comma_node, first_num_node.next_sibling) + self.assertEqual(first_num_node, first_comma_node.prev_sibling) + self.assertEqual(first_comma_node.parent, list_node) + + second_num_node = list_node.children[3] + self.assertEqual(second_num_node, list_node.child(3)) + self.assertEqual(second_num_node, first_comma_node.next_sibling) + self.assertEqual(second_num_node, first_num_node.next_named_sibling) + self.assertEqual(first_num_node, second_num_node.prev_named_sibling) + self.assertEqual(second_num_node.parent, list_node) + self.assertEqual(named_children[1], second_num_node) + self.assertEqual(second_num_node, list_node.named_child(1)) + + second_comma_node = list_node.children[4] + self.assertEqual(second_comma_node, list_node.child(4)) + self.assertEqual(second_comma_node, second_num_node.next_sibling) + self.assertEqual(second_num_node, second_comma_node.prev_sibling) + self.assertEqual(second_comma_node.parent, list_node) + + third_num_node = list_node.children[5] + self.assertEqual(third_num_node, list_node.child(5)) + self.assertEqual(third_num_node, second_comma_node.next_sibling) + self.assertEqual(third_num_node, second_num_node.next_named_sibling) + self.assertEqual(second_num_node, third_num_node.prev_named_sibling) + self.assertEqual(third_num_node.parent, list_node) + self.assertEqual(named_children[2], third_num_node) + self.assertEqual(third_num_node, list_node.named_child(2)) + + close_delim_node = list_node.children[6] + self.assertEqual(close_delim_node, list_node.child(6)) + self.assertEqual(close_delim_node.type, "]") + self.assertEqual(close_delim_node.start_byte, 8) + self.assertEqual(close_delim_node.end_byte, 9) + self.assertEqual(close_delim_node.start_point, (0, 8)) + self.assertEqual(close_delim_node.end_point, (0, 9)) + self.assertEqual(close_delim_node, third_num_node.next_sibling) + self.assertEqual(third_num_node, close_delim_node.prev_sibling) + self.assertEqual(third_num_node, close_delim_node.prev_named_sibling) + self.assertEqual(close_delim_node.parent, list_node) + + self.assertEqual(list_node.child_count, 7) + self.assertEqual(list_node.named_child_count, 3) + + def test_numeric_symbols_respect_simple_aliases(self): + parser = Parser(self.python) + + # Example 1: + # Python argument lists can contain "splat" arguments, which are not allowed within + # other expressions. This includes `parenthesized_list_splat` nodes like `(*b)`. These + # `parenthesized_list_splat` nodes are aliased as `parenthesized_expression`. Their numeric + # `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`. + tree = parser.parse(b"(a((*b)))") + root_node = tree.root_node + self.assertEqual( + str(root_node), + "(module (expression_statement (parenthesized_expression (call " + + "function: (identifier) arguments: (argument_list (parenthesized_expression " + + "(list_splat (identifier))))))))", + ) + + outer_expr_node = root_node.child(0).child(0) + self.assertIsNotNone(outer_expr_node) + self.assertEqual(outer_expr_node.type, "parenthesized_expression") + + inner_expr_node = ( + outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0) + ) + self.assertIsNotNone(inner_expr_node) + + self.assertEqual(inner_expr_node.type, "parenthesized_expression") + self.assertEqual(inner_expr_node.kind_id, outer_expr_node.kind_id) + + def test_tree(self): + code = b"def foo():\n bar()\n\ndef foo():\n bar()" + parser = Parser(self.python) + + for item in parser.parse(code).root_node.children: + self.assertIsNotNone(item.is_named) + + for item in parser.parse(code).root_node.children: + self.assertIsNotNone(item.is_named) + + def test_text(self): + parser = Parser(self.python) + tree = parser.parse(b"[0, [1, 2, 3]]") + + root_node = tree.root_node + self.assertEqual(root_node.text, b"[0, [1, 2, 3]]") + + exp_stmt_node = root_node.children[0] + self.assertEqual(exp_stmt_node.text, b"[0, [1, 2, 3]]") + + list_node = exp_stmt_node.children[0] + self.assertEqual(list_node.text, b"[0, [1, 2, 3]]") + + open_delim_node = list_node.children[0] + self.assertEqual(open_delim_node.text, b"[") + + first_num_node = list_node.children[1] + self.assertEqual(first_num_node.text, b"0") + + first_comma_node = list_node.children[2] + self.assertEqual(first_comma_node.text, b",") + + child_list_node = list_node.children[3] + self.assertEqual(child_list_node.text, b"[1, 2, 3]") + + close_delim_node = list_node.children[4] + self.assertEqual(close_delim_node.text, b"]") + + def test_hash(self): + parser = Parser(self.python) + source_code = b"def foo():\n bar()\n bar()" + tree = parser.parse(source_code) + root_node = tree.root_node + first_function_node = root_node.children[0] + second_function_node = root_node.children[0] + + # Uniqueness and consistency + self.assertEqual(hash(first_function_node), hash(first_function_node)) + self.assertNotEqual(hash(root_node), hash(first_function_node)) + + # Equality implication + self.assertEqual(hash(first_function_node), hash(second_function_node)) + self.assertEqual(first_function_node, second_function_node) + + # Different nodes with different properties + different_tree = parser.parse(b"def baz():\n qux()") + different_node = different_tree.root_node.children[0] + self.assertNotEqual(hash(first_function_node), hash(different_node)) + + # Same code, different parse trees + another_tree = parser.parse(source_code) + another_node = another_tree.root_node.children[0] + self.assertNotEqual(hash(first_function_node), hash(another_node)) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 00000000..2c921ceb --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,426 @@ +from unittest import TestCase + +from tree_sitter import Language, Parser, Range, Tree + +import tree_sitter_html +import tree_sitter_javascript +import tree_sitter_json +import tree_sitter_python +import tree_sitter_rust + + +def simple_range(start, end): + return Range((0, start), (0, end), start, end) + + +class TestParser(TestCase): + @classmethod + def setUpClass(cls): + cls.html = Language(tree_sitter_html.language()) + cls.python = Language(tree_sitter_python.language()) + cls.javascript = Language(tree_sitter_javascript.language()) + cls.json = Language(tree_sitter_json.language()) + cls.rust = Language(tree_sitter_rust.language()) + cls.max_range = Range((0, 0), (0xFFFFFFFF, 0xFFFFFFFF), 0, 0xFFFFFFFF) + cls.min_range = Range((0, 0), (0, 1), 0, 1) + cls.timeout = 1000 + + def test_init_no_args(self): + parser = Parser() + self.assertIsNone(parser.language) + self.assertListEqual(parser.included_ranges, [self.max_range]) + self.assertEqual(parser.timeout_micros, 0) + + def test_init_args(self): + parser = Parser( + language=self.python, included_ranges=[self.min_range], timeout_micros=self.timeout + ) + self.assertEqual(parser.language, self.python) + self.assertListEqual(parser.included_ranges, [self.min_range]) + self.assertEqual(parser.timeout_micros, self.timeout) + + def test_setters(self): + parser = Parser() + + with self.subTest(setter="language"): + parser.language = self.python + self.assertEqual(parser.language, self.python) + + with self.subTest(setter="included_ranges"): + parser.included_ranges = [self.min_range] + self.assertListEqual(parser.included_ranges, [self.min_range]) + with self.assertRaises(ValueError): + parser.included_ranges = [ + Range( + start_byte=23, + end_byte=29, + start_point=(0, 23), + end_point=(0, 29), + ), + Range( + start_byte=0, + end_byte=5, + start_point=(0, 0), + end_point=(0, 5), + ), + Range( + start_byte=50, + end_byte=60, + start_point=(0, 50), + end_point=(0, 60), + ), + ] + with self.assertRaises(ValueError): + parser.included_ranges = [ + Range( + start_byte=10, + end_byte=5, + start_point=(0, 10), + end_point=(0, 5), + ) + ] + + with self.subTest(setter="timeout_micros"): + parser.timeout_micros = self.timeout + self.assertEqual(parser.timeout_micros, self.timeout) + + def test_deleters(self): + parser = Parser() + + with self.subTest(deleter="language"): + del parser.language + self.assertIsNone(parser.language) + + with self.subTest(deleter="included_ranges"): + del parser.included_ranges + self.assertListEqual(parser.included_ranges, [self.max_range]) + + with self.subTest(setter="timeout_micros"): + del parser.timeout_micros + self.assertEqual(parser.timeout_micros, 0) + + def test_parse_buffer(self): + parser = Parser(self.javascript) + with self.subTest(type="bytes"): + self.assertIsInstance(parser.parse(b"test"), Tree) + with self.subTest(type="memoryview"): + self.assertIsInstance(parser.parse(memoryview(b"test")), Tree) + with self.subTest(type="bytearray"): + self.assertIsInstance(parser.parse(bytearray(b"test")), Tree) + + def test_parse_callback(self): + parser = Parser(self.python) + source_lines = ["def foo():\n", " bar()"] + + def read_callback(_, point): + row, column = point + if row >= len(source_lines): + return None + if column >= len(source_lines[row]): + return None + return source_lines[row][column:].encode("utf8") + + tree = parser.parse(read_callback) + self.assertEqual( + str(tree.root_node), + "(module (function_definition" + + " name: (identifier)" + + " parameters: (parameters)" + + " body: (block (expression_statement (call" + + " function: (identifier)" + + " arguments: (argument_list))))))", + ) + + def test_parse_with_one_included_range(self): + source_code = b"hi" + parser = Parser(self.html) + html_tree = parser.parse(source_code) + script_content_node = html_tree.root_node.child(1).child(1) + self.assertIsNotNone(script_content_node) + self.assertEqual(script_content_node.type, "raw_text") + + parser.included_ranges = [script_content_node.range] + parser.language = self.javascript + js_tree = parser.parse(source_code) + self.assertEqual( + str(js_tree.root_node), + "(program (expression_statement (call_expression" + + " function: (member_expression object: (identifier) property: (property_identifier))" + + " arguments: (arguments (string (string_fragment))))))", + ) + self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console"))) + self.assertEqual(js_tree.included_ranges, [script_content_node.range]) + + def test_parse_with_multiple_included_ranges(self): + source_code = b"html `
Hello, ${name.toUpperCase()}, it's ${now()}.
`" + + parser = Parser(self.javascript) + js_tree = parser.parse(source_code) + template_string_node = js_tree.root_node.descendant_for_byte_range( + source_code.index(b"`<"), source_code.index(b">`") + ) + self.assertIsNotNone(template_string_node) + + self.assertEqual(template_string_node.type, "template_string") + + open_quote_node = template_string_node.child(0) + self.assertIsNotNone(open_quote_node) + interpolation_node1 = template_string_node.child(2) + self.assertIsNotNone(interpolation_node1) + interpolation_node2 = template_string_node.child(4) + self.assertIsNotNone(interpolation_node2) + close_quote_node = template_string_node.child(6) + self.assertIsNotNone(close_quote_node) + + html_ranges = [ + Range( + start_byte=open_quote_node.end_byte, + start_point=open_quote_node.end_point, + end_byte=interpolation_node1.start_byte, + end_point=interpolation_node1.start_point, + ), + Range( + start_byte=interpolation_node1.end_byte, + start_point=interpolation_node1.end_point, + end_byte=interpolation_node2.start_byte, + end_point=interpolation_node2.start_point, + ), + Range( + start_byte=interpolation_node2.end_byte, + start_point=interpolation_node2.end_point, + end_byte=close_quote_node.start_byte, + end_point=close_quote_node.start_point, + ), + ] + parser.included_ranges = html_ranges + parser.language = self.html + html_tree = parser.parse(source_code) + + self.assertEqual( + str(html_tree.root_node), + "(document (element" + + " (start_tag (tag_name))" + + " (text)" + + " (element (start_tag (tag_name)) (end_tag (tag_name)))" + + " (text)" + + " (end_tag (tag_name))))" + ) + self.assertEqual(html_tree.included_ranges, html_ranges) + + div_element_node = html_tree.root_node.child(0) + self.assertIsNotNone(div_element_node) + hello_text_node = div_element_node.child(1) + self.assertIsNotNone(hello_text_node) + b_element_node = div_element_node.child(2) + self.assertIsNotNone(b_element_node) + b_start_tag_node = b_element_node.child(0) + self.assertIsNotNone(b_start_tag_node) + b_end_tag_node = b_element_node.child(1) + self.assertIsNotNone(b_end_tag_node) + + self.assertEqual(hello_text_node.type, "text") + self.assertEqual(hello_text_node.start_byte, source_code.index(b"Hello")) + self.assertEqual(hello_text_node.end_byte, source_code.index(b" ")) + + self.assertEqual(b_start_tag_node.type, "start_tag") + self.assertEqual(b_start_tag_node.start_byte, source_code.index(b"")) + self.assertEqual(b_start_tag_node.end_byte, source_code.index(b"${now()}")) + + self.assertEqual(b_end_tag_node.type, "end_tag") + self.assertEqual(b_end_tag_node.start_byte, source_code.index(b"")) + self.assertEqual(b_end_tag_node.end_byte, source_code.index(b".
")) + + def test_parse_with_included_range_containing_mismatched_positions(self): + source_code = b"
test
{_ignore_this_part_}" + end_byte = source_code.index(b"{_ignore_this_part_") + + range_to_parse = Range( + start_byte=0, + start_point=(10, 12), + end_byte=end_byte, + end_point=(10, 12 + end_byte), + ) + + parser = Parser(self.html, included_ranges=[range_to_parse]) + html_tree = parser.parse(source_code) + + self.assertEqual( + str(html_tree.root_node), + "(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))" + ) + + def test_parse_with_included_range_boundaries(self): + source_code = b"a <%= b() %> c <% d() %>" + range1_start_byte = source_code.index(b" b() ") + range1_end_byte = range1_start_byte + len(b" b() ") + range2_start_byte = source_code.index(b" d() ") + range2_end_byte = range2_start_byte + len(b" d() ") + + parser = Parser(self.javascript, included_ranges=[ + Range( + start_byte=range1_start_byte, + end_byte=range1_end_byte, + start_point=(0, range1_start_byte), + end_point=(0, range1_end_byte), + ), + Range( + start_byte=range2_start_byte, + end_byte=range2_end_byte, + start_point=(0, range2_start_byte), + end_point=(0, range2_end_byte), + ) + ]) + + tree = parser.parse(source_code) + root = tree.root_node + statement1 = root.child(0) + self.assertIsNotNone(statement1) + statement2 = root.child(1) + self.assertIsNotNone(statement2) + + self.assertEqual( + str(root), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + + self.assertEqual(statement1.start_byte, source_code.index(b"b()")) + self.assertEqual(statement1.end_byte, source_code.find(b" %> c")) + self.assertEqual(statement2.start_byte, source_code.find(b"d()")) + self.assertEqual(statement2.end_byte, len(source_code) - len(" %>")) + + def test_parse_with_a_newly_excluded_range(self): + source_code = b"
<%= something %>
" + + # Parse HTML including the template directive, which will cause an error + parser = Parser(self.html) + first_tree = parser.parse(source_code) + + prefix = b"a very very long line of plain text. " + first_tree.edit( + start_byte=0, + old_end_byte=0, + new_end_byte=len(prefix), + start_point=(0, 0), + old_end_point=(0, 0), + new_end_point=(0, len(prefix)), + ) + source_code = prefix + source_code + + # Parse the HTML again, this time *excluding* the template directive + # (which has moved since the previous parse). + directive_start = source_code.index(b"<%=") + directive_end = source_code.index(b"") + source_code_end = len(source_code) + parser.included_ranges = [ + Range( + start_byte=0, + end_byte=directive_start, + start_point=(0, 0), + end_point=(0, directive_start), + ), + Range( + start_byte=directive_end, + end_byte=source_code_end, + start_point=(0, directive_end), + end_point=(0, source_code_end), + ), + ] + + tree = parser.parse(source_code, first_tree) + + self.assertEqual( + str(tree.root_node), + "(document (text) (element" + + " (start_tag (tag_name))" + + " (element (start_tag (tag_name)) (end_tag (tag_name)))" + + " (end_tag (tag_name))))" + ) + + self.assertEqual( + tree.changed_ranges(first_tree), + [ + # The first range that has changed syntax is the range of the newly-inserted text. + Range( + start_byte=0, + end_byte=len(prefix), + start_point=(0, 0), + end_point=(0, len(prefix)), + ), + # Even though no edits were applied to the outer `div` element, + # its contents have changed syntax because a range of text that + # was previously included is now excluded. + Range( + start_byte=directive_start, + end_byte=directive_end, + start_point=(0, directive_start), + end_point=(0, directive_end), + ) + ] + ) + + def test_parsing_with_a_newly_included_range(self): + source_code = b"
<%= foo() %>
<%= bar() %><%= baz() %>" + range1_start = source_code.index(b" foo") + range2_start = source_code.index(b" bar") + range3_start = source_code.index(b" baz") + range1_end = range1_start + 7 + range2_end = range2_start + 7 + range3_end = range3_start + 7 + + # Parse only the first code directive as JavaScript + parser = Parser(self.javascript) + parser.included_ranges = [simple_range(range1_start, range1_end)] + tree = parser.parse(source_code) + self.assertEqual( + str(tree.root_node), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + + # Parse both the first and third code directives as JavaScript, using the old tree as a + # reference. + parser.included_ranges = [ + simple_range(range1_start, range1_end), + simple_range(range3_start, range3_end), + ] + tree2 = parser.parse(source_code) + self.assertEqual( + str(tree2.root_node), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + self.assertEqual( + tree2.changed_ranges(tree), + [simple_range(range1_end, range3_end)] + ) + + # Parse all three code directives as JavaScript, using the old tree as a + # reference. + parser.included_ranges = [ + simple_range(range1_start, range1_end), + simple_range(range2_start, range2_end), + simple_range(range3_start, range3_end), + ] + tree3 = parser.parse(source_code) + self.assertEqual( + str(tree3.root_node), + "(program" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments)))" + + " (expression_statement (call_expression" + + " function: (identifier) arguments: (arguments))))" + ) + self.assertEqual( + tree3.changed_ranges(tree2), + [simple_range(range2_start + 1, range2_end - 1)], + ) diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 00000000..1c14a2c4 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,509 @@ +from unittest import TestCase + +import tree_sitter_python +import tree_sitter_javascript + +from tree_sitter import Language, Parser, Query + + +def collect_matches(matches): + return [(m[0], format_captures(m[1])) for m in matches] + + +def format_captures(captures): + return [(name, format_capture(capture)) for name, capture in captures.items()] + + +def format_capture(capture): + return ( + [n.text.decode("utf-8") for n in capture] + if isinstance(capture, list) + else capture.text.decode("utf-8") + ) + + +class TestQuery(TestCase): + @classmethod + def setUpClass(cls): + cls.javascript = Language(tree_sitter_javascript.language()) + cls.python = Language(tree_sitter_python.language()) + + def assert_query_matches(self, language, query, source, expected): + parser = Parser(language) + tree = parser.parse(source) + matches = language.query(query).matches(tree.root_node) + matches = collect_matches(matches) + self.assertEqual(matches, expected) + + def test_errors(self): + with self.assertRaises(NameError, msg="Invalid node type foo"): + Query(self.python, "(list (foo))") + with self.assertRaises(NameError, msg="Invalid field name buzz"): + Query(self.python, "(function_definition buzz: (identifier))") + with self.assertRaises(NameError, msg="Invalid capture name garbage"): + Query(self.python, "((function_definition) (#eq? @garbage foo))") + with self.assertRaises(SyntaxError, msg="Invalid syntax at offset 6"): + Query(self.python, "(list))") + + def test_matches_with_simple_pattern(self): + self.assert_query_matches( + self.javascript, + "(function_declaration name: (identifier) @fn-name)", + b"function one() { two(); function three() {} }", + [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])], + ) + + def test_matches_with_multiple_on_same_root(self): + self.assert_query_matches( + self.javascript, + """ + (class_declaration + name: (identifier) @the-class-name + (class_body + (method_definition + name: (property_identifier) @the-method-name))) + """, + b""" + class Person { + // the constructor + constructor(name) { this.name = name; } + + // the getter + getFullName() { return this.name; } + } + """, + [ + (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]), + (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]), + ], + ) + + def test_matches_with_multiple_patterns_different_roots(self): + self.assert_query_matches( + self.javascript, + """ + (function_declaration name: (identifier) @fn-def) + (call_expression function: (identifier) @fn-ref) + """, + b""" + function f1() { + f2(f3()); + } + """, + [ + (0, [("fn-def", "f1")]), + (1, [("fn-ref", "f2")]), + (1, [("fn-ref", "f3")]), + ], + ) + + def test_matches_with_nesting_and_no_fields(self): + self.assert_query_matches( + self.javascript, + "(array (array (identifier) @x1 (identifier) @x2))", + b""" + [[a]]; + [[c, d], [e, f, g, h]]; + [[h], [i]]; + """, + [ + (0, [("x1", "c"), ("x2", "d")]), + (0, [("x1", "e"), ("x2", "f")]), + (0, [("x1", "e"), ("x2", "g")]), + (0, [("x1", "f"), ("x2", "g")]), + (0, [("x1", "e"), ("x2", "h")]), + (0, [("x1", "f"), ("x2", "h")]), + (0, [("x1", "g"), ("x2", "h")]), + ], + ) + + def test_matches_with_list_capture(self): + self.assert_query_matches( + self.javascript, + """ + (function_declaration + name: (identifier) @fn-name + body: (statement_block (_)* @fn-statements)) + """, + b"""function one() { + x = 1; + y = 2; + z = 3; + } + function two() { + x = 1; + } + """, + [ + ( + 0, + [ + ("fn-name", "one"), + ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]), + ], + ), + (0, [("fn-name", "two"), ("fn-statements", ["x = 1;"])]), + ], + ) + + def test_captures(self): + parser = Parser(self.python) + source = b"def foo():\n bar()\ndef baz():\n quux()\n" + tree = parser.parse(source) + query = self.python.query( + """ + (function_definition name: (identifier) @func-def) + (call function: (identifier) @func-call) + """ + ) + + captures = query.captures(tree.root_node) + + self.assertEqual(captures[0][0].start_point, (0, 4)) + self.assertEqual(captures[0][0].end_point, (0, 7)) + self.assertEqual(captures[0][1], "func-def") + + self.assertEqual(captures[1][0].start_point, (1, 2)) + self.assertEqual(captures[1][0].end_point, (1, 5)) + self.assertEqual(captures[1][1], "func-call") + + self.assertEqual(captures[2][0].start_point, (2, 4)) + self.assertEqual(captures[2][0].end_point, (2, 7)) + self.assertEqual(captures[2][1], "func-def") + + self.assertEqual(captures[3][0].start_point, (3, 2)) + self.assertEqual(captures[3][0].end_point, (3, 6)) + self.assertEqual(captures[3][1], "func-call") + + def test_text_predicates(self): + parser = Parser(self.javascript) + source = b""" + keypair_object = { + key1: value1, + equal: equal + } + + function fun1(arg) { + return 1; + } + + function fun2(arg) { + return 2; + } + """ + tree = parser.parse(source) + root_node = tree.root_node + + # function with name equal to 'fun1' -> test for #eq? @capture string + query1 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? @function-name fun1)) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(1, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + + # functions with name not equal to 'fun1' -> test for #not-eq? @capture string + query2 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#not-eq? @function-name fun1)) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(1, len(captures2)) + self.assertEqual(b"fun2", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + + # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2 + query3 = self.javascript.query( + """ + ((pair + key: (property_identifier) @key-name + value: (identifier) @value-name) + (#eq? @key-name @value-name)) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(2, len(captures3)) + self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3])) + self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3])) + + # key pairs whose key is not equal to its value + # -> test for #not-eq? @capture1 @capture2 + query4 = self.javascript.query( + """ + ((pair + key: (property_identifier) @key-name + value: (identifier) @value-name) + (#not-eq? @key-name @value-name)) + """ + ) + captures4 = query4.captures(root_node) + self.assertEqual(2, len(captures4)) + self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4])) + self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4])) + + # equality that is satisfied by *another* capture + query5 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters (identifier) @parameter-name)) + (#eq? @function-name arg)) + """ + ) + captures5 = query5.captures(root_node) + self.assertEqual(0, len(captures5)) + + # functions that match the regex .*1 -> test for #match @capture regex + query6 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name ".*1")) + """ + ) + captures6 = query6.captures(root_node) + self.assertEqual(1, len(captures6)) + self.assertEqual(b"fun1", captures6[0][0].text) + + # functions that do not match the regex .*1 -> test for #not-match @capture regex + query6 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#not-match? @function-name ".*1")) + """ + ) + captures6 = query6.captures(root_node) + self.assertEqual(1, len(captures6)) + self.assertEqual(b"fun2", captures6[0][0].text) + + # after editing there is no text property, so predicates are ignored + tree.edit( + start_byte=0, + old_end_byte=0, + new_end_byte=2, + start_point=(0, 0), + old_end_point=(0, 0), + new_end_point=(0, 2), + ) + captures_notext = query1.captures(root_node) + self.assertEqual(2, len(captures_notext)) + self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext])) + + def test_text_predicate_on_optional_capture(self): + parser = Parser(self.javascript) + source = b"fun1(1)" + tree = parser.parse(source) + root_node = tree.root_node + + # optional capture that is missing in source used in #eq? @capture string + query1 = self.javascript.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#eq? @optional-string-arg "1"))) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(1, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + + # optional capture that is missing in source used in #eq? @capture @capture + query2 = self.javascript.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#eq? @optional-string-arg @function-name))) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(1, len(captures2)) + self.assertEqual(b"fun1", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + + # optional capture that is missing in source used in #match? @capture string + query3 = self.javascript.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#match? @optional-string-arg "\\d+"))) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(1, len(captures3)) + self.assertEqual(b"fun1", captures3[0][0].text) + self.assertEqual("function-name", captures3[0][1]) + + def test_text_predicates_errors(self): + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? @function-name @function-name fun1)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#eq? fun1 @function-name)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name @function-name fun1)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? fun1 @function-name)) + """ + ) + + with self.assertRaises(RuntimeError): + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name @function-name)) + """ + ) + + def test_multiple_text_predicates(self): + parser = Parser(self.javascript) + source = b""" + keypair_object = { + key1: value1, + equal: equal + } + + function fun1(arg) { + return 1; + } + + function fun1(notarg) { + return 1 + 1; + } + + function fun2(arg) { + return 2; + } + """ + tree = parser.parse(source) + root_node = tree.root_node + + # function with name equal to 'fun1' -> test for first #eq? @capture string + query1 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters + (identifier) @argument-name)) + (#eq? @function-name fun1)) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(4, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + self.assertEqual(b"arg", captures1[1][0].text) + self.assertEqual("argument-name", captures1[1][1]) + self.assertEqual(b"fun1", captures1[2][0].text) + self.assertEqual("function-name", captures1[2][1]) + self.assertEqual(b"notarg", captures1[3][0].text) + self.assertEqual("argument-name", captures1[3][1]) + + # function with argument equal to 'arg' -> test for second #eq? @capture string + query2 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters + (identifier) @argument-name)) + (#eq? @argument-name arg)) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(4, len(captures2)) + self.assertEqual(b"fun1", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + self.assertEqual(b"arg", captures2[1][0].text) + self.assertEqual("argument-name", captures2[1][1]) + self.assertEqual(b"fun2", captures2[2][0].text) + self.assertEqual("function-name", captures2[2][1]) + self.assertEqual(b"arg", captures2[3][0].text) + self.assertEqual("argument-name", captures2[3][1]) + + # function with name equal to 'fun1' & argument 'arg' -> test for both together + query3 = self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name + parameters: (formal_parameters + (identifier) @argument-name)) + (#eq? @function-name fun1) + (#eq? @argument-name arg)) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(2, len(captures3)) + self.assertEqual(b"fun1", captures3[0][0].text) + self.assertEqual("function-name", captures3[0][1]) + self.assertEqual(b"arg", captures3[1][0].text) + self.assertEqual("argument-name", captures3[1][1]) + + def test_point_range_captures(self): + parser = Parser(self.python) + source = b"def foo():\n bar()\ndef baz():\n quux()\n" + tree = parser.parse(source) + query = self.python.query( + """ + (function_definition name: (identifier) @func-def) + (call function: (identifier) @func-call) + """ + ) + + captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0)) + + self.assertEqual(captures[0][0].start_point, (1, 2)) + self.assertEqual(captures[0][0].end_point, (1, 5)) + self.assertEqual(captures[0][1], "func-call") + + def test_byte_range_captures(self): + parser = Parser(self.python) + source = b"def foo():\n bar()\ndef baz():\n quux()\n" + tree = parser.parse(source) + query = self.python.query( + """ + (function_definition name: (identifier) @func-def) + (call function: (identifier) @func-call) + """ + ) + + captures = query.captures(tree.root_node, start_byte=10, end_byte=20) + self.assertEqual(captures[0][0].start_point, (1, 2)) + self.assertEqual(captures[0][0].end_point, (1, 5)) + self.assertEqual(captures[0][1], "func-call") diff --git a/tests/test_tree.py b/tests/test_tree.py new file mode 100644 index 00000000..df7b8ae5 --- /dev/null +++ b/tests/test_tree.py @@ -0,0 +1,152 @@ +from unittest import TestCase + +from tree_sitter import Language, Parser + +import tree_sitter_python +import tree_sitter_rust + + +class TestTree(TestCase): + @classmethod + def setUpClass(cls): + cls.python = Language(tree_sitter_python.language()) + cls.rust = Language(tree_sitter_rust.language()) + + def test_edit(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + + edit_offset = len(b"def foo(") + tree.edit( + start_byte=edit_offset, + old_end_byte=edit_offset, + new_end_byte=edit_offset + 2, + start_point=(0, edit_offset), + old_end_point=(0, edit_offset), + new_end_point=(0, edit_offset + 2), + ) + + fn_node = tree.root_node.children[0] + self.assertEqual(fn_node.type, "function_definition") + self.assertTrue(fn_node.has_changes) + self.assertFalse(fn_node.children[0].has_changes) + self.assertFalse(fn_node.children[1].has_changes) + self.assertFalse(fn_node.children[3].has_changes) + + params_node = fn_node.children[2] + self.assertEqual(params_node.type, "parameters") + self.assertTrue(params_node.has_changes) + self.assertEqual(params_node.start_point, (0, edit_offset - 1)) + self.assertEqual(params_node.end_point, (0, edit_offset + 3)) + + new_tree = parser.parse(b"def foo(ab):\n bar()", tree) + self.assertEqual( + str(new_tree.root_node), + "(module (function_definition" + + " name: (identifier)" + + " parameters: (parameters (identifier))" + + " body: (block" + + " (expression_statement (call" + + " function: (identifier)" + + " arguments: (argument_list))))))", + ) + + def test_changed_ranges(self): + parser = Parser(self.python) + tree = parser.parse(b"def foo():\n bar()") + + edit_offset = len(b"def foo(") + tree.edit( + start_byte=edit_offset, + old_end_byte=edit_offset, + new_end_byte=edit_offset + 2, + start_point=(0, edit_offset), + old_end_point=(0, edit_offset), + new_end_point=(0, edit_offset + 2), + ) + + new_tree = parser.parse(b"def foo(ab):\n bar()", tree) + changed_ranges = tree.changed_ranges(new_tree) + + self.assertEqual(len(changed_ranges), 1) + self.assertEqual(changed_ranges[0].start_byte, edit_offset) + self.assertEqual(changed_ranges[0].start_point, (0, edit_offset)) + self.assertEqual(changed_ranges[0].end_byte, edit_offset + 2) + self.assertEqual(changed_ranges[0].end_point, (0, edit_offset + 2)) + + def test_walk(self): + parser = Parser(self.rust) + + tree = parser.parse( + b""" + struct Stuff { + a: A, + b: Option, + } + """ + ) + + cursor = tree.walk() + + # Node always returns the same instance + self.assertIs(cursor.node, cursor.node) + + self.assertEqual(cursor.node.type, "source_file") + + self.assertEqual(cursor.goto_first_child(), True) + self.assertEqual(cursor.node.type, "struct_item") + + self.assertEqual(cursor.goto_first_child(), True) + self.assertEqual(cursor.node.type, "struct") + self.assertEqual(cursor.node.is_named, False) + + self.assertEqual(cursor.goto_next_sibling(), True) + self.assertEqual(cursor.node.type, "type_identifier") + self.assertEqual(cursor.node.is_named, True) + + self.assertEqual(cursor.goto_next_sibling(), True) + self.assertEqual(cursor.node.type, "field_declaration_list") + self.assertEqual(cursor.node.is_named, True) + + self.assertEqual(cursor.goto_last_child(), True) + self.assertEqual(cursor.node.type, "}") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (4, 16)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, ",") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (3, 32)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, "field_declaration") + self.assertEqual(cursor.node.is_named, True) + self.assertEqual(cursor.node.start_point, (3, 20)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, ",") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (2, 24)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, "field_declaration") + self.assertEqual(cursor.node.is_named, True) + self.assertEqual(cursor.node.start_point, (2, 20)) + + self.assertEqual(cursor.goto_previous_sibling(), True) + self.assertEqual(cursor.node.type, "{") + self.assertEqual(cursor.node.is_named, False) + self.assertEqual(cursor.node.start_point, (1, 29)) + + copy = tree.walk() + copy.reset_to(cursor) + + self.assertEqual(copy.node.type, "{") + self.assertEqual(copy.node.is_named, False) + + self.assertEqual(copy.goto_parent(), True) + self.assertEqual(copy.node.type, "field_declaration_list") + self.assertEqual(copy.node.is_named, True) + + self.assertEqual(copy.goto_parent(), True) + self.assertEqual(copy.node.type, "struct_item") diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py deleted file mode 100644 index 0826af32..00000000 --- a/tests/test_tree_sitter.py +++ /dev/null @@ -1,1906 +0,0 @@ -import re -from os import path -from typing import Dict, List, Optional, Tuple, Union -from unittest import TestCase - -from tree_sitter import Language, LookaheadIterator, Node, Parser, Query, Range, Tree - -LIB_PATH = path.join("build", "languages.so") - -# cibuildwheel uses a funny working directory when running tests. -# This is by design, this way tests import whatever is installed and not from the project. -# -# The languages binary is still relative to current working directory to prevent reusing -# a 32-bit languages binary in a 64-bit build. The working directory is clean every time. -project_root = path.dirname(path.dirname(path.abspath(__file__))) -Language.build_library( - LIB_PATH, - [ - path.join(project_root, "tests", "fixtures", "tree-sitter-embedded-template"), - path.join(project_root, "tests", "fixtures", "tree-sitter-html"), - path.join(project_root, "tests", "fixtures", "tree-sitter-javascript"), - path.join(project_root, "tests", "fixtures", "tree-sitter-json"), - path.join(project_root, "tests", "fixtures", "tree-sitter-python"), - path.join(project_root, "tests", "fixtures", "tree-sitter-rust"), - ], -) - -EMBEDDED_TEMPLATE = Language(LIB_PATH, "embedded_template") -HTML = Language(LIB_PATH, "html") -JAVASCRIPT = Language(LIB_PATH, "javascript") -JSON = Language(LIB_PATH, "json") -PYTHON = Language(LIB_PATH, "python") -RUST = Language(LIB_PATH, "rust") - -JSON_EXAMPLE: bytes = b""" - -[ - 123, - false, - { - "x": null - } -] -""" - - -class TestParser(TestCase): - def test_set_language(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - self.assertEqual( - tree.root_node.sexp(), - trim( - """(module (function_definition - name: (identifier) - parameters: (parameters) - body: (block (expression_statement (call - function: (identifier) - arguments: (argument_list))))))""" - ), - ) - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"function foo() {\n bar();\n}") - self.assertEqual( - tree.root_node.sexp(), - trim( - """(program (function_declaration - name: (identifier) - parameters: (formal_parameters) - body: (statement_block - (expression_statement - (call_expression - function: (identifier) - arguments: (arguments))))))""" - ), - ) - - def test_read_callback(self): - parser = Parser() - parser.set_language(PYTHON) - source_lines = ["def foo():\n", " bar()"] - - def read_callback(_: int, point: Tuple[int, int]) -> Optional[bytes]: - row, column = point - if row >= len(source_lines): - return None - if column >= len(source_lines[row]): - return None - return source_lines[row][column:].encode("utf8") - - tree = parser.parse(read_callback) - self.assertEqual( - tree.root_node.sexp(), - trim( - """(module (function_definition - name: (identifier) - parameters: (parameters) - body: (block (expression_statement (call - function: (identifier) - arguments: (argument_list))))))""" - ), - ) - - def test_multibyte_characters(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source_code = bytes("'😎' && '🐍'", "utf8") - tree = parser.parse(source_code) - root_node = tree.root_node - statement_node = root_node.children[0] - binary_node = statement_node.children[0] - snake_node = binary_node.children[2] - - self.assertEqual(binary_node.type, "binary_expression") - self.assertEqual(snake_node.type, "string") - self.assertEqual( - source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"), - "'🐍'", - ) - - def test_buffer_protocol(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - parser.parse(b"test") - parser.parse(memoryview(b"test")) - parser.parse(bytearray(b"test")) - - def test_multibyte_characters_via_read_callback(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source_code = bytes("'😎' && '🐍'", "utf8") - - def read(byte_position, _): - return source_code[byte_position : byte_position + 1] - - tree = parser.parse(read) - root_node = tree.root_node - statement_node = root_node.children[0] - binary_node = statement_node.children[0] - snake_node = binary_node.children[2] - - self.assertEqual(binary_node.type, "binary_expression") - self.assertEqual(snake_node.type, "string") - self.assertEqual( - source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"), - "'🐍'", - ) - - def test_parsing_with_one_included_range(self): - source_code = b"hi" - parser = Parser() - parser.set_language(HTML) - html_tree = parser.parse(source_code) - script_content_node = html_tree.root_node.child(1).child(1) - if script_content_node is None: - self.fail("script_content_node is None") - self.assertEqual(script_content_node.type, "raw_text") - - parser.set_included_ranges([script_content_node.range]) - parser.set_language(JAVASCRIPT) - js_tree = parser.parse(source_code) - - self.assertEqual( - js_tree.root_node.sexp(), - "(program (expression_statement (call_expression " - + "function: (member_expression object: (identifier) property: (property_identifier)) " - + "arguments: (arguments (string (string_fragment))))))", - ) - self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console"))) - self.assertEqual(js_tree.included_ranges, [script_content_node.range]) - - def test_parsing_with_multiple_included_ranges(self): - source_code = b"html `
Hello, ${name.toUpperCase()}, it's ${now()}.
`" - - parser = Parser() - parser.set_language(JAVASCRIPT) - js_tree = parser.parse(source_code) - template_string_node = js_tree.root_node.descendant_for_byte_range( - source_code.index(b"`<"), source_code.index(b">`") - ) - if template_string_node is None: - self.fail("template_string_node is None") - - self.assertEqual(template_string_node.type, "template_string") - - open_quote_node = template_string_node.child(0) - if open_quote_node is None: - self.fail("open_quote_node is None") - interpolation_node1 = template_string_node.child(2) - if interpolation_node1 is None: - self.fail("interpolation_node1 is None") - interpolation_node2 = template_string_node.child(4) - if interpolation_node2 is None: - self.fail("interpolation_node2 is None") - close_quote_node = template_string_node.child(6) - if close_quote_node is None: - self.fail("close_quote_node is None") - - html_ranges = [ - Range( - start_byte=open_quote_node.end_byte, - start_point=open_quote_node.end_point, - end_byte=interpolation_node1.start_byte, - end_point=interpolation_node1.start_point, - ), - Range( - start_byte=interpolation_node1.end_byte, - start_point=interpolation_node1.end_point, - end_byte=interpolation_node2.start_byte, - end_point=interpolation_node2.start_point, - ), - Range( - start_byte=interpolation_node2.end_byte, - start_point=interpolation_node2.end_point, - end_byte=close_quote_node.start_byte, - end_point=close_quote_node.start_point, - ), - ] - parser.set_included_ranges(html_ranges) - parser.set_language(HTML) - html_tree = parser.parse(source_code) - - self.assertEqual( - html_tree.root_node.sexp(), - "(document (element" - + " (start_tag (tag_name))" - + " (text)" - + " (element (start_tag (tag_name)) (end_tag (tag_name)))" - + " (text)" - + " (end_tag (tag_name))))", - ) - self.assertEqual(html_tree.included_ranges, html_ranges) - - div_element_node = html_tree.root_node.child(0) - if div_element_node is None: - self.fail("div_element_node is None") - hello_text_node = div_element_node.child(1) - if hello_text_node is None: - self.fail("hello_text_node is None") - b_element_node = div_element_node.child(2) - if b_element_node is None: - self.fail("b_element_node is None") - b_start_tag_node = b_element_node.child(0) - if b_start_tag_node is None: - self.fail("b_start_tag_node is None") - b_end_tag_node = b_element_node.child(1) - if b_end_tag_node is None: - self.fail("b_end_tag_node is None") - - self.assertEqual(hello_text_node.type, "text") - self.assertEqual(hello_text_node.start_byte, source_code.index(b"Hello")) - self.assertEqual(hello_text_node.end_byte, source_code.index(b" ")) - - self.assertEqual(b_start_tag_node.type, "start_tag") - self.assertEqual(b_start_tag_node.start_byte, source_code.index(b"")) - self.assertEqual(b_start_tag_node.end_byte, source_code.index(b"${now()}")) - - self.assertEqual(b_end_tag_node.type, "end_tag") - self.assertEqual(b_end_tag_node.start_byte, source_code.index(b"")) - self.assertEqual(b_end_tag_node.end_byte, source_code.index(b".
")) - - def test_parsing_with_included_range_containing_mismatched_positions(self): - source_code = b"
test
{_ignore_this_part_}" - - parser = Parser() - parser.set_language(HTML) - - end_byte = source_code.index(b"{_ignore_this_part_") - - range_to_parse = Range( - start_byte=0, - start_point=(10, 12), - end_byte=end_byte, - end_point=(10, 12 + end_byte), - ) - - parser.set_included_ranges([range_to_parse]) - - html_tree = parser.parse(source_code) - - self.assertEqual( - html_tree.root_node.sexp(), - "(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))", - ) - - def test_parsing_error_in_invalid_included_ranges(self): - parser = Parser() - with self.assertRaises(Exception): - parser.set_included_ranges( - [ - Range( - start_byte=23, - end_byte=29, - start_point=(0, 23), - end_point=(0, 29), - ), - Range( - start_byte=0, - end_byte=5, - start_point=(0, 0), - end_point=(0, 5), - ), - Range( - start_byte=50, - end_byte=60, - start_point=(0, 50), - end_point=(0, 60), - ), - ] - ) - - with self.assertRaises(Exception): - parser.set_included_ranges( - [ - Range( - start_byte=10, - end_byte=5, - start_point=(0, 10), - end_point=(0, 5), - ) - ] - ) - - def test_parsing_with_external_scanner_that_uses_included_range_boundaries(self): - source_code = b"a <%= b() %> c <% d() %>" - range1_start_byte = source_code.index(b" b() ") - range1_end_byte = range1_start_byte + len(b" b() ") - range2_start_byte = source_code.index(b" d() ") - range2_end_byte = range2_start_byte + len(b" d() ") - - parser = Parser() - parser.set_language(JAVASCRIPT) - parser.set_included_ranges( - [ - Range( - start_byte=range1_start_byte, - end_byte=range1_end_byte, - start_point=(0, range1_start_byte), - end_point=(0, range1_end_byte), - ), - Range( - start_byte=range2_start_byte, - end_byte=range2_end_byte, - start_point=(0, range2_start_byte), - end_point=(0, range2_end_byte), - ), - ] - ) - - tree = parser.parse(source_code) - root = tree.root_node - statement1 = root.child(0) - if statement1 is None: - self.fail("statement1 is None") - statement2 = root.child(1) - if statement2 is None: - self.fail("statement2 is None") - - self.assertEqual( - root.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - - self.assertEqual(statement1.start_byte, source_code.index(b"b()")) - self.assertEqual(statement1.end_byte, source_code.find(b" %> c")) - self.assertEqual(statement2.start_byte, source_code.find(b"d()")) - self.assertEqual(statement2.end_byte, len(source_code) - len(" %>")) - - def test_parsing_with_a_newly_excluded_range(self): - source_code = b"
<%= something %>
" - - # Parse HTML including the template directive, which will cause an error - parser = Parser() - parser.set_language(HTML) - first_tree = parser.parse(source_code) - - prefix = b"a very very long line of plain text. " - first_tree.edit( - start_byte=0, - old_end_byte=0, - new_end_byte=len(prefix), - start_point=(0, 0), - old_end_point=(0, 0), - new_end_point=(0, len(prefix)), - ) - source_code = prefix + source_code - - # Parse the HTML again, this time *excluding* the template directive - # (which has moved since the previous parse). - directive_start = source_code.index(b"<%=") - directive_end = source_code.index(b"") - source_code_end = len(source_code) - parser.set_included_ranges( - [ - Range( - start_byte=0, - end_byte=directive_start, - start_point=(0, 0), - end_point=(0, directive_start), - ), - Range( - start_byte=directive_end, - end_byte=source_code_end, - start_point=(0, directive_end), - end_point=(0, source_code_end), - ), - ] - ) - - tree = parser.parse(source_code, first_tree) - - self.assertEqual( - tree.root_node.sexp(), - "(document (text) (element" - + " (start_tag (tag_name))" - + " (element (start_tag (tag_name)) (end_tag (tag_name)))" - + " (end_tag (tag_name))))", - ) - - self.assertEqual( - tree.changed_ranges(first_tree), - [ - # The first range that has changed syntax is the range of the newly-inserted text. - Range( - start_byte=0, - end_byte=len(prefix), - start_point=(0, 0), - end_point=(0, len(prefix)), - ), - # Even though no edits were applied to the outer `div` element, - # its contents have changed syntax because a range of text that - # was previously included is now excluded. - Range( - start_byte=directive_start, - end_byte=directive_end, - start_point=(0, directive_start), - end_point=(0, directive_end), - ), - ], - ) - - def test_parsing_with_a_newly_included_range(self): - source_code = b"
<%= foo() %>
<%= bar() %><%= baz() %>" - range1_start = source_code.index(b" foo") - range2_start = source_code.index(b" bar") - range3_start = source_code.index(b" baz") - range1_end = range1_start + 7 - range2_end = range2_start + 7 - range3_end = range3_start + 7 - - def simple_range(start: int, end: int) -> Range: - return Range( - start_byte=start, - end_byte=end, - start_point=(0, start), - end_point=(0, end), - ) - - # Parse only the first code directive as JavaScript - parser = Parser() - parser.set_language(JAVASCRIPT) - parser.set_included_ranges([simple_range(range1_start, range1_end)]) - tree = parser.parse(source_code) - self.assertEqual( - tree.root_node.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - - # Parse both the first and third code directives as JavaScript, using the old tree as a - # reference. - parser.set_included_ranges( - [ - simple_range(range1_start, range1_end), - simple_range(range3_start, range3_end), - ] - ) - tree2 = parser.parse(source_code) - self.assertEqual( - tree2.root_node.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - self.assertEqual(tree2.changed_ranges(tree), [simple_range(range1_end, range3_end)]) - - # Parse all three code directives as JavaScript, using the old tree as a - # reference. - parser.set_included_ranges( - [ - simple_range(range1_start, range1_end), - simple_range(range2_start, range2_end), - simple_range(range3_start, range3_end), - ] - ) - tree3 = parser.parse(source_code) - self.assertEqual( - tree3.root_node.sexp(), - "(program" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + " " - + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))" - + ")", - ) - self.assertEqual( - tree3.changed_ranges(tree2), - [simple_range(range2_start + 1, range2_end - 1)], - ) - - -class TestNode(TestCase): - def test_child_by_field_id(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - root_node = tree.root_node - fn_node = tree.root_node.children[0] - - self.assertEqual(PYTHON.field_id_for_name("nameasdf"), None) - name_field = PYTHON.field_id_for_name("name") - alias_field = PYTHON.field_id_for_name("alias") - if not isinstance(alias_field, int): - self.fail("alias_field is not an int") - if not isinstance(name_field, int): - self.fail("name_field is not an int") - self.assertEqual(root_node.child_by_field_id(alias_field), None) - self.assertEqual(root_node.child_by_field_id(name_field), None) - self.assertEqual(fn_node.child_by_field_id(alias_field), None) - self.assertEqual(fn_node.child_by_field_id(name_field).type, "identifier") - self.assertRaises(TypeError, root_node.child_by_field_id, "") - self.assertRaises(TypeError, root_node.child_by_field_name, True) - self.assertRaises(TypeError, root_node.child_by_field_name, 1) - - self.assertEqual(fn_node.child_by_field_name("name").type, "identifier") - self.assertEqual(fn_node.child_by_field_name("asdfasdfname"), None) - - self.assertEqual( - fn_node.child_by_field_name("name"), - fn_node.child_by_field_name("name"), - ) - - def test_children_by_field_id(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"
") - jsx_node = tree.root_node.children[0].children[0] - attribute_field = PYTHON.field_id_for_name("attribute") - if not isinstance(attribute_field, int): - self.fail("attribute_field is not an int") - - attributes = jsx_node.children_by_field_id(attribute_field) - self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) - - def test_children_by_field_name(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"
") - jsx_node = tree.root_node.children[0].children[0] - - attributes = jsx_node.children_by_field_name("attribute") - self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"]) - - def test_node_child_by_field_name_with_extra_hidden_children(self): - parser = Parser() - parser.set_language(PYTHON) - - tree = parser.parse(b"while a:\n pass") - while_node = tree.root_node.child(0) - if while_node is None: - self.fail("while_node is None") - self.assertEqual(while_node.type, "while_statement") - self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3)) - - def test_node_descendant_count(self): - parser = Parser() - parser.set_language(JSON) - tree = parser.parse(JSON_EXAMPLE) - value_node = tree.root_node - all_nodes = get_all_nodes(tree) - - self.assertEqual(value_node.descendant_count, len(all_nodes)) - - cursor = value_node.walk() - for i, node in enumerate(all_nodes): - cursor.goto_descendant(i) - self.assertEqual(cursor.node, node, f"index {i}") - - for i, node in reversed(list(enumerate(all_nodes))): - cursor.goto_descendant(i) - self.assertEqual(cursor.node, node, f"rev index {i}") - - def test_descendant_count_single_node_tree(self): - parser = Parser() - parser.set_language(EMBEDDED_TEMPLATE) - tree = parser.parse(b"hello") - - nodes = get_all_nodes(tree) - self.assertEqual(len(nodes), 2) - self.assertEqual(tree.root_node.descendant_count, 2) - - cursor = tree.walk() - - cursor.goto_descendant(0) - self.assertEqual(cursor.depth, 0) - self.assertEqual(cursor.node, nodes[0]) - cursor.goto_descendant(1) - self.assertEqual(cursor.depth, 1) - self.assertEqual(cursor.node, nodes[1]) - - def test_field_name_for_child(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"
") - jsx_node = tree.root_node.children[0].children[0] - - self.assertEqual(jsx_node.field_name_for_child(0), None) - self.assertEqual(jsx_node.field_name_for_child(1), "name") - - def test_descendant_for_byte_range(self): - parser = Parser() - parser.set_language(JSON) - tree = parser.parse(JSON_EXAMPLE) - array_node = tree.root_node - - colon_index = JSON_EXAMPLE.index(b":") - - # Leaf node exactly matches the given bounds - byte query - colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # Leaf node exactly matches the given bounds - point query - colon_node = array_node.descendant_for_point_range((6, 7), (6, 8)) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # The given point is between two adjacent leaf nodes - byte query - colon_node = array_node.descendant_for_byte_range(colon_index, colon_index) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # The given point is between two adjacent leaf nodes - point query - colon_node = array_node.descendant_for_point_range((6, 7), (6, 7)) - if colon_node is None: - self.fail("colon_node is None") - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.start_byte, colon_index) - self.assertEqual(colon_node.end_byte, colon_index + 1) - self.assertEqual(colon_node.start_point, (6, 7)) - self.assertEqual(colon_node.end_point, (6, 8)) - - # Leaf node starts at the lower bound, ends after the upper bound - byte query - string_index = JSON_EXAMPLE.index(b'"x"') - string_node = array_node.descendant_for_byte_range(string_index, string_index + 2) - if string_node is None: - self.fail("string_node is None") - self.assertEqual(string_node.type, "string") - self.assertEqual(string_node.start_byte, string_index) - self.assertEqual(string_node.end_byte, string_index + 3) - self.assertEqual(string_node.start_point, (6, 4)) - self.assertEqual(string_node.end_point, (6, 7)) - - # Leaf node starts at the lower bound, ends after the upper bound - point query - string_node = array_node.descendant_for_point_range((6, 4), (6, 6)) - if string_node is None: - self.fail("string_node is None") - self.assertEqual(string_node.type, "string") - self.assertEqual(string_node.start_byte, string_index) - self.assertEqual(string_node.end_byte, string_index + 3) - self.assertEqual(string_node.start_point, (6, 4)) - self.assertEqual(string_node.end_point, (6, 7)) - - # Leaf node starts before the lower bound, ends at the upper bound - byte query - null_index = JSON_EXAMPLE.index(b"null") - null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4) - if null_node is None: - self.fail("null_node is None") - self.assertEqual(null_node.type, "null") - self.assertEqual(null_node.start_byte, null_index) - self.assertEqual(null_node.end_byte, null_index + 4) - self.assertEqual(null_node.start_point, (6, 9)) - self.assertEqual(null_node.end_point, (6, 13)) - - # Leaf node starts before the lower bound, ends at the upper bound - point query - null_node = array_node.descendant_for_point_range((6, 11), (6, 13)) - if null_node is None: - self.fail("null_node is None") - self.assertEqual(null_node.type, "null") - self.assertEqual(null_node.start_byte, null_index) - self.assertEqual(null_node.end_byte, null_index + 4) - self.assertEqual(null_node.start_point, (6, 9)) - self.assertEqual(null_node.end_point, (6, 13)) - - # The bounds span multiple leaf nodes - return the smallest node that does span it. - pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4) - if pair_node is None: - self.fail("pair_node is None") - self.assertEqual(pair_node.type, "pair") - self.assertEqual(pair_node.start_byte, string_index) - self.assertEqual(pair_node.end_byte, string_index + 9) - self.assertEqual(pair_node.start_point, (6, 4)) - self.assertEqual(pair_node.end_point, (6, 13)) - - self.assertEqual(colon_node.parent, pair_node) - - # No leaf spans the given range - return the smallest node that does span it. - pair_node = array_node.descendant_for_point_range((6, 6), (6, 8)) - if pair_node is None: - self.fail("pair_node is None") - self.assertEqual(pair_node.type, "pair") - self.assertEqual(pair_node.start_byte, string_index) - self.assertEqual(pair_node.end_byte, string_index + 9) - self.assertEqual(pair_node.start_point, (6, 4)) - self.assertEqual(pair_node.end_point, (6, 13)) - - def test_root_node_with_offset(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b" if (a) b") - - node = tree.root_node_with_offset(6, (2, 2)) - if node is None: - self.fail("node is None") - self.assertEqual(node.byte_range, (8, 16)) - self.assertEqual(node.start_point, (2, 4)) - self.assertEqual(node.end_point, (2, 12)) - - child = node.child(0).child(2) - if child is None: - self.fail("child is None") - self.assertEqual(child.type, "expression_statement") - self.assertEqual(child.byte_range, (15, 16)) - self.assertEqual(child.start_point, (2, 11)) - self.assertEqual(child.end_point, (2, 12)) - - cursor = node.walk() - cursor.goto_first_child() - cursor.goto_first_child() - cursor.goto_next_sibling() - child = cursor.node - if child is None: - self.fail("child is None") - self.assertEqual(child.type, "parenthesized_expression") - self.assertEqual(child.byte_range, (11, 14)) - self.assertEqual(child.start_point, (2, 7)) - self.assertEqual(child.end_point, (2, 10)) - - def test_node_is_extra(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - tree = parser.parse(b"foo(/* hi */);") - - root_node = tree.root_node - comment_node = root_node.descendant_for_byte_range(7, 7) - if comment_node is None: - self.fail("comment_node is None") - - self.assertEqual(root_node.type, "program") - self.assertEqual(comment_node.type, "comment") - self.assertEqual(root_node.is_extra, False) - self.assertEqual(comment_node.is_extra, True) - - def test_children(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - - root_node = tree.root_node - self.assertEqual(root_node.type, "module") - self.assertEqual(root_node.start_byte, 0) - self.assertEqual(root_node.end_byte, 18) - self.assertEqual(root_node.start_point, (0, 0)) - self.assertEqual(root_node.end_point, (1, 7)) - - # List object is reused - self.assertIs(root_node.children, root_node.children) - - fn_node = root_node.children[0] - self.assertEqual(fn_node, root_node.child(0)) - self.assertEqual(fn_node.type, "function_definition") - self.assertEqual(fn_node.start_byte, 0) - self.assertEqual(fn_node.end_byte, 18) - self.assertEqual(fn_node.start_point, (0, 0)) - self.assertEqual(fn_node.end_point, (1, 7)) - - def_node = fn_node.children[0] - self.assertEqual(def_node, fn_node.child(0)) - self.assertEqual(def_node.type, "def") - self.assertEqual(def_node.is_named, False) - - id_node = fn_node.children[1] - self.assertEqual(id_node, fn_node.child(1)) - self.assertEqual(id_node.type, "identifier") - self.assertEqual(id_node.is_named, True) - self.assertEqual(len(id_node.children), 0) - - params_node = fn_node.children[2] - self.assertEqual(params_node, fn_node.child(2)) - self.assertEqual(params_node.type, "parameters") - self.assertEqual(params_node.is_named, True) - - colon_node = fn_node.children[3] - self.assertEqual(colon_node, fn_node.child(3)) - self.assertEqual(colon_node.type, ":") - self.assertEqual(colon_node.is_named, False) - - statement_node = fn_node.children[4] - self.assertEqual(statement_node, fn_node.child(4)) - self.assertEqual(statement_node.type, "block") - self.assertEqual(statement_node.is_named, True) - - def test_named_and_sibling_and_count_and_parent(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"[1, 2, 3]") - - root_node = tree.root_node - self.assertEqual(root_node.type, "module") - self.assertEqual(root_node.start_byte, 0) - self.assertEqual(root_node.end_byte, 9) - self.assertEqual(root_node.start_point, (0, 0)) - self.assertEqual(root_node.end_point, (0, 9)) - - exp_stmt_node = root_node.children[0] - self.assertEqual(exp_stmt_node, root_node.child(0)) - self.assertEqual(exp_stmt_node.type, "expression_statement") - self.assertEqual(exp_stmt_node.start_byte, 0) - self.assertEqual(exp_stmt_node.end_byte, 9) - self.assertEqual(exp_stmt_node.start_point, (0, 0)) - self.assertEqual(exp_stmt_node.end_point, (0, 9)) - self.assertEqual(exp_stmt_node.parent, root_node) - - list_node = exp_stmt_node.children[0] - self.assertEqual(list_node, exp_stmt_node.child(0)) - self.assertEqual(list_node.type, "list") - self.assertEqual(list_node.start_byte, 0) - self.assertEqual(list_node.end_byte, 9) - self.assertEqual(list_node.start_point, (0, 0)) - self.assertEqual(list_node.end_point, (0, 9)) - self.assertEqual(list_node.parent, exp_stmt_node) - - named_children = list_node.named_children - - open_delim_node = list_node.children[0] - self.assertEqual(open_delim_node, list_node.child(0)) - self.assertEqual(open_delim_node.type, "[") - self.assertEqual(open_delim_node.start_byte, 0) - self.assertEqual(open_delim_node.end_byte, 1) - self.assertEqual(open_delim_node.start_point, (0, 0)) - self.assertEqual(open_delim_node.end_point, (0, 1)) - self.assertEqual(open_delim_node.parent, list_node) - - first_num_node = list_node.children[1] - self.assertEqual(first_num_node, list_node.child(1)) - self.assertEqual(first_num_node, open_delim_node.next_named_sibling) - self.assertEqual(first_num_node.parent, list_node) - self.assertEqual(named_children[0], first_num_node) - self.assertEqual(first_num_node, list_node.named_child(0)) - - first_comma_node = list_node.children[2] - self.assertEqual(first_comma_node, list_node.child(2)) - self.assertEqual(first_comma_node, first_num_node.next_sibling) - self.assertEqual(first_num_node, first_comma_node.prev_sibling) - self.assertEqual(first_comma_node.parent, list_node) - - second_num_node = list_node.children[3] - self.assertEqual(second_num_node, list_node.child(3)) - self.assertEqual(second_num_node, first_comma_node.next_sibling) - self.assertEqual(second_num_node, first_num_node.next_named_sibling) - self.assertEqual(first_num_node, second_num_node.prev_named_sibling) - self.assertEqual(second_num_node.parent, list_node) - self.assertEqual(named_children[1], second_num_node) - self.assertEqual(second_num_node, list_node.named_child(1)) - - second_comma_node = list_node.children[4] - self.assertEqual(second_comma_node, list_node.child(4)) - self.assertEqual(second_comma_node, second_num_node.next_sibling) - self.assertEqual(second_num_node, second_comma_node.prev_sibling) - self.assertEqual(second_comma_node.parent, list_node) - - third_num_node = list_node.children[5] - self.assertEqual(third_num_node, list_node.child(5)) - self.assertEqual(third_num_node, second_comma_node.next_sibling) - self.assertEqual(third_num_node, second_num_node.next_named_sibling) - self.assertEqual(second_num_node, third_num_node.prev_named_sibling) - self.assertEqual(third_num_node.parent, list_node) - self.assertEqual(named_children[2], third_num_node) - self.assertEqual(third_num_node, list_node.named_child(2)) - - close_delim_node = list_node.children[6] - self.assertEqual(close_delim_node, list_node.child(6)) - self.assertEqual(close_delim_node.type, "]") - self.assertEqual(close_delim_node.start_byte, 8) - self.assertEqual(close_delim_node.end_byte, 9) - self.assertEqual(close_delim_node.start_point, (0, 8)) - self.assertEqual(close_delim_node.end_point, (0, 9)) - self.assertEqual(close_delim_node, third_num_node.next_sibling) - self.assertEqual(third_num_node, close_delim_node.prev_sibling) - self.assertEqual(third_num_node, close_delim_node.prev_named_sibling) - self.assertEqual(close_delim_node.parent, list_node) - - self.assertEqual(list_node.child_count, 7) - self.assertEqual(list_node.named_child_count, 3) - - def test_node_text(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"[0, [1, 2, 3]]") - - self.assertEqual(tree.text, b"[0, [1, 2, 3]]") - - root_node = tree.root_node - self.assertEqual(root_node.text, b"[0, [1, 2, 3]]") - - exp_stmt_node = root_node.children[0] - self.assertEqual(exp_stmt_node.text, b"[0, [1, 2, 3]]") - - list_node = exp_stmt_node.children[0] - self.assertEqual(list_node.text, b"[0, [1, 2, 3]]") - - open_delim_node = list_node.children[0] - self.assertEqual(open_delim_node.text, b"[") - - first_num_node = list_node.children[1] - self.assertEqual(first_num_node.text, b"0") - - first_comma_node = list_node.children[2] - self.assertEqual(first_comma_node.text, b",") - - child_list_node = list_node.children[3] - self.assertEqual(child_list_node.text, b"[1, 2, 3]") - - close_delim_node = list_node.children[4] - self.assertEqual(close_delim_node.text, b"]") - - edit_offset = len(b"[0, [") - tree.edit( - start_byte=edit_offset, - old_end_byte=edit_offset, - new_end_byte=edit_offset + 2, - start_point=(0, edit_offset), - old_end_point=(0, edit_offset), - new_end_point=(0, edit_offset + 2), - ) - self.assertEqual(tree.text, None) - - root_node_again = tree.root_node - self.assertEqual(root_node_again.text, None) - - tree_text_false = parser.parse(b"[0, [1, 2, 3]]", keep_text=False) - self.assertIsNone(tree_text_false.text) - root_node_text_false = tree_text_false.root_node - self.assertIsNone(root_node_text_false.text) - - tree_text_true = parser.parse(b"[0, [1, 2, 3]]", keep_text=True) - self.assertEqual(tree_text_true.text, b"[0, [1, 2, 3]]") - root_node_text_true = tree_text_true.root_node - self.assertEqual(root_node_text_true.text, b"[0, [1, 2, 3]]") - - def test_tree(self): - code = b"def foo():\n bar()\n\ndef foo():\n bar()" - parser = Parser() - parser.set_language(PYTHON) - - def parse_root(bytes_): - tree = parser.parse(bytes_) - return tree.root_node - - root = parse_root(code) - for item in root.children: - self.assertIsNotNone(item.is_named) - - def parse_root_children(bytes_): - tree = parser.parse(bytes_) - return tree.root_node.children - - children = parse_root_children(code) - for item in children: - self.assertIsNotNone(item.is_named) - - def test_node_numeric_symbols_respect_simple_aliases(self): - parser = Parser() - parser.set_language(PYTHON) - - # Example 1: - # Python argument lists can contain "splat" arguments, which are not allowed within - # other expressions. This includes `parenthesized_list_splat` nodes like `(*b)`. These - # `parenthesized_list_splat` nodes are aliased as `parenthesized_expression`. Their numeric - # `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`. - tree = parser.parse(b"(a((*b)))") - root_node = tree.root_node - self.assertEqual( - root_node.sexp(), - "(module (expression_statement (parenthesized_expression (call " - + "function: (identifier) arguments: (argument_list (parenthesized_expression " - + "(list_splat (identifier))))))))", - ) - - outer_expr_node = root_node.child(0).child(0) - if outer_expr_node is None: - self.fail("outer_expr_node is None") - self.assertEqual(outer_expr_node.type, "parenthesized_expression") - - inner_expr_node = ( - outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0) - ) - if inner_expr_node is None: - self.fail("inner_expr_node is None") - - self.assertEqual(inner_expr_node.type, "parenthesized_expression") - self.assertEqual(inner_expr_node.kind_id, outer_expr_node.kind_id) - - -class TestTree(TestCase): - def test_tree_cursor_without_tree(self): - parser = Parser() - parser.set_language(PYTHON) - - def parse(): - tree = parser.parse(b"def foo():\n bar()") - return tree.walk() - - cursor = parse() - self.assertIs(cursor.node, cursor.node) - for item in cursor.node.children: - self.assertIsNotNone(item.is_named) - - cursor = cursor.copy() - self.assertIs(cursor.node, cursor.node) - for item in cursor.node.children: - self.assertIsNotNone(item.is_named) - - def test_walk(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - cursor = tree.walk() - - # Node always returns the same instance - self.assertIs(cursor.node, cursor.node) - - self.assertEqual(cursor.node.type, "module") - self.assertEqual(cursor.node.start_byte, 0) - self.assertEqual(cursor.node.end_byte, 18) - self.assertEqual(cursor.node.start_point, (0, 0)) - self.assertEqual(cursor.node.end_point, (1, 7)) - self.assertEqual(cursor.field_name, None) - - self.assertTrue(cursor.goto_first_child()) - self.assertEqual(cursor.node.type, "function_definition") - self.assertEqual(cursor.node.start_byte, 0) - self.assertEqual(cursor.node.end_byte, 18) - self.assertEqual(cursor.node.start_point, (0, 0)) - self.assertEqual(cursor.node.end_point, (1, 7)) - self.assertEqual(cursor.field_name, None) - - self.assertTrue(cursor.goto_first_child()) - self.assertEqual(cursor.node.type, "def") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.sexp(), '("def")') - self.assertEqual(cursor.field_name, None) - def_node = cursor.node - - # Node remains cached after a failure to move - self.assertFalse(cursor.goto_first_child()) - self.assertIs(cursor.node, def_node) - - self.assertTrue(cursor.goto_next_sibling()) - self.assertEqual(cursor.node.type, "identifier") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.field_name, "name") - self.assertFalse(cursor.goto_first_child()) - - self.assertTrue(cursor.goto_next_sibling()) - self.assertEqual(cursor.node.type, "parameters") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.field_name, "parameters") - - def test_edit(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - - edit_offset = len(b"def foo(") - tree.edit( - start_byte=edit_offset, - old_end_byte=edit_offset, - new_end_byte=edit_offset + 2, - start_point=(0, edit_offset), - old_end_point=(0, edit_offset), - new_end_point=(0, edit_offset + 2), - ) - - fn_node = tree.root_node.children[0] - self.assertEqual(fn_node.type, "function_definition") - self.assertTrue(fn_node.has_changes) - self.assertFalse(fn_node.children[0].has_changes) - self.assertFalse(fn_node.children[1].has_changes) - self.assertFalse(fn_node.children[3].has_changes) - - params_node = fn_node.children[2] - self.assertEqual(params_node.type, "parameters") - self.assertTrue(params_node.has_changes) - self.assertEqual(params_node.start_point, (0, edit_offset - 1)) - self.assertEqual(params_node.end_point, (0, edit_offset + 3)) - - new_tree = parser.parse(b"def foo(ab):\n bar()", tree) - self.assertEqual( - new_tree.root_node.sexp(), - trim( - """(module (function_definition - name: (identifier) - parameters: (parameters (identifier)) - body: (block - (expression_statement (call - function: (identifier) - arguments: (argument_list))))))""" - ), - ) - - def test_changed_ranges(self): - parser = Parser() - parser.set_language(PYTHON) - tree = parser.parse(b"def foo():\n bar()") - - edit_offset = len(b"def foo(") - tree.edit( - start_byte=edit_offset, - old_end_byte=edit_offset, - new_end_byte=edit_offset + 2, - start_point=(0, edit_offset), - old_end_point=(0, edit_offset), - new_end_point=(0, edit_offset + 2), - ) - - new_tree = parser.parse(b"def foo(ab):\n bar()", tree) - changed_ranges = tree.changed_ranges(new_tree) - - self.assertEqual(len(changed_ranges), 1) - self.assertEqual(changed_ranges[0].start_byte, edit_offset) - self.assertEqual(changed_ranges[0].start_point, (0, edit_offset)) - self.assertEqual(changed_ranges[0].end_byte, edit_offset + 2) - self.assertEqual(changed_ranges[0].end_point, (0, edit_offset + 2)) - - def test_tree_cursor(self): - parser = Parser() - parser.set_language(RUST) - - tree = parser.parse( - b""" - struct Stuff { - a: A, - b: Option, - } - """ - ) - - cursor = tree.walk() - self.assertEqual(cursor.node.type, "source_file") - - self.assertEqual(cursor.goto_first_child(), True) - self.assertEqual(cursor.node.type, "struct_item") - - self.assertEqual(cursor.goto_first_child(), True) - self.assertEqual(cursor.node.type, "struct") - self.assertEqual(cursor.node.is_named, False) - - self.assertEqual(cursor.goto_next_sibling(), True) - self.assertEqual(cursor.node.type, "type_identifier") - self.assertEqual(cursor.node.is_named, True) - - self.assertEqual(cursor.goto_next_sibling(), True) - self.assertEqual(cursor.node.type, "field_declaration_list") - self.assertEqual(cursor.node.is_named, True) - - self.assertEqual(cursor.goto_last_child(), True) - self.assertEqual(cursor.node.type, "}") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (4, 16)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, ",") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (3, 32)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, "field_declaration") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.node.start_point, (3, 20)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, ",") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (2, 24)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, "field_declaration") - self.assertEqual(cursor.node.is_named, True) - self.assertEqual(cursor.node.start_point, (2, 20)) - - self.assertEqual(cursor.goto_previous_sibling(), True) - self.assertEqual(cursor.node.type, "{") - self.assertEqual(cursor.node.is_named, False) - self.assertEqual(cursor.node.start_point, (1, 29)) - - copy = tree.walk() - copy.reset_to(cursor) - - self.assertEqual(copy.node.type, "{") - self.assertEqual(copy.node.is_named, False) - - self.assertEqual(copy.goto_parent(), True) - self.assertEqual(copy.node.type, "field_declaration_list") - self.assertEqual(copy.node.is_named, True) - - self.assertEqual(copy.goto_parent(), True) - self.assertEqual(copy.node.type, "struct_item") - - -class TestQuery(TestCase): - def test_errors(self): - with self.assertRaisesRegex(NameError, "Invalid node type foo"): - PYTHON.query("(list (foo))") - with self.assertRaisesRegex(NameError, "Invalid field name buzz"): - PYTHON.query("(function_definition buzz: (identifier))") - with self.assertRaisesRegex(NameError, "Invalid capture name garbage"): - PYTHON.query("((function_definition) (#eq? @garbage foo))") - with self.assertRaisesRegex(SyntaxError, "Invalid syntax at offset 6"): - PYTHON.query("(list))") - PYTHON.query("(function_definition)") - - def collect_matches( - self, - matches: List[Tuple[int, Dict[str, Union[Node, List[Node]]]]], - ) -> List[Tuple[int, List[Tuple[str, Union[str, List[str]]]]]]: - return [(m[0], self.format_captures(m[1])) for m in matches] - - def format_captures( - self, captures: Dict[str, Union[Node, List[Node]]] - ) -> List[Tuple[str, Union[str, List[str]]]]: - return [(name, self.format_capture(capture)) for name, capture in captures.items()] - - def format_capture(self, capture: Union[Node, List[Node]]) -> Union[str, List[str]]: - return ( - [n.text.decode("utf-8") for n in capture] - if isinstance(capture, List) - else capture.text.decode("utf-8") - ) - - def assert_query_matches( - self, - language: Language, - query: Query, - source: bytes, - expected: List[Tuple[int, List[Tuple[str, Union[str, List[str]]]]]], - ): - parser = Parser() - parser.set_language(language) - tree = parser.parse(source) - matches = query.matches(tree.root_node) - matches = self.collect_matches(matches) - self.assertEqual(matches, expected) - - def test_matches_with_simple_pattern(self): - query = JAVASCRIPT.query("(function_declaration name: (identifier) @fn-name)") - self.assert_query_matches( - JAVASCRIPT, - query, - b"function one() { two(); function three() {} }", - [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])], - ) - - def test_matches_with_multiple_on_same_root(self): - query = JAVASCRIPT.query( - """ - (class_declaration - name: (identifier) @the-class-name - (class_body - (method_definition - name: (property_identifier) @the-method-name))) - """ - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b""" - class Person { - // the constructor - constructor(name) { this.name = name; } - - // the getter - getFullName() { return this.name; } - } - """, - [ - (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]), - (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]), - ], - ) - - def test_matches_with_multiple_patterns_different_roots(self): - query = JAVASCRIPT.query( - """ - (function_declaration name:(identifier) @fn-def) - (call_expression function:(identifier) @fn-ref) - """ - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b""" - function f1() { - f2(f3()); - } - """, - [ - (0, [("fn-def", "f1")]), - (1, [("fn-ref", "f2")]), - (1, [("fn-ref", "f3")]), - ], - ) - - def test_matches_with_nesting_and_no_fields(self): - query = JAVASCRIPT.query( - """ - (array - (array - (identifier) @x1 - (identifier) @x2)) - """ - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b""" - [[a]]; - [[c, d], [e, f, g, h]]; - [[h], [i]]; - """, - [ - (0, [("x1", "c"), ("x2", "d")]), - (0, [("x1", "e"), ("x2", "f")]), - (0, [("x1", "e"), ("x2", "g")]), - (0, [("x1", "f"), ("x2", "g")]), - (0, [("x1", "e"), ("x2", "h")]), - (0, [("x1", "f"), ("x2", "h")]), - (0, [("x1", "g"), ("x2", "h")]), - ], - ) - - def test_matches_with_list_capture(self): - query = JAVASCRIPT.query( - """(function_declaration name: (identifier) @fn-name - body: (statement_block (_)* @fn-statements) - )""" - ) - self.assert_query_matches( - JAVASCRIPT, - query, - b"""function one() { - x = 1; - y = 2; - z = 3; - } - function two() { - x = 1; - } - """, - [ - ( - 0, - [ - ("fn-name", "one"), - ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]), - ], - ), - (0, [("fn-name", "two"), ("fn-statements", ["x = 1;"])]), - ], - ) - - def test_captures(self): - parser = Parser() - parser.set_language(PYTHON) - source = b"def foo():\n bar()\ndef baz():\n quux()\n" - tree = parser.parse(source) - query = PYTHON.query( - """ - (function_definition name: (identifier) @func-def) - (call function: (identifier) @func-call) - """ - ) - - captures = query.captures(tree.root_node) - - self.assertEqual(captures[0][0].start_point, (0, 4)) - self.assertEqual(captures[0][0].end_point, (0, 7)) - self.assertEqual(captures[0][1], "func-def") - - self.assertEqual(captures[1][0].start_point, (1, 2)) - self.assertEqual(captures[1][0].end_point, (1, 5)) - self.assertEqual(captures[1][1], "func-call") - - self.assertEqual(captures[2][0].start_point, (2, 4)) - self.assertEqual(captures[2][0].end_point, (2, 7)) - self.assertEqual(captures[2][1], "func-def") - - self.assertEqual(captures[3][0].start_point, (3, 2)) - self.assertEqual(captures[3][0].end_point, (3, 6)) - self.assertEqual(captures[3][1], "func-call") - - def test_text_predicates(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source = b""" - keypair_object = { - key1: value1, - equal: equal - } - - function fun1(arg) { - return 1; - } - - function fun2(arg) { - return 2; - } - """ - tree = parser.parse(source) - root_node = tree.root_node - - # function with name equal to 'fun1' -> test for #eq? @capture string - query1 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#eq? @function-name fun1) - ) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(1, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - - # functions with name not equal to 'fun1' -> test for #not-eq? @capture string - query2 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#not-eq? @function-name fun1) - ) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(1, len(captures2)) - self.assertEqual(b"fun2", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - - # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2 - query3 = JAVASCRIPT.query( - """ - ( - (pair - key: (property_identifier) @key-name - value: (identifier) @value-name) - (#eq? @key-name @value-name) - ) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(2, len(captures3)) - self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3])) - self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3])) - - # key pairs whose key is not equal to its value - # -> test for #not-eq? @capture1 @capture2 - query4 = JAVASCRIPT.query( - """ - ( - (pair - key: (property_identifier) @key-name - value: (identifier) @value-name) - (#not-eq? @key-name @value-name) - ) - """ - ) - captures4 = query4.captures(root_node) - self.assertEqual(2, len(captures4)) - self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4])) - self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4])) - - # equality that is satisfied by *another* capture - query5 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters (identifier) @parameter-name) - ) - (#eq? @function-name arg) - ) - """ - ) - captures5 = query5.captures(root_node) - self.assertEqual(0, len(captures5)) - - # functions that match the regex .*1 -> test for #match @capture regex - query6 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? @function-name ".*1") - ) - """ - ) - captures6 = query6.captures(root_node) - self.assertEqual(1, len(captures6)) - self.assertEqual(b"fun1", captures6[0][0].text) - - # functions that do not match the regex .*1 -> test for #not-match @capture regex - query6 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#not-match? @function-name ".*1") - ) - """ - ) - captures6 = query6.captures(root_node) - self.assertEqual(1, len(captures6)) - self.assertEqual(b"fun2", captures6[0][0].text) - - # after editing there is no text property, so predicates are ignored - tree.edit( - start_byte=0, - old_end_byte=0, - new_end_byte=2, - start_point=(0, 0), - old_end_point=(0, 0), - new_end_point=(0, 2), - ) - captures_notext = query1.captures(root_node) - self.assertEqual(2, len(captures_notext)) - self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext])) - - def test_text_predicate_on_optional_capture(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source = b"fun1(1)" - tree = parser.parse(source) - root_node = tree.root_node - - # optional capture that is missing in source used in #eq? @capture string - query1 = JAVASCRIPT.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#eq? @optional-string-arg "1"))) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(1, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - - # optional capture that is missing in source used in #eq? @capture @capture - query2 = JAVASCRIPT.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#eq? @optional-string-arg @function-name))) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(1, len(captures2)) - self.assertEqual(b"fun1", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - - # optional capture that is missing in source used in #match? @capture string - query3 = JAVASCRIPT.query( - """ - ((call_expression - function: (identifier) @function-name - arguments: (arguments (string)? @optional-string-arg) - (#match? @optional-string-arg "\\d+"))) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(1, len(captures3)) - self.assertEqual(b"fun1", captures3[0][0].text) - self.assertEqual("function-name", captures3[0][1]) - - def test_text_predicates_errors(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#eq? @function-name @function-name fun1) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#eq? fun1 @function-name) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? @function-name @function-name fun1) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? fun1 @function-name) - ) - """ - ) - - with self.assertRaises(RuntimeError): - JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - ) - (#match? @function-name @function-name) - ) - """ - ) - - def test_multiple_text_predicates(self): - parser = Parser() - parser.set_language(JAVASCRIPT) - source = b""" - keypair_object = { - key1: value1, - equal: equal - } - - function fun1(arg) { - return 1; - } - - function fun1(notarg) { - return 1 + 1; - } - - function fun2(arg) { - return 2; - } - """ - tree = parser.parse(source) - root_node = tree.root_node - - # function with name equal to 'fun1' -> test for first #eq? @capture string - query1 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name - ) - ) - (#eq? @function-name fun1) - ) - """ - ) - captures1 = query1.captures(root_node) - self.assertEqual(4, len(captures1)) - self.assertEqual(b"fun1", captures1[0][0].text) - self.assertEqual("function-name", captures1[0][1]) - self.assertEqual(b"arg", captures1[1][0].text) - self.assertEqual("argument-name", captures1[1][1]) - self.assertEqual(b"fun1", captures1[2][0].text) - self.assertEqual("function-name", captures1[2][1]) - self.assertEqual(b"notarg", captures1[3][0].text) - self.assertEqual("argument-name", captures1[3][1]) - - # function with argument equal to 'arg' -> test for second #eq? @capture string - query2 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name - ) - ) - (#eq? @argument-name arg) - ) - """ - ) - captures2 = query2.captures(root_node) - self.assertEqual(4, len(captures2)) - self.assertEqual(b"fun1", captures2[0][0].text) - self.assertEqual("function-name", captures2[0][1]) - self.assertEqual(b"arg", captures2[1][0].text) - self.assertEqual("argument-name", captures2[1][1]) - self.assertEqual(b"fun2", captures2[2][0].text) - self.assertEqual("function-name", captures2[2][1]) - self.assertEqual(b"arg", captures2[3][0].text) - self.assertEqual("argument-name", captures2[3][1]) - - # function with name equal to 'fun1' & argument 'arg' -> test for both together - query3 = JAVASCRIPT.query( - """ - ( - (function_declaration - name: (identifier) @function-name - parameters: (formal_parameters - (identifier) @argument-name - ) - ) - (#eq? @function-name fun1) - (#eq? @argument-name arg) - ) - """ - ) - captures3 = query3.captures(root_node) - self.assertEqual(2, len(captures3)) - self.assertEqual(b"fun1", captures3[0][0].text) - self.assertEqual("function-name", captures3[0][1]) - self.assertEqual(b"arg", captures3[1][0].text) - self.assertEqual("argument-name", captures3[1][1]) - - def test_byte_range_captures(self): - parser = Parser() - parser.set_language(PYTHON) - source = b"def foo():\n bar()\ndef baz():\n quux()\n" - tree = parser.parse(source) - query = PYTHON.query( - """ - (function_definition name: (identifier) @func-def) - (call function: (identifier) @func-call) - """ - ) - - captures = query.captures(tree.root_node, start_byte=10, end_byte=20) - self.assertEqual(captures[0][0].start_point, (1, 2)) - self.assertEqual(captures[0][0].end_point, (1, 5)) - self.assertEqual(captures[0][1], "func-call") - - def test_point_range_captures(self): - parser = Parser() - parser.set_language(PYTHON) - source = b"def foo():\n bar()\ndef baz():\n quux()\n" - tree = parser.parse(source) - query = PYTHON.query( - """ - (function_definition name: (identifier) @func-def) - (call function: (identifier) @func-call) - """ - ) - - captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0)) - - self.assertEqual(captures[0][0].start_point, (1, 2)) - self.assertEqual(captures[0][0].end_point, (1, 5)) - self.assertEqual(captures[0][1], "func-call") - - def test_node_hash(self): - parser = Parser() - parser.set_language(PYTHON) - source_code = b"def foo():\n bar()\n bar()" - tree = parser.parse(source_code) - root_node = tree.root_node - first_function_node = root_node.children[0] - second_function_node = root_node.children[0] - - # Uniqueness and consistency - self.assertEqual(hash(first_function_node), hash(first_function_node)) - self.assertNotEqual(hash(root_node), hash(first_function_node)) - - # Equality implication - self.assertEqual(hash(first_function_node), hash(second_function_node)) - self.assertTrue(first_function_node == second_function_node) - - # Different nodes with different properties - different_tree = parser.parse(b"def baz():\n qux()") - different_node = different_tree.root_node.children[0] - self.assertNotEqual(hash(first_function_node), hash(different_node)) - - # Same code, different parse trees - another_tree = parser.parse(source_code) - another_node = another_tree.root_node.children[0] - self.assertNotEqual(hash(first_function_node), hash(another_node)) - - -class TestLookaheadIterator(TestCase): - def test_lookahead_iterator(self): - parser = Parser() - parser.set_language(RUST) - tree = parser.parse(b"struct Stuff{}") - - cursor = tree.walk() - - self.assertEqual(cursor.goto_first_child(), True) # struct - self.assertEqual(cursor.goto_first_child(), True) # struct keyword - - next_state = cursor.node.next_parse_state - - self.assertNotEqual(next_state, 0) - self.assertEqual( - next_state, RUST.next_state(cursor.node.parse_state, cursor.node.grammar_id) - ) - self.assertLess(next_state, RUST.parse_state_count) - self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier - self.assertEqual(next_state, cursor.node.parse_state) - self.assertEqual(cursor.node.grammar_name, "identifier") - self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id) - - expected_symbols = ["//", "/*", "identifier", "line_comment", "block_comment"] - lookahead: LookaheadIterator = RUST.lookahead_iterator(next_state) - self.assertEqual(lookahead.language, RUST.language_id) - self.assertEqual(list(lookahead.iter_names()), expected_symbols) - - lookahead.reset_state(next_state) - self.assertEqual(list(lookahead.iter_names()), expected_symbols) - - lookahead.reset(RUST.language_id, next_state) - self.assertEqual(list(map(RUST.node_kind_for_id, list(iter(lookahead)))), expected_symbols) - - -def trim(string): - return re.sub(r"\s+", " ", string).strip() - - -def get_all_nodes(tree: Tree) -> List[Node]: - result = [] - visited_children = False - cursor = tree.walk() - while True: - if not visited_children: - result.append(cursor.node) - if not cursor.goto_first_child(): - visited_children = True - elif cursor.goto_next_sibling(): - visited_children = False - elif not cursor.goto_parent(): - break - return result diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py index 18d8f0f6..a21fca5c 100644 --- a/tree_sitter/__init__.py +++ b/tree_sitter/__init__.py @@ -14,9 +14,9 @@ MIN_COMPATIBLE_LANGUAGE_VERSION, ) -Point.__doc__ = 'A position in a multi-line text document, in terms of rows and columns.' -Point.row.__doc__ = 'The zero-based row of the document.' -Point.column.__doc__ = 'The zero-based column of the document.' +Point.__doc__ = "A position in a multi-line text document, in terms of rows and columns." +Point.row.__doc__ = "The zero-based row of the document." +Point.column.__doc__ = "The zero-based column of the document." __all__ = [ "Language", diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi index 07d24610..bd88ccd0 100644 --- a/tree_sitter/__init__.pyi +++ b/tree_sitter/__init__.pyi @@ -4,16 +4,14 @@ from typing_extensions import deprecated _Ptr = Annotated[int, "TSLanguage *"] -_ParseCB = Callable[[int, Point], bytes] +_ParseCB = Callable[[int, Point | tuple[int, int]], bytes] _UINT32_MAX = 0xFFFFFFFF - class Point(NamedTuple): row: int column: int - @final class Language: def __init__(self, ptr: _Ptr, /) -> None: ... @@ -24,288 +22,195 @@ class Language: @property def version(self) -> int: ... - @property def node_kind_count(self) -> int: ... - @property def parse_state_count(self) -> int: ... - @property def field_count(self) -> int: ... - def node_kind_for_id(self, id: int, /) -> str | None: ... - def id_for_node_kind(self, kind: str, named: bool, /) -> int | None: ... - def node_kind_is_named(self, id: int, /) -> bool: ... - def node_kind_is_visible(self, id: int, /) -> bool: ... - def field_name_for_id(self, field_id: int, /) -> str | None: ... - def field_id_for_name(self, name: str, /) -> int | None: ... - def next_state(self, state: int, id: int, /) -> int: ... - def lookahead_iterator(self, state: int, /) -> LookaheadIterator | None: ... - def query(self, source: str, /) -> Query: ... - def __repr__(self) -> str: ... - def __eq__(self, other: Any, /) -> bool: ... - def __ne__(self, other: Any, /) -> bool: ... - def __int__(self) -> int: ... - def __index__(self) -> int: ... - def __hash__(self) -> int: ... - @final class Node: @property def id(self) -> int: ... - @property def kind_id(self) -> int: ... - @property def grammar_id(self) -> int: ... - @property def grammar_name(self) -> str: ... - @property def type(self) -> str: ... - @property def is_named(self) -> bool: ... - @property def is_extra(self) -> bool: ... - @property def has_changes(self) -> bool: ... - @property def has_error(self) -> bool: ... - @property def is_error(self) -> bool: ... - @property def parse_state(self) -> int: ... - @property def next_parse_state(self) -> int: ... - @property def is_missing(self) -> bool: ... - @property def start_byte(self) -> int: ... - @property def end_byte(self) -> int: ... - @property def byte_range(self) -> tuple[int, int]: ... - @property def range(self) -> Range: ... - @property def start_point(self) -> Point: ... - @property def end_point(self) -> Point: ... - @property def children(self) -> list[Node]: ... - @property def child_count(self) -> int: ... - @property def named_children(self) -> list[Node]: ... - @property def named_child_count(self) -> int: ... - @property def parent(self) -> Node | None: ... - @property def next_sibling(self) -> Node | None: ... - @property def prev_sibling(self) -> Node | None: ... - @property def next_named_sibling(self) -> Node | None: ... - @property def prev_named_sibling(self) -> Node | None: ... - @property def descendant_count(self) -> int: ... - @property def text(self) -> bytes | None: ... - def walk(self) -> TreeCursor: ... - def edit( self, start_byte: int, old_end_byte: int, new_end_byte: int, - start_point: Point, - old_end_point: Point, - new_end_point: Point, + start_point: Point | tuple[int, int], + old_end_point: Point | tuple[int, int], + new_end_point: Point | tuple[int, int], ) -> None: ... - def child(self, index: int, /) -> Node | None: ... - def named_child(self, index: int, /) -> Node | None: ... - def child_by_field_id(self, id: int, /) -> Node | None: ... - def child_by_field_name(self, name: str, /) -> Node | None: ... - def children_by_field_id(self, id: int, /) -> list[Node]: ... - def children_by_field_name(self, name: str, /) -> list[Node]: ... - def field_name_for_child(self, child_index: int, /) -> str | None: ... - def descendant_for_byte_range( self, start_byte: int, end_byte: int, /, ) -> Node | None: ... - def named_descendant_for_byte_range( self, start_byte: int, end_byte: int, /, ) -> Node | None: ... - def descendant_for_point_range( self, - start_point: Point, - end_point: Point, + start_point: Point | tuple[int, int], + end_point: Point | tuple[int, int], /, ) -> Node | None: ... - def named_descendant_for_point_range( self, - start_point: Point, - end_point: Point, + start_point: Point | tuple[int, int], + end_point: Point | tuple[int, int], /, ) -> Node | None: ... - @deprecated("Use `str()` instead") def sexp(self) -> str: ... - def __repr__(self) -> str: ... - def __str__(self) -> str: ... - def __eq__(self, other: Any, /) -> bool: ... - def __ne__(self, other: Any, /) -> bool: ... - def __hash__(self) -> int: ... - @final -class Tree(): +class Tree: @property def root_node(self) -> Node: ... - @property def included_ranges(self) -> list[Range]: ... - @property @deprecated("Use `root_node.text` instead") def text(self) -> bytes | None: ... - def root_node_with_offset( self, offset_bytes: int, - offset_extent: Point, + offset_extent: Point | tuple[int, int], /, ) -> Node | None: ... - def edit( self, start_byte: int, old_end_byte: int, new_end_byte: int, - start_point: Point, - old_end_point: Point, - new_end_point: Point, + start_point: Point | tuple[int, int], + old_end_point: Point | tuple[int, int], + new_end_point: Point | tuple[int, int], ) -> None: ... - def walk(self) -> TreeCursor: ... - def changed_ranges(self, new_tree: Tree) -> list[Range]: ... - @final class TreeCursor: @property def node(self) -> Node: ... - @property def field_id(self) -> int | None: ... - @property def field_name(self) -> str | None: ... - @property def depth(self) -> int: ... - @property def descendant_index(self) -> int: ... - def copy(self) -> TreeCursor: ... - def reset(self, node: Node, /) -> None: ... - def reset_to(self, cursor: TreeCursor, /) -> None: ... - def goto_first_child(self) -> bool: ... - def goto_last_child(self) -> bool: ... - def goto_parent(self) -> bool: ... - def goto_next_sibling(self) -> bool: ... - def goto_previous_sibling(self) -> bool: ... - def goto_descendant(self, index: int, /) -> None: ... - def goto_first_child_for_byte(self, byte: int, /) -> bool: ... - @overload - def goto_first_child_for_point(self, point: Point, /) -> bool: ... - + def goto_first_child_for_point(self, point: Point | tuple[int, int], /) -> bool: ... @overload @deprecated("Use `goto_first_child_for_point(point)` instead") def goto_first_child_for_point(self, row: int, column: int, /) -> bool: ... - def __copy__(self) -> TreeCursor: ... - @final class Parser: def __init__( @@ -313,33 +218,24 @@ class Parser: language: Language | None = None, *, included_ranges: Sequence[Range] | None = None, - timeout_micros: int | None = None + timeout_micros: int | None = None, ) -> None: ... - @property def language(self) -> Language | None: ... - @language.setter def language(self, language: Language) -> None: ... - @language.deleter def language(self) -> None: ... - @property def included_ranges(self) -> list[Range]: ... - @included_ranges.setter def included_ranges(self, ranges: Sequence[Range]) -> None: ... - @included_ranges.deleter def included_ranges(self) -> None: ... - @property def timeout_micros(self) -> int: ... - @timeout_micros.setter def timeout_micros(self, timeout: int) -> None: ... - @timeout_micros.deleter def timeout_micros(self) -> None: ... @@ -352,7 +248,6 @@ class Parser: /, old_tree: Tree | None = None, ) -> Tree: ... - @overload @deprecated("`keep_text` will be removed") def parse( @@ -362,19 +257,14 @@ class Parser: old_tree: Tree | None = None, keep_text: bool = True, ) -> Tree: ... - def reset(self) -> None: ... - @deprecated("Use the `language` setter instead") def set_language(self, language: Language, /) -> None: ... - @deprecated("Use the `included_ranges` setter instead") def set_included_ranges(self, ranges: Sequence[Range], /) -> None: ... - @deprecated("Use the `timeout_micros` setter instead") def set_timeout_micros(self, timeout: int, /) -> None: ... - @final class Query: def __init__(self, language: Language, source: str) -> None: ... @@ -386,76 +276,67 @@ class Query: self, node: Node, *, - start_point: Point = Point(0, 0), - end_point: Point = Point(_UINT32_MAX, _UINT32_MAX), + start_point: Point | tuple[int, int] = Point(0, 0), + end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX), start_byte: int = 0, end_byte: int = _UINT32_MAX, ) -> list[tuple[Node, str]]: ... - def matches( self, node: Node, *, - start_point: Point = Point(0, 0), - end_point: Point = Point(_UINT32_MAX, _UINT32_MAX), + start_point: Point | tuple[int, int] = Point(0, 0), + end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX), start_byte: int = 0, end_byte: int = _UINT32_MAX, ) -> list[tuple[int, dict[str, Node | list[Node]]]]: ... - @final class LookaheadIterator(Iterator[int]): @property def language(self) -> Language: ... - @property def current_symbol(self) -> int: ... - @property def current_symbol_name(self) -> str: ... - @deprecated("Use `reset_state()` instead") def reset(self, language: _Ptr, state: int, /) -> None: ... # TODO(0.24): rename to reset def reset_state(self, state: int, language: Language | None = None) -> None: ... - def iter_names(self) -> Iterator[str]: ... - def __next__(self) -> int: ... - @final class Range: def __init__( self, - start_point: Point, - end_point: Point, + start_point: Point | tuple[int, int], + end_point: Point | tuple[int, int], start_byte: int, end_byte: int, ) -> None: ... - @property - def start_point(self): Point + def start_point(self): + Point @property - def end_point(self): Point + def end_point(self): + Point @property - def start_byte(self): int + def start_byte(self): + int @property - def end_byte(self): int + def end_byte(self): + int def __eq__(self, other: Any, /) -> bool: ... - def __ne__(self, other: Any, /) -> bool: ... - def __repr__(self) -> str: ... - def __hash__(self) -> int: ... - LANGUAGE_VERSION: Final[int] MIN_COMPATIBLE_LANGUAGE_VERSION: Final[int] diff --git a/tree_sitter/binding/language.c b/tree_sitter/binding/language.c index e74af1a7..94172dd6 100644 --- a/tree_sitter/binding/language.c +++ b/tree_sitter/binding/language.c @@ -16,7 +16,7 @@ static void segfault_handler(int signal) { TSLanguage *language_check_pointer(void *ptr) { PyOS_setsig(SIGSEGV, segfault_handler); if (!setjmp(segv_jmp)) { - (void)ts_language_version(ptr); + __attribute__((unused)) volatile uint32_t version = ts_language_version((TSLanguage *)ptr); } else { PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer"); } @@ -29,7 +29,7 @@ TSLanguage *language_check_pointer(void *ptr) { // HACK: recover from invalid pointer using SEH (Windows) TSLanguage *language_check_pointer(void *ptr) { __try { - (void)ts_language_version(ptr); + volatile uint32_t version = ts_language_version((TSLanguage *)ptr); } __except (GetExceptionCode() == EXCEPTION_ACCESS_VIOLATION ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH) { PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer"); @@ -43,7 +43,7 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) { if (!PyArg_ParseTuple(args, "O:__init__", &language)) { return -1; } - if (PyLong_AsLong(language) < 1) { + if (PyLong_AsSsize_t(language) < 1) { if (!PyErr_Occurred()) { PyErr_SetString(PyExc_ValueError, "language ID must be positive"); } @@ -61,7 +61,10 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) { return 0; } -void language_dealloc(Language *self) { Py_TYPE(self)->tp_free(self); } +void language_dealloc(Language *self) { + ts_language_delete(self->language); + Py_TYPE(self)->tp_free(self); +} PyObject *language_repr(Language *self) { #if HAS_LANGUAGE_NAMES @@ -88,7 +91,7 @@ PyObject *language_compare(Language *self, PyObject *other, int op) { Language *lang = (Language *)other; bool result = (Py_uintptr_t)self->language == (Py_uintptr_t)lang->language; - return PyBool_FromLong(result & (op == Py_EQ)); + return PyBool_FromLong(result ^ (op == Py_NE)); } #if HAS_LANGUAGE_NAMES @@ -128,7 +131,7 @@ PyObject *language_node_kind_for_id(Language *self, PyObject *args) { PyObject *language_id_for_node_kind(Language *self, PyObject *args) { char *kind; Py_ssize_t length; - bool named; + int named; if (!PyArg_ParseTuple(args, "s#p:id_for_node_kind", &kind, &length, &named)) { return NULL; } @@ -205,8 +208,9 @@ PyObject *language_lookahead_iterator(Language *self, PyObject *args) { if (iter == NULL) { return NULL; } - iter->lookahead_iterator = lookahead_iterator; + Py_INCREF(self); iter->language = (PyObject *)self; + iter->lookahead_iterator = lookahead_iterator; return PyObject_Init((PyObject *)iter, state->lookahead_iterator_type); } diff --git a/tree_sitter/binding/lookahead_iterator.c b/tree_sitter/binding/lookahead_iterator.c index 84fcc2ac..b4f4e571 100644 --- a/tree_sitter/binding/lookahead_iterator.c +++ b/tree_sitter/binding/lookahead_iterator.c @@ -24,11 +24,9 @@ PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *Py_UNUS } language->language = language_id; language->version = ts_language_version(language->language); - PyObject *obj = PyObject_Init((PyObject *)language, state->language_type); - Py_XSETREF(self->language, obj); - } else { - Py_INCREF(self->language); + self->language = PyObject_Init((PyObject *)language, state->language_type); } + Py_INCREF(self->language); return self->language; } @@ -65,7 +63,7 @@ PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args) { PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args, PyObject *kwargs) { uint16_t state_id; - PyObject *language_obj; + PyObject *language_obj = NULL; ModuleState *state = GET_MODULE_STATE(self); char *keywords[] = {"state", "language", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "H|O!:reset_state", keywords, &state_id, diff --git a/tree_sitter/binding/node.c b/tree_sitter/binding/node.c index 0e5d5328..7a56258f 100644 --- a/tree_sitter/binding/node.c +++ b/tree_sitter/binding/node.c @@ -43,7 +43,7 @@ PyObject *node_compare(Node *self, PyObject *other, int op) { } bool result = ts_node_eq(self->node, ((Node *)other)->node); - return PyBool_FromLong(result & (op == Py_EQ)); + return PyBool_FromLong(result ^ (op == Py_NE)); } PyObject *node_sexp(Node *self, PyObject *Py_UNUSED(args)) { @@ -59,9 +59,11 @@ PyObject *node_walk(Node *self, PyObject *Py_UNUSED(args)) { if (tree_cursor == NULL) { return NULL; } - tree_cursor->cursor = ts_tree_cursor_new(self->node); + Py_INCREF(self->tree); tree_cursor->tree = self->tree; + tree_cursor->node = NULL; + tree_cursor->cursor = ts_tree_cursor_new(self->node); return PyObject_Init((PyObject *)tree_cursor, state->tree_cursor_type); } @@ -436,7 +438,7 @@ PyObject *node_get_named_children(Node *self, void *payload) { PyObject *child = PyList_GetItem(self->children, i); if (ts_node_is_named(((Node *)child)->node)) { Py_INCREF(child); - if (PyList_SetItem(result, ++j, child)) { + if (PyList_SetItem(result, j++, child)) { Py_DECREF(result); return NULL; } diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c index ce47f280..ab0c4e10 100644 --- a/tree_sitter/binding/parser.c +++ b/tree_sitter/binding/parser.c @@ -92,7 +92,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo // Store return value in payload so its reference count can be decremented and // return string representation of bytes. wrapper_payload->previous_return_value = rv; - *bytes_read = PyBytes_Size(rv); + *bytes_read = (uint32_t)PyBytes_Size(rv); return PyBytes_AsString(rv); } @@ -100,7 +100,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { ModuleState *state = GET_MODULE_STATE(self); PyObject *source_or_callback; PyObject *old_tree_obj = NULL; - bool keep_text = true; + int keep_text = 1; char *keywords[] = {"", "old_tree", "keep_text", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!p:parse", keywords, &source_or_callback, state->tree_type, &old_tree_obj, &keep_text)) { @@ -114,7 +114,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { if (PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE) > -1) { // parse a buffer const char *source_bytes = (const char *)source_view.buf; - size_t length = source_view.len; + uint32_t length = (uint32_t)source_view.len; new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length); PyBuffer_Release(&source_view); } else if (PyCallable_Check(source_or_callback)) { @@ -137,7 +137,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { source_or_callback = Py_None; keep_text = 0; } else { - PyErr_SetString(PyExc_TypeError, "source must be a byte buffer or a callable"); + PyErr_SetString(PyExc_TypeError, "source must be a bytestring or a callable"); return NULL; } @@ -172,27 +172,20 @@ int parser_set_timeout_micros(Parser *self, PyObject *arg, void *Py_UNUSED(paylo ts_parser_set_timeout_micros(self->parser, 0); return 0; } - if (!PyLong_CheckExact(arg)) { + if (!PyLong_Check(arg)) { PyErr_Format(PyExc_TypeError, "'timeout_micros' must be assigned an int, not %s", arg->ob_type->tp_name); return -1; } - long timeout = PyLong_AsLong(arg); - if (timeout < 0) { - PyErr_SetString(PyExc_ValueError, "'timeout_micros' must be a positive integer"); - return -1; - } - - ts_parser_set_timeout_micros(self->parser, timeout); + ts_parser_set_timeout_micros(self->parser, PyLong_AsUnsignedLong(arg)); return 0; } PyObject *parser_set_timeout_micros_old(Parser *self, PyObject *arg) { - if (PyLong_AsLong(arg) < 0) { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_ValueError, "'timeout_micros' must be a positive integer"); - } + if (!PyLong_Check(arg)) { + PyErr_Format(PyExc_TypeError, "'timeout_micros' must be assigned an int, not %s", + arg->ob_type->tp_name); return NULL; } if (REPLACE("Parser.set_timeout_micros()", "the timeout_micros setter") < 0) { @@ -235,7 +228,7 @@ int parser_set_included_ranges(Parser *self, PyObject *arg, void *Py_UNUSED(payl return -1; } - uint32_t length = PyList_Size(arg); + uint32_t length = (uint32_t)PyList_Size(arg); TSRange *ranges = PyMem_Calloc(length, sizeof(TSRange)); if (!ranges) { PyErr_Format(PyExc_MemoryError, "Failed to allocate memory for ranges of length %u", @@ -315,8 +308,8 @@ int parser_set_language(Parser *self, PyObject *arg, void *Py_UNUSED(payload)) { return -1; } - self->language = (PyObject *)language; - Py_INCREF(self->language); + Py_INCREF(language); + Py_XSETREF(self->language, (PyObject *)language); return 0; } diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c index 89b39aae..b43f757b 100644 --- a/tree_sitter/binding/query.c +++ b/tree_sitter/binding/query.c @@ -490,10 +490,10 @@ PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) { "node", "start_point", "end_point", "start_byte", "end_byte", NULL, }; PyObject *node_obj; - TSPoint start_point = {.row = 0, .column = 0}; - TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX}; - unsigned start_byte = 0, end_byte = UINT32_MAX; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|$(II)(II)II:matches", keywords, + TSPoint start_point = {0, 0}; + TSPoint end_point = {UINT32_MAX, UINT32_MAX}; + uint32_t start_byte = 0, end_byte = UINT32_MAX; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$(II)(II)II:matches", keywords, state->node_type, &node_obj, &start_point.row, &start_point.column, &end_point.row, &end_point.column, &start_byte, &end_byte)) { diff --git a/tree_sitter/binding/range.c b/tree_sitter/binding/range.c index fc52fb2b..a44c4f01 100644 --- a/tree_sitter/binding/range.c +++ b/tree_sitter/binding/range.c @@ -33,13 +33,35 @@ PyObject *range_repr(Range *self) { Py_hash_t range_hash(Range *self) { // FIXME: replace with an efficient integer hashing algorithm - PyObject *row_tuple = PyTuple_Pack(2, PyLong_FromLong(self->range.start_point.row), + PyObject *row_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_point.row), PyLong_FromLong(self->range.end_point.row)); - PyObject *col_tuple = PyTuple_Pack(2, PyLong_FromLong(self->range.start_point.column), - PyLong_FromLong(self->range.end_point.column)); - PyObject *bytes_tuple = PyTuple_Pack(2, PyLong_FromLong(self->range.start_byte), - PyLong_FromLong(self->range.end_byte)); + if (!row_tuple) { + return NULL; + } + + PyObject *col_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_point.column), + PyLong_FromSize_t(self->range.end_point.column)); + if (!col_tuple) { + Py_DECREF(row_tuple); + return NULL; + } + + PyObject *bytes_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_byte), + PyLong_FromSize_t(self->range.end_byte)); + if (!bytes_tuple) { + Py_DECREF(row_tuple); + Py_DECREF(col_tuple); + return NULL; + } + PyObject *range_tuple = PyTuple_Pack(3, row_tuple, col_tuple, bytes_tuple); + if (!range_tuple) { + Py_DECREF(row_tuple); + Py_DECREF(col_tuple); + Py_DECREF(bytes_tuple); + return NULL; + } + Py_hash_t hash = PyObject_Hash(range_tuple); Py_DECREF(range_tuple); @@ -61,7 +83,7 @@ PyObject *range_compare(Range *self, PyObject *other, int op) { (self->range.end_point.row == range->range.end_point.row) && (self->range.end_point.column == range->range.end_point.column) && (self->range.end_byte == range->range.end_byte)); - return PyBool_FromLong(result & (op == Py_EQ)); + return PyBool_FromLong(result ^ (op == Py_NE)); } PyObject *range_get_start_point(Range *self, void *Py_UNUSED(payload)) { diff --git a/tree_sitter/binding/tree.c b/tree_sitter/binding/tree.c index 97e1943b..1c03e2dc 100644 --- a/tree_sitter/binding/tree.c +++ b/tree_sitter/binding/tree.c @@ -45,9 +45,11 @@ PyObject *tree_walk(Tree *self, PyObject *Py_UNUSED(args)) { if (tree_cursor == NULL) { return NULL; } - tree_cursor->cursor = ts_tree_cursor_new(ts_tree_root_node(self->tree)); + Py_INCREF(self); tree_cursor->tree = (PyObject *)self; + tree_cursor->node = NULL; + tree_cursor->cursor = ts_tree_cursor_new(ts_tree_root_node(self->tree)); return PyObject_Init((PyObject *)tree_cursor, state->tree_cursor_type); } diff --git a/tree_sitter/binding/tree_cursor.c b/tree_sitter/binding/tree_cursor.c index f2f93883..7dfbb8df 100644 --- a/tree_sitter/binding/tree_cursor.c +++ b/tree_sitter/binding/tree_cursor.c @@ -9,15 +9,14 @@ void tree_cursor_dealloc(TreeCursor *self) { } PyObject *tree_cursor_get_node(TreeCursor *self, void *Py_UNUSED(payload)) { - ModuleState *state = GET_MODULE_STATE(self); - if (!self->node) { + if (self->node == NULL) { TSNode current_node = ts_tree_cursor_current_node(&self->cursor); if (ts_node_is_null(current_node)) { Py_RETURN_NONE; } - return node_new_internal(state, current_node, self->tree); + ModuleState *state = GET_MODULE_STATE(self); + self->node = node_new_internal(state, current_node, self->tree); } - Py_INCREF(self->node); return self->node; } @@ -174,9 +173,10 @@ PyObject *tree_cursor_copy(PyObject *self, PyObject *Py_UNUSED(args)) { if (copied == NULL) { return NULL; } - copied->cursor = ts_tree_cursor_copy(&origin->cursor); + Py_INCREF(origin->tree); copied->tree = origin->tree; + copied->cursor = ts_tree_cursor_copy(&origin->cursor); return PyObject_Init((PyObject *)copied, state->tree_cursor_type); } diff --git a/tree_sitter/core b/tree_sitter/core index 4bbaee2f..6e6dcf1c 160000 --- a/tree_sitter/core +++ b/tree_sitter/core @@ -1 +1 @@ -Subproject commit 4bbaee2f56d1febc5992c6ebae556d35d571a712 +Subproject commit 6e6dcf1cafb00300338b46bb4bffcd05ad99fafc