Prioritized Replay Buffer for Off-policy Algorithms #158
Comments
Thanks for this! Yes, we should allow it! It is still more expensive to use a prioritized sampler (with priority 0), but we can add an arg where the user can request it and we change the sampler type. It would be as simple a change as you propose because the infrastructure to pass the priority is already all in place (BenchMARL/benchmarl/experiment/experiment.py, Line 795 in e910a83).
Would you be down to try this out and let me know if you get improvements over the normal buffer?
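A rough sketch of what such an opt-in arg could look like, assuming hypothetical config fields (`off_policy_use_prioritized_buffer`, `off_policy_prb_alpha`, `off_policy_prb_beta`) that do not exist in BenchMARL today; only the torchrl sampler classes are real:

```python
from torchrl.data import PrioritizedSampler, RandomSampler, SamplerWithoutReplacement


def make_sampler(on_policy: bool, memory_size: int, config):
    """Pick the replay-buffer sampler based on (hypothetical) config flags."""
    if on_policy:
        # On-policy algorithms keep the current behaviour.
        return SamplerWithoutReplacement()
    if getattr(config, "off_policy_use_prioritized_buffer", False):  # hypothetical flag
        return PrioritizedSampler(
            max_capacity=memory_size,
            alpha=config.off_policy_prb_alpha,  # hypothetical field
            beta=config.off_policy_prb_beta,  # hypothetical field
        )
    # Default off-policy behaviour: uniform sampling, as today.
    return RandomSampler()
```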
I ran a simple test on the VMAS Balance environment with IQL using a PRB with `alpha: 0.6, beta: 0.4`: result_agents_return_balance-crop.pdf. IQL with PRB is worse in this case, but I didn't run an extensive test with a hyperparameter search.
Yes, this is also similar to the results I got back in the day (BenchMARL had prioritised buffers before the public release). I wonder if there is an implementation error somewhere in the torchrl implementation. Could you try with alpha=0 and see if the performance matches? At least this would give us some peace of mind. And yes, a PR would be lovely, thanks!
Here it is compared with PRB having alpha=0: result_return_balance-crop.pdf. As expected, with alpha=0 it reduces to random sampling. I'll make a PR. Thanks!
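Just to spell out why that is expected: PRB samples transition i with probability proportional to priority_i^alpha, so alpha=0 flattens every priority to 1. A tiny torch-only illustration (not BenchMARL or torchrl code):

```python
import torch

td_error = torch.rand(100)  # arbitrary, very different priorities
for alpha in (0.6, 0.0):
    probs = td_error.clamp_min(1e-8) ** alpha
    probs = probs / probs.sum()
    # alpha=0.6 -> spread-out probabilities; alpha=0.0 -> exactly 1/100 each
    print(alpha, probs.min().item(), probs.max().item())
```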
Great to check! Thanks!
Ah, another question: in terms of time taken, how do these three compare?
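In case it helps, a rough self-contained timing sketch (assumed setup, not the BenchMARL training loop) comparing the sampling cost of the three variants; an alpha value near 0 is used instead of exactly 0 to stay clear of any input validation on alpha:

```python
import time

import torch
from tensordict import TensorDict
from torchrl.data import (
    LazyTensorStorage,
    PrioritizedSampler,
    RandomSampler,
    TensorDictReplayBuffer,
)

n, batch_size, iters = 100_000, 256, 500
data = TensorDict({"obs": torch.randn(n, 4), "td_error": torch.rand(n)}, batch_size=[n])


def bench(sampler) -> float:
    """Time `iters` sampling calls from a buffer filled with `n` transitions."""
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(n),
        sampler=sampler,
        batch_size=batch_size,
        priority_key="td_error",  # the TensorDictReplayBuffer default anyway
    )
    rb.extend(data)
    start = time.perf_counter()
    for _ in range(iters):
        rb.sample()
    return time.perf_counter() - start


variants = {
    "uniform (RandomSampler)": RandomSampler(),
    "PRB alpha=0.6": PrioritizedSampler(n, alpha=0.6, beta=0.4),
    # Effectively uniform sampling, but it keeps the sum-tree bookkeeping,
    # so its cost should track the alpha=0.6 variant rather than RandomSampler.
    "PRB alpha~0": PrioritizedSampler(n, alpha=1e-8, beta=0.4),
}
for name, sampler in variants.items():
    print(f"{name}: {bench(sampler):.2f} s")
```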
Is there a way to use a Prioritized Replay Buffer for off-policy algorithms in BenchMARL?
As far as I understand, the framework does not support it at the moment. It uses a RandomSampler here:
BenchMARL/benchmarl/algorithms/common.py, Line 161 in e910a83
The documentation of the returned TensorDictReplayBuffer says:

    priority_key (str, optional): the key at which priority is assumed to
        be stored within TensorDicts added to this ReplayBuffer.
        This is to be used when the sampler is of type
        :class:`~torchrl.data.PrioritizedSampler`. Defaults to "td_error".

Thus, I guess it does not use priorities while sampling. A Prioritized Replay Buffer can be put in place by changing the line to:
    sampler = SamplerWithoutReplacement() if self.on_policy else PrioritizedSampler(memory_size, prb_alpha, prb_beta)
When `prb_alpha = 0`, the priorities are not used, so the user can still use uniform sampling (see the sketch below). Thanks in advance!
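For completeness, a minimal self-contained sketch of how the proposed change could look when building the buffer; the helper name and the `prb_alpha`/`prb_beta` arguments are placeholders for illustration, not existing BenchMARL config fields:

```python
from torchrl.data import (
    LazyTensorStorage,
    PrioritizedSampler,
    SamplerWithoutReplacement,
    TensorDictReplayBuffer,
)


def build_replay_buffer(
    memory_size: int,
    sampling_size: int,
    on_policy: bool,
    prb_alpha: float = 0.6,  # placeholder hyperparameter
    prb_beta: float = 0.4,  # placeholder hyperparameter
) -> TensorDictReplayBuffer:
    """Hypothetical helper mirroring the one-line change proposed above."""
    sampler = (
        SamplerWithoutReplacement()
        if on_policy
        else PrioritizedSampler(memory_size, prb_alpha, prb_beta)
    )
    return TensorDictReplayBuffer(
        storage=LazyTensorStorage(memory_size),
        sampler=sampler,
        batch_size=sampling_size,
        # Key under which per-sample priorities are stored; "td_error" is the
        # TensorDictReplayBuffer default quoted in the docstring above.
        priority_key="td_error",
    )
```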