Skip to content

Commit

Permalink
feat: implement Point namedtuple
Browse files Browse the repository at this point in the history
  • Loading branch information
ObserverOfTime committed Mar 15, 2024
1 parent 1360992 commit ea468a3
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 37 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ concurrency:
jobs:
build_asan:
runs-on: ubuntu-latest
timeout-minutes: 30
env:
PYTHON_VERSION: "3.10"
steps:
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"tree_sitter/binding/lookahead_names_iterator.c",
"tree_sitter/binding/node.c",
"tree_sitter/binding/parser.c",
"tree_sitter/binding/point.c",
"tree_sitter/binding/query.c",
"tree_sitter/binding/range.c",
"tree_sitter/binding/tree.c",
Expand Down
2 changes: 2 additions & 0 deletions tree_sitter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
LookaheadNamesIterator,
Node,
Parser,
Point,
Query,
Range,
Tree,
Expand All @@ -18,6 +19,7 @@
"LookaheadNamesIterator",
"Node",
"Parser",
"Point",
"Query",
"Range",
"Tree",
Expand Down
36 changes: 30 additions & 6 deletions tree_sitter/binding/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,20 @@ static int AddObjectRef(PyObject *module, const char *name, PyObject *value) {
}
#endif

static inline PyObject *import_attribute(const char *mod, const char *attr) {
PyObject *module = PyImport_ImportModule(mod);
if (module == NULL) {
return NULL;
}
PyObject *import = PyObject_GetAttrString(module, attr);
Py_DECREF(module);
return import;
}

