diff --git a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py index f45e71125..6f8a10b3d 100644 --- a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py +++ b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py @@ -28,14 +28,17 @@ def to_db_dict(self): @typechecked class SqlparseComparator(Comparator): def compare(self, solution: str, submission: str) -> list[Error]: - solution_parsed = sqlparse.parse(solution) + solution_parsed = sqlparse.parse(self._preprocess(solution)) dfs = SqlParserDfs() dfs.visit(solution_parsed) cv = SqlParserCoVisitor(dfs.dfs) - submission_parsed = sqlparse.parse(submission) + submission_parsed = sqlparse.parse(self._preprocess(submission)) cv.visit(submission_parsed) return cv.errors + def _preprocess(self, query: str) -> str: + return query.replace("\n", " ") + class SqlParseVisitor: def __init__(self): @@ -90,7 +93,7 @@ def visit(self, tokens: list[sqlparse.sql.Token]): raise ValueError('unhandled token', token) def trace_to_str_list(self) -> list[str]: - return [entry.__class__.__name__ if entry.ttype is None else entry.value for entry in self.parent_stack] + return [token_to_str(entry) for entry in self.parent_stack] class SqlParserDfs(SqlParseVisitor): @@ -114,19 +117,36 @@ def __init__(self, solution): self._i = 0 self.errors = [] + def visit(self, tokens: list[sqlparse.sql.Token]): + super().visit(tokens) + if len(self.parent_stack) == 0 and self._i < len(self._solution): + should = self._solution[self._i] + self.errors.append(Error(token_to_str(should), "EOF", [token_to_str(tokens[0])])) + def _get_should(self): + index = self._i + if index >= len(self._solution): + return None self._i += 1 - return self._solution[self._i - 1] + return self._solution[index] def recursive_visit(self, token: sqlparse.sql.Statement): should = self._get_should() - if token.__class__ != should.__class__: - self.errors.append(Error(should.__class__.__name__, token.__class__.__name__, self.trace_to_str_list())) + if should is None: + self.errors.append(Error("EOF", token_to_str(token), self.trace_to_str_list())) + elif token.__class__ != should.__class__: + self.errors.append(Error(token_to_str(should), token_to_str(token), self.trace_to_str_list())) else: super().recursive_visit(token) def visit_literal(self, token: sqlparse.tokens.Token): should = self._get_should() - if token.value != should.value: - self.errors.append(Error(should.value, token.value, self.trace_to_str_list())) + if should is None: + self.errors.append(Error("EOF", token_to_str(token), self.trace_to_str_list())) + elif token.value != should.value: + self.errors.append(Error(token_to_str(should), token_to_str(token), self.trace_to_str_list())) super().visit_literal(token) + + +def token_to_str(token: sqlparse.tokens.Token) -> str: + return token.__class__.__name__ if token.ttype is None else repr(token.value) diff --git a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py index 6e76e8ac6..6354d04a5 100644 --- a/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py +++ b/modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py @@ -13,6 +13,18 @@ def test_compare_simple(self): "SELECT username, email, password FROM users WHERE username ILIKE 'test%'") assert len(errors) == 0 + def test_compare_shorter(self): + comparator = SqlparseComparator() + errors = comparator.compare("SELECT username, email, password FROM users WHERE username ILIKE 'test%'", + "SELECT username, email, password FROM users") + assert len(errors) != 0 + + def test_compare_shorter_swap(self): + comparator = SqlparseComparator() + errors = comparator.compare("SELECT username, email, password FROM users", + "SELECT username, email, password FROM users WHERE username ILIKE 'test%'",) + assert len(errors) != 0 + def test_compare_with_and(self): comparator = SqlparseComparator() errors = comparator.compare(