From 4143102220fd5b074ef5aa964448b7c068463afd Mon Sep 17 00:00:00 2001 From: Rene Date: Tue, 23 Jul 2024 11:33:11 +0200 Subject: [PATCH] minor update --- src/BasisConvolution/util/arguments.py | 3 ++- src/BasisConvolution/util/hyperparameters.py | 9 ++++++--- src/BasisConvolution/util/testcases.py | 9 +++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/BasisConvolution/util/arguments.py b/src/BasisConvolution/util/arguments.py index f058f24..efab87f 100644 --- a/src/BasisConvolution/util/arguments.py +++ b/src/BasisConvolution/util/arguments.py @@ -69,5 +69,6 @@ parser.add_argument('--shiftLoss', type = bool, default = argparse.SUPPRESS, action=argparse.BooleanOptionalAction, help='Shifting the loop') parser.add_argument('--skipLastShift', type = bool, default = argparse.SUPPRESS, action=argparse.BooleanOptionalAction, help='Shifting the loop') parser.add_argument('--scaleShiftLoss', type = bool, default = argparse.SUPPRESS, action=argparse.BooleanOptionalAction, help='Shifting the loop') +parser.add_argument('--integrationScheme', type = str, default = argparse.SUPPRESS, help='Integration scheme') -parser.add_argument('--exportPath', type = str, default = argparse.SUPPRESS, help='Export path') \ No newline at end of file +parser.add_argument('--exportPath', type = str, default = argparse.SUPPRESS, help='Export path') diff --git a/src/BasisConvolution/util/hyperparameters.py b/src/BasisConvolution/util/hyperparameters.py index 890119a..bc8767c 100644 --- a/src/BasisConvolution/util/hyperparameters.py +++ b/src/BasisConvolution/util/hyperparameters.py @@ -66,7 +66,8 @@ def defaultHyperParameters(): 'inputEncoder': None, 'outputDecoder': None, 'edgeMLP': None, - 'vertexMLP': None + 'vertexMLP': None, + 'integrationScheme': 'semiImplicitEuler', } return hyperParameterDict @@ -135,6 +136,7 @@ def parseArguments(args, hyperParameterDict): hyperParameterDict['scaleShiftLoss'] = args.scaleShiftLoss if hasattr(args, 'scaleShiftLoss') else hyperParameterDict['scaleShiftLoss'] hyperParameterDict['activation'] = args.activation if hasattr(args, 'activation') else hyperParameterDict['activation'] hyperParameterDict['exportPath'] = args.exportPath if hasattr(args, 'exportPath') else hyperParameterDict['exportPath'] + hyperParameterDict['integrationScheme'] = args.integrationScheme if hasattr(args, 'integrationScheme') else hyperParameterDict['integrationScheme'] hyperParameterDict['device'] = args.device if hasattr(args, 'device') else hyperParameterDict['device'] # hyperParameterDict['dtype'] = torch. @@ -162,7 +164,7 @@ def parseArguments(args, hyperParameterDict): 'gain': 1, 'norm': True, 'layout': [32], - 'output': 1, + # 'output': 1, 'preNorm': False, 'postNorm': True, 'noLinear': True, @@ -263,6 +265,7 @@ def parseConfig(config, hyperParameterDict): parseEntry(cfg, 'shifting', 'networkType', hyperParameterDict, 'networkType') parseEntry(cfg, 'shifting', 'shiftLoss', hyperParameterDict, 'shiftLoss') parseEntry(cfg, 'shifting', 'scaleShiftLoss', hyperParameterDict, 'scaleShiftLoss') + parseEntry(cfg, 'shifting', 'integrationScheme', hyperParameterDict, 'integrationScheme') parseEntry(cfg, 'dataset', 'dataIndex', hyperParameterDict, 'dataIndex') parseEntry(cfg, 'shifting', 'skipLastShift', hyperParameterDict, 'skipLastShift') parseEntry(cfg, 'loss', 'dxdtLossScaling', hyperParameterDict, 'dxdtLossScaling') @@ -428,7 +431,7 @@ def toPandaDict(hyperParameterDict): 'skipLastShift': hyperParameterDict['skipLastShift'], 'dxdtLossScaling': hyperParameterDict['dxdtLossScaling'], 'scaleShiftLoss': hyperParameterDict['scaleShiftLoss'] if 'scaleShiftLoss' in hyperParameterDict else False, - + 'integrationScheme': hyperParameterDict['integrationScheme'], 'inputEncoder': True if hyperParameterDict['inputEncoder'] is not None else False, 'outputDecoder': True if hyperParameterDict['outputDecoder'] is not None else False, diff --git a/src/BasisConvolution/util/testcases.py b/src/BasisConvolution/util/testcases.py index 561481a..4fffea3 100644 --- a/src/BasisConvolution/util/testcases.py +++ b/src/BasisConvolution/util/testcases.py @@ -522,12 +522,13 @@ def loadFrame_newFormat(inFile, fileName, key, fileData, fileIndex, fileOffset, iPriorKey = int(key) - hyperParameterDict['frameDistance'] priorState = None - if buildPriorState: + if buildPriorState or hyperParameterDict['adjustForFrameDistance']: if iPriorKey < 0 or hyperParameterDict['frameDistance'] == 0: priorState = copy.deepcopy(state) else: priorState = loadGroup_newFormat(inFile, inFile['simulationExport']['%05d' % iPriorKey], staticBoundaryData, fileName, iPriorKey, fileData, fileIndex, fileOffset, dataset, hyperParameterDict, unrollLength = unrollLength, device = device, dtype = dtype, additionalData = additionalData, buildPriorState = False, buildNextState = False) - + + nextStates = [] if buildNextState: if unrollLength == 0 and hyperParameterDict['frameDistance'] == 0: @@ -543,6 +544,10 @@ def loadFrame_newFormat(inFile, fileName, key, fileData, fileIndex, fileOffset, nextState = loadGroup_newFormat(inFile, inFile['simulationExport']['%05d' % unrollKey], staticBoundaryData, fileName, iPriorKey, fileData, fileIndex, fileOffset, dataset, hyperParameterDict, unrollLength = unrollLength, device = device, dtype = dtype, additionalData = additionalData, buildPriorState = False, buildNextState = False) nextStates.append(nextState) + # if hyperParameterDict['adjustForFrameDistance']: + + +