Skip to content

Commit

Permalink
Add unit test for bean manager commit & reload
Browse files Browse the repository at this point in the history
  • Loading branch information
StdioA committed Aug 26, 2024
1 parent 94d53c2 commit 995251c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
22 changes: 14 additions & 8 deletions bean_utils/bean.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from beancount import loader
from beancount.parser import parser
from beancount.query import query
from beancount.core.data import Open, Transaction
from beancount.core.data import Open, Close, Transaction
from beancount.core.number import MISSING
from typing import List
from bean_utils.vec_query import query_txs
Expand All @@ -35,22 +35,28 @@ def __init__(self, fname=None) -> None:

def _load(self):
self._entries, errors, self._options = loader.load_file(self.fname)
self._accounts = []
self._accounts = set()
self.mtimes = {}
self.account_files = set()
for ent in self._entries:
if isinstance(ent, Open):
self._accounts.append(ent.account)
self._accounts.add(ent.account)
self.account_files.add(ent.meta["filename"])
elif isinstance(ent, Close):
self._accounts.remove(ent.account)
self.account_files.add(ent.meta["filename"])

# Fill mtime
for f in self._options["include"]:
self.mtimes[f] = Path(f).stat().st_mtime

def _auto_reload(self, accounts_only=False):
# Check and reload
for f, mtime in self.mtimes.items():
if accounts_only and ("accounts" not in f):
continue
if mtime != Path(f).stat().st_mtime:
files_to_check = self.mtimes.keys()
if accounts_only:
files_to_check = self.account_files
for fname in files_to_check:
if self.mtimes[fname] != Path(fname).stat().st_mtime:
self._load()
return

Expand Down Expand Up @@ -199,7 +205,7 @@ def commit_trx(self, data):
fname = self.fname
with open(fname, 'a') as f:
f.write("\n" + data + "\n")
subprocess.run(["bean-format", "-o", shlex.quote(fname), shlex.quote(fname)], # noqa: S607,S603
subprocess.run(["bean-format", "-o", shlex.quote(str(fname)), shlex.quote(str(fname))], # noqa: S607,S603
shell=False)


Expand Down
61 changes: 55 additions & 6 deletions bean_utils/bean_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import datetime
import shutil
from pathlib import Path
import requests
import pytest
from conf.conf_test import load_config_from_dict, clear_config
Expand Down Expand Up @@ -69,16 +71,26 @@ def test_account_search(mock_config):


def assert_txs_equal(tx1_str, tx2_str):
tx1 = parser.parse_string(tx1_str)[0][0]
tx2 = parser.parse_string(tx2_str)[0][0]
if isinstance(tx1_str, str):
tx1 = parser.parse_string(tx1_str)[0][0]
else:
tx1 = tx1_str
if isinstance(tx2_str, str):
tx2 = parser.parse_string(tx2_str)[0][0]
else:
tx2 = tx2_str

def clean_lineno(tx):
tx.meta["lineno"] = 0
def clean_meta(tx):
keys = list(tx.meta.keys())
for key in keys:
del tx.meta[key]
for p in tx.postings:
p.meta["lineno"] = 0
keys = list(p.meta.keys())
for key in keys:
del p.meta[key]
return tx

assert clean_lineno(tx1) == clean_lineno(tx2)
assert clean_meta(tx1) == clean_meta(tx2)


def test_build_txs(mock_config):
Expand Down Expand Up @@ -257,3 +269,40 @@ def test_parse_args():

with pytest.raises(ValueError, match=bean.ArgsError.args[0]):
bean.parse_args("a “b c'")


@pytest.fixture
def copied_bean(tmp_path):
new_bean = tmp_path / "example.bean"
shutil.copyfile("testdata/example.bean", new_bean)
yield new_bean
Path(new_bean).unlink()


def test_manager_reload(mock_config, copied_bean):
manager = bean.BeanManager(copied_bean)
account_amount = len(manager.accounts)
entry_amount = len(manager.entries)
assert len(manager.accounts) == 63
assert len(manager.entries) == 2037

# Append a "close" entry
with open(copied_bean, "a") as f:
f.write(f"{today} close Assets:US:BofA:Checking\n")

# The account amount should reloaded
assert len(manager.accounts) == account_amount - 1
assert len(manager.entries) == entry_amount + 1


def test_manager_commmit(mock_config, copied_bean):
manager = bean.BeanManager(copied_bean)
assert len(manager.entries) == 2037
txs = f"""
{today} * "Test Payee" "Test Narration"
Liabilities:US:Chase:Slate -12.30 USD
Expenses:Food:Restaurant 12.30 USD
"""
manager.commit_trx(txs)
assert len(manager.entries) == 2038
assert_txs_equal(manager.entries[-1], txs)

0 comments on commit 995251c

Please sign in to comment.