From 268caec3b184cd263bc1de12888f65a1cbce92e7 Mon Sep 17 00:00:00 2001 From: Emil Valeev Date: Wed, 11 Oct 2023 19:59:33 +0600 Subject: [PATCH] refactor(compiler): minor improvements --- examples/005_add_two_numbers_1.neva | 4 +- internal/compiler/analyzer/analyzer.go | 41 +++++++------- internal/compiler/analyzer/component.go | 75 ++++++++++++++----------- internal/compiler/compiler.go | 21 +++---- internal/compiler/parser/parser.go | 23 +++++++- internal/runtime/funcs/funcs.go | 7 +-- 6 files changed, 97 insertions(+), 74 deletions(-) diff --git a/examples/005_add_two_numbers_1.neva b/examples/005_add_two_numbers_1.neva index 88608eb2..7b584ff2 100644 --- a/examples/005_add_two_numbers_1.neva +++ b/examples/005_add_two_numbers_1.neva @@ -17,8 +17,8 @@ components { print tmp.Print } net { - in.enter -> read1.in.sig - read1.out.v -> parse1.in.v + in.enter -> read1.in.sig + read1.out.v -> parse1.in.v parse1.out.v -> { add.in.a read2.in.sig diff --git a/internal/compiler/analyzer/analyzer.go b/internal/compiler/analyzer/analyzer.go index 0c741f27..4498fe45 100644 --- a/internal/compiler/analyzer/analyzer.go +++ b/internal/compiler/analyzer/analyzer.go @@ -7,6 +7,7 @@ import ( "github.com/nevalang/neva/internal/compiler/src" ts "github.com/nevalang/neva/pkg/typesystem" + "golang.org/x/exp/maps" ) type Analyzer struct { @@ -14,43 +15,41 @@ type Analyzer struct { } // Analyze method formats error from a.analyze so end-user can easily understand what's wrong. -func (a Analyzer) Analyze(prog src.Program) error { - return a.analyze(prog) -} - -var ( - ErrEmptyProgram = errors.New("empty program") - ErrMainPkgNotFound = errors.New("main package not found") - ErrEmptyPkg = errors.New("package must not be empty") - ErrUnknownEntityKind = errors.New("unknown entity kind") -) - -// analyze returns error if program is invalid. It also modifies program by resolving types. -func (a Analyzer) analyze(prog src.Program) error { +func (a Analyzer) Analyze(prog src.Program) (src.Program, error) { if len(prog) == 0 { - return ErrEmptyProgram + return src.Program{}, ErrEmptyProgram } mainPkg, ok := prog["main"] if !ok { - return ErrMainPkgNotFound + return src.Program{}, ErrMainPkgNotFound } if err := a.mainSpecificPkgValidation(mainPkg, prog); err != nil { - return fmt.Errorf("main specific pkg validation: %w", err) + return src.Program{}, fmt.Errorf("main specific pkg validation: %w", err) } - for pkgName := range prog { - resolvedPkg, err := a.analyzePkg(pkgName, prog) + var progCopy src.Program + maps.Copy(progCopy, prog) + + for pkgName := range progCopy { + resolvedPkg, err := a.analyzePkg(pkgName, progCopy) if err != nil { - return fmt.Errorf("analyze pkg: %v: %w", pkgName, err) + return src.Program{}, fmt.Errorf("analyze pkg: %v: %w", pkgName, err) } - prog[pkgName] = resolvedPkg + progCopy[pkgName] = resolvedPkg } - return nil + return progCopy, nil } +var ( + ErrEmptyProgram = errors.New("empty program") + ErrMainPkgNotFound = errors.New("main package not found") + ErrEmptyPkg = errors.New("package must not be empty") + ErrUnknownEntityKind = errors.New("unknown entity kind") +) + // TODO check that there's no 2 entities with the same name // and that there's no unused entities. func (a Analyzer) analyzePkg(pkgName string, prog src.Program) (src.Package, error) { diff --git a/internal/compiler/analyzer/component.go b/internal/compiler/analyzer/component.go index 384bd98c..a0245b7e 100644 --- a/internal/compiler/analyzer/component.go +++ b/internal/compiler/analyzer/component.go @@ -22,12 +22,12 @@ func (a Analyzer) analyzeComponent( return src.Component{}, fmt.Errorf("analyze interface: %w", err) } - resolvedNodes, err := a.analyzeComponentNodes(comp.Nodes, scope) + resolvedNodes, nodesIfaces, err := a.analyzeComponentNodes(comp.Nodes, scope) if err != nil { return src.Component{}, fmt.Errorf("analyze component nodes: %w", err) } - resolvedNet, err := a.analyzeComponentNet(comp.Net, resolvedInterface, resolvedNodes, scope) + resolvedNet, err := a.analyzeComponentNet(comp.Net, resolvedInterface, resolvedNodes, nodesIfaces, scope) if err != nil { return src.Component{}, fmt.Errorf("analyze component network: %w", err) } @@ -39,16 +39,21 @@ func (a Analyzer) analyzeComponent( }, nil } -func (a Analyzer) analyzeComponentNodes(nodes map[string]src.Node, scope src.Scope) (map[string]src.Node, error) { +func (a Analyzer) analyzeComponentNodes( + nodes map[string]src.Node, + scope src.Scope, +) (map[string]src.Node, map[string]src.Interface, error) { resolvedNodes := make(map[string]src.Node, len(nodes)) + nodesIfaces := make(map[string]src.Interface, len(nodes)) for name, node := range nodes { - resolvedNode, err := a.analyzeComponentNode(node, scope) + resolvedNode, iface, err := a.analyzeComponentNode(node, scope) if err != nil { - return nil, fmt.Errorf("analyze node: %w", err) + return nil, nil, fmt.Errorf("analyze node: %w", err) } + nodesIfaces[name] = iface resolvedNodes[name] = resolvedNode } - return resolvedNodes, nil + return resolvedNodes, nodesIfaces, nil } var ( @@ -57,50 +62,50 @@ var ( ErrNodeInterfaceDI = errors.New("interface node cannot have dependency injection") ) -func (a Analyzer) analyzeComponentNode(node src.Node, scope src.Scope) (src.Node, error) { +func (a Analyzer) analyzeComponentNode(node src.Node, scope src.Scope) (src.Node, src.Interface, error) { entity, _, err := scope.Entity(node.EntityRef) if err != nil { - return src.Node{}, fmt.Errorf("entity: %w", err) + return src.Node{}, src.Interface{}, fmt.Errorf("entity: %w", err) } if entity.Kind != src.ComponentEntity && entity.Kind != src.InterfaceEntity { - return src.Node{}, fmt.Errorf("%w: %v", ErrNodeWrongEntity, entity.Kind) + return src.Node{}, src.Interface{}, fmt.Errorf("%w: %v", ErrNodeWrongEntity, entity.Kind) } - var compInterface src.Interface + var iface src.Interface if entity.Kind == src.ComponentEntity { - compInterface = entity.Component.Interface + iface = entity.Component.Interface } else { if node.ComponentDI != nil { - return src.Node{}, ErrNodeInterfaceDI + return src.Node{}, src.Interface{}, ErrNodeInterfaceDI } - compInterface = entity.Interface + iface = entity.Interface } - if len(node.TypeArgs) != len(compInterface.TypeParams) { - return src.Node{}, fmt.Errorf( + if len(node.TypeArgs) != len(iface.TypeParams) { + return src.Node{}, src.Interface{}, fmt.Errorf( "%w: want %v, got %v", - ErrNodeTypeArgsCountMismatch, compInterface.TypeParams, node.TypeArgs, + ErrNodeTypeArgsCountMismatch, iface.TypeParams, node.TypeArgs, ) } - resolvedArgs, _, err := a.resolver.ResolveFrame(node.TypeArgs, compInterface.TypeParams, scope) + resolvedArgs, _, err := a.resolver.ResolveFrame(node.TypeArgs, iface.TypeParams, scope) if err != nil { - return src.Node{}, fmt.Errorf("resolve args: %w", err) + return src.Node{}, src.Interface{}, fmt.Errorf("resolve args: %w", err) } if node.ComponentDI == nil { return src.Node{ EntityRef: node.EntityRef, TypeArgs: resolvedArgs, - }, nil + }, iface, nil } resolvedComponentDI := make(map[string]src.Node, len(node.ComponentDI)) for depName, depNode := range node.ComponentDI { - resolvedDep, err := a.analyzeComponentNode(depNode, scope) + resolvedDep, _, err := a.analyzeComponentNode(depNode, scope) if err != nil { - return src.Node{}, fmt.Errorf("analyze dependency node: %w", err) + return src.Node{}, src.Interface{}, fmt.Errorf("analyze dependency node: %w", err) } resolvedComponentDI[depName] = resolvedDep } @@ -109,30 +114,34 @@ func (a Analyzer) analyzeComponentNode(node src.Node, scope src.Scope) (src.Node EntityRef: node.EntityRef, TypeArgs: resolvedArgs, ComponentDI: resolvedComponentDI, - }, nil + }, iface, nil } func (a Analyzer) analyzeComponentNet( net []src.Connection, compInterface src.Interface, nodes map[string]src.Node, + nodesIfaces map[string]src.Interface, scope src.Scope, ) ([]src.Connection, error) { for _, conn := range net { - senderType, err := a.getSenderType(conn.SenderSide, compInterface.IO.In, nodes, scope) + senderType, err := a.getSenderType(conn.SenderSide, compInterface.IO.In, nodes, nodesIfaces, scope) if err != nil { return nil, fmt.Errorf("get sender type: %w", err) } + for _, receiver := range conn.ReceiverSides { - receiverType, err := a.getReceiverType(receiver, compInterface.IO.Out, nodes, scope) + receiverType, err := a.getReceiverType(receiver, compInterface.IO.Out, nodes, nodesIfaces, scope) if err != nil { - return nil, fmt.Errorf("get sender type: %w", err) + return nil, fmt.Errorf("get sen der type: %w", err) } + if err := a.resolver.IsSubtypeOf(senderType, receiverType, scope); err != nil { return nil, fmt.Errorf("is subtype of: %w", err) } } } + return net, nil } @@ -140,6 +149,7 @@ func (a Analyzer) getReceiverType( receiverSide src.ReceiverConnectionSide, outports map[string]src.Port, nodes map[string]src.Node, + nodesIfaces map[string]src.Interface, scope src.Scope, ) (ts.Expr, error) { if receiverSide.PortAddr.Node == "in" { @@ -207,7 +217,6 @@ func (a Analyzer) getResolvedPortType( return ts.Expr{}, fmt.Errorf("resolve args: %w", err) } - // FIXME resolve t1 resolvedOutportType, err := a.resolver.ResolveExprWithFrame(port.TypeExpr, frame, scope) if err != nil { return ts.Expr{}, fmt.Errorf("resolve expr with frame: %w", err) @@ -230,6 +239,7 @@ func (a Analyzer) getSenderType( senderSide src.SenderConnectionSide, inports map[string]src.Port, nodes map[string]src.Node, + nodesIfaces map[string]src.Interface, scope src.Scope, ) (ts.Expr, error) { if senderSide.ConstRef != nil { @@ -255,7 +265,7 @@ func (a Analyzer) getSenderType( return inport.TypeExpr, nil } - nodeOutportType, err := a.getNodeOutportType(*senderSide.PortAddr, nodes, scope) + nodeOutportType, err := a.getNodeOutportType(*senderSide.PortAddr, nodes, nodesIfaces, scope) if err != nil { return ts.Expr{}, fmt.Errorf("get node outport type: %w", err) } @@ -266,6 +276,7 @@ func (a Analyzer) getSenderType( func (a Analyzer) getNodeOutportType( portAddr src.PortAddr, nodes map[string]src.Node, + nodesIfaces map[string]src.Interface, scope src.Scope, ) (ts.Expr, error) { node, ok := nodes[portAddr.Node] @@ -273,14 +284,14 @@ func (a Analyzer) getNodeOutportType( return ts.Expr{}, ErrNodeNotFound } - entity, _, err := scope.Entity(node.EntityRef) - if err != nil { - panic(err) + nodeIface, ok := nodesIfaces[portAddr.Node] + if !ok { + return ts.Expr{}, ErrNodeNotFound } typ, err := a.getResolvedPortType( - entity.Component.IO.Out, - entity.Component.TypeParams, + nodeIface.IO.Out, + nodeIface.TypeParams, portAddr, node, scope, diff --git a/internal/compiler/compiler.go b/internal/compiler/compiler.go index 9cb5b071..aff640bd 100644 --- a/internal/compiler/compiler.go +++ b/internal/compiler/compiler.go @@ -21,10 +21,10 @@ type ( Save(context.Context, string, *ir.Program) error } Parser interface { - ParseFiles(context.Context, map[string][]byte) (map[string]src.File, error) + ParseProg(context.Context, map[string]RawPackage) (src.Program, error) } Analyzer interface { - Analyze(src.Program) error + Analyze(src.Program) (src.Program, error) } IRGenerator interface { Generate(context.Context, src.Program) (*ir.Program, error) @@ -34,25 +34,22 @@ type ( ) func (c Compiler) Compile(ctx context.Context, srcPath, dstPath string) (*ir.Program, error) { - raw, err := c.repo.ByPath(ctx, srcPath) + rawProg, err := c.repo.ByPath(ctx, srcPath) if err != nil { return nil, fmt.Errorf("repo by path: %w", err) } - parsedPackages := make(src.Program, len(raw)) - for pkgName, files := range raw { - parsedFiles, err := c.parser.ParseFiles(ctx, files) - if err != nil { - return nil, fmt.Errorf("parse files: %w", err) - } - parsedPackages[pkgName] = parsedFiles + parsedProg, err := c.parser.ParseProg(ctx, rawProg) + if err != nil { + return nil, fmt.Errorf("parse prog: %w", err) } - if err := c.analyzer.Analyze(parsedPackages); err != nil { + analyzedProg, err := c.analyzer.Analyze(parsedProg) + if err != nil { return nil, fmt.Errorf("analyze: %w", err) } - irProg, err := c.irgen.Generate(ctx, parsedPackages) + irProg, err := c.irgen.Generate(ctx, analyzedProg) if err != nil { return nil, fmt.Errorf("generate IR: %w", err) } diff --git a/internal/compiler/parser/parser.go b/internal/compiler/parser/parser.go index fa235827..7f011322 100644 --- a/internal/compiler/parser/parser.go +++ b/internal/compiler/parser/parser.go @@ -3,8 +3,10 @@ package parser import ( "context" + "fmt" "github.com/antlr4-go/antlr/v4" + "github.com/nevalang/neva/internal/compiler" generated "github.com/nevalang/neva/internal/compiler/parser/generated" "github.com/nevalang/neva/internal/compiler/src" "golang.org/x/sync/errgroup" @@ -19,14 +21,29 @@ type Parser struct { isDebug bool } -func (p Parser) ParseFiles(ctx context.Context, files map[string][]byte) (map[string]src.File, error) { +func (p Parser) ParseProg(ctx context.Context, rawProg map[string]compiler.RawPackage) (src.Program, error) { + prog := make(src.Program, len(rawProg)) + + for pkgName, pkgFiles := range rawProg { + parsedFiles, err := p.parseFiles(ctx, pkgFiles) + if err != nil { + return src.Program{}, fmt.Errorf("parse files: %w", err) + } + + prog[pkgName] = parsedFiles + } + + return prog, nil +} + +func (p Parser) parseFiles(ctx context.Context, files map[string][]byte) (map[string]src.File, error) { result := make(map[string]src.File, len(files)) g, gctx := errgroup.WithContext(ctx) for name, bb := range files { name := name bb := bb g.Go(func() error { - v, err := p.ParseFile(gctx, bb) + v, err := p.parseFile(gctx, bb) if err != nil { return err } @@ -40,7 +57,7 @@ func (p Parser) ParseFiles(ctx context.Context, files map[string][]byte) (map[st return result, nil } -func (p Parser) ParseFile(ctx context.Context, bb []byte) (src.File, error) { +func (p Parser) parseFile(ctx context.Context, bb []byte) (src.File, error) { input := antlr.NewInputStream(string(bb)) lexer := generated.NewnevaLexer(input) stream := antlr.NewCommonTokenStream(lexer, 0) diff --git a/internal/runtime/funcs/funcs.go b/internal/runtime/funcs/funcs.go index eea8b3ed..46cbc3f8 100644 --- a/internal/runtime/funcs/funcs.go +++ b/internal/runtime/funcs/funcs.go @@ -149,18 +149,17 @@ func Add(ctx context.Context, io runtime.FuncIO) (func(), error) { return nil, errors.New("ctx value is not runtime message") } - // overloading: var handler func(a, b runtime.Msg) runtime.Msg switch typ.Type() { - case runtime.IntMsgType: // int + case runtime.IntMsgType: handler = func(a, b runtime.Msg) runtime.Msg { return runtime.NewIntMsg(a.Int() + b.Int()) } - case runtime.FloatMsgType: // float + case runtime.FloatMsgType: handler = func(a, b runtime.Msg) runtime.Msg { return runtime.NewFloatMsg(a.Float() + b.Float()) } - case runtime.StrMsgType: // string + case runtime.StrMsgType: handler = func(a, b runtime.Msg) runtime.Msg { return runtime.NewStrMsg(a.Str() + b.Str()) }