Skip to content

Commit

Permalink
fix: no more random segfaults (hopefully)
Browse files Browse the repository at this point in the history
Plus many codebase improvements
Plus Query & Parser constructors
  • Loading branch information
ObserverOfTime committed Apr 28, 2024
1 parent 81edd6c commit 8db9475
Show file tree
Hide file tree
Showing 17 changed files with 511 additions and 590 deletions.
40 changes: 26 additions & 14 deletions tree_sitter/binding/language.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#include "language.h"
#include "lookahead_iterator.h"
#include "query.h"

#ifndef _MSC_VER
#include <setjmp.h>
Expand All @@ -20,7 +18,7 @@ TSLanguage *language_check_pointer(void *ptr) {
if (!setjmp(segv_jmp)) {
(void)ts_language_version(ptr);
} else {
PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer.");
PyErr_SetString(PyExc_RuntimeError, "Invalid TSLanguage pointer");
}
PyOS_setsig(SIGSEGV, SIG_DFL);
return PyErr_Occurred() ? NULL : (TSLanguage *)ptr;
Expand All @@ -47,7 +45,7 @@ int language_init(Language *self, PyObject *args, PyObject *Py_UNUSED(kwargs)) {
}
if (PyLong_AsLong(language) < 1) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_ValueError, "The language ID must be positive.");
PyErr_SetString(PyExc_ValueError, "language ID must be positive");
}
return -1;
}
Expand Down Expand Up @@ -121,7 +119,10 @@ PyObject *language_node_kind_for_id(Language *self, PyObject *args) {
return NULL;
}
const char *name = ts_language_symbol_name(self->language, symbol);
return name == NULL ? Py_None : PyUnicode_FromString(name);
if (name == NULL) {
Py_RETURN_NONE;
}
return PyUnicode_FromString(name);
}

PyObject *language_id_for_node_kind(Language *self, PyObject *args) {
Expand All @@ -132,7 +133,10 @@ PyObject *language_id_for_node_kind(Language *self, PyObject *args) {
return NULL;
}
TSSymbol symbol = ts_language_symbol_for_name(self->language, kind, length, named);
return symbol == 0 ? Py_None : PyLong_FromUnsignedLong(symbol);
if (symbol == 0) {
Py_RETURN_NONE;
}
return PyLong_FromUnsignedLong(symbol);
}

PyObject *language_node_kind_is_named(Language *self, PyObject *args) {
Expand All @@ -159,7 +163,10 @@ PyObject *language_field_name_for_id(Language *self, PyObject *args) {
return NULL;
}
const char *field_name = ts_language_field_name_for_id(self->language, field_id);
return field_name == NULL ? Py_None : PyUnicode_FromString(field_name);
if (field_name == NULL) {
Py_RETURN_NONE;
}
return PyUnicode_FromString(field_name);
}

PyObject *language_field_id_for_name(Language *self, PyObject *args) {
Expand All @@ -169,12 +176,14 @@ PyObject *language_field_id_for_name(Language *self, PyObject *args) {
return NULL;
}
TSFieldId field_id = ts_language_field_id_for_name(self->language, field_name, length);
return field_id == 0 ? Py_None : PyLong_FromUnsignedLong(field_id);
if (field_id == 0) {
Py_RETURN_NONE;
}
return PyLong_FromUnsignedLong(field_id);
}

