Skip to content

Commit

Permalink
refactor(compiler): minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
emil14 committed Oct 11, 2023
1 parent 7857c7e commit 268caec
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 74 deletions.
4 changes: 2 additions & 2 deletions examples/005_add_two_numbers_1.neva
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ components {
print tmp.Print<int>
}
net {
in.enter -> read1.in.sig
read1.out.v -> parse1.in.v
in.enter -> read1.in.sig
read1.out.v -> parse1.in.v
parse1.out.v -> {
add.in.a
read2.in.sig
Expand Down
41 changes: 20 additions & 21 deletions internal/compiler/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,49 @@ import (

"github.com/nevalang/neva/internal/compiler/src"
ts "github.com/nevalang/neva/pkg/typesystem"
"golang.org/x/exp/maps"
)

type Analyzer struct {
resolver ts.Resolver
}

// Analyze method formats error from a.analyze so end-user can easily understand what's wrong.
func (a Analyzer) Analyze(prog src.Program) error {
return a.analyze(prog)
}

var (
ErrEmptyProgram = errors.New("empty program")
ErrMainPkgNotFound = errors.New("main package not found")
ErrEmptyPkg = errors.New("package must not be empty")
ErrUnknownEntityKind = errors.New("unknown entity kind")
)

// analyze returns error if program is invalid. It also modifies program by resolving types.
func (a Analyzer) analyze(prog src.Program) error {
func (a Analyzer) Analyze(prog src.Program) (src.Program, error) {
if len(prog) == 0 {
return ErrEmptyProgram
return src.Program{}, ErrEmptyProgram
}

mainPkg, ok := prog["main"]
if !ok {
return ErrMainPkgNotFound
return src.Program{}, ErrMainPkgNotFound
}

if err := a.mainSpecificPkgValidation(mainPkg, prog); err != nil {
return fmt.Errorf("main specific pkg validation: %w", err)
return src.Program{}, fmt.Errorf("main specific pkg validation: %w", err)
}

for pkgName := range prog {
resolvedPkg, err := a.analyzePkg(pkgName, prog)
var progCopy src.Program
maps.Copy(progCopy, prog)

for pkgName := range progCopy {
resolvedPkg, err := a.analyzePkg(pkgName, progCopy)
if err != nil {
return fmt.Errorf("analyze pkg: %v: %w", pkgName, err)
return src.Program{}, fmt.Errorf("analyze pkg: %v: %w", pkgName, err)
}
prog[pkgName] = resolvedPkg
progCopy[pkgName] = resolvedPkg
}

return nil
return progCopy, nil
}

var (
ErrEmptyProgram = errors.New("empty program")
ErrMainPkgNotFound = errors.New("main package not found")
ErrEmptyPkg = errors.New("package must not be empty")
ErrUnknownEntityKind = errors.New("unknown entity kind")
)

// TODO check that there's no 2 entities with the same name
// and that there's no unused entities.
func (a Analyzer) analyzePkg(pkgName string, prog src.Program) (src.Package, error) {
Expand Down
75 changes: 43 additions & 32 deletions internal/compiler/analyzer/component.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ func (a Analyzer) analyzeComponent(
return src.Component{}, fmt.Errorf("analyze interface: %w", err)
}

resolvedNodes, err := a.analyzeComponentNodes(comp.Nodes, scope)
resolvedNodes, nodesIfaces, err := a.analyzeComponentNodes(comp.Nodes, scope)
if err != nil {
return src.Component{}, fmt.Errorf("analyze component nodes: %w", err)
}

resolvedNet, err := a.analyzeComponentNet(comp.Net, resolvedInterface, resolvedNodes, scope)
resolvedNet, err := a.analyzeComponentNet(comp.Net, resolvedInterface, resolvedNodes, nodesIfaces, scope)
if err != nil {
return src.Component{}, fmt.Errorf("analyze component network: %w", err)
}
Expand All @@ -39,16 +39,21 @@ func (a Analyzer) analyzeComponent(
}, nil
}

func (a Analyzer) analyzeComponentNodes(nodes map[string]src.Node, scope src.Scope) (map[string]src.Node, error) {
func (a Analyzer) analyzeComponentNodes(
nodes map[string]src.Node,
scope src.Scope,
) (map[string]src.Node, map[string]src.Interface, error) {
resolvedNodes := make(map[string]src.Node, len(nodes))
nodesIfaces := make(map[string]src.Interface, len(nodes))
for name, node := range nodes {
resolvedNode, err := a.analyzeComponentNode(node, scope)
resolvedNode, iface, err := a.analyzeComponentNode(node, scope)
if err != nil {
return nil, fmt.Errorf("analyze node: %w", err)
return nil, nil, fmt.Errorf("analyze node: %w", err)
}
nodesIfaces[name] = iface
resolvedNodes[name] = resolvedNode
}
return resolvedNodes, nil
return resolvedNodes, nodesIfaces, nil
}

var (
Expand All @@ -57,50 +62,50 @@ var (
ErrNodeInterfaceDI = errors.New("interface node cannot have dependency injection")
)

func (a Analyzer) analyzeComponentNode(node src.Node, scope src.Scope) (src.Node, error) {
func (a Analyzer) analyzeComponentNode(node src.Node, scope src.Scope) (src.Node, src.Interface, error) {
entity, _, err := scope.Entity(node.EntityRef)
if err != nil {
return src.Node{}, fmt.Errorf("entity: %w", err)
return src.Node{}, src.Interface{}, fmt.Errorf("entity: %w", err)
}

if entity.Kind != src.ComponentEntity && entity.Kind != src.InterfaceEntity {
return src.Node{}, fmt.Errorf("%w: %v", ErrNodeWrongEntity, entity.Kind)
return src.Node{}, src.Interface{}, fmt.Errorf("%w: %v", ErrNodeWrongEntity, entity.Kind)
}

var compInterface src.Interface
var iface src.Interface
if entity.Kind == src.ComponentEntity {
compInterface = entity.Component.Interface
iface = entity.Component.Interface
} else {
if node.ComponentDI != nil {
return src.Node{}, ErrNodeInterfaceDI
return src.Node{}, src.Interface{}, ErrNodeInterfaceDI
}
compInterface = entity.Interface
iface = entity.Interface
}

if len(node.TypeArgs) != len(compInterface.TypeParams) {
return src.Node{}, fmt.Errorf(
if len(node.TypeArgs) != len(iface.TypeParams) {
return src.Node{}, src.Interface{}, fmt.Errorf(
"%w: want %v, got %v",
ErrNodeTypeArgsCountMismatch, compInterface.TypeParams, node.TypeArgs,
ErrNodeTypeArgsCountMismatch, iface.TypeParams, node.TypeArgs,
)
}

resolvedArgs, _, err := a.resolver.ResolveFrame(node.TypeArgs, compInterface.TypeParams, scope)
resolvedArgs, _, err := a.resolver.ResolveFrame(node.TypeArgs, iface.TypeParams, scope)
if err != nil {
return src.Node{}, fmt.Errorf("resolve args: %w", err)
return src.Node{}, src.Interface{}, fmt.Errorf("resolve args: %w", err)
}

if node.ComponentDI == nil {
return src.Node{
EntityRef: node.EntityRef,
TypeArgs: resolvedArgs,
}, nil
}, iface, nil
}

resolvedComponentDI := make(map[string]src.Node, len(node.ComponentDI))
for depName, depNode := range node.ComponentDI {
resolvedDep, err := a.analyzeComponentNode(depNode, scope)
resolvedDep, _, err := a.analyzeComponentNode(depNode, scope)
if err != nil {
return src.Node{}, fmt.Errorf("analyze dependency node: %w", err)
return src.Node{}, src.Interface{}, fmt.Errorf("analyze dependency node: %w", err)
}
resolvedComponentDI[depName] = resolvedDep
}
Expand All @@ -109,37 +114,42 @@ func (a Analyzer) analyzeComponentNode(node src.Node, scope src.Scope) (src.Node
EntityRef: node.EntityRef,
TypeArgs: resolvedArgs,
ComponentDI: resolvedComponentDI,
}, nil
}, iface, nil
}

func (a Analyzer) analyzeComponentNet(
net []src.Connection,
compInterface src.Interface,
nodes map[string]src.Node,
nodesIfaces map[string]src.Interface,
scope src.Scope,
) ([]src.Connection, error) {
for _, conn := range net {
senderType, err := a.getSenderType(conn.SenderSide, compInterface.IO.In, nodes, scope)
senderType, err := a.getSenderType(conn.SenderSide, compInterface.IO.In, nodes, nodesIfaces, scope)
if err != nil {
return nil, fmt.Errorf("get sender type: %w", err)
}

for _, receiver := range conn.ReceiverSides {
receiverType, err := a.getReceiverType(receiver, compInterface.IO.Out, nodes, scope)
receiverType, err := a.getReceiverType(receiver, compInterface.IO.Out, nodes, nodesIfaces, scope)
if err != nil {
return nil, fmt.Errorf("get sender type: %w", err)
return nil, fmt.Errorf("get sen der type: %w", err)
}

if err := a.resolver.IsSubtypeOf(senderType, receiverType, scope); err != nil {
return nil, fmt.Errorf("is subtype of: %w", err)
}
}
}

return net, nil
}

func (a Analyzer) getReceiverType(
receiverSide src.ReceiverConnectionSide,
outports map[string]src.Port,
nodes map[string]src.Node,
nodesIfaces map[string]src.Interface,

Check failure on line 152 in internal/compiler/analyzer/component.go

View workflow job for this annotation

GitHub Actions / lint

`(Analyzer).getReceiverType` - `nodesIfaces` is unused (unparam)
scope src.Scope,
) (ts.Expr, error) {
if receiverSide.PortAddr.Node == "in" {
Expand Down Expand Up @@ -207,7 +217,6 @@ func (a Analyzer) getResolvedPortType(
return ts.Expr{}, fmt.Errorf("resolve args: %w", err)
}

// FIXME resolve t1
resolvedOutportType, err := a.resolver.ResolveExprWithFrame(port.TypeExpr, frame, scope)
if err != nil {
return ts.Expr{}, fmt.Errorf("resolve expr with frame: %w", err)
Expand All @@ -230,6 +239,7 @@ func (a Analyzer) getSenderType(
senderSide src.SenderConnectionSide,
inports map[string]src.Port,
nodes map[string]src.Node,
nodesIfaces map[string]src.Interface,
scope src.Scope,
) (ts.Expr, error) {
if senderSide.ConstRef != nil {
Expand All @@ -255,7 +265,7 @@ func (a Analyzer) getSenderType(
return inport.TypeExpr, nil
}

nodeOutportType, err := a.getNodeOutportType(*senderSide.PortAddr, nodes, scope)
nodeOutportType, err := a.getNodeOutportType(*senderSide.PortAddr, nodes, nodesIfaces, scope)
if err != nil {
return ts.Expr{}, fmt.Errorf("get node outport type: %w", err)
}
Expand All @@ -266,21 +276,22 @@ func (a Analyzer) getSenderType(
func (a Analyzer) getNodeOutportType(
portAddr src.PortAddr,
nodes map[string]src.Node,
nodesIfaces map[string]src.Interface,
scope src.Scope,
) (ts.Expr, error) {
node, ok := nodes[portAddr.Node]
if !ok {
return ts.Expr{}, ErrNodeNotFound
}

entity, _, err := scope.Entity(node.EntityRef)
if err != nil {
panic(err)
nodeIface, ok := nodesIfaces[portAddr.Node]
if !ok {
return ts.Expr{}, ErrNodeNotFound
}

typ, err := a.getResolvedPortType(
entity.Component.IO.Out,
entity.Component.TypeParams,
nodeIface.IO.Out,
nodeIface.TypeParams,
portAddr,
node,
scope,
Expand Down
21 changes: 9 additions & 12 deletions internal/compiler/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ type (
Save(context.Context, string, *ir.Program) error
}
Parser interface {
ParseFiles(context.Context, map[string][]byte) (map[string]src.File, error)
ParseProg(context.Context, map[string]RawPackage) (src.Program, error)
}
Analyzer interface {
Analyze(src.Program) error
Analyze(src.Program) (src.Program, error)
}
IRGenerator interface {
Generate(context.Context, src.Program) (*ir.Program, error)
Expand All @@ -34,25 +34,22 @@ type (
)

func (c Compiler) Compile(ctx context.Context, srcPath, dstPath string) (*ir.Program, error) {
raw, err := c.repo.ByPath(ctx, srcPath)
rawProg, err := c.repo.ByPath(ctx, srcPath)
if err != nil {
return nil, fmt.Errorf("repo by path: %w", err)
}

parsedPackages := make(src.Program, len(raw))
for pkgName, files := range raw {
parsedFiles, err := c.parser.ParseFiles(ctx, files)
if err != nil {
return nil, fmt.Errorf("parse files: %w", err)
}
parsedPackages[pkgName] = parsedFiles
parsedProg, err := c.parser.ParseProg(ctx, rawProg)
if err != nil {
return nil, fmt.Errorf("parse prog: %w", err)
}

if err := c.analyzer.Analyze(parsedPackages); err != nil {
analyzedProg, err := c.analyzer.Analyze(parsedProg)
if err != nil {
return nil, fmt.Errorf("analyze: %w", err)
}

irProg, err := c.irgen.Generate(ctx, parsedPackages)
irProg, err := c.irgen.Generate(ctx, analyzedProg)
if err != nil {
return nil, fmt.Errorf("generate IR: %w", err)
}
Expand Down
23 changes: 20 additions & 3 deletions internal/compiler/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package parser

import (
"context"
"fmt"

"github.com/antlr4-go/antlr/v4"
"github.com/nevalang/neva/internal/compiler"
generated "github.com/nevalang/neva/internal/compiler/parser/generated"
"github.com/nevalang/neva/internal/compiler/src"
"golang.org/x/sync/errgroup"
Expand All @@ -19,14 +21,29 @@ type Parser struct {
isDebug bool
}

func (p Parser) ParseFiles(ctx context.Context, files map[string][]byte) (map[string]src.File, error) {
func (p Parser) ParseProg(ctx context.Context, rawProg map[string]compiler.RawPackage) (src.Program, error) {
prog := make(src.Program, len(rawProg))

for pkgName, pkgFiles := range rawProg {
parsedFiles, err := p.parseFiles(ctx, pkgFiles)
if err != nil {
return src.Program{}, fmt.Errorf("parse files: %w", err)
}

prog[pkgName] = parsedFiles
}

return prog, nil
}

func (p Parser) parseFiles(ctx context.Context, files map[string][]byte) (map[string]src.File, error) {
result := make(map[string]src.File, len(files))
g, gctx := errgroup.WithContext(ctx)
for name, bb := range files {
name := name
bb := bb
g.Go(func() error {
v, err := p.ParseFile(gctx, bb)
v, err := p.parseFile(gctx, bb)
if err != nil {
return err
}
Expand All @@ -40,7 +57,7 @@ func (p Parser) ParseFiles(ctx context.Context, files map[string][]byte) (map[st
return result, nil
}

func (p Parser) ParseFile(ctx context.Context, bb []byte) (src.File, error) {
func (p Parser) parseFile(ctx context.Context, bb []byte) (src.File, error) {
input := antlr.NewInputStream(string(bb))
lexer := generated.NewnevaLexer(input)
stream := antlr.NewCommonTokenStream(lexer, 0)
Expand Down
Loading

0 comments on commit 268caec

Please sign in to comment.