Skip to content

Commit

Permalink
Using base visitor in geo (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdelewski authored Jul 16, 2024
1 parent a6de33f commit f858dc7
Showing 1 changed file with 49 additions and 60 deletions.
109 changes: 49 additions & 60 deletions quesma/quesma/schema_transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,76 +155,65 @@ func (s *SchemaCheckPass) applyIpTransformations(query *model.Query) (*model.Que
return query, nil
}

type GeoIpVisitor struct {
model.ExprVisitor
tableName string
schemaRegistry schema.Registry
}

func (v *GeoIpVisitor) VisitTableRef(e model.TableRef) interface{} {
return model.NewTableRef(e.Name)
}

func (v *GeoIpVisitor) VisitSelectCommand(e model.SelectCommand) interface{} {
if v.schemaRegistry == nil {
logger.Error().Msg("Schema registry is not set")
return e
}
schemaInstance, exists := v.schemaRegistry.FindSchema(schema.TableName(v.tableName))
if !exists {
logger.Error().Msgf("Schema fot table %s not found", v.tableName)
return e
}
var groupBy []model.Expr
for _, expr := range e.GroupBy {
groupByExpr := expr.Accept(v).(model.Expr)
if col, ok := expr.(model.ColumnRef); ok {
// This checks if the column is of type point
// and if it is, it appends the lat and lon columns to the group by clause
field := schemaInstance.Fields[schema.FieldName(col.ColumnName)]
if field.Type.Name == schema.TypePoint.Name {
// TODO suffixes ::lat, ::lon are hardcoded for now
groupBy = append(groupBy, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lat"))
groupBy = append(groupBy, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lon"))
func (s *SchemaCheckPass) applyGeoTransformations(query *model.Query) (*model.Query, error) {
fromTable := getFromTable(query.TableName)
visitor := model.NewBaseVisitor()
visitor.OverrideVisitSelectCommand = func(b *model.BaseExprVisitor, e model.SelectCommand) interface{} {
if s.schemaRegistry == nil {
logger.Error().Msg("Schema registry is not set")
return e
}
schemaInstance, exists := s.schemaRegistry.FindSchema(schema.TableName(fromTable))
if !exists {
logger.Error().Msgf("Schema fot table %s not found", fromTable)
return e
}
var groupBy []model.Expr
for _, expr := range e.GroupBy {
groupByExpr := expr.Accept(b).(model.Expr)
if col, ok := expr.(model.ColumnRef); ok {
// This checks if the column is of type point
// and if it is, it appends the lat and lon columns to the group by clause
field := schemaInstance.Fields[schema.FieldName(col.ColumnName)]
if field.Type.Name == schema.TypePoint.Name {
// TODO suffixes ::lat, ::lon are hardcoded for now
groupBy = append(groupBy, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lat"))
groupBy = append(groupBy, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lon"))
} else {
groupBy = append(groupBy, groupByExpr)
}
} else {
groupBy = append(groupBy, groupByExpr)
}
} else {
groupBy = append(groupBy, groupByExpr)
}
}
var columns []model.Expr
for _, expr := range e.Columns {
if col, ok := expr.(model.ColumnRef); ok {
// This checks if the column is of type point
// and if it is, it appends the lat and lon columns to the select clause
field := schemaInstance.Fields[schema.FieldName(col.ColumnName)]
if field.Type.Name == schema.TypePoint.Name {
// TODO suffixes ::lat, ::lon are hardcoded for now
columns = append(columns, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lat"))
columns = append(columns, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lon"))
var columns []model.Expr
for _, expr := range e.Columns {
if col, ok := expr.(model.ColumnRef); ok {
// This checks if the column is of type point
// and if it is, it appends the lat and lon columns to the select clause
field := schemaInstance.Fields[schema.FieldName(col.ColumnName)]
if field.Type.Name == schema.TypePoint.Name {
// TODO suffixes ::lat, ::lon are hardcoded for now
columns = append(columns, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lat"))
columns = append(columns, model.NewColumnRef(field.InternalPropertyName.AsString()+"::lon"))
} else {
columns = append(columns, expr.Accept(b).(model.Expr))
}
} else {
columns = append(columns, expr.Accept(v).(model.Expr))
columns = append(columns, expr.Accept(b).(model.Expr))
}
} else {
columns = append(columns, expr.Accept(v).(model.Expr))
}
}

var fromClause model.Expr
if e.FromClause != nil {
fromClause = e.FromClause.Accept(v).(model.Expr)
}

return model.NewSelectCommand(columns, groupBy, e.OrderBy,
fromClause, e.WhereClause, e.LimitBy, e.Limit, e.SampleLimit, e.IsDistinct, e.CTEs)
}
var fromClause model.Expr
if e.FromClause != nil {
fromClause = e.FromClause.Accept(b).(model.Expr)
}

func (s *SchemaCheckPass) applyGeoTransformations(query *model.Query) (*model.Query, error) {
fromTable := getFromTable(query.TableName)
return model.NewSelectCommand(columns, groupBy, e.OrderBy,
fromClause, e.WhereClause, e.LimitBy, e.Limit, e.SampleLimit, e.IsDistinct, e.CTEs)
}

geoIpVisitor := &GeoIpVisitor{ExprVisitor: model.NoOpVisitor{}, tableName: fromTable, schemaRegistry: s.schemaRegistry}
expr := query.SelectCommand.Accept(geoIpVisitor)
expr := query.SelectCommand.Accept(visitor)
if _, ok := expr.(*model.SelectCommand); ok {
query.SelectCommand = *expr.(*model.SelectCommand)
}
Expand Down

0 comments on commit f858dc7

Please sign in to comment.