Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(driver): Simplify code to keep only one round state #71

Merged
merged 3 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 27 additions & 39 deletions Code/driver/src/driver.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use alloc::collections::BTreeMap;

use malachite_round::state_machine::RoundData;

use malachite_common::{
Expand Down Expand Up @@ -32,13 +30,11 @@ where
pub env: Env,
pub proposer_selector: PSel,

pub height: Ctx::Height,
pub address: Ctx::Address,
pub validator_set: Ctx::ValidatorSet,

pub round: Round,
pub votes: VoteKeeper<Ctx>,
pub round_states: BTreeMap<Round, RoundState<Ctx>>,
pub round_state: RoundState<Ctx>,
}

impl<Ctx, Env, PSel> Driver<Ctx, Env, PSel>
Expand All @@ -51,7 +47,6 @@ where
ctx: Ctx,
env: Env,
proposer_selector: PSel,
height: Ctx::Height,
validator_set: Ctx::ValidatorSet,
address: Ctx::Address,
) -> Self {
Expand All @@ -61,17 +56,17 @@ where
ctx,
env,
proposer_selector,
height,
address,
validator_set,
round: Round::NIL,
votes,
round_states: BTreeMap::new(),
round_state: RoundState::default(),
}
}

async fn get_value(&self, round: Round) -> Option<Ctx::Value> {
self.env.get_value(self.height.clone(), round).await
async fn get_value(&self) -> Option<Ctx::Value> {
self.env
.get_value(self.round_state.height.clone(), self.round_state.round)
.await
}

pub async fn execute(&mut self, msg: Event<Ctx>) -> Result<Option<Message<Ctx>>, Error<Ctx>> {
Expand All @@ -82,9 +77,7 @@ where

let msg = match round_msg {
RoundMessage::NewRound(round) => {
// XXX: Check if there is an existing state?
assert!(self.round < round);
Message::NewRound(round)
Message::NewRound(self.round_state.height.clone(), round)
}

RoundMessage::Proposal(proposal) => {
Expand All @@ -108,21 +101,27 @@ where
Ok(Some(msg))
}

async fn apply(&mut self, msg: Event<Ctx>) -> Result<Option<RoundMessage<Ctx>>, Error<Ctx>> {
match msg {
Event::NewRound(round) => self.apply_new_round(round).await,
async fn apply(&mut self, event: Event<Ctx>) -> Result<Option<RoundMessage<Ctx>>, Error<Ctx>> {
match event {
Event::NewRound(height, round) => self.apply_new_round(height, round).await,

Event::Proposal(proposal, validity) => {
Ok(self.apply_proposal(proposal, validity).await)
}

Event::Vote(signed_vote) => self.apply_vote(signed_vote),

Event::TimeoutElapsed(timeout) => Ok(self.apply_timeout(timeout)),
}
}

async fn apply_new_round(
&mut self,
height: Ctx::Height,
round: Round,
) -> Result<Option<RoundMessage<Ctx>>, Error<Ctx>> {
self.round_state = RoundState::new(height, round);

let proposer_address = self
.proposer_selector
.select_proposer(round, &self.validator_set);
Expand All @@ -136,7 +135,7 @@ where
// We are the proposer
// TODO: Schedule propose timeout

let Some(value) = self.get_value(round).await else {
let Some(value) = self.get_value().await else {
return Err(Error::NoValueToPropose);
};

Expand All @@ -145,13 +144,6 @@ where
RoundEvent::NewRound
};

assert!(self.round < round);
self.round_states.insert(
round,
RoundState::default().new_round(self.height.clone(), round),
);
self.round = round;

Ok(self.apply_event(round, event))
}

Expand All @@ -161,23 +153,24 @@ where
validity: Validity,
) -> Option<RoundMessage<Ctx>> {
// Check that there is an ongoing round
let Some(round_state) = self.round_states.get(&self.round) else {
// TODO: Add logging
if self.round_state.round == Round::NIL {
return None;
};
}

// Only process the proposal if there is no other proposal
if round_state.proposal.is_some() {
if self.round_state.proposal.is_some() {
return None;
}

// Check that the proposal is for the current height and round
if self.height != proposal.height() || self.round != proposal.round() {
if self.round_state.height != proposal.height()
|| self.round_state.round != proposal.round()
{
return None;
}

// TODO: Document
if proposal.pol_round().is_defined() && proposal.pol_round() >= round_state.round {
if proposal.pol_round().is_defined() && proposal.pol_round() >= self.round_state.round {
return None;
}

Expand Down Expand Up @@ -268,10 +261,9 @@ where

/// Apply the event, update the state.
fn apply_event(&mut self, round: Round, event: RoundEvent<Ctx>) -> Option<RoundMessage<Ctx>> {
// Get the round state, or create a new one
let round_state = self.round_states.remove(&round).unwrap_or_default();
let round_state = core::mem::take(&mut self.round_state);

let data = RoundData::new(round, &self.height, &self.address);
let data = RoundData::new(round, round_state.height.clone(), &self.address);

// Multiplex the event with the round state.
let mux_event = match event {
Expand All @@ -295,13 +287,9 @@ where
let transition = round_state.apply_event(&data, mux_event);

// Update state
self.round_states.insert(round, transition.next_state);
self.round_state = transition.next_state;

// Return message, if any
transition.message
}

pub fn round_state(&self, round: Round) -> Option<&RoundState<Ctx>> {
self.round_states.get(&round)
}
}
2 changes: 1 addition & 1 deletion Code/driver/src/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub enum Event<Ctx>
where
Ctx: Context,
{
NewRound(Round),
NewRound(Ctx::Height, Round),
Proposal(Ctx::Proposal, Validity),
Vote(SignedVote<Ctx>),
TimeoutElapsed(Timeout),
Expand Down
10 changes: 6 additions & 4 deletions Code/driver/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ where
Vote(SignedVote<Ctx>),
Decide(Round, Ctx::Value),
ScheduleTimeout(Timeout),
NewRound(Round),
NewRound(Ctx::Height, Round),
}

// NOTE: We have to derive these instances manually, otherwise
Expand All @@ -26,7 +26,7 @@ impl<Ctx: Context> Clone for Message<Ctx> {
Message::Vote(signed_vote) => Message::Vote(signed_vote.clone()),
Message::Decide(round, value) => Message::Decide(*round, value.clone()),
Message::ScheduleTimeout(timeout) => Message::ScheduleTimeout(*timeout),
Message::NewRound(round) => Message::NewRound(*round),
Message::NewRound(height, round) => Message::NewRound(height.clone(), *round),
}
}
}
Expand All @@ -39,7 +39,7 @@ impl<Ctx: Context> fmt::Debug for Message<Ctx> {
Message::Vote(signed_vote) => write!(f, "Vote({:?})", signed_vote),
Message::Decide(round, value) => write!(f, "Decide({:?}, {:?})", round, value),
Message::ScheduleTimeout(timeout) => write!(f, "ScheduleTimeout({:?})", timeout),
Message::NewRound(round) => write!(f, "NewRound({:?})", round),
Message::NewRound(height, round) => write!(f, "NewRound({:?}, {:?})", height, round),
}
}
}
Expand All @@ -60,7 +60,9 @@ impl<Ctx: Context> PartialEq for Message<Ctx> {
(Message::ScheduleTimeout(timeout), Message::ScheduleTimeout(other_timeout)) => {
timeout == other_timeout
}
(Message::NewRound(round), Message::NewRound(other_round)) => round == other_round,
(Message::NewRound(height, round), Message::NewRound(other_height, other_round)) => {
height == other_height && round == other_round
}
_ => false,
}
}
Expand Down
16 changes: 4 additions & 12 deletions Code/round/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,17 @@ impl<Ctx> State<Ctx>
where
Ctx: Context,
{
pub fn new() -> Self {
pub fn new(height: Ctx::Height, round: Round) -> Self {
Self {
height: Ctx::Height::default(),
round: Round::INITIAL,
height,
round,
step: Step::NewRound,
proposal: None,
locked: None,
valid: None,
}
}

pub fn new_round(self, height: Ctx::Height, round: Round) -> Self {
Self {
height,
round,
step: Step::NewRound,
..self
}
}
pub fn with_step(self, step: Step) -> Self {
Self { step, ..self }
}
Expand Down Expand Up @@ -97,7 +89,7 @@ where
Ctx: Context,
{
fn default() -> Self {
Self::new()
Self::new(Ctx::Height::default(), Round::NIL)
}
}

Expand Down
9 changes: 4 additions & 5 deletions Code/round/src/state_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ where
Ctx: Context,
{
pub round: Round,
pub height: &'a Ctx::Height,
pub height: Ctx::Height,
pub address: &'a Ctx::Address,
}

impl<'a, Ctx> RoundData<'a, Ctx>
where
Ctx: Context,
{
pub fn new(round: Round, height: &'a Ctx::Height, address: &'a Ctx::Address) -> Self {
pub fn new(round: Round, height: Ctx::Height, address: &'a Ctx::Address) -> Self {
Self {
round,
height,
Expand Down Expand Up @@ -62,7 +62,7 @@ where
match (state.step, event) {
// From NewRound. Event must be for current round.
(Step::NewRound, Event::NewRoundProposer(value)) if this_round => {
propose(state, data.height, value) // L11/L14
propose(state, &data.height, value) // L11/L14
}
(Step::NewRound, Event::NewRound) if this_round => schedule_timeout_propose(state), // L11/L20

Expand Down Expand Up @@ -331,8 +331,7 @@ pub fn round_skip<Ctx>(state: State<Ctx>, round: Round) -> Transition<Ctx>
where
Ctx: Context,
{
Transition::to(state.clone().new_round(state.height.clone(), round))
.with_message(Message::NewRound(round))
Transition::to(State::new(state.height.clone(), round)).with_message(Message::NewRound(round))
}

/// We received +2/3 precommits for a value - commit and decide that value!
Expand Down
Loading
Loading