Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev/weighted_loss' into dev/weig…
Browse files Browse the repository at this point in the history
…hted_loss
  • Loading branch information
qiaoting159753 committed Dec 30, 2024
2 parents 3e2127b + 7884942 commit 5e25ad4
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions cares_reinforcement_learning/util/network_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def create_STEVESAC_Bounded(

from cares_reinforcement_learning.algorithm.mbrl import STEVESAC_Bounded
from cares_reinforcement_learning.networks.SAC import Actor, Critic
from cares_reinforcement_learning.networks.world_models.ensemble import Ensemble_Dyna_Big
from cares_reinforcement_learning.networks.world_models.ensemble import (
Ensemble_Dyna_Big,
)

actor = Actor(observation_size, action_num, config=config)
critic = Critic(observation_size, action_num, config=config)
Expand Down Expand Up @@ -445,10 +447,10 @@ def create_STEVESAC_Bounded(

class NetworkFactory:
def create_network(
self,
observation_size,
action_num: int,
config: acf.AlgorithmConfig,
self,
observation_size,
action_num: int,
config: acf.AlgorithmConfig,
):
algorithm = config.algorithm

Expand Down

0 comments on commit 5e25ad4

Please sign in to comment.