Skip to content

Commit

Permalink
fix(sql-parser): fix different query lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
Zitrone44 committed Dec 7, 2024
1 parent 99f48b6 commit f85f8bf
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
36 changes: 28 additions & 8 deletions modules/fbs-sql-checker/api/comparator/sqlparse_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ def to_db_dict(self):
@typechecked
class SqlparseComparator(Comparator):
def compare(self, solution: str, submission: str) -> list[Error]:
solution_parsed = sqlparse.parse(solution)
solution_parsed = sqlparse.parse(self._preprocess(solution))
dfs = SqlParserDfs()
dfs.visit(solution_parsed)
cv = SqlParserCoVisitor(dfs.dfs)
submission_parsed = sqlparse.parse(submission)
submission_parsed = sqlparse.parse(self._preprocess(submission))
cv.visit(submission_parsed)
return cv.errors

def _preprocess(self, query: str) -> str:
return query.replace("\n", " ")


class SqlParseVisitor:
def __init__(self):
Expand Down Expand Up @@ -90,7 +93,7 @@ def visit(self, tokens: list[sqlparse.sql.Token]):
raise ValueError('unhandled token', token)

def trace_to_str_list(self) -> list[str]:
return [entry.__class__.__name__ if entry.ttype is None else entry.value for entry in self.parent_stack]
return [token_to_str(entry) for entry in self.parent_stack]


class SqlParserDfs(SqlParseVisitor):
Expand All @@ -114,19 +117,36 @@ def __init__(self, solution):
self._i = 0
self.errors = []

def visit(self, tokens: list[sqlparse.sql.Token]):
super().visit(tokens)
if len(self.parent_stack) == 0 and self._i < len(self._solution):
should = self._solution[self._i]
self.errors.append(Error(token_to_str(should), "EOF", [token_to_str(tokens[0])]))

def _get_should(self):
index = self._i
if index >= len(self._solution):
return None
self._i += 1
return self._solution[self._i - 1]
return self._solution[index]

def recursive_visit(self, token: sqlparse.sql.Statement):
should = self._get_should()
if token.__class__ != should.__class__:
self.errors.append(Error(should.__class__.__name__, token.__class__.__name__, self.trace_to_str_list()))
if should is None:
self.errors.append(Error("EOF", token_to_str(token), self.trace_to_str_list()))
elif token.__class__ != should.__class__:
self.errors.append(Error(token_to_str(should), token_to_str(token), self.trace_to_str_list()))
else:
super().recursive_visit(token)

def visit_literal(self, token: sqlparse.tokens.Token):
should = self._get_should()
if token.value != should.value:
self.errors.append(Error(should.value, token.value, self.trace_to_str_list()))
if should is None:
self.errors.append(Error("EOF", token_to_str(token), self.trace_to_str_list()))
elif token.value != should.value:
self.errors.append(Error(token_to_str(should), token_to_str(token), self.trace_to_str_list()))
super().visit_literal(token)


def token_to_str(token: sqlparse.tokens.Token) -> str:
return token.__class__.__name__ if token.ttype is None else repr(token.value)
12 changes: 12 additions & 0 deletions modules/fbs-sql-checker/api/comparator/sqlparse_comparator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ def test_compare_simple(self):
"SELECT username, email, password FROM users WHERE username ILIKE 'test%'")
assert len(errors) == 0

def test_compare_shorter(self):
comparator = SqlparseComparator()
errors = comparator.compare("SELECT username, email, password FROM users WHERE username ILIKE 'test%'",
"SELECT username, email, password FROM users")
assert len(errors) != 0

def test_compare_shorter_swap(self):
comparator = SqlparseComparator()
errors = comparator.compare("SELECT username, email, password FROM users",
"SELECT username, email, password FROM users WHERE username ILIKE 'test%'",)
assert len(errors) != 0

def test_compare_with_and(self):
comparator = SqlparseComparator()
errors = comparator.compare(
Expand Down

0 comments on commit f85f8bf

Please sign in to comment.