Skip to content

Commit

Permalink
feat: Support multi inheritance on contracts
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanmenzel committed Jan 9, 2025
1 parent aff7644 commit ee9acb4
Show file tree
Hide file tree
Showing 70 changed files with 9,325 additions and 567 deletions.
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"dev:examples": "tsx src/cli.ts build examples --output-awst --output-awst-json",
"dev:approvals": "rimraf tests/approvals/out && tsx src/cli.ts build tests/approvals --dry-run",
"dev:expected-output": "tsx src/cli.ts build tests/expected-output --dry-run",
"dev:testing": "tsx src/cli.ts build tests/approvals/arc-28-events.algo.ts --output-awst --output-awst-json --output-ssa-ir --log-level=info --out-dir out/[name] --optimization-level=0",
"dev:testing": "tsx src/cli.ts build tests/approvals/multi-inheritance.algo.ts --output-awst --output-awst-json --output-ssa-ir --log-level=info --out-dir out/[name] --optimization-level=0",
"audit": "better-npm-audit audit",
"format": "prettier --write .",
"lint": "eslint \"src/**/*.ts\"",
Expand Down
2 changes: 1 addition & 1 deletion packages/algo-ts/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@algorandfoundation/algorand-typescript",
"version": "1.0.0-beta.1",
"version": "1.0.0-beta.2",
"description": "This package contains definitions for the types which comprise Algorand TypeScript which can be compiled to run on the Algorand Virtual Machine using the Puya compiler.",
"private": false,
"main": "index.js",
Expand Down
4 changes: 4 additions & 0 deletions packages/algo-ts/src/impl/primitives.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ describe('ArrayUtil', () => {
[[1, 2, 3], [-3], [1, 2, 3]],
[[1, 2, 3], [4], []],
[[1, 2, 3], [-4], [1, 2, 3]],
])('%s.at(%d) results in %s', (theArray, [start, stop], theResult) => {
expect(arrayUtil.arraySlice(theArray, start, stop)).toEqual(theResult)
})
it.each([
[new Uint8Array([1, 2, 3]), [], new Uint8Array([1, 2, 3])],
[new Uint8Array([1, 2, 3]), [0], new Uint8Array([1, 2, 3])],
[new Uint8Array([1, 2, 3]), [1], new Uint8Array([2, 3])],
Expand Down
3 changes: 1 addition & 2 deletions packages/algo-ts/src/impl/primitives.ts
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,7 @@ export const arrayUtil = new (class ArrayUtil {
}
arraySlice(arrayLike: Uint8Array, start: undefined | StubUint64Compat, end: undefined | StubUint64Compat): Uint8Array
arraySlice<T>(arrayLike: T[], start: undefined | StubUint64Compat, end: undefined | StubUint64Compat): T[]
arraySlice<T>(arrayLike: T[] | Uint8Array, start: undefined | StubUint64Compat, end: undefined | StubUint64Compat): Uint8Array | T[]
arraySlice<T>(arrayLike: T[] | Uint8Array, start: undefined | StubUint64Compat, end: undefined | StubUint64Compat) {
arraySlice<T>(arrayLike: T[] | Uint8Array, start: undefined | StubUint64Compat, end: undefined | StubUint64Compat): Uint8Array | T[] {
const startNum = start === undefined ? undefined : getNumber(start)
const endNum = end === undefined ? undefined : getNumber(end)
if (arrayLike instanceof Uint8Array) {
Expand Down
14 changes: 10 additions & 4 deletions src/awst_build/ast-visitors/constructor-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import * as awst from '../../awst/nodes'
import { AwstBuildFailureError } from '../../errors'
import { codeInvariant, invariant } from '../../util'
import type { AwstBuildContext } from '../context/awst-build-context'
import type { ContractClassPType } from '../ptypes'
import { voidPType } from '../ptypes'
import { ContractMethodBaseVisitor } from './contract-method-visitor'

Expand All @@ -17,8 +18,8 @@ export class ConstructorVisitor extends ContractMethodBaseVisitor {
private readonly _result: awst.ContractMethod
private _foundSuperCall = false
private readonly _propertyInitializerStatements: awst.Statement[]
constructor(ctx: AwstBuildContext, node: ts.ConstructorDeclaration, contractInfo: ConstructorInfo) {
super(ctx, node)
constructor(ctx: AwstBuildContext, node: ts.ConstructorDeclaration, contractType: ContractClassPType, contractInfo: ConstructorInfo) {
super(ctx, node, contractType)
this._propertyInitializerStatements = contractInfo.propertyInitializerStatements
const sourceLocation = this.sourceLocation(node)

Expand All @@ -40,8 +41,13 @@ export class ConstructorVisitor extends ContractMethodBaseVisitor {
return this._result
}

public static buildConstructor(parentCtx: AwstBuildContext, node: ts.ConstructorDeclaration, constructorMethodInfo: ConstructorInfo) {
const result = new ConstructorVisitor(parentCtx.createChildContext(), node, constructorMethodInfo).result
public static buildConstructor(
parentCtx: AwstBuildContext,
node: ts.ConstructorDeclaration,
contractType: ContractClassPType,
constructorMethodInfo: ConstructorInfo,
) {
const result = new ConstructorVisitor(parentCtx.createChildContext(), node, contractType, constructorMethodInfo).result
invariant(result instanceof awst.ContractMethod, "result must be ContractMethod'")
return result
}
Expand Down
31 changes: 15 additions & 16 deletions src/awst_build/ast-visitors/contract-method-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,45 @@ import type { SourceLocation } from '../../awst/source-location'
import { Constants } from '../../constants'
import { CodeError } from '../../errors'
import { logger } from '../../logger'
import { codeInvariant, isIn } from '../../util'
import { codeInvariant, invariant, isIn } from '../../util'
import { getArc4StructDef, getFunctionTypes, ptypeToArc4PType } from '../arc4-util'
import type { AwstBuildContext } from '../context/awst-build-context'
import type { NodeBuilder } from '../eb'
import { ContractSuperBuilder, ContractThisBuilder } from '../eb/contract-builder'
import { isValidLiteralForPType } from '../eb/util'
import type { Arc4AbiDecoratorData, DecoratorData } from '../models/decorator-data'
import type { FunctionPType } from '../ptypes'
import { ContractClassPType, GlobalStateType } from '../ptypes'
import type { ContractClassPType, FunctionPType } from '../ptypes'
import { GlobalStateType } from '../ptypes'
import { ARC4StructType } from '../ptypes/arc4-types'
import { DecoratorVisitor } from './decorator-visitor'
import { FunctionVisitor } from './function-visitor'

export class ContractMethodBaseVisitor extends FunctionVisitor {
protected readonly _contractType: ContractClassPType
constructor(ctx: AwstBuildContext, node: ts.MethodDeclaration | ts.ConstructorDeclaration, contractType: ContractClassPType) {
super(ctx, node)
this._contractType = contractType
}
visitSuperKeyword(node: ts.SuperExpression): NodeBuilder {
const sourceLocation = this.sourceLocation(node)
const ptype = this.context.getPTypeForNode(node)
if (ptype instanceof ContractClassPType) {
return new ContractSuperBuilder(ptype, sourceLocation, this.context)
}
throw new CodeError(`'super' keyword is not valid outside of a contract type`, { sourceLocation })

// Only the polytype clustered class should have more than one base type, and it shouldn't have
// any user code with super calls
invariant(this._contractType.baseTypes.length === 1, 'Super keyword only valid if contract has a single base type')
return new ContractSuperBuilder(this._contractType.baseTypes[0], sourceLocation, this.context)
}

visitThisKeyword(node: ts.ThisExpression): NodeBuilder {
const sourceLocation = this.sourceLocation(node)
const ptype = this.context.getPTypeForNode(node)
if (ptype instanceof ContractClassPType) {
return new ContractThisBuilder(ptype, sourceLocation, this.context)
}
throw new CodeError(`'this' keyword is not valid outside of a contract type`, { sourceLocation })
return new ContractThisBuilder(this._contractType, sourceLocation, this.context)
}
}

export class ContractMethodVisitor extends ContractMethodBaseVisitor {
private readonly _result: awst.ContractMethod
private readonly _contractType: ContractClassPType

constructor(ctx: AwstBuildContext, node: ts.MethodDeclaration, contractType: ContractClassPType) {
super(ctx, node)
this._contractType = contractType
super(ctx, node, contractType)
const sourceLocation = this.sourceLocation(node)
const { args, body, documentation } = this.buildFunctionAwst(node)
const cref = ContractReference.fromPType(this._contractType)
Expand Down
3 changes: 1 addition & 2 deletions src/awst_build/ast-visitors/contract-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ export class ContractVisitor extends BaseVisitor implements Visitor<ClassElement
appState: this.context.getStorageDefinitionsForContract(this._contractPType),
ctor: this._ctor ?? this.makeDefaultConstructor(sourceLocation),
methods: this._methods,
bases: this._contractPType.baseTypes.map((bt) => ContractReference.fromPType(bt)),
description: this.getNodeDescription(classDec),
approvalProgram: this._contractPType.isARC4 ? null : this._approvalProgram,
clearProgram: this._clearStateProgram,
Expand Down Expand Up @@ -125,7 +124,7 @@ export class ContractVisitor extends BaseVisitor implements Visitor<ClassElement
this.throwNotSupported(node, 'class static blocks')
}
visitConstructor(node: ts.ConstructorDeclaration): void {
this._ctor = ConstructorVisitor.buildConstructor(this.context, node, {
this._ctor = ConstructorVisitor.buildConstructor(this.context, node, this._contractPType, {
cref: ContractReference.fromPType(this._contractPType),
propertyInitializerStatements: this._propertyInitialization,
})
Expand Down
28 changes: 10 additions & 18 deletions src/awst_build/context/awst-build-context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,19 @@ class AwstBuildContextImpl implements AwstBuildContext {

getStorageDefinitionsForContract(contractType: ContractClassPType): AppStorageDefinition[] {
const result = new Map<string, AppStorageDefinition>()
for (const baseType of contractType.baseTypes) {
for (const definition of this.getStorageDefinitionsForContract(baseType)) {
if (result.has(definition.memberName)) {
logger.error(
definition.sourceLocation,
`Redefinition of app storage member, original declared in ${result.get(definition.memberName)?.sourceLocation}`,
)
}
result.set(definition.memberName, definition)
}
}
const localDeclarations = this.storageDeclarations.get(contractType.fullName)
if (localDeclarations) {
for (const [member, declaration] of localDeclarations) {
if (result.has(member)) {
const seenContracts = new Set<string>()
for (const ct of [contractType, ...contractType.allBases()]) {
if (seenContracts.has(ct.fullName)) continue
seenContracts.add(ct.fullName)

for (const [memberName, declaration] of this.storageDeclarations.get(ct.fullName) ?? []) {
if (result.has(memberName)) {
logger.error(
declaration.sourceLocation,
`Redefinition of app storage member, original declared in ${result.get(member)?.sourceLocation}`,
result.get(memberName)?.sourceLocation,
`Redefinition of app storage member, original declared in ${declaration.sourceLocation}`,
)
}
result.set(member, declaration.definition)
result.set(memberName, declaration.definition)
}
}
return Array.from(result.values())
Expand Down
79 changes: 77 additions & 2 deletions src/awst_build/eb/contract-builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { PType } from '../ptypes'
import {
arc4BaseContractType,
baseContractType,
ClusteredContractClassType,
ContractClassPType,
contractOptionsDecorator,
numberPType,
Expand All @@ -21,7 +22,11 @@ import {

import { instanceEb } from '../type-registry'

import { BaseContractMethodExpressionBuilder, ContractMethodExpressionBuilder } from './free-subroutine-expression-builder'
import {
BaseContractMethodExpressionBuilder,
ContractMethodExpressionBuilder,
ExplicitBaseContractMethodExpressionBuilder,
} from './free-subroutine-expression-builder'
import type { NodeBuilder } from './index'
import { DecoratorDataBuilder, FunctionBuilder, InstanceBuilder } from './index'
import { ArrayLiteralExpressionBuilder } from './literal/array-literal-expression-builder'
Expand All @@ -31,6 +36,9 @@ import { parseFunctionArgs } from './util/arg-parsing'
import { requireAvmVersion } from './util/avm-version'
import { VoidExpressionBuilder } from './void-expression-builder'

/**
* Handles expressions using `this` in the context of a contract
*/
export class ContractThisBuilder extends InstanceBuilder<ContractClassPType> {
resolve(): Expression {
throw new CodeError('this keyword is not valid as a value', { sourceLocation: this.sourceLocation })
Expand Down Expand Up @@ -69,6 +77,9 @@ export class ContractThisBuilder extends InstanceBuilder<ContractClassPType> {
}
}

/**
* Handles expressions using `super` in the context of a contract
*/
export class ContractSuperBuilder extends ContractThisBuilder {
constructor(ptype: ContractClassPType, sourceLocation: SourceLocation, context: AwstBuildContext) {
super(ptype, sourceLocation, context)
Expand All @@ -94,9 +105,73 @@ export class ContractSuperBuilder extends ContractThisBuilder {
}

memberAccess(name: string, sourceLocation: SourceLocation): NodeBuilder {
if (this.ptype instanceof ClusteredContractClassType && name === 'class') {
return new PolytypeClassSuperMethodBuilder(this.ptype, sourceLocation, this.context)
}

const method = this.ptype.methods[name]
if (method) {
return new BaseContractMethodExpressionBuilder(sourceLocation, method, this.ptype)
return new BaseContractMethodExpressionBuilder(sourceLocation, method)
}
return super.memberAccess(name, sourceLocation)
}
}

/**
* Handles calls of `super.class` from polytype library which is used to access the prototype of a specific base type
*/
class PolytypeClassSuperMethodBuilder extends FunctionBuilder {
constructor(
public readonly ptype: ClusteredContractClassType,
sourceLocation: SourceLocation,
private readonly context: AwstBuildContext,
) {
super(sourceLocation)
}
call(args: ReadonlyArray<NodeBuilder>, typeArgs: ReadonlyArray<PType>, sourceLocation: SourceLocation): NodeBuilder {
const {
args: [contract],
} = parseFunctionArgs({
args,
typeArgs,
genericTypeArgs: 1,
callLocation: sourceLocation,
funcName: 'super.class',
argSpec: (a) => [a.required(ContractClassPType)],
})
const matchedBaseType = this.ptype.baseTypes.find((b) => b.equals(contract.ptype))

codeInvariant(matchedBaseType, `${contract.ptype} must be a direct base type of this class`)
return new PolytypeExplicitClassAccessExpressionBuilder(matchedBaseType, sourceLocation)
}
}

/**
* Matches polytype's super.class(SomeType) expression
*/
export class PolytypeExplicitClassAccessExpressionBuilder extends InstanceBuilder {
resolve(): Expression {
throw new CodeError('Contract class cannot be used as a value')
}
resolveLValue(): LValue {
throw new CodeError('Contract class cannot be used as a value')
}
constructor(
public readonly ptype: ContractClassPType,
sourceLocation: SourceLocation,
) {
super(sourceLocation)
}

memberAccess(name: string, sourceLocation: SourceLocation): NodeBuilder {
const method = this.ptype.methods[name]
if (method) {
return new ExplicitBaseContractMethodExpressionBuilder(sourceLocation, method, this.ptype)
}
if (name in this.ptype.properties) {
throw new CodeError(`Not Supported: Accessing properties of a specific base type. Instead just use \`this.${name}\``, {
sourceLocation,
})
}
return super.memberAccess(name, sourceLocation)
}
Expand Down
28 changes: 27 additions & 1 deletion src/awst_build/eb/free-subroutine-expression-builder.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { ContractReference } from '../../awst/models'
import { nodeFactory } from '../../awst/node-factory'
import type { InstanceMethodTarget, InstanceSuperMethodTarget, SubroutineID } from '../../awst/nodes'
import type { SourceLocation } from '../../awst/source-location'
Expand Down Expand Up @@ -40,6 +41,25 @@ abstract class SubroutineExpressionBuilder extends FunctionBuilder {
}
}

/**
* Invoke a contract method by naming the contract explicitly
*/
export class ExplicitBaseContractMethodExpressionBuilder extends SubroutineExpressionBuilder {
constructor(sourceLocation: SourceLocation, ptype: FunctionPType, baseContractPType: ContractClassPType) {
super(
sourceLocation,
ptype,
nodeFactory.contractMethodTarget({
cref: ContractReference.fromPType(baseContractPType),
memberName: ptype.name,
}),
)
}
}

/**
* Invoke a contract method on the current contract (ie. this.someMethod())
*/
export class ContractMethodExpressionBuilder extends SubroutineExpressionBuilder {
constructor(sourceLocation: SourceLocation, ptype: FunctionPType) {
super(
Expand All @@ -52,8 +72,11 @@ export class ContractMethodExpressionBuilder extends SubroutineExpressionBuilder
}
}

/**
* Invoke a contract method on the super contract (ie. super.someMethod())
*/
export class BaseContractMethodExpressionBuilder extends SubroutineExpressionBuilder {
constructor(sourceLocation: SourceLocation, ptype: FunctionPType, baseContractPType: ContractClassPType) {
constructor(sourceLocation: SourceLocation, ptype: FunctionPType) {
super(
sourceLocation,
ptype,
Expand All @@ -64,6 +87,9 @@ export class BaseContractMethodExpressionBuilder extends SubroutineExpressionBui
}
}

/**
* Invoke a free subroutine (ie. someMethod())
*/
export class FreeSubroutineExpressionBuilder extends SubroutineExpressionBuilder {
constructor(sourceLocation: SourceLocation, ptype: PType) {
if (!(ptype instanceof FunctionPType)) {
Expand Down
2 changes: 0 additions & 2 deletions src/awst_build/lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ export function buildLibAwst(context: AwstBuildContext) {
isAbstract: true,
propertyInitialization: [],
ctor: null,
bases: [],
methods: [],
appState: [],
options: undefined,
Expand Down Expand Up @@ -52,7 +51,6 @@ export function buildLibAwst(context: AwstBuildContext) {
appState: [],
options: undefined,
description: null,
bases: [baseContractCref],
clearProgram: null,
sourceLocation: SourceLocation.None,
approvalProgram: nodeFactory.contractMethod({
Expand Down
Loading

0 comments on commit ee9acb4

Please sign in to comment.