From 082e5c19817401ff8e9ec4f99b8ec9f3daf9cc2b Mon Sep 17 00:00:00 2001 From: Stijn Tintel Date: Wed, 20 Dec 2023 22:33:53 +0200 Subject: [PATCH] WIP: was: support SR model update in post_config --- app/internal/was.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/app/internal/was.py b/app/internal/was.py index d0616f4..b757703 100644 --- a/app/internal/was.py +++ b/app/internal/was.py @@ -252,6 +252,26 @@ def get_nvs(): return get_json_from_file(STORAGE_USER_NVS) +def get_asset_url(asset_type, asset): + was_url = get_was_url() + parsed = urllib.parse.urlparse(was_url) + + if parsed.scheme == "ws": + parsed = parsed._replace(scheme="http") + elif parsed.scheme == "wss": + parsed = parsed._replace(scheme="https") + + parsed = parsed._replace(path="/api/asset") + + query = f"asset={asset}&type={asset_type}" + parsed = parsed._replace(query=query) + + url = urllib.parse.urlunparse(parsed) + + log.debug("get_asset_url: url='{url}'") + return url + + # TODO: Support HTTPs def get_release_url(was_url, version, platform): url_parts = re.match(r"^(?:\w+:\/\/)?([^\/:]+)(?::(\d+))?", was_url) @@ -376,10 +396,17 @@ def merge_dict(dict_1, dict_2): async def post_config(request, apply=False): data = await request.json() + if 'hostname' in data: hostname = data["hostname"] data = get_config_db() - msg = build_msg(data, "config") + # TODO only flash srmodel when needed + if "wake_model" in data and "wake_word_friendly" in data: + ota_url = get_asset_url("other", "srmodels.bin") + msg = json.dumps({'cmd': 'srmodels_ota_start', 'ota_url': ota_url}) + else: + msg = build_msg(data, "config") + try: ws = request.app.connmgr.get_client_by_hostname(hostname) await ws.send_text(msg) @@ -388,12 +415,16 @@ async def post_config(request, apply=False): log.error(f"Failed to apply config to {hostname} ({e})") return "Error" else: + if "wake_model" in data and "wake_word_friendly" in data: + build_srmodels_bin([data["wake_model"]]) + if "wis_tts_url" in data: data["wis_tts_url_v2"] = construct_wis_tts_url(data["wis_tts_url"]) del data["wis_tts_url"] log.debug(f"wis_tts_url_v2: {data['wis_tts_url_v2']}") save_config_to_db(data) + # TODO support flash srmodel broadcast msg = build_msg(data, "config") log.debug(str(msg)) if apply: