Skip to content

Commit

Permalink
Return information about extracted constants from pg_query_normalize
Browse files Browse the repository at this point in the history
`PgQueryNormalizeResult` now includes information about the extracted
constants, their location and extent, as well as the lexer token type
and the value of constants as interpreted by the lexer.  This makes
`pg_query_normalize` usable not just for query identification, but also
as a basis for automatic conversion of literal queries into
constant-agnostic prepared statements or other applications where
auto-parametrization of queries is useful.

This also switches `pg_query_normalize` from `core_yylex` to `base_yylex`,
which normalizes `USCONST` constants properly.
  • Loading branch information
elprans committed Nov 8, 2024
1 parent fce106a commit a8cb6f4
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 19 deletions.
11 changes: 11 additions & 0 deletions pg_query.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,20 @@ typedef struct {
PgQueryError* error;
} PgQueryFingerprintResult;

/*
 * Describes one constant that query normalization extracted from the
 * query text: where it is, how long it is, which parameter replaces it,
 * and what the lexer reported for it.
 */
typedef struct {
int location; /* start offset of the constant in the query text */
int length; /* length in bytes, or -1 to ignore */
int param_id; /* param id to use; if negative, take abs(..) and add highest_extern_param_id */
int token; /* constant token type as reported by the lexer; 0 if none was recorded */
char *val; /* constant value as a string, or NULL if none was captured; released by pg_query_free_normalize_result */
} PgQueryNormalizeConstLocation;

/*
 * Result of query normalization. Release with
 * pg_query_free_normalize_result once it is no longer needed.
 */
typedef struct {
char* normalized_query; /* heap-allocated normalized query text */
PgQueryError* error; /* error details; checked for NULL before being freed */
PgQueryNormalizeConstLocation *clocations; /* array of clocations_count entries; allocated only when the count is > 0 */
int clocations_count; /* number of entries in clocations */
int highest_extern_param_id; /* NOTE(review): presumably the highest $n already present in the input query - confirm */
} PgQueryNormalizeResult;

// Postgres parser options (parse mode and GUCs that affect parsing)
Expand Down
95 changes: 76 additions & 19 deletions src/pg_query_normalize.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@
#include "pg_query_internal.h"
#include "pg_query_fingerprint.h"

#include "gramparse.h"
#include "parser/parser.h"
#include "parser/scanner.h"
#include "parser/scansup.h"
#include "mb/pg_wchar.h"
#include "nodes/nodeFuncs.h"

#include "pg_query_outfuncs.h"
#include "postgres/include/parser/scanner.h"

#include <limits.h>
#include <stdio.h>

/*
* Struct for tracking locations/lengths of constants during normalization
Expand All @@ -18,6 +23,8 @@ typedef struct pgssLocationLen
int location; /* start offset in query text */
int length; /* length in bytes, or -1 to ignore */
int param_id; /* Param id to use - if negative prefix, need to abs(..) and add highest_extern_param_id */
char *val; /* constant value */
int token; /* token type as reported by the lexer */
} pgssLocationLen;

/*
Expand Down Expand Up @@ -107,9 +114,9 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)
{
pgssLocationLen *locs;
core_yyscan_t yyscanner;
core_yy_extra_type yyextra;
core_YYSTYPE yylval;
YYLTYPE yylloc;
base_yy_extra_type yyextra;
YYSTYPE yylval;
YYLTYPE yylloc = 0;
int last_loc = -1;
int i;

Expand All @@ -124,10 +131,12 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)

/* initialize the flex scanner --- should match raw_parser() */
yyscanner = scanner_init(query,
&yyextra,
&yyextra.core_yy_extra,
&ScanKeywords,
ScanKeywordTokens);

yyextra.have_lookahead = false;

