diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 64717a4c..6a2322a2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -34,16 +34,11 @@ jobs:
with:
python-version: ${{matrix.python}}
- name: Lint
- run: pipx run ruff check .
+ continue-on-error: true
+ run: pipx run ruff check . --output-format=github
- name: Build
- run: pip install -e .
+ run: pip install -v -e ".[tests]"
env:
- CFLAGS: "-O0 -g"
+ CFLAGS: -Wextra -Og -g -fno-omit-frame-pointer
- name: Test
- shell: python
- # run: python -Wignore:::tree_sitter -munittest
- run: |-
- try: __import__('tree_sitter').Language(1048576)
- except RuntimeError as err: print(err)
- env:
- PYTHONFAULTHANDLER: 1
+ run: python -munittest -v
diff --git a/.gitmodules b/.gitmodules
index fa624605..a5f5accd 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,21 +1,3 @@
[submodule "tree-sitter"]
path = tree_sitter/core
url = https://github.com/tree-sitter/tree-sitter
-[submodule "tree-sitter-embedded-template"]
- path = tests/fixtures/tree-sitter-embedded-template
- url = https://github.com/tree-sitter/tree-sitter-embedded-template
-[submodule "tree-sitter-html"]
- path = tests/fixtures/tree-sitter-html
- url = https://github.com/tree-sitter/tree-sitter-html
-[submodule "tree-sitter-javascript"]
- path = tests/fixtures/tree-sitter-javascript
- url = https://github.com/tree-sitter/tree-sitter-javascript
-[submodule "tree-sitter-json"]
- path = tests/fixtures/tree-sitter-json
- url = https://github.com/tree-sitter/tree-sitter-json
-[submodule "tree-sitter-python"]
- path = tests/fixtures/tree-sitter-python
- url = https://github.com/tree-sitter/tree-sitter-python
-[submodule "tree-sitter-rust"]
- path = tests/fixtures/tree-sitter-rust
- url = https://github.com/tree-sitter/tree-sitter-rust
diff --git a/docs/conf.py b/docs/conf.py
index 4331e53b..773f4701 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -54,30 +54,30 @@
html_favicon = "_static/favicon.png"
-special_doc = regex('\S*self[^.]+')
+special_doc = regex(r"\S*self[^.]+")  # raw string: "\S" is not a valid str escape
def process_signature(_app, _what, name, _obj, _options, _signature, return_annotation):
- if name == 'tree_sitter.Language':
- return '(ptr)', return_annotation
- if name == 'tree_sitter.Query':
- return '(language, source)', return_annotation
- if name == 'tree_sitter.Parser':
- return '(language, *, included_ranges=None, timeout_micros=None)', return_annotation
- if name == 'tree_sitter.Range':
- return '(start_point, end_point, start_byte, end_byte)', return_annotation
+ if name == "tree_sitter.Language":
+ return "(ptr)", return_annotation
+ if name == "tree_sitter.Query":
+ return "(language, source)", return_annotation
+ if name == "tree_sitter.Parser":
+ return "(language, *, included_ranges=None, timeout_micros=None)", return_annotation
+ if name == "tree_sitter.Range":
+ return "(start_point, end_point, start_byte, end_byte)", return_annotation
def process_docstring(_app, what, name, _obj, _options, lines):
- if what == 'data':
+ if what == "data":
lines.clear()
- elif what == 'method':
- if name.endswith('__index__'):
- lines[0] = 'Converts ``self`` to an integer for use as an index.'
- elif name.endswith('__') and lines and 'self' in lines[0]:
- lines[0] = f'Implements ``{special_doc.search(lines[0]).group(0)}``.'
+ elif what == "method":
+ if name.endswith("__index__"):
+ lines[0] = "Converts ``self`` to an integer for use as an index."
+ elif name.endswith("__") and lines and "self" in lines[0]:
+ lines[0] = f"Implements ``{special_doc.search(lines[0]).group(0)}``."
def setup(app):
- app.connect('autodoc-process-signature', process_signature)
- app.connect('autodoc-process-docstring', process_docstring)
+ app.connect("autodoc-process-signature", process_signature)
+ app.connect("autodoc-process-docstring", process_docstring)
diff --git a/pyproject.toml b/pyproject.toml
index c6b24b6d..d2ca87d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,13 @@ email = "maxbrunsfeld@gmail.com"
[project.optional-dependencies]
docs = ["sphinx~=7.3", "sphinx-book-theme"]
+tests = [
+ "tree-sitter-html",
+ "tree-sitter-javascript",
+ "tree-sitter-json",
+ "tree-sitter-python",
+ "tree-sitter-rust",
+]
[tool.ruff]
target-version = "py39"
@@ -39,7 +46,7 @@ indent-width = 4
extend-exclude = [
".github",
"__pycache__",
- "tests/fixtures",
+ "setup.py",
"tree_sitter/core",
]
@@ -49,6 +56,7 @@ indent-style = "space"
[tool.cibuildwheel]
build-frontend = "build"
+test-extras = ["tests"]
test-command = "python -munittest discover -s {project}/tests"
[tool.cibuildwheel.environment]
diff --git a/setup.py b/setup.py
index 4bf875bc..77ea76a9 100644
--- a/setup.py
+++ b/setup.py
@@ -1,14 +1,12 @@
-"""Py-Tree-sitter"""
-
from platform import system
-from setuptools import Extension, setup
+from setuptools import Extension, setup # type: ignore
setup(
packages=["tree_sitter"],
include_package_data=False,
package_data={
- "tree_sitter": ["py.typed", "*.pyi"]
+ "tree_sitter": ["py.typed", "*.pyi"],
},
ext_modules=[
Extension(
@@ -41,8 +39,12 @@
extra_compile_args=[
"-std=c11",
"-fvisibility=hidden",
+ "-Wno-cast-function-type",
"-Werror=implicit-function-declaration",
- ] if system() != "Windows" else None
+ ] if system() != "Windows" else [
+ "/std:c11",
+ "/wd4244",
+ ],
)
- ]
+ ],
)
diff --git a/tests/fixtures/tree-sitter-embedded-template b/tests/fixtures/tree-sitter-embedded-template
deleted file mode 160000
index 6d791b89..00000000
--- a/tests/fixtures/tree-sitter-embedded-template
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 6d791b897ecda59baa0689a85a9906348a2a6414
diff --git a/tests/fixtures/tree-sitter-html b/tests/fixtures/tree-sitter-html
deleted file mode 160000
index b285e25c..00000000
--- a/tests/fixtures/tree-sitter-html
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit b285e25c1ba8729399ce4f15ac5375cf6c3aa5be
diff --git a/tests/fixtures/tree-sitter-javascript b/tests/fixtures/tree-sitter-javascript
deleted file mode 160000
index de1e6822..00000000
--- a/tests/fixtures/tree-sitter-javascript
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit de1e682289a417354df5b4437a3e4f92e0722a0f
diff --git a/tests/fixtures/tree-sitter-json b/tests/fixtures/tree-sitter-json
deleted file mode 160000
index 3b129203..00000000
--- a/tests/fixtures/tree-sitter-json
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 3b129203f4b72d532f58e72c5310c0a7db3b8e6d
diff --git a/tests/fixtures/tree-sitter-python b/tests/fixtures/tree-sitter-python
deleted file mode 160000
index b8a4c641..00000000
--- a/tests/fixtures/tree-sitter-python
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit b8a4c64121ba66b460cb878e934e3157ecbfb124
diff --git a/tests/fixtures/tree-sitter-rust b/tests/fixtures/tree-sitter-rust
deleted file mode 160000
index 3a56481f..00000000
--- a/tests/fixtures/tree-sitter-rust
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 3a56481f8d13b6874a28752502a58520b9139dc7
diff --git a/tests/test_language.py b/tests/test_language.py
new file mode 100644
index 00000000..50d86047
--- /dev/null
+++ b/tests/test_language.py
@@ -0,0 +1,86 @@
+from unittest import TestCase
+
+from tree_sitter import Language, Query
+
+import tree_sitter_html
+import tree_sitter_javascript
+import tree_sitter_json
+import tree_sitter_python
+import tree_sitter_rust
+
+
+class TestLanguage(TestCase):
+ def setUp(self):
+ self.html = tree_sitter_html.language()
+ self.javascript = tree_sitter_javascript.language()
+ self.json = tree_sitter_json.language()
+ self.python = tree_sitter_python.language()
+ self.rust = tree_sitter_rust.language()
+
+ def test_init_not_positive(self):
+ self.assertRaises(ValueError, Language, -1)
+
+ def test_init_segv(self):
+ self.assertRaises(RuntimeError, Language, 1024)
+
+ def test_properties(self):
+ lang = Language(self.python)
+ self.assertEqual(lang.version, 14)
+ self.assertEqual(lang.node_kind_count, 274)
+ self.assertEqual(lang.parse_state_count, 2831)
+ self.assertEqual(lang.field_count, 32)
+
+ def test_node_kind_for_id(self):
+ lang = Language(self.json)
+ self.assertEqual(lang.node_kind_for_id(1), "{")
+ self.assertEqual(lang.node_kind_for_id(3), "}")
+
+ def test_id_for_node_kind(self):
+ lang = Language(self.json)
+ self.assertEqual(lang.id_for_node_kind(":", False), 4)
+ self.assertEqual(lang.id_for_node_kind("string", True), 20)
+
+ def test_node_kind_is_named(self):
+ lang = Language(self.json)
+ self.assertFalse(lang.node_kind_is_named(4))
+ self.assertTrue(lang.node_kind_is_named(20))
+
+ def test_node_kind_is_visible(self):
+ lang = Language(self.json)
+ self.assertTrue(lang.node_kind_is_visible(2))
+
+ def test_field_name_for_id(self):
+ lang = Language(self.json)
+ self.assertEqual(lang.field_name_for_id(1), "key")
+ self.assertEqual(lang.field_name_for_id(2), "value")
+
+ def test_field_id_for_name(self):
+ lang = Language(self.json)
+ self.assertEqual(lang.field_id_for_name("key"), 1)
+ self.assertEqual(lang.field_id_for_name("value"), 2)
+
+ def test_next_state(self):
+ lang = Language(self.javascript)
+ self.assertNotEqual(lang.next_state(1, 1), 0)
+
+ def test_lookahead_iterator(self):
+ lang = Language(self.javascript)
+ self.assertIsNotNone(lang.lookahead_iterator(0))
+ self.assertIsNone(lang.lookahead_iterator(9999))
+
+ def test_query(self):
+ lang = Language(self.json)
+ query = lang.query("(string) @string")
+ self.assertIsInstance(query, Query)
+
+ def test_eq(self):
+ self.assertEqual(Language(self.json), Language(self.json))
+ self.assertNotEqual(Language(self.rust), Language(self.html))
+
+ def test_int(self):
+ for name in ["html", "javascript", "json", "python", "rust"]:
+ with self.subTest(language=name):
+ ptr = getattr(self, name)
+ lang = Language(ptr)
+ self.assertEqual(int(lang), ptr)
+ self.assertEqual(hash(lang), ptr)
diff --git a/tests/test_lookahead_iterator.py b/tests/test_lookahead_iterator.py
new file mode 100644
index 00000000..540da6a1
--- /dev/null
+++ b/tests/test_lookahead_iterator.py
@@ -0,0 +1,43 @@
+from unittest import TestCase
+
+from tree_sitter import Language, Parser
+
+import tree_sitter_rust
+
+
+class TestLookaheadIterator(TestCase):
+    @classmethod
+    def setUpClass(cls):  # classmethod: the first argument is the class, not an instance
+        cls.rust = Language(tree_sitter_rust.language())
+
+ def test_lookahead_iterator(self):
+ parser = Parser(self.rust)
+ cursor = parser.parse(b"struct Stuff{}").walk()
+
+ self.assertEqual(cursor.goto_first_child(), True) # struct
+ self.assertEqual(cursor.goto_first_child(), True) # struct keyword
+
+ next_state = cursor.node.next_parse_state
+
+ self.assertNotEqual(next_state, 0)
+ self.assertEqual(
+ next_state, self.rust.next_state(cursor.node.parse_state, cursor.node.grammar_id)
+ )
+ self.assertLess(next_state, self.rust.parse_state_count)
+ self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier
+ self.assertEqual(next_state, cursor.node.parse_state)
+ self.assertEqual(cursor.node.grammar_name, "identifier")
+ self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id)
+
+ expected_symbols = ["//", "/*", "identifier", "line_comment", "block_comment"]
+ lookahead = self.rust.lookahead_iterator(next_state)
+ self.assertEqual(lookahead.language, self.rust)
+ self.assertListEqual(list(lookahead.iter_names()), expected_symbols)
+
+ lookahead.reset_state(next_state)
+ self.assertListEqual(list(lookahead.iter_names()), expected_symbols)
+
+ lookahead.reset_state(next_state, self.rust)
+ self.assertListEqual(
+ list(map(self.rust.node_kind_for_id, list(iter(lookahead)))), expected_symbols
+ )
diff --git a/tests/test_node.py b/tests/test_node.py
new file mode 100644
index 00000000..6df15656
--- /dev/null
+++ b/tests/test_node.py
@@ -0,0 +1,480 @@
+from unittest import TestCase
+
+import tree_sitter_python
+import tree_sitter_javascript
+import tree_sitter_json
+
+from tree_sitter import Language, Parser
+
+JSON_EXAMPLE = b"""
+
+[
+  123,
+  false,
+  {
+    "x": null
+  }
+]
+"""  # the row/column assertions in TestNode depend on this exact indentation
+
+
+def get_all_nodes(tree):
+ result = []
+ visited_children = False
+ cursor = tree.walk()
+ while True:
+ if not visited_children:
+ result.append(cursor.node)
+ if not cursor.goto_first_child():
+ visited_children = True
+ elif cursor.goto_next_sibling():
+ visited_children = False
+ elif not cursor.goto_parent():
+ break
+ return result
+
+
+class TestNode(TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.javascript = Language(tree_sitter_javascript.language())
+ cls.json = Language(tree_sitter_json.language())
+ cls.python = Language(tree_sitter_python.language())
+
+ def test_child_by_field_id(self):
+ parser = Parser(self.python)
+ tree = parser.parse(b"def foo():\n bar()")
+ root_node = tree.root_node
+ fn_node = tree.root_node.children[0]
+
+ self.assertIsNone(self.python.field_id_for_name("noname"))
+ name_field = self.python.field_id_for_name("name")
+ alias_field = self.python.field_id_for_name("alias")
+ self.assertIsNone(root_node.child_by_field_id(alias_field))
+ self.assertIsNone(root_node.child_by_field_id(name_field))
+ self.assertIsNone(fn_node.child_by_field_id(alias_field))
+ self.assertIsNone(fn_node.child_by_field_name("noname"))
+ self.assertEqual(fn_node.child_by_field_name("name"), fn_node.child_by_field_name("name"))
+
+ def test_child_by_field_name(self):
+ parser = Parser(self.python)
+ tree = parser.parse(b"while a:\n pass")
+ while_node = tree.root_node.child(0)
+ self.assertIsNotNone(while_node)
+ self.assertEqual(while_node.type, "while_statement")
+ self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3))
+
+    def test_children_by_field_id(self):
+        parser = Parser(self.javascript)
+        tree = parser.parse(b"<div a={1} b={2} />")  # NOTE(review): restored literal, confirm
+        jsx_node = tree.root_node.children[0].children[0]
+        attribute_field = self.javascript.field_id_for_name("attribute")
+        attributes = jsx_node.children_by_field_id(attribute_field)
+        self.assertListEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"])
+
+    def test_children_by_field_name(self):
+        parser = Parser(self.javascript)
+        tree = parser.parse(b"<div a={1} b={2} />")  # NOTE(review): restored literal, confirm
+        jsx_node = tree.root_node.children[0].children[0]
+        attributes = jsx_node.children_by_field_name("attribute")
+        self.assertListEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"])
+
+    def test_field_name_for_child(self):
+        parser = Parser(self.javascript)
+        tree = parser.parse(b"<div a={1} b={2} />")  # NOTE(review): restored literal, confirm
+        jsx_node = tree.root_node.children[0].children[0]
+
+        self.assertIsNone(jsx_node.field_name_for_child(0))
+        self.assertEqual(jsx_node.field_name_for_child(1), "name")
+
+    def test_root_node_with_offset(self):
+        parser = Parser(self.javascript)
+        tree = parser.parse(b"  if (a) b")  # two leading spaces: offsets below assume byte 2
+
+        node = tree.root_node_with_offset(6, (2, 2))
+        self.assertIsNotNone(node)
+        self.assertEqual(node.byte_range, (8, 16))
+        self.assertEqual(node.start_point, (2, 4))
+        self.assertEqual(node.end_point, (2, 12))
+
+        child = node.child(0).child(2)
+        self.assertIsNotNone(child)
+        self.assertEqual(child.type, "expression_statement")
+        self.assertEqual(child.byte_range, (15, 16))
+        self.assertEqual(child.start_point, (2, 11))
+        self.assertEqual(child.end_point, (2, 12))
+
+        cursor = node.walk()
+        cursor.goto_first_child()
+        cursor.goto_first_child()
+        cursor.goto_next_sibling()
+        child = cursor.node
+        self.assertIsNotNone(child)
+        self.assertEqual(child.type, "parenthesized_expression")
+        self.assertEqual(child.byte_range, (11, 14))
+        self.assertEqual(child.start_point, (2, 7))
+        self.assertEqual(child.end_point, (2, 10))
+
+ def test_descendant_count(self):
+ parser = Parser(self.json)
+ tree = parser.parse(JSON_EXAMPLE)
+ value_node = tree.root_node
+ all_nodes = get_all_nodes(tree)
+
+ self.assertEqual(value_node.descendant_count, len(all_nodes))
+
+ cursor = value_node.walk()
+ for i, node in enumerate(all_nodes):
+ cursor.goto_descendant(i)
+ self.assertEqual(cursor.node, node, f"index {i}")
+
+ for i, node in reversed(list(enumerate(all_nodes))):
+ cursor.goto_descendant(i)
+ self.assertEqual(cursor.node, node, f"rev index {i}")
+
+ def test_descendant_for_byte_range(self):
+ parser = Parser(self.json)
+ tree = parser.parse(JSON_EXAMPLE)
+ array_node = tree.root_node
+
+ colon_index = JSON_EXAMPLE.index(b":")
+
+ # Leaf node exactly matches the given bounds - byte query
+ colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1)
+ self.assertIsNotNone(colon_node)
+ self.assertEqual(colon_node.type, ":")
+ self.assertEqual(colon_node.start_byte, colon_index)
+ self.assertEqual(colon_node.end_byte, colon_index + 1)
+ self.assertEqual(colon_node.start_point, (6, 7))
+ self.assertEqual(colon_node.end_point, (6, 8))
+
+ # Leaf node exactly matches the given bounds - point query
+ colon_node = array_node.descendant_for_point_range((6, 7), (6, 8))
+ self.assertIsNotNone(colon_node)
+ self.assertEqual(colon_node.type, ":")
+ self.assertEqual(colon_node.start_byte, colon_index)
+ self.assertEqual(colon_node.end_byte, colon_index + 1)
+ self.assertEqual(colon_node.start_point, (6, 7))
+ self.assertEqual(colon_node.end_point, (6, 8))
+
+ # The given point is between two adjacent leaf nodes - byte query
+ colon_node = array_node.descendant_for_byte_range(colon_index, colon_index)
+ self.assertIsNotNone(colon_node)
+ self.assertEqual(colon_node.type, ":")
+ self.assertEqual(colon_node.start_byte, colon_index)
+ self.assertEqual(colon_node.end_byte, colon_index + 1)
+ self.assertEqual(colon_node.start_point, (6, 7))
+ self.assertEqual(colon_node.end_point, (6, 8))
+
+ # The given point is between two adjacent leaf nodes - point query
+ colon_node = array_node.descendant_for_point_range((6, 7), (6, 7))
+ self.assertIsNotNone(colon_node)
+ self.assertEqual(colon_node.type, ":")
+ self.assertEqual(colon_node.start_byte, colon_index)
+ self.assertEqual(colon_node.end_byte, colon_index + 1)
+ self.assertEqual(colon_node.start_point, (6, 7))
+ self.assertEqual(colon_node.end_point, (6, 8))
+
+ # Leaf node starts at the lower bound, ends after the upper bound - byte query
+ string_index = JSON_EXAMPLE.index(b'"x"')
+ string_node = array_node.descendant_for_byte_range(string_index, string_index + 2)
+ self.assertIsNotNone(string_node)
+ self.assertEqual(string_node.type, "string")
+ self.assertEqual(string_node.start_byte, string_index)
+ self.assertEqual(string_node.end_byte, string_index + 3)
+ self.assertEqual(string_node.start_point, (6, 4))
+ self.assertEqual(string_node.end_point, (6, 7))
+
+ # Leaf node starts at the lower bound, ends after the upper bound - point query
+ string_node = array_node.descendant_for_point_range((6, 4), (6, 6))
+ self.assertIsNotNone(string_node)
+ self.assertEqual(string_node.type, "string")
+ self.assertEqual(string_node.start_byte, string_index)
+ self.assertEqual(string_node.end_byte, string_index + 3)
+ self.assertEqual(string_node.start_point, (6, 4))
+ self.assertEqual(string_node.end_point, (6, 7))
+
+ # Leaf node starts before the lower bound, ends at the upper bound - byte query
+ null_index = JSON_EXAMPLE.index(b"null")
+ null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4)
+ self.assertIsNotNone(null_node)
+ self.assertEqual(null_node.type, "null")
+ self.assertEqual(null_node.start_byte, null_index)
+ self.assertEqual(null_node.end_byte, null_index + 4)
+ self.assertEqual(null_node.start_point, (6, 9))
+ self.assertEqual(null_node.end_point, (6, 13))
+
+ # Leaf node starts before the lower bound, ends at the upper bound - point query
+ null_node = array_node.descendant_for_point_range((6, 11), (6, 13))
+ self.assertIsNotNone(null_node)
+ self.assertEqual(null_node.type, "null")
+ self.assertEqual(null_node.start_byte, null_index)
+ self.assertEqual(null_node.end_byte, null_index + 4)
+ self.assertEqual(null_node.start_point, (6, 9))
+ self.assertEqual(null_node.end_point, (6, 13))
+
+ # The bounds span multiple leaf nodes - return the smallest node that does span it.
+ pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4)
+ self.assertIsNotNone(pair_node)
+ self.assertEqual(pair_node.type, "pair")
+ self.assertEqual(pair_node.start_byte, string_index)
+ self.assertEqual(pair_node.end_byte, string_index + 9)
+ self.assertEqual(pair_node.start_point, (6, 4))
+ self.assertEqual(pair_node.end_point, (6, 13))
+
+ self.assertEqual(colon_node.parent, pair_node)
+
+ # No leaf spans the given range - return the smallest node that does span it.
+ pair_node = array_node.descendant_for_point_range((6, 6), (6, 8))
+ self.assertIsNotNone(pair_node)
+ self.assertEqual(pair_node.type, "pair")
+ self.assertEqual(pair_node.start_byte, string_index)
+ self.assertEqual(pair_node.end_byte, string_index + 9)
+ self.assertEqual(pair_node.start_point, (6, 4))
+ self.assertEqual(pair_node.end_point, (6, 13))
+
+    def test_children(self):
+        parser = Parser(self.python)
+        tree = parser.parse(b"def foo():\n  bar()")  # 18 bytes; second row is 7 columns wide
+
+        root_node = tree.root_node
+        self.assertEqual(root_node.type, "module")
+        self.assertEqual(root_node.start_byte, 0)
+        self.assertEqual(root_node.end_byte, 18)
+        self.assertEqual(root_node.start_point, (0, 0))
+        self.assertEqual(root_node.end_point, (1, 7))
+
+        # List object is reused
+        self.assertIs(root_node.children, root_node.children)
+
+        fn_node = root_node.children[0]
+        self.assertEqual(fn_node, root_node.child(0))
+        self.assertEqual(fn_node.type, "function_definition")
+        self.assertEqual(fn_node.start_byte, 0)
+        self.assertEqual(fn_node.end_byte, 18)
+        self.assertEqual(fn_node.start_point, (0, 0))
+        self.assertEqual(fn_node.end_point, (1, 7))
+
+        def_node = fn_node.children[0]
+        self.assertEqual(def_node, fn_node.child(0))
+        self.assertEqual(def_node.type, "def")
+        self.assertEqual(def_node.is_named, False)
+
+        id_node = fn_node.children[1]
+        self.assertEqual(id_node, fn_node.child(1))
+        self.assertEqual(id_node.type, "identifier")
+        self.assertEqual(id_node.is_named, True)
+        self.assertEqual(len(id_node.children), 0)
+
+        params_node = fn_node.children[2]
+        self.assertEqual(params_node, fn_node.child(2))
+        self.assertEqual(params_node.type, "parameters")
+        self.assertEqual(params_node.is_named, True)
+
+        colon_node = fn_node.children[3]
+        self.assertEqual(colon_node, fn_node.child(3))
+        self.assertEqual(colon_node.type, ":")
+        self.assertEqual(colon_node.is_named, False)
+
+        statement_node = fn_node.children[4]
+        self.assertEqual(statement_node, fn_node.child(4))
+        self.assertEqual(statement_node.type, "block")
+        self.assertEqual(statement_node.is_named, True)
+
+ def test_is_extra(self):
+ parser = Parser(self.javascript)
+ tree = parser.parse(b"foo(/* hi */);")
+
+ root_node = tree.root_node
+ comment_node = root_node.descendant_for_byte_range(7, 7)
+ self.assertIsNotNone(comment_node)
+
+ self.assertEqual(root_node.type, "program")
+ self.assertEqual(comment_node.type, "comment")
+ self.assertEqual(root_node.is_extra, False)
+ self.assertEqual(comment_node.is_extra, True)
+
+ def test_properties(self):
+ parser = Parser(self.python)
+ tree = parser.parse(b"[1, 2, 3]")
+
+ root_node = tree.root_node
+ self.assertEqual(root_node.type, "module")
+ self.assertEqual(root_node.start_byte, 0)
+ self.assertEqual(root_node.end_byte, 9)
+ self.assertEqual(root_node.start_point, (0, 0))
+ self.assertEqual(root_node.end_point, (0, 9))
+
+ exp_stmt_node = root_node.children[0]
+ self.assertEqual(exp_stmt_node, root_node.child(0))
+ self.assertEqual(exp_stmt_node.type, "expression_statement")
+ self.assertEqual(exp_stmt_node.start_byte, 0)
+ self.assertEqual(exp_stmt_node.end_byte, 9)
+ self.assertEqual(exp_stmt_node.start_point, (0, 0))
+ self.assertEqual(exp_stmt_node.end_point, (0, 9))
+ self.assertEqual(exp_stmt_node.parent, root_node)
+
+ list_node = exp_stmt_node.children[0]
+ self.assertEqual(list_node, exp_stmt_node.child(0))
+ self.assertEqual(list_node.type, "list")
+ self.assertEqual(list_node.start_byte, 0)
+ self.assertEqual(list_node.end_byte, 9)
+ self.assertEqual(list_node.start_point, (0, 0))
+ self.assertEqual(list_node.end_point, (0, 9))
+ self.assertEqual(list_node.parent, exp_stmt_node)
+
+ named_children = list_node.named_children
+
+ open_delim_node = list_node.children[0]
+ self.assertEqual(open_delim_node, list_node.child(0))
+ self.assertEqual(open_delim_node.type, "[")
+ self.assertEqual(open_delim_node.start_byte, 0)
+ self.assertEqual(open_delim_node.end_byte, 1)
+ self.assertEqual(open_delim_node.start_point, (0, 0))
+ self.assertEqual(open_delim_node.end_point, (0, 1))
+ self.assertEqual(open_delim_node.parent, list_node)
+
+ first_num_node = list_node.children[1]
+ self.assertEqual(first_num_node, list_node.child(1))
+ self.assertEqual(first_num_node, open_delim_node.next_named_sibling)
+ self.assertEqual(first_num_node.parent, list_node)
+ self.assertEqual(named_children[0], first_num_node)
+ self.assertEqual(first_num_node, list_node.named_child(0))
+
+ first_comma_node = list_node.children[2]
+ self.assertEqual(first_comma_node, list_node.child(2))
+ self.assertEqual(first_comma_node, first_num_node.next_sibling)
+ self.assertEqual(first_num_node, first_comma_node.prev_sibling)
+ self.assertEqual(first_comma_node.parent, list_node)
+
+ second_num_node = list_node.children[3]
+ self.assertEqual(second_num_node, list_node.child(3))
+ self.assertEqual(second_num_node, first_comma_node.next_sibling)
+ self.assertEqual(second_num_node, first_num_node.next_named_sibling)
+ self.assertEqual(first_num_node, second_num_node.prev_named_sibling)
+ self.assertEqual(second_num_node.parent, list_node)
+ self.assertEqual(named_children[1], second_num_node)
+ self.assertEqual(second_num_node, list_node.named_child(1))
+
+ second_comma_node = list_node.children[4]
+ self.assertEqual(second_comma_node, list_node.child(4))
+ self.assertEqual(second_comma_node, second_num_node.next_sibling)
+ self.assertEqual(second_num_node, second_comma_node.prev_sibling)
+ self.assertEqual(second_comma_node.parent, list_node)
+
+ third_num_node = list_node.children[5]
+ self.assertEqual(third_num_node, list_node.child(5))
+ self.assertEqual(third_num_node, second_comma_node.next_sibling)
+ self.assertEqual(third_num_node, second_num_node.next_named_sibling)
+ self.assertEqual(second_num_node, third_num_node.prev_named_sibling)
+ self.assertEqual(third_num_node.parent, list_node)
+ self.assertEqual(named_children[2], third_num_node)
+ self.assertEqual(third_num_node, list_node.named_child(2))
+
+ close_delim_node = list_node.children[6]
+ self.assertEqual(close_delim_node, list_node.child(6))
+ self.assertEqual(close_delim_node.type, "]")
+ self.assertEqual(close_delim_node.start_byte, 8)
+ self.assertEqual(close_delim_node.end_byte, 9)
+ self.assertEqual(close_delim_node.start_point, (0, 8))
+ self.assertEqual(close_delim_node.end_point, (0, 9))
+ self.assertEqual(close_delim_node, third_num_node.next_sibling)
+ self.assertEqual(third_num_node, close_delim_node.prev_sibling)
+ self.assertEqual(third_num_node, close_delim_node.prev_named_sibling)
+ self.assertEqual(close_delim_node.parent, list_node)
+
+ self.assertEqual(list_node.child_count, 7)
+ self.assertEqual(list_node.named_child_count, 3)
+
+ def test_numeric_symbols_respect_simple_aliases(self):
+ parser = Parser(self.python)
+
+ # Example 1:
+ # Python argument lists can contain "splat" arguments, which are not allowed within
+ # other expressions. This includes `parenthesized_list_splat` nodes like `(*b)`. These
+ # `parenthesized_list_splat` nodes are aliased as `parenthesized_expression`. Their numeric
+ # `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`.
+ tree = parser.parse(b"(a((*b)))")
+ root_node = tree.root_node
+ self.assertEqual(
+ str(root_node),
+ "(module (expression_statement (parenthesized_expression (call "
+ + "function: (identifier) arguments: (argument_list (parenthesized_expression "
+ + "(list_splat (identifier))))))))",
+ )
+
+ outer_expr_node = root_node.child(0).child(0)
+ self.assertIsNotNone(outer_expr_node)
+ self.assertEqual(outer_expr_node.type, "parenthesized_expression")
+
+ inner_expr_node = (
+ outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0)
+ )
+ self.assertIsNotNone(inner_expr_node)
+
+ self.assertEqual(inner_expr_node.type, "parenthesized_expression")
+ self.assertEqual(inner_expr_node.kind_id, outer_expr_node.kind_id)
+
+ def test_tree(self):
+ code = b"def foo():\n bar()\n\ndef foo():\n bar()"
+ parser = Parser(self.python)
+
+ for item in parser.parse(code).root_node.children:
+ self.assertIsNotNone(item.is_named)
+
+ for item in parser.parse(code).root_node.children:
+ self.assertIsNotNone(item.is_named)
+
+ def test_text(self):
+ parser = Parser(self.python)
+ tree = parser.parse(b"[0, [1, 2, 3]]")
+
+ root_node = tree.root_node
+ self.assertEqual(root_node.text, b"[0, [1, 2, 3]]")
+
+ exp_stmt_node = root_node.children[0]
+ self.assertEqual(exp_stmt_node.text, b"[0, [1, 2, 3]]")
+
+ list_node = exp_stmt_node.children[0]
+ self.assertEqual(list_node.text, b"[0, [1, 2, 3]]")
+
+ open_delim_node = list_node.children[0]
+ self.assertEqual(open_delim_node.text, b"[")
+
+ first_num_node = list_node.children[1]
+ self.assertEqual(first_num_node.text, b"0")
+
+ first_comma_node = list_node.children[2]
+ self.assertEqual(first_comma_node.text, b",")
+
+ child_list_node = list_node.children[3]
+ self.assertEqual(child_list_node.text, b"[1, 2, 3]")
+
+ close_delim_node = list_node.children[4]
+ self.assertEqual(close_delim_node.text, b"]")
+
+ def test_hash(self):
+ parser = Parser(self.python)
+ source_code = b"def foo():\n bar()\n bar()"
+ tree = parser.parse(source_code)
+ root_node = tree.root_node
+ first_function_node = root_node.children[0]
+ second_function_node = root_node.children[0]
+
+ # Uniqueness and consistency
+ self.assertEqual(hash(first_function_node), hash(first_function_node))
+ self.assertNotEqual(hash(root_node), hash(first_function_node))
+
+ # Equality implication
+ self.assertEqual(hash(first_function_node), hash(second_function_node))
+ self.assertEqual(first_function_node, second_function_node)
+
+ # Different nodes with different properties
+ different_tree = parser.parse(b"def baz():\n qux()")
+ different_node = different_tree.root_node.children[0]
+ self.assertNotEqual(hash(first_function_node), hash(different_node))
+
+ # Same code, different parse trees
+ another_tree = parser.parse(source_code)
+ another_node = another_tree.root_node.children[0]
+ self.assertNotEqual(hash(first_function_node), hash(another_node))
diff --git a/tests/test_parser.py b/tests/test_parser.py
new file mode 100644
index 00000000..2c921ceb
--- /dev/null
+++ b/tests/test_parser.py
@@ -0,0 +1,426 @@
+from unittest import TestCase
+
+from tree_sitter import Language, Parser, Range, Tree
+
+import tree_sitter_html
+import tree_sitter_javascript
+import tree_sitter_json
+import tree_sitter_python
+import tree_sitter_rust
+
+
+ def simple_range(start, end):  # Range on row 0 whose column offsets equal its byte offsets
+ return Range((0, start), (0, end), start, end)
+
+
+class TestParser(TestCase):
+ @classmethod
+ def setUpClass(cls):  # load the grammars and shared fixture values as class attributes
+ cls.html = Language(tree_sitter_html.language())
+ cls.python = Language(tree_sitter_python.language())
+ cls.javascript = Language(tree_sitter_javascript.language())
+ cls.json = Language(tree_sitter_json.language())  # NOTE(review): not referenced by any test visible in this file
+ cls.rust = Language(tree_sitter_rust.language())  # NOTE(review): not referenced by any test visible in this file
+ cls.max_range = Range((0, 0), (0xFFFFFFFF, 0xFFFFFFFF), 0, 0xFFFFFFFF)  # what included_ranges is expected to report by default (see test_init_no_args)
+ cls.min_range = Range((0, 0), (0, 1), 0, 1)  # one-byte range used as a non-default value
+ cls.timeout = 1000  # used as timeout_micros
+
+ def test_init_no_args(self):  # defaults of a bare Parser(): no language, max range, zero timeout
+ parser = Parser()
+ self.assertIsNone(parser.language)
+ self.assertListEqual(parser.included_ranges, [self.max_range])
+ self.assertEqual(parser.timeout_micros, 0)
+
+ def test_init_args(self):  # constructor keyword arguments are readable back through the properties
+ parser = Parser(
+ language=self.python, included_ranges=[self.min_range], timeout_micros=self.timeout
+ )
+ self.assertEqual(parser.language, self.python)
+ self.assertListEqual(parser.included_ranges, [self.min_range])
+ self.assertEqual(parser.timeout_micros, self.timeout)
+
+ def test_setters(self):  # each Parser property can be assigned; invalid included_ranges raise ValueError
+ parser = Parser()
+
+ with self.subTest(setter="language"):
+ parser.language = self.python
+ self.assertEqual(parser.language, self.python)
+
+ with self.subTest(setter="included_ranges"):
+ parser.included_ranges = [self.min_range]
+ self.assertListEqual(parser.included_ranges, [self.min_range])
+ with self.assertRaises(ValueError):  # ranges listed out of order: 23-29 comes before 0-5
+ parser.included_ranges = [
+ Range(
+ start_byte=23,
+ end_byte=29,
+ start_point=(0, 23),
+ end_point=(0, 29),
+ ),
+ Range(
+ start_byte=0,
+ end_byte=5,
+ start_point=(0, 0),
+ end_point=(0, 5),
+ ),
+ Range(
+ start_byte=50,
+ end_byte=60,
+ start_point=(0, 50),
+ end_point=(0, 60),
+ ),
+ ]
+ with self.assertRaises(ValueError):  # a single range whose end (5) precedes its start (10)
+ parser.included_ranges = [
+ Range(
+ start_byte=10,
+ end_byte=5,
+ start_point=(0, 10),
+ end_point=(0, 5),
+ )
+ ]
+
+ with self.subTest(setter="timeout_micros"):
+ parser.timeout_micros = self.timeout
+ self.assertEqual(parser.timeout_micros, self.timeout)
+
+ def test_deleters(self):
+ parser = Parser()
+
+ with self.subTest(deleter="language"):
+ del parser.language
+ self.assertIsNone(parser.language)
+
+ with self.subTest(deleter="included_ranges"):
+ del parser.included_ranges
+ self.assertListEqual(parser.included_ranges, [self.max_range])
+
+ with self.subTest(setter="timeout_micros"):
+ del parser.timeout_micros
+ self.assertEqual(parser.timeout_micros, 0)
+
+ def test_parse_buffer(self):  # parse() returns a Tree for bytes, memoryview and bytearray input
+ parser = Parser(self.javascript)
+ with self.subTest(type="bytes"):
+ self.assertIsInstance(parser.parse(b"test"), Tree)
+ with self.subTest(type="memoryview"):
+ self.assertIsInstance(parser.parse(memoryview(b"test")), Tree)
+ with self.subTest(type="bytearray"):
+ self.assertIsInstance(parser.parse(bytearray(b"test")), Tree)
+
+ def test_parse_callback(self):  # parse() also accepts a callable that supplies the source piecewise
+ parser = Parser(self.python)
+ source_lines = ["def foo():\n", " bar()"]
+
+ def read_callback(_, point):  # first argument is ignored; only the (row, column) point is used
+ row, column = point
+ if row >= len(source_lines):  # past the last line: nothing more to read
+ return None
+ if column >= len(source_lines[row]):  # past the end of this line
+ return None
+ return source_lines[row][column:].encode("utf8")  # rest of the current line, as bytes
+
+ tree = parser.parse(read_callback)
+ self.assertEqual(
+ str(tree.root_node),
+ "(module (function_definition"
+ + " name: (identifier)"
+ + " parameters: (parameters)"
+ + " body: (block (expression_statement (call"
+ + " function: (identifier)"
+ + " arguments: (argument_list))))))",
+ )
+
+ def test_parse_with_one_included_range(self):
+ source_code = b"hi"
+ parser = Parser(self.html)
+ html_tree = parser.parse(source_code)
+ script_content_node = html_tree.root_node.child(1).child(1)
+ self.assertIsNotNone(script_content_node)
+ self.assertEqual(script_content_node.type, "raw_text")
+
+ parser.included_ranges = [script_content_node.range]
+ parser.language = self.javascript
+ js_tree = parser.parse(source_code)
+ self.assertEqual(
+ str(js_tree.root_node),
+ "(program (expression_statement (call_expression"
+ + " function: (member_expression object: (identifier) property: (property_identifier))"
+ + " arguments: (arguments (string (string_fragment))))))",
+ )
+ self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console")))
+ self.assertEqual(js_tree.included_ranges, [script_content_node.range])
+
+ def test_parse_with_multiple_included_ranges(self):
+ source_code = b"html `Hello, ${name.toUpperCase()}, it's ${now()}.
`"
+
+ parser = Parser(self.javascript)
+ js_tree = parser.parse(source_code)
+ template_string_node = js_tree.root_node.descendant_for_byte_range(
+ source_code.index(b"`<"), source_code.index(b">`")
+ )
+ self.assertIsNotNone(template_string_node)
+
+ self.assertEqual(template_string_node.type, "template_string")
+
+ open_quote_node = template_string_node.child(0)
+ self.assertIsNotNone(open_quote_node)
+ interpolation_node1 = template_string_node.child(2)
+ self.assertIsNotNone(interpolation_node1)
+ interpolation_node2 = template_string_node.child(4)
+ self.assertIsNotNone(interpolation_node2)
+ close_quote_node = template_string_node.child(6)
+ self.assertIsNotNone(close_quote_node)
+
+ html_ranges = [
+ Range(
+ start_byte=open_quote_node.end_byte,
+ start_point=open_quote_node.end_point,
+ end_byte=interpolation_node1.start_byte,
+ end_point=interpolation_node1.start_point,
+ ),
+ Range(
+ start_byte=interpolation_node1.end_byte,
+ start_point=interpolation_node1.end_point,
+ end_byte=interpolation_node2.start_byte,
+ end_point=interpolation_node2.start_point,
+ ),
+ Range(
+ start_byte=interpolation_node2.end_byte,
+ start_point=interpolation_node2.end_point,
+ end_byte=close_quote_node.start_byte,
+ end_point=close_quote_node.start_point,
+ ),
+ ]
+ parser.included_ranges = html_ranges
+ parser.language = self.html
+ html_tree = parser.parse(source_code)
+
+ self.assertEqual(
+ str(html_tree.root_node),
+ "(document (element"
+ + " (start_tag (tag_name))"
+ + " (text)"
+ + " (element (start_tag (tag_name)) (end_tag (tag_name)))"
+ + " (text)"
+ + " (end_tag (tag_name))))"
+ )
+ self.assertEqual(html_tree.included_ranges, html_ranges)
+
+ div_element_node = html_tree.root_node.child(0)
+ self.assertIsNotNone(div_element_node)
+ hello_text_node = div_element_node.child(1)
+ self.assertIsNotNone(hello_text_node)
+ b_element_node = div_element_node.child(2)
+ self.assertIsNotNone(b_element_node)
+ b_start_tag_node = b_element_node.child(0)
+ self.assertIsNotNone(b_start_tag_node)
+ b_end_tag_node = b_element_node.child(1)
+ self.assertIsNotNone(b_end_tag_node)
+
+ self.assertEqual(hello_text_node.type, "text")
+ self.assertEqual(hello_text_node.start_byte, source_code.index(b"Hello"))
+ self.assertEqual(hello_text_node.end_byte, source_code.index(b" "))
+
+ self.assertEqual(b_start_tag_node.type, "start_tag")
+ self.assertEqual(b_start_tag_node.start_byte, source_code.index(b""))
+ self.assertEqual(b_start_tag_node.end_byte, source_code.index(b"${now()}"))
+
+ self.assertEqual(b_end_tag_node.type, "end_tag")
+ self.assertEqual(b_end_tag_node.start_byte, source_code.index(b""))
+ self.assertEqual(b_end_tag_node.end_byte, source_code.index(b"."))
+
+ def test_parse_with_included_range_containing_mismatched_positions(self):
+ source_code = b"test
{_ignore_this_part_}"
+ end_byte = source_code.index(b"{_ignore_this_part_")
+
+ range_to_parse = Range(
+ start_byte=0,
+ start_point=(10, 12),
+ end_byte=end_byte,
+ end_point=(10, 12 + end_byte),
+ )
+
+ parser = Parser(self.html, included_ranges=[range_to_parse])
+ html_tree = parser.parse(source_code)
+
+ self.assertEqual(
+ str(html_tree.root_node),
+ "(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))"
+ )
+
+ def test_parse_with_included_range_boundaries(self):  # node offsets stay absolute when two separate ranges are parsed as JavaScript
+ source_code = b"a <%= b() %> c <% d() %>"
+ range1_start_byte = source_code.index(b" b() ")  # each range includes the spaces around its call
+ range1_end_byte = range1_start_byte + len(b" b() ")
+ range2_start_byte = source_code.index(b" d() ")
+ range2_end_byte = range2_start_byte + len(b" d() ")
+
+ parser = Parser(self.javascript, included_ranges=[
+ Range(
+ start_byte=range1_start_byte,
+ end_byte=range1_end_byte,
+ start_point=(0, range1_start_byte),
+ end_point=(0, range1_end_byte),
+ ),
+ Range(
+ start_byte=range2_start_byte,
+ end_byte=range2_end_byte,
+ start_point=(0, range2_start_byte),
+ end_point=(0, range2_end_byte),
+ )
+ ])
+
+ tree = parser.parse(source_code)
+ root = tree.root_node
+ statement1 = root.child(0)
+ self.assertIsNotNone(statement1)
+ statement2 = root.child(1)
+ self.assertIsNotNone(statement2)
+
+ self.assertEqual(
+ str(root),
+ "(program" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments)))" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments))))"
+ )
+
+ self.assertEqual(statement1.start_byte, source_code.index(b"b()"))
+ self.assertEqual(statement1.end_byte, source_code.find(b" %> c"))  # ends before the space that closes range 1
+ self.assertEqual(statement2.start_byte, source_code.find(b"d()"))
+ self.assertEqual(statement2.end_byte, len(source_code) - len(" %>"))  # str literal here; only its length (3) is used
+
+ def test_parse_with_a_newly_excluded_range(self):
+ source_code = b"<%= something %>
"
+
+ # Parse HTML including the template directive, which will cause an error
+ parser = Parser(self.html)
+ first_tree = parser.parse(source_code)
+
+ prefix = b"a very very long line of plain text. "
+ first_tree.edit(
+ start_byte=0,
+ old_end_byte=0,
+ new_end_byte=len(prefix),
+ start_point=(0, 0),
+ old_end_point=(0, 0),
+ new_end_point=(0, len(prefix)),
+ )
+ source_code = prefix + source_code
+
+ # Parse the HTML again, this time *excluding* the template directive
+ # (which has moved since the previous parse).
+ directive_start = source_code.index(b"<%=")
+ directive_end = source_code.index(b"")
+ source_code_end = len(source_code)
+ parser.included_ranges = [
+ Range(
+ start_byte=0,
+ end_byte=directive_start,
+ start_point=(0, 0),
+ end_point=(0, directive_start),
+ ),
+ Range(
+ start_byte=directive_end,
+ end_byte=source_code_end,
+ start_point=(0, directive_end),
+ end_point=(0, source_code_end),
+ ),
+ ]
+
+ tree = parser.parse(source_code, first_tree)
+
+ self.assertEqual(
+ str(tree.root_node),
+ "(document (text) (element" +
+ " (start_tag (tag_name))" +
+ " (element (start_tag (tag_name)) (end_tag (tag_name)))" +
+ " (end_tag (tag_name))))"
+ )
+
+ self.assertEqual(
+ tree.changed_ranges(first_tree),
+ [
+ # The first range that has changed syntax is the range of the newly-inserted text.
+ Range(
+ start_byte=0,
+ end_byte=len(prefix),
+ start_point=(0, 0),
+ end_point=(0, len(prefix)),
+ ),
+ # Even though no edits were applied to the outer `div` element,
+ # its contents have changed syntax because a range of text that
+ # was previously included is now excluded.
+ Range(
+ start_byte=directive_start,
+ end_byte=directive_end,
+ start_point=(0, directive_start),
+ end_point=(0, directive_end),
+ )
+ ]
+ )
+
+ def test_parsing_with_a_newly_included_range(self):
+ source_code = b"<%= foo() %>
<%= bar() %><%= baz() %>"
+ range1_start = source_code.index(b" foo")
+ range2_start = source_code.index(b" bar")
+ range3_start = source_code.index(b" baz")
+ range1_end = range1_start + 7
+ range2_end = range2_start + 7
+ range3_end = range3_start + 7
+
+ # Parse only the first code directive as JavaScript
+ parser = Parser(self.javascript)
+ parser.included_ranges = [simple_range(range1_start, range1_end)]
+ tree = parser.parse(source_code)
+ self.assertEqual(
+ str(tree.root_node),
+ "(program" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments))))"
+ )
+
+ # Parse both the first and third code directives as JavaScript, using the old tree as a
+ # reference.
+ parser.included_ranges = [
+ simple_range(range1_start, range1_end),
+ simple_range(range3_start, range3_end),
+ ]
+ tree2 = parser.parse(source_code)
+ self.assertEqual(
+ str(tree2.root_node),
+ "(program" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments)))" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments))))"
+ )
+ self.assertEqual(
+ tree2.changed_ranges(tree),
+ [simple_range(range1_end, range3_end)]
+ )
+
+ # Parse all three code directives as JavaScript, using the old tree as a
+ # reference.
+ parser.included_ranges = [
+ simple_range(range1_start, range1_end),
+ simple_range(range2_start, range2_end),
+ simple_range(range3_start, range3_end),
+ ]
+ tree3 = parser.parse(source_code)
+ self.assertEqual(
+ str(tree3.root_node),
+ "(program" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments)))" +
+ " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments)))"
+ + " (expression_statement (call_expression" +
+ " function: (identifier) arguments: (arguments))))"
+ )
+ self.assertEqual(
+ tree3.changed_ranges(tree2),
+ [simple_range(range2_start + 1, range2_end - 1)],
+ )
diff --git a/tests/test_query.py b/tests/test_query.py
new file mode 100644
index 00000000..1c14a2c4
--- /dev/null
+++ b/tests/test_query.py
@@ -0,0 +1,509 @@
+from unittest import TestCase
+
+import tree_sitter_python
+import tree_sitter_javascript
+
+from tree_sitter import Language, Parser, Query
+
+
+ def collect_matches(matches):  # turn (m[0], captures) matches into (m[0], [(name, text), ...]) for comparison
+ return [(m[0], format_captures(m[1])) for m in matches]
+
+
+ def format_captures(captures):  # captures mapping -> list of (name, formatted capture) pairs in mapping order
+ return [(name, format_capture(capture)) for name, capture in captures.items()]
+
+
+ def format_capture(capture):  # decode one capture's text; a list capture becomes a list of strings
+ return (
+ [n.text.decode("utf-8") for n in capture]
+ if isinstance(capture, list)
+ else capture.text.decode("utf-8")
+ )
+
+
+class TestQuery(TestCase):
+ @classmethod
+ def setUpClass(cls):  # load the two grammars used by the query tests as class attributes
+ cls.javascript = Language(tree_sitter_javascript.language())
+ cls.python = Language(tree_sitter_python.language())
+
+ def assert_query_matches(self, language, query, source, expected):  # parse source, run query on the root, compare formatted matches
+ parser = Parser(language)
+ tree = parser.parse(source)
+ matches = language.query(query).matches(tree.root_node)
+ matches = collect_matches(matches)  # nodes replaced by their decoded text
+ self.assertEqual(matches, expected)
+
+ def test_errors(self):  # malformed queries raise NameError or SyntaxError at construction
+ with self.assertRaises(NameError, msg="Invalid node type foo"):  # NOTE(review): msg= is only the failure text; it is not matched against the exception message
+ Query(self.python, "(list (foo))")
+ with self.assertRaises(NameError, msg="Invalid field name buzz"):
+ Query(self.python, "(function_definition buzz: (identifier))")
+ with self.assertRaises(NameError, msg="Invalid capture name garbage"):
+ Query(self.python, "((function_definition) (#eq? @garbage foo))")
+ with self.assertRaises(SyntaxError, msg="Invalid syntax at offset 6"):  # unbalanced closing parenthesis
+ Query(self.python, "(list))")
+
+ def test_matches_with_simple_pattern(self):  # one pattern, one capture; the nested declaration is matched too
+ self.assert_query_matches(
+ self.javascript,
+ "(function_declaration name: (identifier) @fn-name)",
+ b"function one() { two(); function three() {} }",
+ [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])],
+ )
+
+ def test_matches_with_multiple_on_same_root(self):  # one class with two methods yields two matches sharing the class capture
+ self.assert_query_matches(
+ self.javascript,
+ """
+ (class_declaration
+ name: (identifier) @the-class-name
+ (class_body
+ (method_definition
+ name: (property_identifier) @the-method-name)))
+ """,
+ b"""
+ class Person {
+ // the constructor
+ constructor(name) { this.name = name; }
+
+ // the getter
+ getFullName() { return this.name; }
+ }
+ """,
+ [
+ (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]),
+ (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]),
+ ],
+ )
+
+ def test_matches_with_multiple_patterns_different_roots(self):  # the first tuple element distinguishes the two patterns (0 and 1)
+ self.assert_query_matches(
+ self.javascript,
+ """
+ (function_declaration name: (identifier) @fn-def)
+ (call_expression function: (identifier) @fn-ref)
+ """,
+ b"""
+ function f1() {
+ f2(f3());
+ }
+ """,
+ [
+ (0, [("fn-def", "f1")]),
+ (1, [("fn-ref", "f2")]),
+ (1, [("fn-ref", "f3")]),
+ ],
+ )
+
+ def test_matches_with_nesting_and_no_fields(self):  # every ordered pair of identifiers inside an inner array is a match
+ self.assert_query_matches(
+ self.javascript,
+ "(array (array (identifier) @x1 (identifier) @x2))",
+ b"""
+ [[a]];
+ [[c, d], [e, f, g, h]];
+ [[h], [i]];
+ """,
+ [
+ (0, [("x1", "c"), ("x2", "d")]),
+ (0, [("x1", "e"), ("x2", "f")]),
+ (0, [("x1", "e"), ("x2", "g")]),
+ (0, [("x1", "f"), ("x2", "g")]),
+ (0, [("x1", "e"), ("x2", "h")]),
+ (0, [("x1", "f"), ("x2", "h")]),
+ (0, [("x1", "g"), ("x2", "h")]),
+ ],
+ )
+
+ def test_matches_with_list_capture(self):  # a quantified capture, (_)* @name, is returned as a list of nodes
+ self.assert_query_matches(
+ self.javascript,
+ """
+ (function_declaration
+ name: (identifier) @fn-name
+ body: (statement_block (_)* @fn-statements))
+ """,
+ b"""function one() {
+ x = 1;
+ y = 2;
+ z = 3;
+ }
+ function two() {
+ x = 1;
+ }
+ """,
+ [
+ (
+ 0,
+ [
+ ("fn-name", "one"),
+ ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]),
+ ],
+ ),
+ (0, [("fn-name", "two"), ("fn-statements", ["x = 1;"])]),  # a single statement still comes back as a list
+ ],
+ )
+
+ def test_captures(self):
+ parser = Parser(self.python)
+ source = b"def foo():\n bar()\ndef baz():\n quux()\n"
+ tree = parser.parse(source)
+ query = self.python.query(
+ """
+ (function_definition name: (identifier) @func-def)
+ (call function: (identifier) @func-call)
+ """
+ )
+
+ captures = query.captures(tree.root_node)
+
+ self.assertEqual(captures[0][0].start_point, (0, 4))
+ self.assertEqual(captures[0][0].end_point, (0, 7))
+ self.assertEqual(captures[0][1], "func-def")
+
+ self.assertEqual(captures[1][0].start_point, (1, 2))
+ self.assertEqual(captures[1][0].end_point, (1, 5))
+ self.assertEqual(captures[1][1], "func-call")
+
+ self.assertEqual(captures[2][0].start_point, (2, 4))
+ self.assertEqual(captures[2][0].end_point, (2, 7))
+ self.assertEqual(captures[2][1], "func-def")
+
+ self.assertEqual(captures[3][0].start_point, (3, 2))
+ self.assertEqual(captures[3][0].end_point, (3, 6))
+ self.assertEqual(captures[3][1], "func-call")
+
+ def test_text_predicates(self):  # #eq?, #not-eq?, #match? and #not-match? filter captures by node text
+ parser = Parser(self.javascript)
+ source = b"""
+ keypair_object = {
+ key1: value1,
+ equal: equal
+ }
+
+ function fun1(arg) {
+ return 1;
+ }
+
+ function fun2(arg) {
+ return 2;
+ }
+ """
+ tree = parser.parse(source)
+ root_node = tree.root_node
+
+ # function with name equal to 'fun1' -> test for #eq? @capture string
+ query1 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#eq? @function-name fun1))
+ """
+ )
+ captures1 = query1.captures(root_node)
+ self.assertEqual(1, len(captures1))
+ self.assertEqual(b"fun1", captures1[0][0].text)
+ self.assertEqual("function-name", captures1[0][1])
+
+ # functions with name not equal to 'fun1' -> test for #not-eq? @capture string
+ query2 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#not-eq? @function-name fun1))
+ """
+ )
+ captures2 = query2.captures(root_node)
+ self.assertEqual(1, len(captures2))
+ self.assertEqual(b"fun2", captures2[0][0].text)
+ self.assertEqual("function-name", captures2[0][1])
+
+ # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2
+ query3 = self.javascript.query(
+ """
+ ((pair
+ key: (property_identifier) @key-name
+ value: (identifier) @value-name)
+ (#eq? @key-name @value-name))
+ """
+ )
+ captures3 = query3.captures(root_node)
+ self.assertEqual(2, len(captures3))  # one pair matches, contributing a key capture and a value capture
+ self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3]))
+ self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3]))
+
+ # key pairs whose key is not equal to its value
+ # -> test for #not-eq? @capture1 @capture2
+ query4 = self.javascript.query(
+ """
+ ((pair
+ key: (property_identifier) @key-name
+ value: (identifier) @value-name)
+ (#not-eq? @key-name @value-name))
+ """
+ )
+ captures4 = query4.captures(root_node)
+ self.assertEqual(2, len(captures4))
+ self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4]))
+ self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4]))
+
+ # equality that is satisfied by *another* capture
+ query5 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name
+ parameters: (formal_parameters (identifier) @parameter-name))
+ (#eq? @function-name arg))
+ """
+ )
+ captures5 = query5.captures(root_node)
+ self.assertEqual(0, len(captures5))  # "arg" is a parameter name, never a function name
+
+ # functions that match the regex .*1 -> test for #match @capture regex
+ query6 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#match? @function-name ".*1"))
+ """
+ )
+ captures6 = query6.captures(root_node)
+ self.assertEqual(1, len(captures6))
+ self.assertEqual(b"fun1", captures6[0][0].text)
+
+ # functions that do not match the regex .*1 -> test for #not-match @capture regex
+ query6 = self.javascript.query(  # NOTE(review): rebinds query6/captures6 from the #match? case above
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#not-match? @function-name ".*1"))
+ """
+ )
+ captures6 = query6.captures(root_node)
+ self.assertEqual(1, len(captures6))
+ self.assertEqual(b"fun2", captures6[0][0].text)
+
+ # after editing there is no text property, so predicates are ignored
+ tree.edit(
+ start_byte=0,
+ old_end_byte=0,
+ new_end_byte=2,
+ start_point=(0, 0),
+ old_end_point=(0, 0),
+ new_end_point=(0, 2),
+ )
+ captures_notext = query1.captures(root_node)  # same #eq? query as the first case, now unfiltered
+ self.assertEqual(2, len(captures_notext))  # both functions are captured
+ self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext]))
+
+ def test_text_predicate_on_optional_capture(self):  # a predicate on an optional capture that is absent does not reject the match
+ parser = Parser(self.javascript)
+ source = b"fun1(1)"  # the argument is a number, so the optional (string) capture is absent
+ tree = parser.parse(source)
+ root_node = tree.root_node
+
+ # optional capture that is missing in source used in #eq? @capture string
+ query1 = self.javascript.query(
+ """
+ ((call_expression
+ function: (identifier) @function-name
+ arguments: (arguments (string)? @optional-string-arg)
+ (#eq? @optional-string-arg "1")))
+ """
+ )
+ captures1 = query1.captures(root_node)
+ self.assertEqual(1, len(captures1))  # only the function name is captured
+ self.assertEqual(b"fun1", captures1[0][0].text)
+ self.assertEqual("function-name", captures1[0][1])
+
+ # optional capture that is missing in source used in #eq? @capture @capture
+ query2 = self.javascript.query(
+ """
+ ((call_expression
+ function: (identifier) @function-name
+ arguments: (arguments (string)? @optional-string-arg)
+ (#eq? @optional-string-arg @function-name)))
+ """
+ )
+ captures2 = query2.captures(root_node)
+ self.assertEqual(1, len(captures2))
+ self.assertEqual(b"fun1", captures2[0][0].text)
+ self.assertEqual("function-name", captures2[0][1])
+
+ # optional capture that is missing in source used in #match? @capture string
+ query3 = self.javascript.query(
+ """
+ ((call_expression
+ function: (identifier) @function-name
+ arguments: (arguments (string)? @optional-string-arg)
+ (#match? @optional-string-arg "\\d+")))
+ """
+ )
+ captures3 = query3.captures(root_node)
+ self.assertEqual(1, len(captures3))
+ self.assertEqual(b"fun1", captures3[0][0].text)
+ self.assertEqual("function-name", captures3[0][1])
+
+ def test_text_predicates_errors(self):  # malformed #eq? / #match? predicates raise RuntimeError when the query is built
+ with self.assertRaises(RuntimeError):  # #eq? with three arguments
+ self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#eq? @function-name @function-name fun1))
+ """
+ )
+
+ with self.assertRaises(RuntimeError):  # #eq? whose first argument is a string, not a capture
+ self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#eq? fun1 @function-name))
+ """
+ )
+
+ with self.assertRaises(RuntimeError):  # #match? with three arguments
+ self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#match? @function-name @function-name fun1))
+ """
+ )
+
+ with self.assertRaises(RuntimeError):  # #match? whose first argument is a string, not a capture
+ self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#match? fun1 @function-name))
+ """
+ )
+
+ with self.assertRaises(RuntimeError):  # #match? whose second argument is a capture, not a pattern string
+ self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name)
+ (#match? @function-name @function-name))
+ """
+ )
+
+ def test_multiple_text_predicates(self):  # two #eq? predicates in one pattern must both hold
+ parser = Parser(self.javascript)
+ source = b"""
+ keypair_object = {
+ key1: value1,
+ equal: equal
+ }
+
+ function fun1(arg) {
+ return 1;
+ }
+
+ function fun1(notarg) {
+ return 1 + 1;
+ }
+
+ function fun2(arg) {
+ return 2;
+ }
+ """
+ tree = parser.parse(source)
+ root_node = tree.root_node
+
+ # function with name equal to 'fun1' -> test for first #eq? @capture string
+ query1 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name
+ parameters: (formal_parameters
+ (identifier) @argument-name))
+ (#eq? @function-name fun1))
+ """
+ )
+ captures1 = query1.captures(root_node)
+ self.assertEqual(4, len(captures1))  # two functions named fun1, two captures each
+ self.assertEqual(b"fun1", captures1[0][0].text)
+ self.assertEqual("function-name", captures1[0][1])
+ self.assertEqual(b"arg", captures1[1][0].text)
+ self.assertEqual("argument-name", captures1[1][1])
+ self.assertEqual(b"fun1", captures1[2][0].text)
+ self.assertEqual("function-name", captures1[2][1])
+ self.assertEqual(b"notarg", captures1[3][0].text)
+ self.assertEqual("argument-name", captures1[3][1])
+
+ # function with argument equal to 'arg' -> test for second #eq? @capture string
+ query2 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name
+ parameters: (formal_parameters
+ (identifier) @argument-name))
+ (#eq? @argument-name arg))
+ """
+ )
+ captures2 = query2.captures(root_node)
+ self.assertEqual(4, len(captures2))  # fun1(arg) and fun2(arg), two captures each
+ self.assertEqual(b"fun1", captures2[0][0].text)
+ self.assertEqual("function-name", captures2[0][1])
+ self.assertEqual(b"arg", captures2[1][0].text)
+ self.assertEqual("argument-name", captures2[1][1])
+ self.assertEqual(b"fun2", captures2[2][0].text)
+ self.assertEqual("function-name", captures2[2][1])
+ self.assertEqual(b"arg", captures2[3][0].text)
+ self.assertEqual("argument-name", captures2[3][1])
+
+ # function with name equal to 'fun1' & argument 'arg' -> test for both together
+ query3 = self.javascript.query(
+ """
+ ((function_declaration
+ name: (identifier) @function-name
+ parameters: (formal_parameters
+ (identifier) @argument-name))
+ (#eq? @function-name fun1)
+ (#eq? @argument-name arg))
+ """
+ )
+ captures3 = query3.captures(root_node)
+ self.assertEqual(2, len(captures3))  # only fun1(arg) satisfies both predicates
+ self.assertEqual(b"fun1", captures3[0][0].text)
+ self.assertEqual("function-name", captures3[0][1])
+ self.assertEqual(b"arg", captures3[1][0].text)
+ self.assertEqual("argument-name", captures3[1][1])
+
+ def test_point_range_captures(self):
+ parser = Parser(self.python)
+ source = b"def foo():\n bar()\ndef baz():\n quux()\n"
+ tree = parser.parse(source)
+ query = self.python.query(
+ """
+ (function_definition name: (identifier) @func-def)
+ (call function: (identifier) @func-call)
+ """
+ )
+
+ captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0))
+
+ self.assertEqual(captures[0][0].start_point, (1, 2))
+ self.assertEqual(captures[0][0].end_point, (1, 5))
+ self.assertEqual(captures[0][1], "func-call")
+
+ def test_byte_range_captures(self):
+ parser = Parser(self.python)
+ source = b"def foo():\n bar()\ndef baz():\n quux()\n"
+ tree = parser.parse(source)
+ query = self.python.query(
+ """
+ (function_definition name: (identifier) @func-def)
+ (call function: (identifier) @func-call)
+ """
+ )
+
+ captures = query.captures(tree.root_node, start_byte=10, end_byte=20)
+ self.assertEqual(captures[0][0].start_point, (1, 2))
+ self.assertEqual(captures[0][0].end_point, (1, 5))
+ self.assertEqual(captures[0][1], "func-call")
diff --git a/tests/test_tree.py b/tests/test_tree.py
new file mode 100644
index 00000000..df7b8ae5
--- /dev/null
+++ b/tests/test_tree.py
@@ -0,0 +1,152 @@
+from unittest import TestCase
+
+from tree_sitter import Language, Parser
+
+import tree_sitter_python
+import tree_sitter_rust
+
+
+class TestTree(TestCase):
+ @classmethod
+ def setUpClass(cls):  # load the two grammars used by the tree tests as class attributes
+ cls.python = Language(tree_sitter_python.language())
+ cls.rust = Language(tree_sitter_rust.language())
+
+ def test_edit(self):  # Tree.edit() marks affected nodes and shifts positions; re-parsing with the old tree then succeeds
+ parser = Parser(self.python)
+ tree = parser.parse(b"def foo():\n bar()")
+
+ edit_offset = len(b"def foo(")  # just inside the opening parenthesis
+ tree.edit(  # describe a two-byte insertion at edit_offset
+ start_byte=edit_offset,
+ old_end_byte=edit_offset,
+ new_end_byte=edit_offset + 2,
+ start_point=(0, edit_offset),
+ old_end_point=(0, edit_offset),
+ new_end_point=(0, edit_offset + 2),
+ )
+
+ fn_node = tree.root_node.children[0]
+ self.assertEqual(fn_node.type, "function_definition")
+ self.assertTrue(fn_node.has_changes)
+ self.assertFalse(fn_node.children[0].has_changes)  # siblings of the parameters node are untouched
+ self.assertFalse(fn_node.children[1].has_changes)
+ self.assertFalse(fn_node.children[3].has_changes)
+
+ params_node = fn_node.children[2]
+ self.assertEqual(params_node.type, "parameters")
+ self.assertTrue(params_node.has_changes)
+ self.assertEqual(params_node.start_point, (0, edit_offset - 1))
+ self.assertEqual(params_node.end_point, (0, edit_offset + 3))  # end moved right by the two inserted bytes
+
+ new_tree = parser.parse(b"def foo(ab):\n bar()", tree)  # edited tree passed as the old tree
+ self.assertEqual(
+ str(new_tree.root_node),
+ "(module (function_definition"
+ + " name: (identifier)"
+ + " parameters: (parameters (identifier))"
+ + " body: (block"
+ + " (expression_statement (call"
+ + " function: (identifier)"
+ + " arguments: (argument_list))))))",
+ )
+
+ def test_changed_ranges(self):  # changed_ranges() between the edited old tree and the new tree is the inserted span
+ parser = Parser(self.python)
+ tree = parser.parse(b"def foo():\n bar()")
+
+ edit_offset = len(b"def foo(")
+ tree.edit(  # describe a two-byte insertion at edit_offset
+ start_byte=edit_offset,
+ old_end_byte=edit_offset,
+ new_end_byte=edit_offset + 2,
+ start_point=(0, edit_offset),
+ old_end_point=(0, edit_offset),
+ new_end_point=(0, edit_offset + 2),
+ )
+
+ new_tree = parser.parse(b"def foo(ab):\n bar()", tree)
+ changed_ranges = tree.changed_ranges(new_tree)  # called on the old tree with the new tree as argument
+
+ self.assertEqual(len(changed_ranges), 1)
+ self.assertEqual(changed_ranges[0].start_byte, edit_offset)
+ self.assertEqual(changed_ranges[0].start_point, (0, edit_offset))
+ self.assertEqual(changed_ranges[0].end_byte, edit_offset + 2)
+ self.assertEqual(changed_ranges[0].end_point, (0, edit_offset + 2))
+
+ def test_walk(self):
+ parser = Parser(self.rust)
+
+ tree = parser.parse(
+ b"""
+ struct Stuff {
+ a: A,
+ b: Option,
+ }
+ """
+ )
+
+ cursor = tree.walk()
+
+ # Node always returns the same instance
+ self.assertIs(cursor.node, cursor.node)
+
+ self.assertEqual(cursor.node.type, "source_file")
+
+ self.assertEqual(cursor.goto_first_child(), True)
+ self.assertEqual(cursor.node.type, "struct_item")
+
+ self.assertEqual(cursor.goto_first_child(), True)
+ self.assertEqual(cursor.node.type, "struct")
+ self.assertEqual(cursor.node.is_named, False)
+
+ self.assertEqual(cursor.goto_next_sibling(), True)
+ self.assertEqual(cursor.node.type, "type_identifier")
+ self.assertEqual(cursor.node.is_named, True)
+
+ self.assertEqual(cursor.goto_next_sibling(), True)
+ self.assertEqual(cursor.node.type, "field_declaration_list")
+ self.assertEqual(cursor.node.is_named, True)
+
+ self.assertEqual(cursor.goto_last_child(), True)
+ self.assertEqual(cursor.node.type, "}")
+ self.assertEqual(cursor.node.is_named, False)
+ self.assertEqual(cursor.node.start_point, (4, 16))
+
+ self.assertEqual(cursor.goto_previous_sibling(), True)
+ self.assertEqual(cursor.node.type, ",")
+ self.assertEqual(cursor.node.is_named, False)
+ self.assertEqual(cursor.node.start_point, (3, 32))
+
+ self.assertEqual(cursor.goto_previous_sibling(), True)
+ self.assertEqual(cursor.node.type, "field_declaration")
+ self.assertEqual(cursor.node.is_named, True)
+ self.assertEqual(cursor.node.start_point, (3, 20))
+
+ self.assertEqual(cursor.goto_previous_sibling(), True)
+ self.assertEqual(cursor.node.type, ",")
+ self.assertEqual(cursor.node.is_named, False)
+ self.assertEqual(cursor.node.start_point, (2, 24))
+
+ self.assertEqual(cursor.goto_previous_sibling(), True)
+ self.assertEqual(cursor.node.type, "field_declaration")
+ self.assertEqual(cursor.node.is_named, True)
+ self.assertEqual(cursor.node.start_point, (2, 20))
+
+ self.assertEqual(cursor.goto_previous_sibling(), True)
+ self.assertEqual(cursor.node.type, "{")
+ self.assertEqual(cursor.node.is_named, False)
+ self.assertEqual(cursor.node.start_point, (1, 29))
+
+ copy = tree.walk()
+ copy.reset_to(cursor)
+
+ self.assertEqual(copy.node.type, "{")
+ self.assertEqual(copy.node.is_named, False)
+
+ self.assertEqual(copy.goto_parent(), True)
+ self.assertEqual(copy.node.type, "field_declaration_list")
+ self.assertEqual(copy.node.is_named, True)
+
+ self.assertEqual(copy.goto_parent(), True)
+ self.assertEqual(copy.node.type, "struct_item")
diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py
deleted file mode 100644
index 0826af32..00000000
--- a/tests/test_tree_sitter.py
+++ /dev/null
@@ -1,1906 +0,0 @@
-import re
-from os import path
-from typing import Dict, List, Optional, Tuple, Union
-from unittest import TestCase
-
-from tree_sitter import Language, LookaheadIterator, Node, Parser, Query, Range, Tree
-
-LIB_PATH = path.join("build", "languages.so")
-
-# cibuildwheel uses a funny working directory when running tests.
-# This is by design, this way tests import whatever is installed and not from the project.
-#
-# The languages binary is still relative to current working directory to prevent reusing
-# a 32-bit languages binary in a 64-bit build. The working directory is clean every time.
-project_root = path.dirname(path.dirname(path.abspath(__file__)))
-Language.build_library(
- LIB_PATH,
- [
- path.join(project_root, "tests", "fixtures", "tree-sitter-embedded-template"),
- path.join(project_root, "tests", "fixtures", "tree-sitter-html"),
- path.join(project_root, "tests", "fixtures", "tree-sitter-javascript"),
- path.join(project_root, "tests", "fixtures", "tree-sitter-json"),
- path.join(project_root, "tests", "fixtures", "tree-sitter-python"),
- path.join(project_root, "tests", "fixtures", "tree-sitter-rust"),
- ],
-)
-
-EMBEDDED_TEMPLATE = Language(LIB_PATH, "embedded_template")
-HTML = Language(LIB_PATH, "html")
-JAVASCRIPT = Language(LIB_PATH, "javascript")
-JSON = Language(LIB_PATH, "json")
-PYTHON = Language(LIB_PATH, "python")
-RUST = Language(LIB_PATH, "rust")
-
-JSON_EXAMPLE: bytes = b"""
-
-[
-  123,
-  false,
-  {
-    "x": null
-  }
-]
-"""
-
-
-class TestParser(TestCase):
- def test_set_language(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"def foo():\n bar()")
- self.assertEqual(
- tree.root_node.sexp(),
- trim(
- """(module (function_definition
- name: (identifier)
- parameters: (parameters)
- body: (block (expression_statement (call
- function: (identifier)
- arguments: (argument_list))))))"""
- ),
- )
- parser.set_language(JAVASCRIPT)
- tree = parser.parse(b"function foo() {\n bar();\n}")
- self.assertEqual(
- tree.root_node.sexp(),
- trim(
- """(program (function_declaration
- name: (identifier)
- parameters: (formal_parameters)
- body: (statement_block
- (expression_statement
- (call_expression
- function: (identifier)
- arguments: (arguments))))))"""
- ),
- )
-
- def test_read_callback(self):
- parser = Parser()
- parser.set_language(PYTHON)
- source_lines = ["def foo():\n", " bar()"]
-
- def read_callback(_: int, point: Tuple[int, int]) -> Optional[bytes]:
- row, column = point
- if row >= len(source_lines):
- return None
- if column >= len(source_lines[row]):
- return None
- return source_lines[row][column:].encode("utf8")
-
- tree = parser.parse(read_callback)
- self.assertEqual(
- tree.root_node.sexp(),
- trim(
- """(module (function_definition
- name: (identifier)
- parameters: (parameters)
- body: (block (expression_statement (call
- function: (identifier)
- arguments: (argument_list))))))"""
- ),
- )
-
- def test_multibyte_characters(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- source_code = bytes("'😎' && '🐍'", "utf8")
- tree = parser.parse(source_code)
- root_node = tree.root_node
- statement_node = root_node.children[0]
- binary_node = statement_node.children[0]
- snake_node = binary_node.children[2]
-
- self.assertEqual(binary_node.type, "binary_expression")
- self.assertEqual(snake_node.type, "string")
- self.assertEqual(
- source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"),
- "'🐍'",
- )
-
- def test_buffer_protocol(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- parser.parse(b"test")
- parser.parse(memoryview(b"test"))
- parser.parse(bytearray(b"test"))
-
- def test_multibyte_characters_via_read_callback(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- source_code = bytes("'😎' && '🐍'", "utf8")
-
- def read(byte_position, _):
- return source_code[byte_position : byte_position + 1]
-
- tree = parser.parse(read)
- root_node = tree.root_node
- statement_node = root_node.children[0]
- binary_node = statement_node.children[0]
- snake_node = binary_node.children[2]
-
- self.assertEqual(binary_node.type, "binary_expression")
- self.assertEqual(snake_node.type, "string")
- self.assertEqual(
- source_code[snake_node.start_byte : snake_node.end_byte].decode("utf8"),
- "'🐍'",
- )
-
- def test_parsing_with_one_included_range(self):
- source_code = b"<span>hi</span><script>console.log('sup');</script>"
- parser = Parser()
- parser.set_language(HTML)
- html_tree = parser.parse(source_code)
- script_content_node = html_tree.root_node.child(1).child(1)
- if script_content_node is None:
- self.fail("script_content_node is None")
- self.assertEqual(script_content_node.type, "raw_text")
-
- parser.set_included_ranges([script_content_node.range])
- parser.set_language(JAVASCRIPT)
- js_tree = parser.parse(source_code)
-
- self.assertEqual(
- js_tree.root_node.sexp(),
- "(program (expression_statement (call_expression "
- + "function: (member_expression object: (identifier) property: (property_identifier)) "
- + "arguments: (arguments (string (string_fragment))))))",
- )
- self.assertEqual(js_tree.root_node.start_point, (0, source_code.index(b"console")))
- self.assertEqual(js_tree.included_ranges, [script_content_node.range])
-
- def test_parsing_with_multiple_included_ranges(self):
- source_code = b"html `<div>Hello, ${name.toUpperCase()}, it's <b>${now()}</b>.</div>`"
-
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- js_tree = parser.parse(source_code)
- template_string_node = js_tree.root_node.descendant_for_byte_range(
- source_code.index(b"`<"), source_code.index(b">`")
- )
- if template_string_node is None:
- self.fail("template_string_node is None")
-
- self.assertEqual(template_string_node.type, "template_string")
-
- open_quote_node = template_string_node.child(0)
- if open_quote_node is None:
- self.fail("open_quote_node is None")
- interpolation_node1 = template_string_node.child(2)
- if interpolation_node1 is None:
- self.fail("interpolation_node1 is None")
- interpolation_node2 = template_string_node.child(4)
- if interpolation_node2 is None:
- self.fail("interpolation_node2 is None")
- close_quote_node = template_string_node.child(6)
- if close_quote_node is None:
- self.fail("close_quote_node is None")
-
- html_ranges = [
- Range(
- start_byte=open_quote_node.end_byte,
- start_point=open_quote_node.end_point,
- end_byte=interpolation_node1.start_byte,
- end_point=interpolation_node1.start_point,
- ),
- Range(
- start_byte=interpolation_node1.end_byte,
- start_point=interpolation_node1.end_point,
- end_byte=interpolation_node2.start_byte,
- end_point=interpolation_node2.start_point,
- ),
- Range(
- start_byte=interpolation_node2.end_byte,
- start_point=interpolation_node2.end_point,
- end_byte=close_quote_node.start_byte,
- end_point=close_quote_node.start_point,
- ),
- ]
- parser.set_included_ranges(html_ranges)
- parser.set_language(HTML)
- html_tree = parser.parse(source_code)
-
- self.assertEqual(
- html_tree.root_node.sexp(),
- "(document (element"
- + " (start_tag (tag_name))"
- + " (text)"
- + " (element (start_tag (tag_name)) (end_tag (tag_name)))"
- + " (text)"
- + " (end_tag (tag_name))))",
- )
- self.assertEqual(html_tree.included_ranges, html_ranges)
-
- div_element_node = html_tree.root_node.child(0)
- if div_element_node is None:
- self.fail("div_element_node is None")
- hello_text_node = div_element_node.child(1)
- if hello_text_node is None:
- self.fail("hello_text_node is None")
- b_element_node = div_element_node.child(2)
- if b_element_node is None:
- self.fail("b_element_node is None")
- b_start_tag_node = b_element_node.child(0)
- if b_start_tag_node is None:
- self.fail("b_start_tag_node is None")
- b_end_tag_node = b_element_node.child(1)
- if b_end_tag_node is None:
- self.fail("b_end_tag_node is None")
-
- self.assertEqual(hello_text_node.type, "text")
- self.assertEqual(hello_text_node.start_byte, source_code.index(b"Hello"))
- self.assertEqual(hello_text_node.end_byte, source_code.index(b" <b>"))
-
- self.assertEqual(b_start_tag_node.type, "start_tag")
- self.assertEqual(b_start_tag_node.start_byte, source_code.index(b"<b>"))
- self.assertEqual(b_start_tag_node.end_byte, source_code.index(b"${now()}"))
-
- self.assertEqual(b_end_tag_node.type, "end_tag")
- self.assertEqual(b_end_tag_node.start_byte, source_code.index(b"</b>"))
- self.assertEqual(b_end_tag_node.end_byte, source_code.index(b".</div>"))
-
- def test_parsing_with_included_range_containing_mismatched_positions(self):
- source_code = b"<div>test</div>{_ignore_this_part_}"
-
- parser = Parser()
- parser.set_language(HTML)
-
- end_byte = source_code.index(b"{_ignore_this_part_")
-
- range_to_parse = Range(
- start_byte=0,
- start_point=(10, 12),
- end_byte=end_byte,
- end_point=(10, 12 + end_byte),
- )
-
- parser.set_included_ranges([range_to_parse])
-
- html_tree = parser.parse(source_code)
-
- self.assertEqual(
- html_tree.root_node.sexp(),
- "(document (element (start_tag (tag_name)) (text) (end_tag (tag_name))))",
- )
-
- def test_parsing_error_in_invalid_included_ranges(self):
- parser = Parser()
- with self.assertRaises(Exception):
- parser.set_included_ranges(
- [
- Range(
- start_byte=23,
- end_byte=29,
- start_point=(0, 23),
- end_point=(0, 29),
- ),
- Range(
- start_byte=0,
- end_byte=5,
- start_point=(0, 0),
- end_point=(0, 5),
- ),
- Range(
- start_byte=50,
- end_byte=60,
- start_point=(0, 50),
- end_point=(0, 60),
- ),
- ]
- )
-
- with self.assertRaises(Exception):
- parser.set_included_ranges(
- [
- Range(
- start_byte=10,
- end_byte=5,
- start_point=(0, 10),
- end_point=(0, 5),
- )
- ]
- )
-
- def test_parsing_with_external_scanner_that_uses_included_range_boundaries(self):
- source_code = b"a <%= b() %> c <% d() %>"
- range1_start_byte = source_code.index(b" b() ")
- range1_end_byte = range1_start_byte + len(b" b() ")
- range2_start_byte = source_code.index(b" d() ")
- range2_end_byte = range2_start_byte + len(b" d() ")
-
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- parser.set_included_ranges(
- [
- Range(
- start_byte=range1_start_byte,
- end_byte=range1_end_byte,
- start_point=(0, range1_start_byte),
- end_point=(0, range1_end_byte),
- ),
- Range(
- start_byte=range2_start_byte,
- end_byte=range2_end_byte,
- start_point=(0, range2_start_byte),
- end_point=(0, range2_end_byte),
- ),
- ]
- )
-
- tree = parser.parse(source_code)
- root = tree.root_node
- statement1 = root.child(0)
- if statement1 is None:
- self.fail("statement1 is None")
- statement2 = root.child(1)
- if statement2 is None:
- self.fail("statement2 is None")
-
- self.assertEqual(
- root.sexp(),
- "(program"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")",
- )
-
- self.assertEqual(statement1.start_byte, source_code.index(b"b()"))
- self.assertEqual(statement1.end_byte, source_code.find(b" %> c"))
- self.assertEqual(statement2.start_byte, source_code.find(b"d()"))
- self.assertEqual(statement2.end_byte, len(source_code) - len(" %>"))
-
- def test_parsing_with_a_newly_excluded_range(self):
- source_code = b"<div><span><%= something %></span></div>"
-
- # Parse HTML including the template directive, which will cause an error
- parser = Parser()
- parser.set_language(HTML)
- first_tree = parser.parse(source_code)
-
- prefix = b"a very very long line of plain text. "
- first_tree.edit(
- start_byte=0,
- old_end_byte=0,
- new_end_byte=len(prefix),
- start_point=(0, 0),
- old_end_point=(0, 0),
- new_end_point=(0, len(prefix)),
- )
- source_code = prefix + source_code
-
- # Parse the HTML again, this time *excluding* the template directive
- # (which has moved since the previous parse).
- directive_start = source_code.index(b"<%=")
- directive_end = source_code.index(b"</span>")
- source_code_end = len(source_code)
- parser.set_included_ranges(
- [
- Range(
- start_byte=0,
- end_byte=directive_start,
- start_point=(0, 0),
- end_point=(0, directive_start),
- ),
- Range(
- start_byte=directive_end,
- end_byte=source_code_end,
- start_point=(0, directive_end),
- end_point=(0, source_code_end),
- ),
- ]
- )
-
- tree = parser.parse(source_code, first_tree)
-
- self.assertEqual(
- tree.root_node.sexp(),
- "(document (text) (element"
- + " (start_tag (tag_name))"
- + " (element (start_tag (tag_name)) (end_tag (tag_name)))"
- + " (end_tag (tag_name))))",
- )
-
- self.assertEqual(
- tree.changed_ranges(first_tree),
- [
- # The first range that has changed syntax is the range of the newly-inserted text.
- Range(
- start_byte=0,
- end_byte=len(prefix),
- start_point=(0, 0),
- end_point=(0, len(prefix)),
- ),
- # Even though no edits were applied to the outer `div` element,
- # its contents have changed syntax because a range of text that
- # was previously included is now excluded.
- Range(
- start_byte=directive_start,
- end_byte=directive_end,
- start_point=(0, directive_start),
- end_point=(0, directive_end),
- ),
- ],
- )
-
- def test_parsing_with_a_newly_included_range(self):
- source_code = b"<div><%= foo() %></div><span><%= bar() %></span><%= baz() %>"
- range1_start = source_code.index(b" foo")
- range2_start = source_code.index(b" bar")
- range3_start = source_code.index(b" baz")
- range1_end = range1_start + 7
- range2_end = range2_start + 7
- range3_end = range3_start + 7
-
- def simple_range(start: int, end: int) -> Range:
- return Range(
- start_byte=start,
- end_byte=end,
- start_point=(0, start),
- end_point=(0, end),
- )
-
- # Parse only the first code directive as JavaScript
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- parser.set_included_ranges([simple_range(range1_start, range1_end)])
- tree = parser.parse(source_code)
- self.assertEqual(
- tree.root_node.sexp(),
- "(program"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")",
- )
-
- # Parse both the first and third code directives as JavaScript, using the old tree as a
- # reference.
- parser.set_included_ranges(
- [
- simple_range(range1_start, range1_end),
- simple_range(range3_start, range3_end),
- ]
- )
- tree2 = parser.parse(source_code)
- self.assertEqual(
- tree2.root_node.sexp(),
- "(program"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")",
- )
- self.assertEqual(tree2.changed_ranges(tree), [simple_range(range1_end, range3_end)])
-
- # Parse all three code directives as JavaScript, using the old tree as a
- # reference.
- parser.set_included_ranges(
- [
- simple_range(range1_start, range1_end),
- simple_range(range2_start, range2_end),
- simple_range(range3_start, range3_end),
- ]
- )
- tree3 = parser.parse(source_code)
- self.assertEqual(
- tree3.root_node.sexp(),
- "(program"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + " "
- + "(expression_statement (call_expression function: (identifier) arguments: (arguments)))"
- + ")",
- )
- self.assertEqual(
- tree3.changed_ranges(tree2),
- [simple_range(range2_start + 1, range2_end - 1)],
- )
-
-
-class TestNode(TestCase):
- def test_child_by_field_id(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"def foo():\n bar()")
- root_node = tree.root_node
- fn_node = tree.root_node.children[0]
-
- self.assertEqual(PYTHON.field_id_for_name("nameasdf"), None)
- name_field = PYTHON.field_id_for_name("name")
- alias_field = PYTHON.field_id_for_name("alias")
- if not isinstance(alias_field, int):
- self.fail("alias_field is not an int")
- if not isinstance(name_field, int):
- self.fail("name_field is not an int")
- self.assertEqual(root_node.child_by_field_id(alias_field), None)
- self.assertEqual(root_node.child_by_field_id(name_field), None)
- self.assertEqual(fn_node.child_by_field_id(alias_field), None)
- self.assertEqual(fn_node.child_by_field_id(name_field).type, "identifier")
- self.assertRaises(TypeError, root_node.child_by_field_id, "")
- self.assertRaises(TypeError, root_node.child_by_field_name, True)
- self.assertRaises(TypeError, root_node.child_by_field_name, 1)
-
- self.assertEqual(fn_node.child_by_field_name("name").type, "identifier")
- self.assertEqual(fn_node.child_by_field_name("asdfasdfname"), None)
-
- self.assertEqual(
- fn_node.child_by_field_name("name"),
- fn_node.child_by_field_name("name"),
- )
-
- def test_children_by_field_id(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- tree = parser.parse(b"<div a={1} b={2} />")
- jsx_node = tree.root_node.children[0].children[0]
- attribute_field = PYTHON.field_id_for_name("attribute")
- if not isinstance(attribute_field, int):
- self.fail("attribute_field is not an int")
-
- attributes = jsx_node.children_by_field_id(attribute_field)
- self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"])
-
- def test_children_by_field_name(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- tree = parser.parse(b"<div a={1} b={2} />")
- jsx_node = tree.root_node.children[0].children[0]
-
- attributes = jsx_node.children_by_field_name("attribute")
- self.assertEqual([a.type for a in attributes], ["jsx_attribute", "jsx_attribute"])
-
- def test_node_child_by_field_name_with_extra_hidden_children(self):
- parser = Parser()
- parser.set_language(PYTHON)
-
- tree = parser.parse(b"while a:\n pass")
- while_node = tree.root_node.child(0)
- if while_node is None:
- self.fail("while_node is None")
- self.assertEqual(while_node.type, "while_statement")
- self.assertEqual(while_node.child_by_field_name("body"), while_node.child(3))
-
- def test_node_descendant_count(self):
- parser = Parser()
- parser.set_language(JSON)
- tree = parser.parse(JSON_EXAMPLE)
- value_node = tree.root_node
- all_nodes = get_all_nodes(tree)
-
- self.assertEqual(value_node.descendant_count, len(all_nodes))
-
- cursor = value_node.walk()
- for i, node in enumerate(all_nodes):
- cursor.goto_descendant(i)
- self.assertEqual(cursor.node, node, f"index {i}")
-
- for i, node in reversed(list(enumerate(all_nodes))):
- cursor.goto_descendant(i)
- self.assertEqual(cursor.node, node, f"rev index {i}")
-
- def test_descendant_count_single_node_tree(self):
- parser = Parser()
- parser.set_language(EMBEDDED_TEMPLATE)
- tree = parser.parse(b"hello")
-
- nodes = get_all_nodes(tree)
- self.assertEqual(len(nodes), 2)
- self.assertEqual(tree.root_node.descendant_count, 2)
-
- cursor = tree.walk()
-
- cursor.goto_descendant(0)
- self.assertEqual(cursor.depth, 0)
- self.assertEqual(cursor.node, nodes[0])
- cursor.goto_descendant(1)
- self.assertEqual(cursor.depth, 1)
- self.assertEqual(cursor.node, nodes[1])
-
- def test_field_name_for_child(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- tree = parser.parse(b"<div a={1} b={2} />")
- jsx_node = tree.root_node.children[0].children[0]
-
- self.assertEqual(jsx_node.field_name_for_child(0), None)
- self.assertEqual(jsx_node.field_name_for_child(1), "name")
-
- def test_descendant_for_byte_range(self):
- parser = Parser()
- parser.set_language(JSON)
- tree = parser.parse(JSON_EXAMPLE)
- array_node = tree.root_node
-
- colon_index = JSON_EXAMPLE.index(b":")
-
- # Leaf node exactly matches the given bounds - byte query
- colon_node = array_node.descendant_for_byte_range(colon_index, colon_index + 1)
- if colon_node is None:
- self.fail("colon_node is None")
- self.assertEqual(colon_node.type, ":")
- self.assertEqual(colon_node.start_byte, colon_index)
- self.assertEqual(colon_node.end_byte, colon_index + 1)
- self.assertEqual(colon_node.start_point, (6, 7))
- self.assertEqual(colon_node.end_point, (6, 8))
-
- # Leaf node exactly matches the given bounds - point query
- colon_node = array_node.descendant_for_point_range((6, 7), (6, 8))
- if colon_node is None:
- self.fail("colon_node is None")
- self.assertEqual(colon_node.type, ":")
- self.assertEqual(colon_node.start_byte, colon_index)
- self.assertEqual(colon_node.end_byte, colon_index + 1)
- self.assertEqual(colon_node.start_point, (6, 7))
- self.assertEqual(colon_node.end_point, (6, 8))
-
- # The given point is between two adjacent leaf nodes - byte query
- colon_node = array_node.descendant_for_byte_range(colon_index, colon_index)
- if colon_node is None:
- self.fail("colon_node is None")
- self.assertEqual(colon_node.type, ":")
- self.assertEqual(colon_node.start_byte, colon_index)
- self.assertEqual(colon_node.end_byte, colon_index + 1)
- self.assertEqual(colon_node.start_point, (6, 7))
- self.assertEqual(colon_node.end_point, (6, 8))
-
- # The given point is between two adjacent leaf nodes - point query
- colon_node = array_node.descendant_for_point_range((6, 7), (6, 7))
- if colon_node is None:
- self.fail("colon_node is None")
- self.assertEqual(colon_node.type, ":")
- self.assertEqual(colon_node.start_byte, colon_index)
- self.assertEqual(colon_node.end_byte, colon_index + 1)
- self.assertEqual(colon_node.start_point, (6, 7))
- self.assertEqual(colon_node.end_point, (6, 8))
-
- # Leaf node starts at the lower bound, ends after the upper bound - byte query
- string_index = JSON_EXAMPLE.index(b'"x"')
- string_node = array_node.descendant_for_byte_range(string_index, string_index + 2)
- if string_node is None:
- self.fail("string_node is None")
- self.assertEqual(string_node.type, "string")
- self.assertEqual(string_node.start_byte, string_index)
- self.assertEqual(string_node.end_byte, string_index + 3)
- self.assertEqual(string_node.start_point, (6, 4))
- self.assertEqual(string_node.end_point, (6, 7))
-
- # Leaf node starts at the lower bound, ends after the upper bound - point query
- string_node = array_node.descendant_for_point_range((6, 4), (6, 6))
- if string_node is None:
- self.fail("string_node is None")
- self.assertEqual(string_node.type, "string")
- self.assertEqual(string_node.start_byte, string_index)
- self.assertEqual(string_node.end_byte, string_index + 3)
- self.assertEqual(string_node.start_point, (6, 4))
- self.assertEqual(string_node.end_point, (6, 7))
-
- # Leaf node starts before the lower bound, ends at the upper bound - byte query
- null_index = JSON_EXAMPLE.index(b"null")
- null_node = array_node.descendant_for_byte_range(null_index + 1, null_index + 4)
- if null_node is None:
- self.fail("null_node is None")
- self.assertEqual(null_node.type, "null")
- self.assertEqual(null_node.start_byte, null_index)
- self.assertEqual(null_node.end_byte, null_index + 4)
- self.assertEqual(null_node.start_point, (6, 9))
- self.assertEqual(null_node.end_point, (6, 13))
-
- # Leaf node starts before the lower bound, ends at the upper bound - point query
- null_node = array_node.descendant_for_point_range((6, 11), (6, 13))
- if null_node is None:
- self.fail("null_node is None")
- self.assertEqual(null_node.type, "null")
- self.assertEqual(null_node.start_byte, null_index)
- self.assertEqual(null_node.end_byte, null_index + 4)
- self.assertEqual(null_node.start_point, (6, 9))
- self.assertEqual(null_node.end_point, (6, 13))
-
- # The bounds span multiple leaf nodes - return the smallest node that does span it.
- pair_node = array_node.descendant_for_byte_range(string_index + 2, string_index + 4)
- if pair_node is None:
- self.fail("pair_node is None")
- self.assertEqual(pair_node.type, "pair")
- self.assertEqual(pair_node.start_byte, string_index)
- self.assertEqual(pair_node.end_byte, string_index + 9)
- self.assertEqual(pair_node.start_point, (6, 4))
- self.assertEqual(pair_node.end_point, (6, 13))
-
- self.assertEqual(colon_node.parent, pair_node)
-
- # No leaf spans the given range - return the smallest node that does span it.
- pair_node = array_node.descendant_for_point_range((6, 6), (6, 8))
- if pair_node is None:
- self.fail("pair_node is None")
- self.assertEqual(pair_node.type, "pair")
- self.assertEqual(pair_node.start_byte, string_index)
- self.assertEqual(pair_node.end_byte, string_index + 9)
- self.assertEqual(pair_node.start_point, (6, 4))
- self.assertEqual(pair_node.end_point, (6, 13))
-
- def test_root_node_with_offset(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- tree = parser.parse(b"  if (a) b")
-
- node = tree.root_node_with_offset(6, (2, 2))
- if node is None:
- self.fail("node is None")
- self.assertEqual(node.byte_range, (8, 16))
- self.assertEqual(node.start_point, (2, 4))
- self.assertEqual(node.end_point, (2, 12))
-
- child = node.child(0).child(2)
- if child is None:
- self.fail("child is None")
- self.assertEqual(child.type, "expression_statement")
- self.assertEqual(child.byte_range, (15, 16))
- self.assertEqual(child.start_point, (2, 11))
- self.assertEqual(child.end_point, (2, 12))
-
- cursor = node.walk()
- cursor.goto_first_child()
- cursor.goto_first_child()
- cursor.goto_next_sibling()
- child = cursor.node
- if child is None:
- self.fail("child is None")
- self.assertEqual(child.type, "parenthesized_expression")
- self.assertEqual(child.byte_range, (11, 14))
- self.assertEqual(child.start_point, (2, 7))
- self.assertEqual(child.end_point, (2, 10))
-
- def test_node_is_extra(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- tree = parser.parse(b"foo(/* hi */);")
-
- root_node = tree.root_node
- comment_node = root_node.descendant_for_byte_range(7, 7)
- if comment_node is None:
- self.fail("comment_node is None")
-
- self.assertEqual(root_node.type, "program")
- self.assertEqual(comment_node.type, "comment")
- self.assertEqual(root_node.is_extra, False)
- self.assertEqual(comment_node.is_extra, True)
-
- def test_children(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"def foo():\n  bar()")
-
- root_node = tree.root_node
- self.assertEqual(root_node.type, "module")
- self.assertEqual(root_node.start_byte, 0)
- self.assertEqual(root_node.end_byte, 18)
- self.assertEqual(root_node.start_point, (0, 0))
- self.assertEqual(root_node.end_point, (1, 7))
-
- # List object is reused
- self.assertIs(root_node.children, root_node.children)
-
- fn_node = root_node.children[0]
- self.assertEqual(fn_node, root_node.child(0))
- self.assertEqual(fn_node.type, "function_definition")
- self.assertEqual(fn_node.start_byte, 0)
- self.assertEqual(fn_node.end_byte, 18)
- self.assertEqual(fn_node.start_point, (0, 0))
- self.assertEqual(fn_node.end_point, (1, 7))
-
- def_node = fn_node.children[0]
- self.assertEqual(def_node, fn_node.child(0))
- self.assertEqual(def_node.type, "def")
- self.assertEqual(def_node.is_named, False)
-
- id_node = fn_node.children[1]
- self.assertEqual(id_node, fn_node.child(1))
- self.assertEqual(id_node.type, "identifier")
- self.assertEqual(id_node.is_named, True)
- self.assertEqual(len(id_node.children), 0)
-
- params_node = fn_node.children[2]
- self.assertEqual(params_node, fn_node.child(2))
- self.assertEqual(params_node.type, "parameters")
- self.assertEqual(params_node.is_named, True)
-
- colon_node = fn_node.children[3]
- self.assertEqual(colon_node, fn_node.child(3))
- self.assertEqual(colon_node.type, ":")
- self.assertEqual(colon_node.is_named, False)
-
- statement_node = fn_node.children[4]
- self.assertEqual(statement_node, fn_node.child(4))
- self.assertEqual(statement_node.type, "block")
- self.assertEqual(statement_node.is_named, True)
-
- def test_named_and_sibling_and_count_and_parent(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"[1, 2, 3]")
-
- root_node = tree.root_node
- self.assertEqual(root_node.type, "module")
- self.assertEqual(root_node.start_byte, 0)
- self.assertEqual(root_node.end_byte, 9)
- self.assertEqual(root_node.start_point, (0, 0))
- self.assertEqual(root_node.end_point, (0, 9))
-
- exp_stmt_node = root_node.children[0]
- self.assertEqual(exp_stmt_node, root_node.child(0))
- self.assertEqual(exp_stmt_node.type, "expression_statement")
- self.assertEqual(exp_stmt_node.start_byte, 0)
- self.assertEqual(exp_stmt_node.end_byte, 9)
- self.assertEqual(exp_stmt_node.start_point, (0, 0))
- self.assertEqual(exp_stmt_node.end_point, (0, 9))
- self.assertEqual(exp_stmt_node.parent, root_node)
-
- list_node = exp_stmt_node.children[0]
- self.assertEqual(list_node, exp_stmt_node.child(0))
- self.assertEqual(list_node.type, "list")
- self.assertEqual(list_node.start_byte, 0)
- self.assertEqual(list_node.end_byte, 9)
- self.assertEqual(list_node.start_point, (0, 0))
- self.assertEqual(list_node.end_point, (0, 9))
- self.assertEqual(list_node.parent, exp_stmt_node)
-
- named_children = list_node.named_children
-
- open_delim_node = list_node.children[0]
- self.assertEqual(open_delim_node, list_node.child(0))
- self.assertEqual(open_delim_node.type, "[")
- self.assertEqual(open_delim_node.start_byte, 0)
- self.assertEqual(open_delim_node.end_byte, 1)
- self.assertEqual(open_delim_node.start_point, (0, 0))
- self.assertEqual(open_delim_node.end_point, (0, 1))
- self.assertEqual(open_delim_node.parent, list_node)
-
- first_num_node = list_node.children[1]
- self.assertEqual(first_num_node, list_node.child(1))
- self.assertEqual(first_num_node, open_delim_node.next_named_sibling)
- self.assertEqual(first_num_node.parent, list_node)
- self.assertEqual(named_children[0], first_num_node)
- self.assertEqual(first_num_node, list_node.named_child(0))
-
- first_comma_node = list_node.children[2]
- self.assertEqual(first_comma_node, list_node.child(2))
- self.assertEqual(first_comma_node, first_num_node.next_sibling)
- self.assertEqual(first_num_node, first_comma_node.prev_sibling)
- self.assertEqual(first_comma_node.parent, list_node)
-
- second_num_node = list_node.children[3]
- self.assertEqual(second_num_node, list_node.child(3))
- self.assertEqual(second_num_node, first_comma_node.next_sibling)
- self.assertEqual(second_num_node, first_num_node.next_named_sibling)
- self.assertEqual(first_num_node, second_num_node.prev_named_sibling)
- self.assertEqual(second_num_node.parent, list_node)
- self.assertEqual(named_children[1], second_num_node)
- self.assertEqual(second_num_node, list_node.named_child(1))
-
- second_comma_node = list_node.children[4]
- self.assertEqual(second_comma_node, list_node.child(4))
- self.assertEqual(second_comma_node, second_num_node.next_sibling)
- self.assertEqual(second_num_node, second_comma_node.prev_sibling)
- self.assertEqual(second_comma_node.parent, list_node)
-
- third_num_node = list_node.children[5]
- self.assertEqual(third_num_node, list_node.child(5))
- self.assertEqual(third_num_node, second_comma_node.next_sibling)
- self.assertEqual(third_num_node, second_num_node.next_named_sibling)
- self.assertEqual(second_num_node, third_num_node.prev_named_sibling)
- self.assertEqual(third_num_node.parent, list_node)
- self.assertEqual(named_children[2], third_num_node)
- self.assertEqual(third_num_node, list_node.named_child(2))
-
- close_delim_node = list_node.children[6]
- self.assertEqual(close_delim_node, list_node.child(6))
- self.assertEqual(close_delim_node.type, "]")
- self.assertEqual(close_delim_node.start_byte, 8)
- self.assertEqual(close_delim_node.end_byte, 9)
- self.assertEqual(close_delim_node.start_point, (0, 8))
- self.assertEqual(close_delim_node.end_point, (0, 9))
- self.assertEqual(close_delim_node, third_num_node.next_sibling)
- self.assertEqual(third_num_node, close_delim_node.prev_sibling)
- self.assertEqual(third_num_node, close_delim_node.prev_named_sibling)
- self.assertEqual(close_delim_node.parent, list_node)
-
- self.assertEqual(list_node.child_count, 7)
- self.assertEqual(list_node.named_child_count, 3)
-
- def test_node_text(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"[0, [1, 2, 3]]")
-
- self.assertEqual(tree.text, b"[0, [1, 2, 3]]")
-
- root_node = tree.root_node
- self.assertEqual(root_node.text, b"[0, [1, 2, 3]]")
-
- exp_stmt_node = root_node.children[0]
- self.assertEqual(exp_stmt_node.text, b"[0, [1, 2, 3]]")
-
- list_node = exp_stmt_node.children[0]
- self.assertEqual(list_node.text, b"[0, [1, 2, 3]]")
-
- open_delim_node = list_node.children[0]
- self.assertEqual(open_delim_node.text, b"[")
-
- first_num_node = list_node.children[1]
- self.assertEqual(first_num_node.text, b"0")
-
- first_comma_node = list_node.children[2]
- self.assertEqual(first_comma_node.text, b",")
-
- child_list_node = list_node.children[3]
- self.assertEqual(child_list_node.text, b"[1, 2, 3]")
-
- close_delim_node = list_node.children[4]
- self.assertEqual(close_delim_node.text, b"]")
-
- edit_offset = len(b"[0, [")
- tree.edit(
- start_byte=edit_offset,
- old_end_byte=edit_offset,
- new_end_byte=edit_offset + 2,
- start_point=(0, edit_offset),
- old_end_point=(0, edit_offset),
- new_end_point=(0, edit_offset + 2),
- )
- self.assertEqual(tree.text, None)
-
- root_node_again = tree.root_node
- self.assertEqual(root_node_again.text, None)
-
- tree_text_false = parser.parse(b"[0, [1, 2, 3]]", keep_text=False)
- self.assertIsNone(tree_text_false.text)
- root_node_text_false = tree_text_false.root_node
- self.assertIsNone(root_node_text_false.text)
-
- tree_text_true = parser.parse(b"[0, [1, 2, 3]]", keep_text=True)
- self.assertEqual(tree_text_true.text, b"[0, [1, 2, 3]]")
- root_node_text_true = tree_text_true.root_node
- self.assertEqual(root_node_text_true.text, b"[0, [1, 2, 3]]")
-
- def test_tree(self):
- code = b"def foo():\n bar()\n\ndef foo():\n bar()"
- parser = Parser()
- parser.set_language(PYTHON)
-
- def parse_root(bytes_):
- tree = parser.parse(bytes_)
- return tree.root_node
-
- root = parse_root(code)
- for item in root.children:
- self.assertIsNotNone(item.is_named)
-
- def parse_root_children(bytes_):
- tree = parser.parse(bytes_)
- return tree.root_node.children
-
- children = parse_root_children(code)
- for item in children:
- self.assertIsNotNone(item.is_named)
-
- def test_node_numeric_symbols_respect_simple_aliases(self):
- parser = Parser()
- parser.set_language(PYTHON)
-
- # Example 1:
- # Python argument lists can contain "splat" arguments, which are not allowed within
- # other expressions. This includes `parenthesized_list_splat` nodes like `(*b)`. These
- # `parenthesized_list_splat` nodes are aliased as `parenthesized_expression`. Their numeric
- # `symbol`, aka `kind_id` should match that of a normal `parenthesized_expression`.
- tree = parser.parse(b"(a((*b)))")
- root_node = tree.root_node
- self.assertEqual(
- root_node.sexp(),
- "(module (expression_statement (parenthesized_expression (call "
- + "function: (identifier) arguments: (argument_list (parenthesized_expression "
- + "(list_splat (identifier))))))))",
- )
-
- outer_expr_node = root_node.child(0).child(0)
- if outer_expr_node is None:
- self.fail("outer_expr_node is None")
- self.assertEqual(outer_expr_node.type, "parenthesized_expression")
-
- inner_expr_node = (
- outer_expr_node.named_child(0).child_by_field_name("arguments").named_child(0)
- )
- if inner_expr_node is None:
- self.fail("inner_expr_node is None")
-
- self.assertEqual(inner_expr_node.type, "parenthesized_expression")
- self.assertEqual(inner_expr_node.kind_id, outer_expr_node.kind_id)
-
-
-class TestTree(TestCase):
- def test_tree_cursor_without_tree(self):
- parser = Parser()
- parser.set_language(PYTHON)
-
- def parse():
- tree = parser.parse(b"def foo():\n bar()")
- return tree.walk()
-
- cursor = parse()
- self.assertIs(cursor.node, cursor.node)
- for item in cursor.node.children:
- self.assertIsNotNone(item.is_named)
-
- cursor = cursor.copy()
- self.assertIs(cursor.node, cursor.node)
- for item in cursor.node.children:
- self.assertIsNotNone(item.is_named)
-
- def test_walk(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"def foo():\n bar()")
- cursor = tree.walk()
-
- # Node always returns the same instance
- self.assertIs(cursor.node, cursor.node)
-
- self.assertEqual(cursor.node.type, "module")
- self.assertEqual(cursor.node.start_byte, 0)
- self.assertEqual(cursor.node.end_byte, 18)
- self.assertEqual(cursor.node.start_point, (0, 0))
- self.assertEqual(cursor.node.end_point, (1, 7))
- self.assertEqual(cursor.field_name, None)
-
- self.assertTrue(cursor.goto_first_child())
- self.assertEqual(cursor.node.type, "function_definition")
- self.assertEqual(cursor.node.start_byte, 0)
- self.assertEqual(cursor.node.end_byte, 18)
- self.assertEqual(cursor.node.start_point, (0, 0))
- self.assertEqual(cursor.node.end_point, (1, 7))
- self.assertEqual(cursor.field_name, None)
-
- self.assertTrue(cursor.goto_first_child())
- self.assertEqual(cursor.node.type, "def")
- self.assertEqual(cursor.node.is_named, False)
- self.assertEqual(cursor.node.sexp(), '("def")')
- self.assertEqual(cursor.field_name, None)
- def_node = cursor.node
-
- # Node remains cached after a failure to move
- self.assertFalse(cursor.goto_first_child())
- self.assertIs(cursor.node, def_node)
-
- self.assertTrue(cursor.goto_next_sibling())
- self.assertEqual(cursor.node.type, "identifier")
- self.assertEqual(cursor.node.is_named, True)
- self.assertEqual(cursor.field_name, "name")
- self.assertFalse(cursor.goto_first_child())
-
- self.assertTrue(cursor.goto_next_sibling())
- self.assertEqual(cursor.node.type, "parameters")
- self.assertEqual(cursor.node.is_named, True)
- self.assertEqual(cursor.field_name, "parameters")
-
- def test_edit(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"def foo():\n bar()")
-
- edit_offset = len(b"def foo(")
- tree.edit(
- start_byte=edit_offset,
- old_end_byte=edit_offset,
- new_end_byte=edit_offset + 2,
- start_point=(0, edit_offset),
- old_end_point=(0, edit_offset),
- new_end_point=(0, edit_offset + 2),
- )
-
- fn_node = tree.root_node.children[0]
- self.assertEqual(fn_node.type, "function_definition")
- self.assertTrue(fn_node.has_changes)
- self.assertFalse(fn_node.children[0].has_changes)
- self.assertFalse(fn_node.children[1].has_changes)
- self.assertFalse(fn_node.children[3].has_changes)
-
- params_node = fn_node.children[2]
- self.assertEqual(params_node.type, "parameters")
- self.assertTrue(params_node.has_changes)
- self.assertEqual(params_node.start_point, (0, edit_offset - 1))
- self.assertEqual(params_node.end_point, (0, edit_offset + 3))
-
- new_tree = parser.parse(b"def foo(ab):\n bar()", tree)
- self.assertEqual(
- new_tree.root_node.sexp(),
- trim(
- """(module (function_definition
- name: (identifier)
- parameters: (parameters (identifier))
- body: (block
- (expression_statement (call
- function: (identifier)
- arguments: (argument_list))))))"""
- ),
- )
-
- def test_changed_ranges(self):
- parser = Parser()
- parser.set_language(PYTHON)
- tree = parser.parse(b"def foo():\n bar()")
-
- edit_offset = len(b"def foo(")
- tree.edit(
- start_byte=edit_offset,
- old_end_byte=edit_offset,
- new_end_byte=edit_offset + 2,
- start_point=(0, edit_offset),
- old_end_point=(0, edit_offset),
- new_end_point=(0, edit_offset + 2),
- )
-
- new_tree = parser.parse(b"def foo(ab):\n bar()", tree)
- changed_ranges = tree.changed_ranges(new_tree)
-
- self.assertEqual(len(changed_ranges), 1)
- self.assertEqual(changed_ranges[0].start_byte, edit_offset)
- self.assertEqual(changed_ranges[0].start_point, (0, edit_offset))
- self.assertEqual(changed_ranges[0].end_byte, edit_offset + 2)
- self.assertEqual(changed_ranges[0].end_point, (0, edit_offset + 2))
-
- def test_tree_cursor(self):
- parser = Parser()
- parser.set_language(RUST)
-
- tree = parser.parse(
- b"""
- struct Stuff {
- a: A,
- b: Option,
- }
- """
- )
-
- cursor = tree.walk()
- self.assertEqual(cursor.node.type, "source_file")
-
- self.assertEqual(cursor.goto_first_child(), True)
- self.assertEqual(cursor.node.type, "struct_item")
-
- self.assertEqual(cursor.goto_first_child(), True)
- self.assertEqual(cursor.node.type, "struct")
- self.assertEqual(cursor.node.is_named, False)
-
- self.assertEqual(cursor.goto_next_sibling(), True)
- self.assertEqual(cursor.node.type, "type_identifier")
- self.assertEqual(cursor.node.is_named, True)
-
- self.assertEqual(cursor.goto_next_sibling(), True)
- self.assertEqual(cursor.node.type, "field_declaration_list")
- self.assertEqual(cursor.node.is_named, True)
-
- self.assertEqual(cursor.goto_last_child(), True)
- self.assertEqual(cursor.node.type, "}")
- self.assertEqual(cursor.node.is_named, False)
- self.assertEqual(cursor.node.start_point, (4, 16))
-
- self.assertEqual(cursor.goto_previous_sibling(), True)
- self.assertEqual(cursor.node.type, ",")
- self.assertEqual(cursor.node.is_named, False)
- self.assertEqual(cursor.node.start_point, (3, 32))
-
- self.assertEqual(cursor.goto_previous_sibling(), True)
- self.assertEqual(cursor.node.type, "field_declaration")
- self.assertEqual(cursor.node.is_named, True)
- self.assertEqual(cursor.node.start_point, (3, 20))
-
- self.assertEqual(cursor.goto_previous_sibling(), True)
- self.assertEqual(cursor.node.type, ",")
- self.assertEqual(cursor.node.is_named, False)
- self.assertEqual(cursor.node.start_point, (2, 24))
-
- self.assertEqual(cursor.goto_previous_sibling(), True)
- self.assertEqual(cursor.node.type, "field_declaration")
- self.assertEqual(cursor.node.is_named, True)
- self.assertEqual(cursor.node.start_point, (2, 20))
-
- self.assertEqual(cursor.goto_previous_sibling(), True)
- self.assertEqual(cursor.node.type, "{")
- self.assertEqual(cursor.node.is_named, False)
- self.assertEqual(cursor.node.start_point, (1, 29))
-
- copy = tree.walk()
- copy.reset_to(cursor)
-
- self.assertEqual(copy.node.type, "{")
- self.assertEqual(copy.node.is_named, False)
-
- self.assertEqual(copy.goto_parent(), True)
- self.assertEqual(copy.node.type, "field_declaration_list")
- self.assertEqual(copy.node.is_named, True)
-
- self.assertEqual(copy.goto_parent(), True)
- self.assertEqual(copy.node.type, "struct_item")
-
-
-class TestQuery(TestCase):
- def test_errors(self):
- with self.assertRaisesRegex(NameError, "Invalid node type foo"):
- PYTHON.query("(list (foo))")
- with self.assertRaisesRegex(NameError, "Invalid field name buzz"):
- PYTHON.query("(function_definition buzz: (identifier))")
- with self.assertRaisesRegex(NameError, "Invalid capture name garbage"):
- PYTHON.query("((function_definition) (#eq? @garbage foo))")
- with self.assertRaisesRegex(SyntaxError, "Invalid syntax at offset 6"):
- PYTHON.query("(list))")
- PYTHON.query("(function_definition)")
-
- def collect_matches(
- self,
- matches: List[Tuple[int, Dict[str, Union[Node, List[Node]]]]],
- ) -> List[Tuple[int, List[Tuple[str, Union[str, List[str]]]]]]:
- return [(m[0], self.format_captures(m[1])) for m in matches]
-
- def format_captures(
- self, captures: Dict[str, Union[Node, List[Node]]]
- ) -> List[Tuple[str, Union[str, List[str]]]]:
- return [(name, self.format_capture(capture)) for name, capture in captures.items()]
-
- def format_capture(self, capture: Union[Node, List[Node]]) -> Union[str, List[str]]:
- return (
- [n.text.decode("utf-8") for n in capture]
- if isinstance(capture, List)
- else capture.text.decode("utf-8")
- )
-
- def assert_query_matches(
- self,
- language: Language,
- query: Query,
- source: bytes,
- expected: List[Tuple[int, List[Tuple[str, Union[str, List[str]]]]]],
- ):
- parser = Parser()
- parser.set_language(language)
- tree = parser.parse(source)
- matches = query.matches(tree.root_node)
- matches = self.collect_matches(matches)
- self.assertEqual(matches, expected)
-
- def test_matches_with_simple_pattern(self):
- query = JAVASCRIPT.query("(function_declaration name: (identifier) @fn-name)")
- self.assert_query_matches(
- JAVASCRIPT,
- query,
- b"function one() { two(); function three() {} }",
- [(0, [("fn-name", "one")]), (0, [("fn-name", "three")])],
- )
-
- def test_matches_with_multiple_on_same_root(self):
- query = JAVASCRIPT.query(
- """
- (class_declaration
- name: (identifier) @the-class-name
- (class_body
- (method_definition
- name: (property_identifier) @the-method-name)))
- """
- )
- self.assert_query_matches(
- JAVASCRIPT,
- query,
- b"""
- class Person {
- // the constructor
- constructor(name) { this.name = name; }
-
- // the getter
- getFullName() { return this.name; }
- }
- """,
- [
- (0, [("the-class-name", "Person"), ("the-method-name", "constructor")]),
- (0, [("the-class-name", "Person"), ("the-method-name", "getFullName")]),
- ],
- )
-
- def test_matches_with_multiple_patterns_different_roots(self):
- query = JAVASCRIPT.query(
- """
- (function_declaration name:(identifier) @fn-def)
- (call_expression function:(identifier) @fn-ref)
- """
- )
- self.assert_query_matches(
- JAVASCRIPT,
- query,
- b"""
- function f1() {
- f2(f3());
- }
- """,
- [
- (0, [("fn-def", "f1")]),
- (1, [("fn-ref", "f2")]),
- (1, [("fn-ref", "f3")]),
- ],
- )
-
- def test_matches_with_nesting_and_no_fields(self):
- query = JAVASCRIPT.query(
- """
- (array
- (array
- (identifier) @x1
- (identifier) @x2))
- """
- )
- self.assert_query_matches(
- JAVASCRIPT,
- query,
- b"""
- [[a]];
- [[c, d], [e, f, g, h]];
- [[h], [i]];
- """,
- [
- (0, [("x1", "c"), ("x2", "d")]),
- (0, [("x1", "e"), ("x2", "f")]),
- (0, [("x1", "e"), ("x2", "g")]),
- (0, [("x1", "f"), ("x2", "g")]),
- (0, [("x1", "e"), ("x2", "h")]),
- (0, [("x1", "f"), ("x2", "h")]),
- (0, [("x1", "g"), ("x2", "h")]),
- ],
- )
-
- def test_matches_with_list_capture(self):
- query = JAVASCRIPT.query(
- """(function_declaration name: (identifier) @fn-name
- body: (statement_block (_)* @fn-statements)
- )"""
- )
- self.assert_query_matches(
- JAVASCRIPT,
- query,
- b"""function one() {
- x = 1;
- y = 2;
- z = 3;
- }
- function two() {
- x = 1;
- }
- """,
- [
- (
- 0,
- [
- ("fn-name", "one"),
- ("fn-statements", ["x = 1;", "y = 2;", "z = 3;"]),
- ],
- ),
- (0, [("fn-name", "two"), ("fn-statements", ["x = 1;"])]),
- ],
- )
-
- def test_captures(self):
- parser = Parser()
- parser.set_language(PYTHON)
- source = b"def foo():\n bar()\ndef baz():\n quux()\n"
- tree = parser.parse(source)
- query = PYTHON.query(
- """
- (function_definition name: (identifier) @func-def)
- (call function: (identifier) @func-call)
- """
- )
-
- captures = query.captures(tree.root_node)
-
- self.assertEqual(captures[0][0].start_point, (0, 4))
- self.assertEqual(captures[0][0].end_point, (0, 7))
- self.assertEqual(captures[0][1], "func-def")
-
- self.assertEqual(captures[1][0].start_point, (1, 2))
- self.assertEqual(captures[1][0].end_point, (1, 5))
- self.assertEqual(captures[1][1], "func-call")
-
- self.assertEqual(captures[2][0].start_point, (2, 4))
- self.assertEqual(captures[2][0].end_point, (2, 7))
- self.assertEqual(captures[2][1], "func-def")
-
- self.assertEqual(captures[3][0].start_point, (3, 2))
- self.assertEqual(captures[3][0].end_point, (3, 6))
- self.assertEqual(captures[3][1], "func-call")
-
- def test_text_predicates(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- source = b"""
- keypair_object = {
- key1: value1,
- equal: equal
- }
-
- function fun1(arg) {
- return 1;
- }
-
- function fun2(arg) {
- return 2;
- }
- """
- tree = parser.parse(source)
- root_node = tree.root_node
-
- # function with name equal to 'fun1' -> test for #eq? @capture string
- query1 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#eq? @function-name fun1)
- )
- """
- )
- captures1 = query1.captures(root_node)
- self.assertEqual(1, len(captures1))
- self.assertEqual(b"fun1", captures1[0][0].text)
- self.assertEqual("function-name", captures1[0][1])
-
- # functions with name not equal to 'fun1' -> test for #not-eq? @capture string
- query2 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#not-eq? @function-name fun1)
- )
- """
- )
- captures2 = query2.captures(root_node)
- self.assertEqual(1, len(captures2))
- self.assertEqual(b"fun2", captures2[0][0].text)
- self.assertEqual("function-name", captures2[0][1])
-
- # key pairs whose key is equal to its value -> test for #eq? @capture1 @capture2
- query3 = JAVASCRIPT.query(
- """
- (
- (pair
- key: (property_identifier) @key-name
- value: (identifier) @value-name)
- (#eq? @key-name @value-name)
- )
- """
- )
- captures3 = query3.captures(root_node)
- self.assertEqual(2, len(captures3))
- self.assertSetEqual({b"equal"}, set([c[0].text for c in captures3]))
- self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures3]))
-
- # key pairs whose key is not equal to its value
- # -> test for #not-eq? @capture1 @capture2
- query4 = JAVASCRIPT.query(
- """
- (
- (pair
- key: (property_identifier) @key-name
- value: (identifier) @value-name)
- (#not-eq? @key-name @value-name)
- )
- """
- )
- captures4 = query4.captures(root_node)
- self.assertEqual(2, len(captures4))
- self.assertSetEqual({b"key1", b"value1"}, set([c[0].text for c in captures4]))
- self.assertSetEqual({"key-name", "value-name"}, set([c[1] for c in captures4]))
-
- # equality that is satisfied by *another* capture
- query5 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- parameters: (formal_parameters (identifier) @parameter-name)
- )
- (#eq? @function-name arg)
- )
- """
- )
- captures5 = query5.captures(root_node)
- self.assertEqual(0, len(captures5))
-
- # functions that match the regex .*1 -> test for #match @capture regex
- query6 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#match? @function-name ".*1")
- )
- """
- )
- captures6 = query6.captures(root_node)
- self.assertEqual(1, len(captures6))
- self.assertEqual(b"fun1", captures6[0][0].text)
-
- # functions that do not match the regex .*1 -> test for #not-match @capture regex
- query6 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#not-match? @function-name ".*1")
- )
- """
- )
- captures6 = query6.captures(root_node)
- self.assertEqual(1, len(captures6))
- self.assertEqual(b"fun2", captures6[0][0].text)
-
- # after editing there is no text property, so predicates are ignored
- tree.edit(
- start_byte=0,
- old_end_byte=0,
- new_end_byte=2,
- start_point=(0, 0),
- old_end_point=(0, 0),
- new_end_point=(0, 2),
- )
- captures_notext = query1.captures(root_node)
- self.assertEqual(2, len(captures_notext))
- self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext]))
-
- def test_text_predicate_on_optional_capture(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- source = b"fun1(1)"
- tree = parser.parse(source)
- root_node = tree.root_node
-
- # optional capture that is missing in source used in #eq? @capture string
- query1 = JAVASCRIPT.query(
- """
- ((call_expression
- function: (identifier) @function-name
- arguments: (arguments (string)? @optional-string-arg)
- (#eq? @optional-string-arg "1")))
- """
- )
- captures1 = query1.captures(root_node)
- self.assertEqual(1, len(captures1))
- self.assertEqual(b"fun1", captures1[0][0].text)
- self.assertEqual("function-name", captures1[0][1])
-
- # optional capture that is missing in source used in #eq? @capture @capture
- query2 = JAVASCRIPT.query(
- """
- ((call_expression
- function: (identifier) @function-name
- arguments: (arguments (string)? @optional-string-arg)
- (#eq? @optional-string-arg @function-name)))
- """
- )
- captures2 = query2.captures(root_node)
- self.assertEqual(1, len(captures2))
- self.assertEqual(b"fun1", captures2[0][0].text)
- self.assertEqual("function-name", captures2[0][1])
-
- # optional capture that is missing in source used in #match? @capture string
- query3 = JAVASCRIPT.query(
- """
- ((call_expression
- function: (identifier) @function-name
- arguments: (arguments (string)? @optional-string-arg)
- (#match? @optional-string-arg "\\d+")))
- """
- )
- captures3 = query3.captures(root_node)
- self.assertEqual(1, len(captures3))
- self.assertEqual(b"fun1", captures3[0][0].text)
- self.assertEqual("function-name", captures3[0][1])
-
- def test_text_predicates_errors(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- with self.assertRaises(RuntimeError):
- JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#eq? @function-name @function-name fun1)
- )
- """
- )
-
- with self.assertRaises(RuntimeError):
- JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#eq? fun1 @function-name)
- )
- """
- )
-
- with self.assertRaises(RuntimeError):
- JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#match? @function-name @function-name fun1)
- )
- """
- )
-
- with self.assertRaises(RuntimeError):
- JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#match? fun1 @function-name)
- )
- """
- )
-
- with self.assertRaises(RuntimeError):
- JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- )
- (#match? @function-name @function-name)
- )
- """
- )
-
- def test_multiple_text_predicates(self):
- parser = Parser()
- parser.set_language(JAVASCRIPT)
- source = b"""
- keypair_object = {
- key1: value1,
- equal: equal
- }
-
- function fun1(arg) {
- return 1;
- }
-
- function fun1(notarg) {
- return 1 + 1;
- }
-
- function fun2(arg) {
- return 2;
- }
- """
- tree = parser.parse(source)
- root_node = tree.root_node
-
- # function with name equal to 'fun1' -> test for first #eq? @capture string
- query1 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- parameters: (formal_parameters
- (identifier) @argument-name
- )
- )
- (#eq? @function-name fun1)
- )
- """
- )
- captures1 = query1.captures(root_node)
- self.assertEqual(4, len(captures1))
- self.assertEqual(b"fun1", captures1[0][0].text)
- self.assertEqual("function-name", captures1[0][1])
- self.assertEqual(b"arg", captures1[1][0].text)
- self.assertEqual("argument-name", captures1[1][1])
- self.assertEqual(b"fun1", captures1[2][0].text)
- self.assertEqual("function-name", captures1[2][1])
- self.assertEqual(b"notarg", captures1[3][0].text)
- self.assertEqual("argument-name", captures1[3][1])
-
- # function with argument equal to 'arg' -> test for second #eq? @capture string
- query2 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- parameters: (formal_parameters
- (identifier) @argument-name
- )
- )
- (#eq? @argument-name arg)
- )
- """
- )
- captures2 = query2.captures(root_node)
- self.assertEqual(4, len(captures2))
- self.assertEqual(b"fun1", captures2[0][0].text)
- self.assertEqual("function-name", captures2[0][1])
- self.assertEqual(b"arg", captures2[1][0].text)
- self.assertEqual("argument-name", captures2[1][1])
- self.assertEqual(b"fun2", captures2[2][0].text)
- self.assertEqual("function-name", captures2[2][1])
- self.assertEqual(b"arg", captures2[3][0].text)
- self.assertEqual("argument-name", captures2[3][1])
-
- # function with name equal to 'fun1' & argument 'arg' -> test for both together
- query3 = JAVASCRIPT.query(
- """
- (
- (function_declaration
- name: (identifier) @function-name
- parameters: (formal_parameters
- (identifier) @argument-name
- )
- )
- (#eq? @function-name fun1)
- (#eq? @argument-name arg)
- )
- """
- )
- captures3 = query3.captures(root_node)
- self.assertEqual(2, len(captures3))
- self.assertEqual(b"fun1", captures3[0][0].text)
- self.assertEqual("function-name", captures3[0][1])
- self.assertEqual(b"arg", captures3[1][0].text)
- self.assertEqual("argument-name", captures3[1][1])
-
- def test_byte_range_captures(self):
- parser = Parser()
- parser.set_language(PYTHON)
- source = b"def foo():\n bar()\ndef baz():\n quux()\n"
- tree = parser.parse(source)
- query = PYTHON.query(
- """
- (function_definition name: (identifier) @func-def)
- (call function: (identifier) @func-call)
- """
- )
-
- captures = query.captures(tree.root_node, start_byte=10, end_byte=20)
- self.assertEqual(captures[0][0].start_point, (1, 2))
- self.assertEqual(captures[0][0].end_point, (1, 5))
- self.assertEqual(captures[0][1], "func-call")
-
- def test_point_range_captures(self):
- parser = Parser()
- parser.set_language(PYTHON)
- source = b"def foo():\n bar()\ndef baz():\n quux()\n"
- tree = parser.parse(source)
- query = PYTHON.query(
- """
- (function_definition name: (identifier) @func-def)
- (call function: (identifier) @func-call)
- """
- )
-
- captures = query.captures(tree.root_node, start_point=(1, 0), end_point=(2, 0))
-
- self.assertEqual(captures[0][0].start_point, (1, 2))
- self.assertEqual(captures[0][0].end_point, (1, 5))
- self.assertEqual(captures[0][1], "func-call")
-
- def test_node_hash(self):
- parser = Parser()
- parser.set_language(PYTHON)
- source_code = b"def foo():\n bar()\n bar()"
- tree = parser.parse(source_code)
- root_node = tree.root_node
- first_function_node = root_node.children[0]
- second_function_node = root_node.children[0]
-
- # Uniqueness and consistency
- self.assertEqual(hash(first_function_node), hash(first_function_node))
- self.assertNotEqual(hash(root_node), hash(first_function_node))
-
- # Equality implication
- self.assertEqual(hash(first_function_node), hash(second_function_node))
- self.assertTrue(first_function_node == second_function_node)
-
- # Different nodes with different properties
- different_tree = parser.parse(b"def baz():\n qux()")
- different_node = different_tree.root_node.children[0]
- self.assertNotEqual(hash(first_function_node), hash(different_node))
-
- # Same code, different parse trees
- another_tree = parser.parse(source_code)
- another_node = another_tree.root_node.children[0]
- self.assertNotEqual(hash(first_function_node), hash(another_node))
-
-
-class TestLookaheadIterator(TestCase):
- def test_lookahead_iterator(self):
- parser = Parser()
- parser.set_language(RUST)
- tree = parser.parse(b"struct Stuff{}")
-
- cursor = tree.walk()
-
- self.assertEqual(cursor.goto_first_child(), True) # struct
- self.assertEqual(cursor.goto_first_child(), True) # struct keyword
-
- next_state = cursor.node.next_parse_state
-
- self.assertNotEqual(next_state, 0)
- self.assertEqual(
- next_state, RUST.next_state(cursor.node.parse_state, cursor.node.grammar_id)
- )
- self.assertLess(next_state, RUST.parse_state_count)
- self.assertEqual(cursor.goto_next_sibling(), True) # type_identifier
- self.assertEqual(next_state, cursor.node.parse_state)
- self.assertEqual(cursor.node.grammar_name, "identifier")
- self.assertNotEqual(cursor.node.grammar_id, cursor.node.kind_id)
-
- expected_symbols = ["//", "/*", "identifier", "line_comment", "block_comment"]
- lookahead: LookaheadIterator = RUST.lookahead_iterator(next_state)
- self.assertEqual(lookahead.language, RUST.language_id)
- self.assertEqual(list(lookahead.iter_names()), expected_symbols)
-
- lookahead.reset_state(next_state)
- self.assertEqual(list(lookahead.iter_names()), expected_symbols)
-
- lookahead.reset(RUST.language_id, next_state)
- self.assertEqual(list(map(RUST.node_kind_for_id, list(iter(lookahead)))), expected_symbols)
-
-
-def trim(string):
- return re.sub(r"\s+", " ", string).strip()
-
-
-def get_all_nodes(tree: Tree) -> List[Node]:
- result = []
- visited_children = False
- cursor = tree.walk()
- while True:
- if not visited_children:
- result.append(cursor.node)
- if not cursor.goto_first_child():
- visited_children = True
- elif cursor.goto_next_sibling():
- visited_children = False
- elif not cursor.goto_parent():
- break
- return result
diff --git a/tree_sitter/__init__.py b/tree_sitter/__init__.py
index 18d8f0f6..a21fca5c 100644
--- a/tree_sitter/__init__.py
+++ b/tree_sitter/__init__.py
@@ -14,9 +14,9 @@
MIN_COMPATIBLE_LANGUAGE_VERSION,
)
-Point.__doc__ = 'A position in a multi-line text document, in terms of rows and columns.'
-Point.row.__doc__ = 'The zero-based row of the document.'
-Point.column.__doc__ = 'The zero-based column of the document.'
+Point.__doc__ = "A position in a multi-line text document, in terms of rows and columns."
+Point.row.__doc__ = "The zero-based row of the document."
+Point.column.__doc__ = "The zero-based column of the document."
__all__ = [
"Language",
diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi
index 07d24610..bd88ccd0 100644
--- a/tree_sitter/__init__.pyi
+++ b/tree_sitter/__init__.pyi
@@ -4,16 +4,14 @@ from typing_extensions import deprecated
_Ptr = Annotated[int, "TSLanguage *"]
-_ParseCB = Callable[[int, Point], bytes]
+_ParseCB = Callable[[int, Point | tuple[int, int]], bytes]
_UINT32_MAX = 0xFFFFFFFF
-
class Point(NamedTuple):
row: int
column: int
-
@final
class Language:
def __init__(self, ptr: _Ptr, /) -> None: ...
@@ -24,288 +22,195 @@ class Language:
@property
def version(self) -> int: ...
-
@property
def node_kind_count(self) -> int: ...
-
@property
def parse_state_count(self) -> int: ...
-
@property
def field_count(self) -> int: ...
-
def node_kind_for_id(self, id: int, /) -> str | None: ...
-
def id_for_node_kind(self, kind: str, named: bool, /) -> int | None: ...
-
def node_kind_is_named(self, id: int, /) -> bool: ...
-
def node_kind_is_visible(self, id: int, /) -> bool: ...
-
def field_name_for_id(self, field_id: int, /) -> str | None: ...
-
def field_id_for_name(self, name: str, /) -> int | None: ...
-
def next_state(self, state: int, id: int, /) -> int: ...
-
def lookahead_iterator(self, state: int, /) -> LookaheadIterator | None: ...
-
def query(self, source: str, /) -> Query: ...
-
def __repr__(self) -> str: ...
-
def __eq__(self, other: Any, /) -> bool: ...
-
def __ne__(self, other: Any, /) -> bool: ...
-
def __int__(self) -> int: ...
-
def __index__(self) -> int: ...
-
def __hash__(self) -> int: ...
-
@final
class Node:
@property
def id(self) -> int: ...
-
@property
def kind_id(self) -> int: ...
-
@property
def grammar_id(self) -> int: ...
-
@property
def grammar_name(self) -> str: ...
-
@property
def type(self) -> str: ...
-
@property
def is_named(self) -> bool: ...
-
@property
def is_extra(self) -> bool: ...
-
@property
def has_changes(self) -> bool: ...
-
@property
def has_error(self) -> bool: ...
-
@property
def is_error(self) -> bool: ...
-
@property
def parse_state(self) -> int: ...
-
@property
def next_parse_state(self) -> int: ...
-
@property
def is_missing(self) -> bool: ...
-
@property
def start_byte(self) -> int: ...
-
@property
def end_byte(self) -> int: ...
-
@property
def byte_range(self) -> tuple[int, int]: ...
-
@property
def range(self) -> Range: ...
-
@property
def start_point(self) -> Point: ...
-
@property
def end_point(self) -> Point: ...
-
@property
def children(self) -> list[Node]: ...
-
@property
def child_count(self) -> int: ...
-
@property
def named_children(self) -> list[Node]: ...
-
@property
def named_child_count(self) -> int: ...
-
@property
def parent(self) -> Node | None: ...
-
@property
def next_sibling(self) -> Node | None: ...
-
@property
def prev_sibling(self) -> Node | None: ...
-
@property
def next_named_sibling(self) -> Node | None: ...
-
@property
def prev_named_sibling(self) -> Node | None: ...
-
@property
def descendant_count(self) -> int: ...
-
@property
def text(self) -> bytes | None: ...
-
def walk(self) -> TreeCursor: ...
-
def edit(
self,
start_byte: int,
old_end_byte: int,
new_end_byte: int,
- start_point: Point,
- old_end_point: Point,
- new_end_point: Point,
+ start_point: Point | tuple[int, int],
+ old_end_point: Point | tuple[int, int],
+ new_end_point: Point | tuple[int, int],
) -> None: ...
-
def child(self, index: int, /) -> Node | None: ...
-
def named_child(self, index: int, /) -> Node | None: ...
-
def child_by_field_id(self, id: int, /) -> Node | None: ...
-
def child_by_field_name(self, name: str, /) -> Node | None: ...
-
def children_by_field_id(self, id: int, /) -> list[Node]: ...
-
def children_by_field_name(self, name: str, /) -> list[Node]: ...
-
def field_name_for_child(self, child_index: int, /) -> str | None: ...
-
def descendant_for_byte_range(
self,
start_byte: int,
end_byte: int,
/,
) -> Node | None: ...
-
def named_descendant_for_byte_range(
self,
start_byte: int,
end_byte: int,
/,
) -> Node | None: ...
-
def descendant_for_point_range(
self,
- start_point: Point,
- end_point: Point,
+ start_point: Point | tuple[int, int],
+ end_point: Point | tuple[int, int],
/,
) -> Node | None: ...
-
def named_descendant_for_point_range(
self,
- start_point: Point,
- end_point: Point,
+ start_point: Point | tuple[int, int],
+ end_point: Point | tuple[int, int],
/,
) -> Node | None: ...
-
@deprecated("Use `str()` instead")
def sexp(self) -> str: ...
-
def __repr__(self) -> str: ...
-
def __str__(self) -> str: ...
-
def __eq__(self, other: Any, /) -> bool: ...
-
def __ne__(self, other: Any, /) -> bool: ...
-
def __hash__(self) -> int: ...
-
@final
-class Tree():
+class Tree:
@property
def root_node(self) -> Node: ...
-
@property
def included_ranges(self) -> list[Range]: ...
-
@property
@deprecated("Use `root_node.text` instead")
def text(self) -> bytes | None: ...
-
def root_node_with_offset(
self,
offset_bytes: int,
- offset_extent: Point,
+ offset_extent: Point | tuple[int, int],
/,
) -> Node | None: ...
-
def edit(
self,
start_byte: int,
old_end_byte: int,
new_end_byte: int,
- start_point: Point,
- old_end_point: Point,
- new_end_point: Point,
+ start_point: Point | tuple[int, int],
+ old_end_point: Point | tuple[int, int],
+ new_end_point: Point | tuple[int, int],
) -> None: ...
-
def walk(self) -> TreeCursor: ...
-
def changed_ranges(self, new_tree: Tree) -> list[Range]: ...
-
@final
class TreeCursor:
@property
def node(self) -> Node: ...
-
@property
def field_id(self) -> int | None: ...
-
@property
def field_name(self) -> str | None: ...
-
@property
def depth(self) -> int: ...
-
@property
def descendant_index(self) -> int: ...
-
def copy(self) -> TreeCursor: ...
-
def reset(self, node: Node, /) -> None: ...
-
def reset_to(self, cursor: TreeCursor, /) -> None: ...
-
def goto_first_child(self) -> bool: ...
-
def goto_last_child(self) -> bool: ...
-
def goto_parent(self) -> bool: ...
-
def goto_next_sibling(self) -> bool: ...
-
def goto_previous_sibling(self) -> bool: ...
-
def goto_descendant(self, index: int, /) -> None: ...
-
def goto_first_child_for_byte(self, byte: int, /) -> bool: ...
-
@overload
- def goto_first_child_for_point(self, point: Point, /) -> bool: ...
-
+ def goto_first_child_for_point(self, point: Point | tuple[int, int], /) -> bool: ...
@overload
@deprecated("Use `goto_first_child_for_point(point)` instead")
def goto_first_child_for_point(self, row: int, column: int, /) -> bool: ...
-
def __copy__(self) -> TreeCursor: ...
-
@final
class Parser:
def __init__(
@@ -313,33 +218,24 @@ class Parser:
language: Language | None = None,
*,
included_ranges: Sequence[Range] | None = None,
- timeout_micros: int | None = None
+ timeout_micros: int | None = None,
) -> None: ...
-
@property
def language(self) -> Language | None: ...
-
@language.setter
def language(self, language: Language) -> None: ...
-
@language.deleter
def language(self) -> None: ...
-
@property
def included_ranges(self) -> list[Range]: ...
-
@included_ranges.setter
def included_ranges(self, ranges: Sequence[Range]) -> None: ...
-
@included_ranges.deleter
def included_ranges(self) -> None: ...
-
@property
def timeout_micros(self) -> int: ...
-
@timeout_micros.setter
def timeout_micros(self, timeout: int) -> None: ...
-
@timeout_micros.deleter
def timeout_micros(self) -> None: ...
@@ -352,7 +248,6 @@ class Parser:
/,
old_tree: Tree | None = None,
) -> Tree: ...
-
@overload
@deprecated("`keep_text` will be removed")
def parse(
@@ -362,19 +257,14 @@ class Parser:
old_tree: Tree | None = None,
keep_text: bool = True,
) -> Tree: ...
-
def reset(self) -> None: ...
-
@deprecated("Use the `language` setter instead")
def set_language(self, language: Language, /) -> None: ...
-
@deprecated("Use the `included_ranges` setter instead")
def set_included_ranges(self, ranges: Sequence[Range], /) -> None: ...
-
@deprecated("Use the `timeout_micros` setter instead")
def set_timeout_micros(self, timeout: int, /) -> None: ...
-
@final
class Query:
def __init__(self, language: Language, source: str) -> None: ...
@@ -386,76 +276,67 @@ class Query:
self,
node: Node,
*,
- start_point: Point = Point(0, 0),
- end_point: Point = Point(_UINT32_MAX, _UINT32_MAX),
+ start_point: Point | tuple[int, int] = Point(0, 0),
+ end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX),
start_byte: int = 0,
end_byte: int = _UINT32_MAX,
) -> list[tuple[Node, str]]: ...
-
def matches(
self,
node: Node,
*,
- start_point: Point = Point(0, 0),
- end_point: Point = Point(_UINT32_MAX, _UINT32_MAX),
+ start_point: Point | tuple[int, int] = Point(0, 0),
+ end_point: Point | tuple[int, int] = Point(_UINT32_MAX, _UINT32_MAX),
start_byte: int = 0,
end_byte: int = _UINT32_MAX,
) -> list[tuple[int, dict[str, Node | list[Node]]]]: ...
-
@final
class LookaheadIterator(Iterator[int]):
@property
def language(self) -> Language: ...
-
@property
def current_symbol(self) -> int: ...
-
@property
def current_symbol_name(self) -> str: ...
-
@deprecated("Use `reset_state()` instead")
def reset(self, language: _Ptr, state: int, /) -> None: ...
# TODO(0.24): rename to reset
def reset_state(self, state: int, language: Language | None = None) -> None: ...
-
def iter_names(self) -> Iterator[str]: ...
-
def __next__(self) -> int: ...
-
@final
class Range:
def __init__(
self,
- start_point: Point,
- end_point: Point,
+ start_point: Point | tuple[int, int],
+ end_point: Point | tuple[int, int],
start_byte: int,
end_byte: int,
) -> None: ...
-
@property
- def start_point(self): Point
+ def start_point(self) -> Point:  # return type belongs in the signature, not in the body
+ ...
@property
- def end_point(self): Point
+ def end_point(self) -> Point:
+ ...
@property
- def start_byte(self): int
+ def start_byte(self) -> int:
+ ...
@property
- def end_byte(self): int
+ def end_byte(self) -> int:
+ ...
def __eq__(self, other: Any, /) -> bool: ...
-
def __ne__(self, other: Any, /) -> bool: ...
-
def __repr__(self) -> str: ...
-
def __hash__(self) -> int: ...
-
LANGUAGE_VERSION: Final[int]
MIN_COMPATIBLE_LANGUAGE_VERSION: Final[int]
diff --git a/tree_sitter/binding/language.c b/tree_sitter/binding/language.c
index e74af1a7..94172dd6 100644
--- a/tree_sitter/binding/language.c
+++ b/tree_sitter/binding/language.c
@@ -16,7 +16,7 @@ static void segfault_handler(int signal) {
TSLanguage *language_check_pointer(void *ptr) {
PyOS_setsig(SIGSEGV, segfault_handler);
if (!setjmp(segv_jmp)) {
- (void)ts_language_version(ptr);
+ __attribute__((unused)) volatile uint32_t version = ts_language_version((TSLanguage *)ptr);
} else {
PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer");
}
@@ -29,7 +29,7 @@ TSLanguage *language_check_pointer(void *ptr) {
// HACK: recover from invalid pointer using SEH (Windows)
TSLanguage *language_check_pointer(void *ptr) {
__try {
- (void)ts_language_version(ptr);
+ volatile uint32_t version = ts_language_version((TSLanguage *)ptr);
} __except (GetExceptionCode() == EXCEPTION_ACCESS_VIOLATION ? EXCEPTION_EXECUTE_HANDLER
: EXCEPTION_CONTINUE_SEARCH) {
PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer");
@@ -43,7 +43,7 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
if (!PyArg_ParseTuple(args, "O:__init__", &language)) {
return -1;
}
- if (PyLong_AsLong(language) < 1) {
+ if (PyLong_AsSsize_t(language) < 1) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_ValueError, "language ID must be positive");
}
@@ -61,7 +61,10 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
return 0;
}
-void language_dealloc(Language *self) { Py_TYPE(self)->tp_free(self); }
+void language_dealloc(Language *self) {
+ ts_language_delete(self->language);
+ Py_TYPE(self)->tp_free(self);
+}
PyObject *language_repr(Language *self) {
#if HAS_LANGUAGE_NAMES
@@ -88,7 +91,7 @@ PyObject *language_compare(Language *self, PyObject *other, int op) {
Language *lang = (Language *)other;
bool result = (Py_uintptr_t)self->language == (Py_uintptr_t)lang->language;
- return PyBool_FromLong(result & (op == Py_EQ));
+ return PyBool_FromLong(result ^ (op == Py_NE));
}
#if HAS_LANGUAGE_NAMES
@@ -128,7 +131,7 @@ PyObject *language_node_kind_for_id(Language *self, PyObject *args) {
PyObject *language_id_for_node_kind(Language *self, PyObject *args) {
char *kind;
Py_ssize_t length;
- bool named;
+ int named;
if (!PyArg_ParseTuple(args, "s#p:id_for_node_kind", &kind, &length, &named)) {
return NULL;
}
@@ -205,8 +208,9 @@ PyObject *language_lookahead_iterator(Language *self, PyObject *args) {
if (iter == NULL) {
return NULL;
}
- iter->lookahead_iterator = lookahead_iterator;
+ Py_INCREF(self);
iter->language = (PyObject *)self;
+ iter->lookahead_iterator = lookahead_iterator;
return PyObject_Init((PyObject *)iter, state->lookahead_iterator_type);
}
diff --git a/tree_sitter/binding/lookahead_iterator.c b/tree_sitter/binding/lookahead_iterator.c
index 84fcc2ac..b4f4e571 100644
--- a/tree_sitter/binding/lookahead_iterator.c
+++ b/tree_sitter/binding/lookahead_iterator.c
@@ -24,11 +24,9 @@ PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *Py_UNUS
}
language->language = language_id;
language->version = ts_language_version(language->language);
- PyObject *obj = PyObject_Init((PyObject *)language, state->language_type);
- Py_XSETREF(self->language, obj);
- } else {
- Py_INCREF(self->language);
+ self->language = PyObject_Init((PyObject *)language, state->language_type);
}
+ Py_INCREF(self->language);
return self->language;
}
@@ -65,7 +63,7 @@ PyObject *lookahead_iterator_reset(LookaheadIterator *self, PyObject *args) {
PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args,
PyObject *kwargs) {
uint16_t state_id;
- PyObject *language_obj;
+ PyObject *language_obj = NULL;
ModuleState *state = GET_MODULE_STATE(self);
char *keywords[] = {"state", "language", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "H|O!:reset_state", keywords, &state_id,
diff --git a/tree_sitter/binding/node.c b/tree_sitter/binding/node.c
index 0e5d5328..7a56258f 100644
--- a/tree_sitter/binding/node.c
+++ b/tree_sitter/binding/node.c
@@ -43,7 +43,7 @@ PyObject *node_compare(Node *self, PyObject *other, int op) {
}
bool result = ts_node_eq(self->node, ((Node *)other)->node);
- return PyBool_FromLong(result & (op == Py_EQ));
+ return PyBool_FromLong(result ^ (op == Py_NE));
}
PyObject *node_sexp(Node *self, PyObject *Py_UNUSED(args)) {
@@ -59,9 +59,11 @@ PyObject *node_walk(Node *self, PyObject *Py_UNUSED(args)) {
if (tree_cursor == NULL) {
return NULL;
}
- tree_cursor->cursor = ts_tree_cursor_new(self->node);
+
Py_INCREF(self->tree);
tree_cursor->tree = self->tree;
+ tree_cursor->node = NULL;
+ tree_cursor->cursor = ts_tree_cursor_new(self->node);
return PyObject_Init((PyObject *)tree_cursor, state->tree_cursor_type);
}
@@ -436,7 +438,7 @@ PyObject *node_get_named_children(Node *self, void *payload) {
PyObject *child = PyList_GetItem(self->children, i);
if (ts_node_is_named(((Node *)child)->node)) {
Py_INCREF(child);
- if (PyList_SetItem(result, ++j, child)) {
+ if (PyList_SetItem(result, j++, child)) {
Py_DECREF(result);
return NULL;
}
diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c
index ce47f280..ab0c4e10 100644
--- a/tree_sitter/binding/parser.c
+++ b/tree_sitter/binding/parser.c
@@ -92,7 +92,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo
// Store return value in payload so its reference count can be decremented and
// return string representation of bytes.
wrapper_payload->previous_return_value = rv;
- *bytes_read = PyBytes_Size(rv);
+ *bytes_read = (uint32_t)PyBytes_Size(rv);
return PyBytes_AsString(rv);
}
@@ -100,7 +100,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
ModuleState *state = GET_MODULE_STATE(self);
PyObject *source_or_callback;
PyObject *old_tree_obj = NULL;
- bool keep_text = true;
+ int keep_text = 1;
char *keywords[] = {"", "old_tree", "keep_text", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!p:parse", keywords, &source_or_callback,
state->tree_type, &old_tree_obj, &keep_text)) {
@@ -114,7 +114,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
if (PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE) > -1) {
// parse a buffer
const char *source_bytes = (const char *)source_view.buf;
- size_t length = source_view.len;
+ uint32_t length = (uint32_t)source_view.len;
new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length);
PyBuffer_Release(&source_view);
} else if (PyCallable_Check(source_or_callback)) {
@@ -137,7 +137,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
source_or_callback = Py_None;
keep_text = 0;
} else {
- PyErr_SetString(PyExc_TypeError, "source must be a byte buffer or a callable");
+ PyErr_SetString(PyExc_TypeError, "source must be a bytestring or a callable");
return NULL;
}
@@ -172,27 +172,20 @@ int parser_set_timeout_micros(Parser *self, PyObject *arg, void *Py_UNUSED(paylo
ts_parser_set_timeout_micros(self->parser, 0);
return 0;
}
- if (!PyLong_CheckExact(arg)) {
+ if (!PyLong_Check(arg)) {
PyErr_Format(PyExc_TypeError, "'timeout_micros' must be assigned an int, not %s",
arg->ob_type->tp_name);
return -1;
}
- long timeout = PyLong_AsLong(arg);
- if (timeout < 0) {
- PyErr_SetString(PyExc_ValueError, "'timeout_micros' must be a positive integer");
- return -1;
- }
-
- ts_parser_set_timeout_micros(self->parser, timeout);
+ ts_parser_set_timeout_micros(self->parser, PyLong_AsUnsignedLong(arg));
return 0;
}
PyObject *parser_set_timeout_micros_old(Parser *self, PyObject *arg) {
- if (PyLong_AsLong(arg) < 0) {
- if (!PyErr_Occurred()) {
- PyErr_SetString(PyExc_ValueError, "'timeout_micros' must be a positive integer");
- }
+ if (!PyLong_Check(arg)) {
+ PyErr_Format(PyExc_TypeError, "'timeout_micros' must be assigned an int, not %s",
+ arg->ob_type->tp_name);
return NULL;
}
if (REPLACE("Parser.set_timeout_micros()", "the timeout_micros setter") < 0) {
@@ -235,7 +228,7 @@ int parser_set_included_ranges(Parser *self, PyObject *arg, void *Py_UNUSED(payl
return -1;
}
- uint32_t length = PyList_Size(arg);
+ uint32_t length = (uint32_t)PyList_Size(arg);
TSRange *ranges = PyMem_Calloc(length, sizeof(TSRange));
if (!ranges) {
PyErr_Format(PyExc_MemoryError, "Failed to allocate memory for ranges of length %u",
@@ -315,8 +308,8 @@ int parser_set_language(Parser *self, PyObject *arg, void *Py_UNUSED(payload)) {
return -1;
}
- self->language = (PyObject *)language;
- Py_INCREF(self->language);
+ Py_INCREF(language);
+ Py_XSETREF(self->language, (PyObject *)language);
return 0;
}
diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c
index 89b39aae..b43f757b 100644
--- a/tree_sitter/binding/query.c
+++ b/tree_sitter/binding/query.c
@@ -490,10 +490,10 @@ PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) {
"node", "start_point", "end_point", "start_byte", "end_byte", NULL,
};
PyObject *node_obj;
- TSPoint start_point = {.row = 0, .column = 0};
- TSPoint end_point = {.row = UINT32_MAX, .column = UINT32_MAX};
- unsigned start_byte = 0, end_byte = UINT32_MAX;
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|$(II)(II)II:matches", keywords,
+ TSPoint start_point = {0, 0};
+ TSPoint end_point = {UINT32_MAX, UINT32_MAX};
+ uint32_t start_byte = 0, end_byte = UINT32_MAX;
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!|$(II)(II)II:matches", keywords,
state->node_type, &node_obj, &start_point.row,
&start_point.column, &end_point.row, &end_point.column,
&start_byte, &end_byte)) {
diff --git a/tree_sitter/binding/range.c b/tree_sitter/binding/range.c
index fc52fb2b..a44c4f01 100644
--- a/tree_sitter/binding/range.c
+++ b/tree_sitter/binding/range.c
@@ -33,13 +33,35 @@ PyObject *range_repr(Range *self) {
Py_hash_t range_hash(Range *self) {
// FIXME: replace with an efficient integer hashing algorithm
- PyObject *row_tuple = PyTuple_Pack(2, PyLong_FromLong(self->range.start_point.row),
- PyLong_FromLong(self->range.end_point.row));
+ PyObject *row_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_point.row),
+ PyLong_FromSize_t(self->range.end_point.row));
- PyObject *col_tuple = PyTuple_Pack(2, PyLong_FromLong(self->range.start_point.column),
- PyLong_FromLong(self->range.end_point.column));
- PyObject *bytes_tuple = PyTuple_Pack(2, PyLong_FromLong(self->range.start_byte),
- PyLong_FromLong(self->range.end_byte));
+ if (!row_tuple) {
+ return -1; // Py_hash_t is an integer: a hash function reports failure with -1, not NULL
+ }
+
+ PyObject *col_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_point.column),
+ PyLong_FromSize_t(self->range.end_point.column));
+ if (!col_tuple) {
+ Py_DECREF(row_tuple);
+ return -1;
+ }
+
+ PyObject *bytes_tuple = PyTuple_Pack(2, PyLong_FromSize_t(self->range.start_byte),
+ PyLong_FromSize_t(self->range.end_byte));
+ if (!bytes_tuple) {
+ Py_DECREF(row_tuple);
+ Py_DECREF(col_tuple);
+ return -1;
+ }
+
PyObject *range_tuple = PyTuple_Pack(3, row_tuple, col_tuple, bytes_tuple);
+ if (!range_tuple) {
+ Py_DECREF(row_tuple);
+ Py_DECREF(col_tuple);
+ Py_DECREF(bytes_tuple);
+ return -1;
+ }
+
Py_hash_t hash = PyObject_Hash(range_tuple);
Py_DECREF(range_tuple);
@@ -61,7 +83,7 @@ PyObject *range_compare(Range *self, PyObject *other, int op) {
(self->range.end_point.row == range->range.end_point.row) &&
(self->range.end_point.column == range->range.end_point.column) &&
(self->range.end_byte == range->range.end_byte));
- return PyBool_FromLong(result & (op == Py_EQ));
+ return PyBool_FromLong(result ^ (op == Py_NE));
}
PyObject *range_get_start_point(Range *self, void *Py_UNUSED(payload)) {
diff --git a/tree_sitter/binding/tree.c b/tree_sitter/binding/tree.c
index 97e1943b..1c03e2dc 100644
--- a/tree_sitter/binding/tree.c
+++ b/tree_sitter/binding/tree.c
@@ -45,9 +45,11 @@ PyObject *tree_walk(Tree *self, PyObject *Py_UNUSED(args)) {
if (tree_cursor == NULL) {
return NULL;
}
- tree_cursor->cursor = ts_tree_cursor_new(ts_tree_root_node(self->tree));
+
Py_INCREF(self);
tree_cursor->tree = (PyObject *)self;
+ tree_cursor->node = NULL;
+ tree_cursor->cursor = ts_tree_cursor_new(ts_tree_root_node(self->tree));
return PyObject_Init((PyObject *)tree_cursor, state->tree_cursor_type);
}
diff --git a/tree_sitter/binding/tree_cursor.c b/tree_sitter/binding/tree_cursor.c
index f2f93883..7dfbb8df 100644
--- a/tree_sitter/binding/tree_cursor.c
+++ b/tree_sitter/binding/tree_cursor.c
@@ -9,15 +9,14 @@ void tree_cursor_dealloc(TreeCursor *self) {
}
PyObject *tree_cursor_get_node(TreeCursor *self, void *Py_UNUSED(payload)) {
- ModuleState *state = GET_MODULE_STATE(self);
- if (!self->node) {
+ if (self->node == NULL) {
TSNode current_node = ts_tree_cursor_current_node(&self->cursor);
if (ts_node_is_null(current_node)) {
Py_RETURN_NONE;
}
- return node_new_internal(state, current_node, self->tree);
+ ModuleState *state = GET_MODULE_STATE(self);
+ self->node = node_new_internal(state, current_node, self->tree);
}
-
- Py_INCREF(self->node);
+ Py_XINCREF(self->node); // stays NULL if node_new_internal() failed; NULL is then returned as the error
return self->node;
}
@@ -174,9 +173,11 @@ PyObject *tree_cursor_copy(PyObject *self, PyObject *Py_UNUSED(args)) {
if (copied == NULL) {
return NULL;
}
- copied->cursor = ts_tree_cursor_copy(&origin->cursor);
+
Py_INCREF(origin->tree);
copied->tree = origin->tree;
+ copied->node = NULL; // same as node_walk()/tree_walk(): the node getter tests this field against NULL
+ copied->cursor = ts_tree_cursor_copy(&origin->cursor);
return PyObject_Init((PyObject *)copied, state->tree_cursor_type);
}
diff --git a/tree_sitter/core b/tree_sitter/core
index 4bbaee2f..6e6dcf1c 160000
--- a/tree_sitter/core
+++ b/tree_sitter/core
@@ -1 +1 @@
-Subproject commit 4bbaee2f56d1febc5992c6ebae556d35d571a712
+Subproject commit 6e6dcf1cafb00300338b46bb4bffcd05ad99fafc