doesn't crash now; other fixes
r3w0p committed Jan 9, 2025
1 parent ffa6d73 commit 111aba3
Showing 3 changed files with 44 additions and 47 deletions.
1 change: 1 addition & 0 deletions include/caravan/core/training.h
@@ -61,6 +61,7 @@ typedef struct TrainConfig {
     float learning{0.0};
     uint32_t episode_max{0};
     uint32_t episode{0};
+    PlayerName focus{NO_PLAYER};
 } TrainConfig;


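For context on the new field: `focus` pins learning to one player, regardless of who moves first in a given episode. A minimal, self-contained sketch of how the field is used, assuming stand-in types (the struct is abbreviated from this diff; the `discount`/`explore` members are guessed from `tc.discount`/`tc.explore` usage in training.cpp below):

```cpp
#include <cstdint>

// Stand-in types, reduced from this repository's headers.
enum PlayerName { NO_PLAYER, PLAYER_ABC, PLAYER_DEF };

typedef struct TrainConfig {
    float discount{0.0};   // assumed from tc.discount usage below
    float explore{0.0};    // assumed from tc.explore usage below
    float learning{0.0};
    uint32_t episode_max{0};
    uint32_t episode{0};
    PlayerName focus{NO_PLAYER};  // the new field: who is being trained
} TrainConfig;

int main() {
    TrainConfig tc{};
    tc.focus = PLAYER_ABC;

    // The learning gate from training.cpp: only the focus player's
    // turns update the Q-table, even when PLAYER_DEF moves first.
    PlayerName pturn = PLAYER_DEF;
    bool learning = pturn == tc.focus;  // false on opponent turns
    return learning ? 0 : 1;
}
```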
32 changes: 12 additions & 20 deletions src/caravan/core/training.cpp
@@ -452,20 +452,11 @@ bool train_on_game(Game *game, QTable &q_table, ActionSpace &action_space,
 
     // Only first player is learning
     // Opp always makes random moves and does not influence learning
-    bool learning = pturn == gc.player_first;
+    bool learning = pturn == tc.focus;
 
     // Get game state in relation to current player
     get_game_state(&gs, game, pturn);
-
-    // Maybe add game state and actions if new state discovered
-    /*
-    if (!q_table.contains(gs)) {
-        for (uint16_t i = 0; i < SIZE_ACTION_SPACE; i++) {
-            q_table[gs][action_space[i]] = 0;
-        }
-    }
-    */
 
     // Use action pool that depletes as actions are found to be invalid
     for (int i = 0; i < SIZE_ACTION_SPACE; i++) {
         action_pool.push_back(action_space[i]);
@@ -476,6 +467,12 @@
 
     // Find a valid action
     while (true) {
+        if (action_pool.empty()) {
+            // Move that is guaranteed to be valid
+            command = {.option = OPTION_DISCARD, .pos_hand = 1};
+            break;
+        }
+
         if (!learning or explore or !q_table.contains(gs)) {
             // If exploring, fetch a random action from the action pool
             std::uniform_int_distribution<uint16_t> dist_pool(
@@ -491,8 +488,7 @@
             // Try all known actions first to see if any are above 0
             for (auto it_q = q_table[gs].begin(); it_q != q_table[gs].end(); it_q++) {
                 Action a = it_q->first;
-
-                if (q_table[gs][a] > action_value) {
+                if (action_index == -1 or q_table[gs][a] > action_value) {
                     // Find its index in action pool
                     auto it_ap = std::find(
                         action_pool.begin(), action_pool.end(), a);
@@ -501,19 +497,15 @@
                     if (it_ap == action_pool.end()) continue;
 
                     action_index = std::distance(action_pool.begin(), it_ap);
 
-                    // Found an action explored in the past with a better-than-default value
-                    action_value = q_table[gs][action_pool[action_index]];
+                    action = action_pool[action_index];
+                    action_value = q_table[gs][action];
                 }
             }
 
-            if (action_index == -1) {
-                explore = true;
-                continue;
-            }
 
             // Otherwise, pick the optimal action from the q-table
             action = action_pool[action_index];
         }
 
         // Generate input from action
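A plausible reading of the crash this commit fixes: the pool of candidate actions shrinks as actions are found invalid, and the old loop had no exit once the pool ran dry, so it could either sample from an empty range (undefined behaviour) or spin forever; the new `action_pool.empty()` branch guarantees a valid fallback move. A self-contained toy sketch of the repaired select-or-fallback pattern (`is_valid` and the `int` actions are illustrative stand-ins, not the project's API):

```cpp
#include <random>
#include <vector>

int main() {
    std::vector<int> action_pool = {0, 1, 2, 3};   // toy action space
    auto is_valid = [](int a) { return a == 2; };  // stand-in for game rules

    std::random_device rd;
    std::mt19937 gen(rd());

    int chosen = -1;
    while (true) {
        if (action_pool.empty()) {
            chosen = -2;  // guaranteed-valid fallback (a discard in the diff)
            break;
        }

        // Sample from the pool; note the bound shrinks with the pool.
        std::uniform_int_distribution<size_t> dist(0, action_pool.size() - 1);
        size_t i = dist(gen);

        if (is_valid(action_pool[i])) {
            chosen = action_pool[i];
            break;
        }

        // Invalid: deplete the pool so the same action is not retried.
        action_pool.erase(action_pool.begin() + i);
    }
    return chosen >= 0 ? 0 : 1;
}
```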
@@ -554,15 +546,15 @@ bool train_on_game(Game *game, QTable &q_table, ActionSpace &action_space,
         PlayerName winner_name = game->get_winner_name();
 
         if (winner_name != NO_PLAYER) {
-            if (winner_name == pturn) {
+            if (winner_name == tc.focus) {
                 q_table[gs][action] = 1;
                 winner = true;
             } else {
                 q_table[gs][action] = -1;
             }
         }
 
-        if (learning)
+        if (learning or winner_name != NO_PLAYER)
             q_table[last_gs][last_action] =
                 q_table[last_gs][last_action] + tc.learning * (
                     tc.discount * q_table[gs][action] -
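The assignment at the end of this hunk is a temporal-difference style update, Q(s_prev, a_prev) += alpha * (gamma * Q(s, a) - Q(s_prev, a_prev)), with alpha = tc.learning and gamma = tc.discount; there is no separate reward term, since the terminal ±1 values written into q_table[gs][action] just above play that role. A worked numeric example using the default parameters from train.cpp (state values illustrative):

```cpp
#include <iostream>

int main() {
    float learning = 0.7f;   // alpha, as in train.cpp
    float discount = 0.95f;  // gamma, as in train.cpp

    float q_prev = 0.0f;  // q_table[last_gs][last_action]
    float q_curr = 1.0f;  // q_table[gs][action], set to 1 on a focus-player win

    // The update from the hunk above.
    q_prev = q_prev + learning * (discount * q_curr - q_prev);

    std::cout << q_prev << "\n";  // 0.665: the terminal value propagates back
    return 0;
}
```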
58 changes: 31 additions & 27 deletions src/caravan/train.cpp
@@ -15,65 +15,69 @@ int main(int argc, char *argv[]) {
     std::unique_ptr<Game> game = nullptr;
     GameConfig gc;
     TrainConfig tc;
-    uint8_t rand_first;
 
     // Training
     QTable q_table;
     ActionSpace action_space;
 
-    // Random number generator
+    // Random number generators
     std::random_device rd;
     std::mt19937 gen(rd());
     std::uniform_int_distribution<uint8_t> dist_first_player(
             NUM_PLAYER_ABC, NUM_PLAYER_DEF);
+    std::uniform_int_distribution<uint8_t> dist_num_cards(
+            DECK_CARAVAN_MIN, DECK_CARAVAN_MAX);
+    std::uniform_int_distribution<uint8_t> dist_num_samples(
+            SAMPLE_DECKS_MIN, SAMPLE_DECKS_MAX);
+    std::uniform_int_distribution<uint8_t> dist_balanced(0, 1);
 
-    uint16_t checkpoint = 1000;
+    uint16_t checkpoint = 10000;
     uint16_t num_wins = 0;
 
     // Training parameters TODO user-defined arguments
     float discount = 0.95;
     float learning = 0.7;
-    uint32_t episode_max = 1000000;
+    uint32_t episode_max = 100000;
+    uint32_t episode_half = episode_max / 2;
 
     try {
         // Fill action space with all possible actions
         populate_action_space(&action_space);
 
-        // Game config uses largest deck with most samples and balance to
-        // maximise chance of encountering every player hand combination.
-        // TODO random card and sample sizes
-        gc = {
-                .player_abc_cards = DECK_CARAVAN_MAX,
-                .player_abc_samples = SAMPLE_DECKS_MAX,
-                .player_abc_balanced = true,
-                .player_def_cards = DECK_CARAVAN_MAX,
-                .player_def_samples = SAMPLE_DECKS_MAX,
-                .player_def_balanced = true
-        };
 
         // Train config is passed to bots to manage their training.
         tc = {
                 .episode_max = episode_max,
-                .episode = 1
+                .episode = 1,
+                .focus = PLAYER_ABC
         };
 
         for(; tc.episode <= tc.episode_max; tc.episode++) {
-            // Random first player
-            rand_first = dist_first_player(gen);
-            gc.player_first = rand_first == NUM_PLAYER_ABC ?
-                              PLAYER_ABC : PLAYER_DEF;
+            // Game config uses largest deck with most samples and balance to
+            // maximise chance of encountering every player hand combination.
+            uint8_t rand_first = dist_first_player(gen);
+
+            gc = {
+                    .player_abc_cards = DECK_CARAVAN_MAX,
+                    .player_abc_samples = SAMPLE_DECKS_MAX,
+                    .player_abc_balanced = true,
+                    .player_def_cards = DECK_CARAVAN_MAX,
+                    .player_def_samples = SAMPLE_DECKS_MAX,
+                    .player_def_balanced = true,
+                    .player_first = rand_first == NUM_PLAYER_ABC ? PLAYER_ABC : PLAYER_DEF
+            };
+
             // Set training parameters
             tc.discount = discount;
+            tc.learning = learning;
 
-
-            if (tc.episode > (tc.episode_max / 2))
+            if (tc.episode > episode_half) {
                 tc.explore =
-                        static_cast<float>((tc.episode_max/2) - (tc.episode - (tc.episode_max/2) - 1)) /
-                        static_cast<float>(tc.episode_max/2);
-            else
+                        static_cast<float>(episode_half - (tc.episode - episode_half - 1)) /
+                        static_cast<float>(episode_half);
+            } else {
                 tc.explore = 1.0;
-
-            tc.learning = learning;
+            }
 
             if (tc.episode % checkpoint == 0) {
                 float per_wins = static_cast<float>(num_wins) / static_cast<float>(checkpoint);
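The epsilon schedule introduced here keeps `tc.explore` at 1.0 for the first half of training, then decays it linearly to nearly zero; precomputing `episode_half` also removes the repeated `tc.episode_max / 2` divisions. A small self-contained sketch that prints the schedule at a few episodes:

```cpp
#include <cstdint>
#include <iostream>

int main() {
    uint32_t episode_max = 100000;  // as in train.cpp
    uint32_t episode_half = episode_max / 2;

    for (uint32_t episode : {1u, 50000u, 50001u, 75000u, 100000u}) {
        float explore;
        if (episode > episode_half) {
            // Linear decay over the second half of training.
            explore =
                static_cast<float>(episode_half - (episode - episode_half - 1)) /
                static_cast<float>(episode_half);
        } else {
            explore = 1.0;
        }
        std::cout << episode << " -> " << explore << "\n";
    }
    // Output: 1 -> 1, 50000 -> 1, 50001 -> 1, 75000 -> 0.50002, 100000 -> 2e-05
    return 0;
}
```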