/* Search for each constant, in sequence */
for (i = 0; i < jstate->clocations_count; i++)
{
Expand All @@ -142,7 +151,7 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)
/* Lex tokens until we find the desired constant */
for (;;)
{
tok = core_yylex(&yylval, &yylloc, yyscanner);
tok = base_yylex(&yylval, &yylloc, yyscanner);

/* We should not hit end-of-string, but if we do, behave sanely */
if (tok == 0)
Expand All @@ -154,6 +163,8 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)
*/
if (yylloc >= loc)
{
bool negative = false;

if (query[loc] == '-')
{
/*
Expand All @@ -168,29 +179,37 @@ fill_in_constant_lengths(pgssConstLocations *jstate, const char *query)
* where bar = 1" and "select * from foo where bar = -2"
* will have identical normalized query strings.
*/
tok = core_yylex(&yylval, &yylloc, yyscanner);
tok = base_yylex(&yylval, &yylloc, yyscanner);
if (tok == 0)
break; /* out of inner for-loop */
negative = true;
}

/*
* We now rely on the assumption that flex has placed a zero
* byte after the text of the current token in scanbuf.
*/
locs[i].length = (int) strlen(yyextra.scanbuf + loc);
locs[i].length = (int) strlen(yyextra.core_yy_extra.scanbuf + loc);
locs[i].token = tok;

/* Quoted string with Unicode escapes
*
* The lexer consumes trailing whitespace in order to find UESCAPE, but if there
* is no UESCAPE it has still consumed it - don't include it in constant length.
*/
if (locs[i].length > 4 && /* U&'' */
(yyextra.scanbuf[loc] == 'u' || yyextra.scanbuf[loc] == 'U') &&
yyextra.scanbuf[loc + 1] == '&' && yyextra.scanbuf[loc + 2] == '\'')
if (tok == SCONST || tok == FCONST || tok == BCONST || tok == XCONST)
{
int j = locs[i].length - 1; /* Skip the \0 */
for (; j >= 0 && scanner_isspace(yyextra.scanbuf[loc + j]); j--) {}
locs[i].length = j + 1; /* Count the \0 */
locs[i].val = palloc(strlen(yylval.core_yystype.str) + 1);
strcpy(locs[i].val, yylval.core_yystype.str);
}
else if (tok == ICONST)
{
int val = yylval.core_yystype.ival;
/* Maximum number of digits in 32-bit int is 10 */
int buf_size = 10 + 1;
if (negative)
{
buf_size += 1;
val = -val;
}

locs[i].val = (char *)palloc(buf_size * sizeof(char));
snprintf(locs[i].val, buf_size, "%d", val);
}

break; /* out of inner for-loop */
Expand Down Expand Up @@ -322,6 +341,8 @@ static void RecordConstLocation(pgssConstLocations *jstate, int location)
jstate->clocations[jstate->clocations_count].length = -1;
/* by default we assume that we need a new param ref */
jstate->clocations[jstate->clocations_count].param_id = - jstate->highest_normalize_param_id;
jstate->clocations[jstate->clocations_count].val = NULL;
jstate->clocations[jstate->clocations_count].token = 0;
jstate->highest_normalize_param_id++;
/* record param ref number if requested */
if (jstate->param_refs != NULL) {
Expand Down Expand Up @@ -599,6 +620,7 @@ PgQueryNormalizeResult pg_query_normalize_ext(const char* input, bool normalize_
List *tree;
pgssConstLocations jstate;
int query_len;
int i;

/* Parse query */
tree = raw_parser(input, RAW_PARSE_DEFAULT);
Expand All @@ -624,6 +646,28 @@ PgQueryNormalizeResult pg_query_normalize_ext(const char* input, bool normalize_

/* Normalize query */
result.normalized_query = strdup(generate_normalized_query(&jstate, 0, &query_len, PG_UTF8));

/* Report constant locations */
result.clocations_count = jstate.clocations_count;
if (result.clocations_count > 0)
{
result.clocations = (PgQueryNormalizeConstLocation *)
malloc(result.clocations_count * sizeof(PgQueryNormalizeConstLocation));

for (i = 0; i < result.clocations_count; i++)
{
pgssLocationLen jloc = jstate.clocations[i];
result.clocations[i].location = jloc.location;
result.clocations[i].length = jloc.length;
result.clocations[i].param_id = jloc.param_id;
if (jloc.val != NULL)
result.clocations[i].val = strdup(jloc.val);
else
result.clocations[i].val = NULL;
result.clocations[i].token = jloc.token;
}
}
result.highest_extern_param_id = jstate.highest_extern_param_id;
}
PG_CATCH();
{
Expand Down Expand Up @@ -664,12 +708,25 @@ PgQueryNormalizeResult pg_query_normalize_utility(const char* input)

/*
 * Release all heap memory owned by a PgQueryNormalizeResult: the error
 * record (and its strings), every extracted constant value, the constant
 * location array, and the normalized query text.
 *
 * The result is passed by value, so the NULL assignments below only reset
 * this function's local copy. The caller's struct still holds the stale
 * pointers and must not be used (or freed again) after this call.
 */
void pg_query_free_normalize_result(PgQueryNormalizeResult result)
{
	if (result.error)
	{
		free(result.error->message);
		free(result.error->filename);
		free(result.error->funcname);
		free(result.error);
		result.error = NULL;
	}

	if (result.clocations)
	{
		int i;

		for (i = 0; i < result.clocations_count; i++)
		{
			/* val may be NULL for constants without a captured value; free(NULL) is a no-op */
			free(result.clocations[i].val);
		}
		free(result.clocations);
		result.clocations = NULL;
	}

	free(result.normalized_query);
	result.normalized_query = NULL;
}
2 changes: 2 additions & 0 deletions test/normalize_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ const char* tests[] = {
"CLOSE cursor_a",
"SELECT 1; ALTER USER a WITH PASSWORD 'b'",
"SELECT $1; ALTER USER a WITH PASSWORD $2",
"SELECT U&'d!0061t!+000061' UESCAPE '!'",
"SELECT $1",
};

size_t testsLength = __LINE__ - 7;

0 comments on commit a8cb6f4

Please sign in to comment.