diff --git a/agent/agent.py b/agent/agent.py index bc6eafa9aa7..9ade5773c5b 100644 --- a/agent/agent.py +++ b/agent/agent.py @@ -40,7 +40,7 @@ if sys.maxsize > 2**32 and sys.platform == "win32": sys.exit("You should install python3 x86! not x64") -AGENT_VERSION = "0.16" +AGENT_VERSION = "0.17" AGENT_FEATURES = [ "execpy", "execute", @@ -49,6 +49,8 @@ "largefile", "unicodepath", ] +BASE_64_ENCODING = "base64" + if sys.platform == "win32": AGENT_FEATURES.append("mutex") MUTEX_TIMEOUT_MS = 500 @@ -257,26 +259,30 @@ class send_file: """Wrapper that represents Flask.send_file functionality.""" def __init__(self, path, encoding): + self.length = None self.path = path self.status_code = 200 self.encoding = encoding + def okay_to_send(self): + return os.path.isfile(self.path) and os.access(self.path, os.R_OK) + def init(self): - if os.path.isfile(self.path) and os.access(self.path, os.R_OK): - self.length = os.path.getsize(self.path) + if self.okay_to_send(): + if self.encoding != BASE_64_ENCODING: + self.length = os.path.getsize(self.path) else: self.status_code = 404 - self.length = 0 def write(self, httplog, sock): - if not self.length: + if not self.okay_to_send(): return try: with open(self.path, "rb") as f: buf = f.read(1024 * 1024) while buf: - if self.encoding == "base64": + if self.encoding == BASE_64_ENCODING: buf = base64.b64encode(buf) sock.write(buf) buf = f.read(1024 * 1024) @@ -632,7 +638,7 @@ def do_execute(): else: p = subprocess.Popen(command_to_execute, shell=shell, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) stdout, stderr = p.communicate() - if request.form.get("encoding", "") == "base64": + if request.form.get("encoding", "") == BASE_64_ENCODING: stdout = base64.b64encode(stdout) stderr = base64.b64encode(stderr) except Exception as ex: @@ -693,7 +699,7 @@ def do_execpy(): # Execute the command asynchronously? As a shell command? async_exec = "async" in request.form - base64_encode = request.form.get("encoding", "") == "base64" + base64_encode = request.form.get("encoding", "") == BASE_64_ENCODING cwd = request.form.get("cwd") diff --git a/agent/test_agent.py b/agent/test_agent.py index 40fe686339a..6dbfb76cccb 100644 --- a/agent/test_agent.py +++ b/agent/test_agent.py @@ -1,5 +1,5 @@ """Tests for the agent.""" - +import base64 import datetime import io import json @@ -473,6 +473,14 @@ def test_retrieve(self): assert r.status_code == 200 assert first_line in r.text assert last_line in r.text + # Also test the base64-encoded retrieval. + form["encoding"] = "base64" + r = requests.post(f"{BASE_URL}/retrieve", data=form) + assert r.status_code == 200 + decoded = base64.b64decode(r.text + "==").decode() + assert "test data" in decoded + assert first_line in decoded + assert last_line in decoded def test_retrieve_invalid(self): js = self.post_form("retrieve", {}, 400)