From d28d4f25f06064466b39979bdd2179eb0d46ebb7 Mon Sep 17 00:00:00 2001 From: ictxiangxin Date: Thu, 13 Apr 2017 23:27:19 +0800 Subject: [PATCH] fix bug: get right index which symbol in gradients. --- paradox/kernel/engine.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paradox/kernel/engine.py b/paradox/kernel/engine.py index f942870..cbb1635 100644 --- a/paradox/kernel/engine.py +++ b/paradox/kernel/engine.py @@ -105,11 +105,14 @@ def __compute_gradient(self, variable: Symbol): if hash(self.__symbol) == hash(variable): self.__gradients[variable] = broadcast(Constant(1), self.shape(self.__symbol)) return + current_operator = None for forward in variable.output: if self.gradient(forward) is not None: + if current_operator != forward.operator: + current_operator = forward.operator + index = -1 gradients = forward.operator.gradient(self, forward, *forward.input) - index = None - for i, _variable in enumerate(forward.input): + for i, _variable in enumerate(forward.input, start=index + 1): if hash(_variable) == hash(variable): index = i break