PyObject *language_next_state(Language *self, PyObject *args) {
uint16_t state_id;
uint16_t symbol;
uint16_t state_id, symbol;
if (!PyArg_ParseTuple(args, "HH:next_state", &state_id, &symbol)) {
return NULL;
}
Expand All @@ -183,7 +192,6 @@ PyObject *language_next_state(Language *self, PyObject *args) {
}

PyObject *language_lookahead_iterator(Language *self, PyObject *args) {
ModuleState *state = PyType_GetModuleState(Py_TYPE(self));
uint16_t state_id;
if (!PyArg_ParseTuple(args, "H:lookahead_iterator", &state_id)) {
return NULL;
Expand All @@ -192,17 +200,21 @@ PyObject *language_lookahead_iterator(Language *self, PyObject *args) {
if (lookahead_iterator == NULL) {
Py_RETURN_NONE;
}
return lookahead_iterator_new_internal(state, lookahead_iterator);
ModuleState *state = GET_MODULE_STATE(self);
LookaheadIterator* iter = PyObject_New(LookaheadIterator, state->lookahead_iterator_type);
iter->lookahead_iterator = lookahead_iterator;
iter->language = (PyObject *)self;
return PyObject_Init((PyObject *)iter, state->lookahead_iterator_type);
}

PyObject *language_query(Language *self, PyObject *args) {
ModuleState *state = GET_MODULE_STATE(self);
char *source;
Py_ssize_t length;
if (!PyArg_ParseTuple(args, "s#:query", &source, &length)) {
return NULL;
}
ModuleState *state = PyType_GetModuleState(Py_TYPE(self));
return query_new_internal(state, self->language, source, length);
return PyObject_CallFunction((PyObject *)state->query_type, "Os#", self, source, length);
}

static PyMethodDef language_methods[] = {
Expand Down
46 changes: 17 additions & 29 deletions tree_sitter/binding/lookahead_iterator.c
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
#include "lookahead_iterator.h"
#include "language.h"
#include "lookahead_names_iterator.h"

PyObject *lookahead_iterator_new_internal(ModuleState *state,
TSLookaheadIterator *lookahead_iterator) {
LookaheadIterator *self = (LookaheadIterator *)state->lookahead_iterator_type->tp_alloc(
state->lookahead_iterator_type, 0);
if (self != NULL) {
self->lookahead_iterator = lookahead_iterator;
self->language = NULL;
}
return (PyObject *)self;
}

void lookahead_iterator_dealloc(LookaheadIterator *self) {
if (self->lookahead_iterator) {
Expand All @@ -25,26 +13,27 @@ PyObject *lookahead_iterator_repr(LookaheadIterator *self) {
return PyUnicode_FromFormat("<LookaheadIterator %p>", self->lookahead_iterator);
}

PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *payload) {
PyObject *lookahead_iterator_get_language(LookaheadIterator *self, void *Py_UNUSED(payload)) {
TSLanguage *language_id = (TSLanguage *)ts_lookahead_iterator_language(self->lookahead_iterator);
if (self->language == NULL || ((Language *)self->language)->language != language_id) {
ModuleState *state = GET_MODULE_STATE(Py_TYPE(self));
ModuleState *state = GET_MODULE_STATE(self);
Language* language = PyObject_New(Language, state->language_type);
language->language = language_id;
language->version = ts_language_version(language->language);
Py_XSETREF(self->language, PyObject_Init((PyObject *)language, state->language_type));
PyObject *obj = PyObject_Init((PyObject *)language, state->language_type);
Py_XSETREF(self->language, obj);
} else {
Py_INCREF(self->language);
}
return self->language;
}

PyObject *lookahead_iterator_get_current_symbol(LookaheadIterator *self, void *payload) {
PyObject *lookahead_iterator_get_current_symbol(LookaheadIterator *self, void *Py_UNUSED(payload)) {
TSSymbol symbol = ts_lookahead_iterator_current_symbol(self->lookahead_iterator);
return PyLong_FromUnsignedLong(symbol);
}

PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, void *payload) {
PyObject *lookahead_iterator_get_current_symbol_name(LookaheadIterator *self, void *Py_UNUSED(payload)) {
const char *name = ts_lookahead_iterator_current_symbol_name(self->lookahead_iterator);
return PyUnicode_FromString(name);
}
Expand Down Expand Up @@ -72,22 +61,17 @@ PyObject *lookahead_iterator_reset_state(LookaheadIterator *self, PyObject *args
PyObject *kwargs) {
uint16_t state_id;
PyObject *language_obj;
ModuleState *state = GET_MODULE_STATE(self);
char *keywords[] = {"state", "language", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "H|O:reset_state", keywords, &state_id,
&language_obj)) {
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "H|O!:reset_state", keywords, &state_id,
state->language_type, &language_obj)) {
return NULL;
}

bool result;
if (language_obj == NULL) {
result = ts_lookahead_iterator_reset_state(self->lookahead_iterator, state_id);
} else {
if (!IS_INSTANCE(language_obj, language_type)) {
PyErr_Format(PyExc_TypeError,
"the 'language' argument must be a Language object, not '%s'",
language_obj->ob_type->tp_name);
return NULL;
}
TSLanguage *language_id = ((Language *)language_obj)->language;
result = ts_lookahead_iterator_reset(self->lookahead_iterator, language_id, state_id);
}
Expand All @@ -104,12 +88,15 @@ PyObject *lookahead_iterator_next(LookaheadIterator *self) {
PyErr_SetNone(PyExc_StopIteration);
return NULL;
}
return PyLong_FromUnsignedLong(ts_lookahead_iterator_current_symbol(self->lookahead_iterator));
TSSymbol symbol = ts_lookahead_iterator_current_symbol(self->lookahead_iterator);
return PyLong_FromUnsignedLong(symbol);
}

PyObject *lookahead_iterator_names_iterator(LookaheadIterator *self) {
return lookahead_names_iterator_new_internal(PyType_GetModuleState(Py_TYPE(self)),
self->lookahead_iterator);
ModuleState *state = GET_MODULE_STATE(self);
LookaheadNamesIterator* iter = PyObject_New(LookaheadNamesIterator, state->lookahead_names_iterator_type);
iter->lookahead_iterator = self->lookahead_iterator;
return PyObject_Init((PyObject *)iter, state->lookahead_names_iterator_type);
}

static PyGetSetDef lookahead_iterator_accessors[] = {
Expand Down Expand Up @@ -146,6 +133,7 @@ static PyMethodDef lookahead_iterator_methods[] = {

static PyType_Slot lookahead_iterator_type_slots[] = {
{Py_tp_doc, "An iterator over the possible syntax nodes that could come next."},
{Py_tp_new, NULL},
{Py_tp_dealloc, lookahead_iterator_dealloc},
{Py_tp_repr, lookahead_iterator_repr},
{Py_tp_getset, lookahead_iterator_accessors},
Expand All @@ -159,6 +147,6 @@ PyType_Spec lookahead_iterator_type_spec = {
.name = "tree_sitter.LookaheadIterator",
.basicsize = sizeof(LookaheadIterator),
.itemsize = 0,
.flags = Py_TPFLAGS_DEFAULT,
.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
.slots = lookahead_iterator_type_slots,
};
3 changes: 0 additions & 3 deletions tree_sitter/binding/lookahead_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

#include "types.h"

PyObject *lookahead_iterator_new_internal(ModuleState *state,
TSLookaheadIterator *lookahead_iterator);

void lookahead_iterator_dealloc(LookaheadIterator *self);

PyObject *lookahead_iterator_repr(LookaheadIterator *self);
Expand Down
15 changes: 2 additions & 13 deletions tree_sitter/binding/lookahead_names_iterator.c
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
#include "lookahead_names_iterator.h"

PyObject *lookahead_names_iterator_new_internal(ModuleState *state,
TSLookaheadIterator *lookahead_iterator) {
LookaheadNamesIterator *self =
(LookaheadNamesIterator *)state->lookahead_names_iterator_type->tp_alloc(
state->lookahead_names_iterator_type, 0);
if (self == NULL) {
return NULL;
}
self->lookahead_iterator = lookahead_iterator;
return (PyObject *)self;
}

PyObject *lookahead_names_iterator_repr(LookaheadNamesIterator *self) {
return PyUnicode_FromFormat("<LookaheadNamesIterator %p>", self->lookahead_iterator);
}
Expand All @@ -36,6 +24,7 @@ PyObject *lookahead_names_iterator_next(LookaheadNamesIterator *self) {

static PyType_Slot lookahead_names_iterator_type_slots[] = {
{Py_tp_doc, "An iterator over the possible syntax nodes that could come next."},
{Py_tp_new, NULL},
{Py_tp_dealloc, lookahead_names_iterator_dealloc},
{Py_tp_repr, lookahead_names_iterator_repr},
{Py_tp_iter, lookahead_names_iterator_iter},
Expand All @@ -47,6 +36,6 @@ PyType_Spec lookahead_names_iterator_type_spec = {
.name = "tree_sitter.LookaheadNamesIterator",
.basicsize = sizeof(LookaheadNamesIterator),
.itemsize = 0,
.flags = Py_TPFLAGS_DEFAULT,
.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_DISALLOW_INSTANTIATION,
.slots = lookahead_names_iterator_type_slots,
};
3 changes: 0 additions & 3 deletions tree_sitter/binding/lookahead_names_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

#include "types.h"

PyObject *lookahead_names_iterator_new_internal(ModuleState *state,
TSLookaheadIterator *lookahead_iterator);

PyObject *lookahead_names_iterator_repr(LookaheadNamesIterator *self);

void lookahead_names_iterator_dealloc(LookaheadNamesIterator *self);
Expand Down
Loading

0 comments on commit 8db9475

Please sign in to comment.