static void module_free(void *self) {
ModuleState *state = PyModule_GetState((PyObject *)self);
ts_query_cursor_delete(state->query_cursor);
Py_XDECREF(state->point_type);
Py_XDECREF(state->tree_type);
Py_XDECREF(state->tree_cursor_type);
Py_XDECREF(state->language_type);
Expand All @@ -48,6 +59,7 @@ static void module_free(void *self) {
Py_XDECREF(state->capture_match_string_type);
Py_XDECREF(state->lookahead_iterator_type);
Py_XDECREF(state->re_compile);
Py_XDECREF(state->namedtuple);
}

static struct PyModuleDef module_definition = {
Expand All @@ -66,6 +78,8 @@ PyMODINIT_FUNC PyInit__binding(void) {

ModuleState *state = PyModule_GetState(module);

state->query_cursor = ts_query_cursor_new();

state->tree_type = (PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_type_spec, NULL);
state->tree_cursor_type =
(PyTypeObject *)PyType_FromModuleAndSpec(module, &tree_cursor_type_spec, NULL);
Expand All @@ -90,7 +104,6 @@ PyMODINIT_FUNC PyInit__binding(void) {
state->lookahead_names_iterator_type =
(PyTypeObject *)PyType_FromModuleAndSpec(module, &lookahead_names_iterator_type_spec, NULL);

state->query_cursor = ts_query_cursor_new();
if ((AddObjectRef(module, "Tree", (PyObject *)state->tree_type) < 0) ||
(AddObjectRef(module, "TreeCursor", (PyObject *)state->tree_cursor_type) < 0) ||
(AddObjectRef(module, "Language", (PyObject *)state->language_type) < 0) ||
Expand All @@ -112,13 +125,24 @@ PyMODINIT_FUNC PyInit__binding(void) {
goto cleanup;
}

PyObject *re_module = PyImport_ImportModule("re");
if (re_module == NULL) {
state->re_compile = import_attribute("re", "compile");
if (state->re_compile == NULL) {
goto cleanup;
}
state->re_compile = PyObject_GetAttrString(re_module, (char *)"compile");
Py_DECREF(re_module);
if (state->re_compile == NULL) {

state->namedtuple = import_attribute("collections", "namedtuple");
if (state->namedtuple == NULL) {
goto cleanup;
}

PyObject *point_args = Py_BuildValue("s[ss]", "Point", "row", "column");
PyObject *point_kwargs = PyDict_New();
PyDict_SetItemString(point_kwargs, "module", PyUnicode_FromString("tree_sitter"));
state->point_type = (PyTypeObject *)PyObject_Call(state->namedtuple, point_args, point_kwargs);
Py_DECREF(point_args);
Py_DECREF(point_kwargs);
if (state->point_type == NULL ||
AddObjectRef(module, "Point", (PyObject *)state->point_type) < 0) {
goto cleanup;
}

Expand Down
6 changes: 4 additions & 2 deletions tree_sitter/binding/node.c
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,13 @@ PyObject *node_get_range(Node *self, void *payload) {
}

PyObject *node_get_start_point(Node *self, void *payload) {
return point_new(ts_node_start_point(self->node));
TSPoint point = ts_node_start_point(self->node);
return POINT_NEW(GET_MODULE_STATE(Py_TYPE(self)), point);
}

PyObject *node_get_end_point(Node *self, void *payload) {
return point_new(ts_node_end_point(self->node));
TSPoint point = ts_node_end_point(self->node);
return POINT_NEW(GET_MODULE_STATE(Py_TYPE(self)), point);
}

PyObject *node_get_children(Node *self, void *payload) {
Expand Down
6 changes: 3 additions & 3 deletions tree_sitter/binding/parser.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "parser.h"
#include "point.h"
#include "tree.h"

PyObject *parser_new(PyTypeObject *type, PyObject *args, PyObject *kwds) {
Expand All @@ -17,7 +16,7 @@ void parser_dealloc(Parser *self) {

static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPoint position,
uint32_t *bytes_read) {
ReadWrapperPayload *wrapper_payload = payload;
ReadWrapperPayload *wrapper_payload = (ReadWrapperPayload *)payload;
PyObject *read_cb = wrapper_payload->read_cb;

// We assume that the parser only needs the return value until the next time
Expand All @@ -30,7 +29,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo

// Form arguments to callable.
PyObject *byte_offset_obj = PyLong_FromSize_t((size_t)byte_offset);
PyObject *position_obj = point_new(position);
PyObject *position_obj = POINT_NEW(wrapper_payload->state, position);
if (!position_obj || !byte_offset_obj) {
*bytes_read = 0;
return NULL;
Expand Down Expand Up @@ -98,6 +97,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
PyErr_Clear(); // clear the GetBuffer error
// parse a callable
ReadWrapperPayload payload = {
.state = state,
.read_cb = source_or_callback,
.previous_return_value = NULL,
};
Expand Down
1 change: 1 addition & 0 deletions tree_sitter/binding/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
typedef struct {
PyObject *read_cb;
PyObject *previous_return_value;
ModuleState *state;
} ReadWrapperPayload;

PyObject *parser_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
Expand Down
16 changes: 0 additions & 16 deletions tree_sitter/binding/point.c

This file was deleted.

5 changes: 0 additions & 5 deletions tree_sitter/binding/point.h

This file was deleted.

4 changes: 2 additions & 2 deletions tree_sitter/binding/range.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ PyObject *range_compare(Range *self, Range *other, int op) {
}

PyObject *range_get_start_point(Range *self, void *payload) {
return point_new(self->range.start_point);
return POINT_NEW(GET_MODULE_STATE(Py_TYPE(self)), self->range.start_point);
}

PyObject *range_get_end_point(Range *self, void *payload) {
return point_new(self->range.end_point);
return POINT_NEW(GET_MODULE_STATE(Py_TYPE(self)), self->range.end_point);
}

PyObject *range_get_start_byte(Range *self, void *payload) {
Expand Down
9 changes: 8 additions & 1 deletion tree_sitter/binding/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ typedef LookaheadIterator LookaheadNamesIterator;
typedef struct {
TSTreeCursor default_cursor;
TSQueryCursor *query_cursor;

PyObject *re_compile;
PyObject *namedtuple;

PyTypeObject *point_type;
PyTypeObject *tree_type;
PyTypeObject *tree_cursor_type;
PyTypeObject *language_type;
Expand All @@ -113,4 +116,8 @@ typedef struct {

#define GET_MODULE_STATE(type) ((ModuleState *)PyType_GetModuleState(type))

#define IS_INSTANCE(obj, type) PyObject_IsInstance((obj), (PyObject *)(GET_MODULE_STATE(Py_TYPE(obj))->type))
#define IS_INSTANCE(obj, type) \
PyObject_IsInstance((obj), (PyObject *)(GET_MODULE_STATE(Py_TYPE(obj))->type))

#define POINT_NEW(state, point) \
PyObject_CallFunction((PyObject *)(state)->point_type, "II", (point).row, (point).column)

0 comments on commit ea468a3

Please sign in to comment.