diff --git a/dnschef/kitchen.py b/dnschef/kitchen.py index c2242f0..ef44ecd 100644 --- a/dnschef/kitchen.py +++ b/dnschef/kitchen.py @@ -37,7 +37,7 @@ def get_file_chunk(file_path, chunk_index, chunk_size): return next(itertools.islice( chunk_file(file_path, chunk_size), chunk_index, - chunk_index+1 + chunk_index + 1 ), b'') @@ -65,28 +65,36 @@ async def do_A(self, addr, qname, qtype, record): log.warning(f"chunk_size {chunk_size} is too large for A record, defaulting to 4") chunk_size = 4 - file_chunk = await stage_file(qname, record, chunk_size) - if file_chunk: - record = file_chunk + record = await stage_file(qname, record, chunk_size) + if record and len(record) < 4: + record = record.ljust(4, b'\x00') - ipv4_hex_tuple = list(map(int, IPv4Address(record).packed)) - return RR(qname, getattr(QTYPE, qtype), rdata=RDMAP[qtype](ipv4_hex_tuple)) + if record: + ipv4_hex_tuple = list(map(int, IPv4Address(record).packed)) + return RR(qname, getattr(QTYPE, qtype), rdata=RDMAP[qtype](ipv4_hex_tuple)) async def do_TXT(self, addr, qname, qtype, record): if isinstance(record, dict): - prefix = random.choice(record.get('response_prefix_pool')) - response_format = record.get('response_format') + chunk_size = record.get('chunk_size') + prefix = random.choice(record.get('response_prefix_pool', [''])) + response_format = record.get('response_format', '{prefix}{chunk}') space_left = 255 - len(response_format.format(prefix=prefix, chunk='')) max_data_len = ( space_left // 4 ) * 3 - file_chunk = await stage_file(qname, record, chunk_size=max_data_len) - if file_chunk: - record = response_format.format(prefix=prefix, chunk=base64.b64encode(file_chunk).decode()) + if chunk_size: + max_data_len = min(chunk_size, max_data_len) + if chunk_size > max_data_len: + log.warning(f"chunk_size {chunk_size} is too large for the TXT record, defaulting to {max_data_len}") - # dnslib doesn't like trailing dots - if record[-1] == ".": record = record[:-1] - return RR(qname, getattr(QTYPE, qtype), rdata=RDMAP[qtype](record)) + record = await stage_file(qname, record, chunk_size=max_data_len) + if record: + record = response_format.format(prefix=prefix, chunk=base64.b64encode(record).decode()) + + if record: + # dnslib doesn't like trailing dots + record = record.rstrip('.') + return RR(qname, getattr(QTYPE, qtype), rdata=RDMAP[qtype](record)) async def do_AAAA(self, addr, qname, qtype, record): if isinstance(record, dict): @@ -95,12 +103,13 @@ async def do_AAAA(self, addr, qname, qtype, record): log.warning(f"chunk_size {chunk_size} is too large for AAAA record, defaulting to 16") chunk_size = 16 - file_chunk = await stage_file(qname, record, chunk_size) - if file_chunk: - record = file_chunk + record = await stage_file(qname, record, chunk_size) + if record and len(record) < 16: + record = record.ljust(16, b'\x00') - ipv6_hex_tuple = list(map(int, IPv6Address(record).packed)) - return RR(qname, getattr(QTYPE, qtype), rdata=RDMAP[qtype](ipv6_hex_tuple)) + if record: + ipv6_hex_tuple = list(map(int, IPv6Address(record).packed)) + return RR(qname, getattr(QTYPE, qtype), rdata=RDMAP[qtype](ipv6_hex_tuple)) async def do_HTTPS(self, addr, qname, qtype, record): kv_pairs = record.split(" ") @@ -199,8 +208,8 @@ async def we_cookin(self, logger, d, qtype, qname, addr): self.do_default ) - response.add_answer( - (await response_func(addr, qname, qtype, cooked_reply)) - ) + answer = await response_func(addr, qname, qtype, cooked_reply) + if answer: + response.add_answer(answer) return response diff --git a/tests/conftest.py b/tests/conftest.py index 9300ef4..06080a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,14 @@ def random_string(): return ''.join(random.choices(string.ascii_letters, k=6)) +@pytest.fixture +def random_string_gen(): + def _random_gen(): + while True: + yield ''.join(random.choices(string.ascii_letters, k=6)) + + return _random_gen() + @pytest.fixture def api_test_client(): return TestClient(app) diff --git a/tests/dnschef-tests.toml b/tests/dnschef-tests.toml index 4e74bbe..578b389 100644 --- a/tests/dnschef-tests.toml +++ b/tests/dnschef-tests.toml @@ -4,11 +4,11 @@ "*.*.thesprawl.org" = "1.1.1.1" "c.*.*.thesprawl.org" = "1.1.2.2" "fuck.shit.com" = "192.168.0.1" -"*.wat.org" = { file = "./requirements.txt", chunk_size = 122 } +"*.wat.org" = { file = "tests/small-bin-test", chunk_size = 122 } [AAAA] # Queries for IPv6 address records "*.thesprawl.org" = "2001:db8::1" -"*.wat.org" = { file = "./requirements.txt", chunk_size = 122 } +"*.wat.org" = { file = "tests/small-bin-test", chunk_size = 122 } [MX] # Queries for mail server records "*.thesprawl.org" = "mail.fake.com" @@ -23,11 +23,13 @@ "*.thesprawl.org" = "fake message" "ok.thesprawl.org" = "fake message" "*.something.wattahog.org" = "fuck off" -"wa*.aint.nothing.org" = "sequoia banshee buggers" -"ns*.shit.fuck.org" = { file = "./requirements.txt", chunk_size = 189, response_format = "{prefix}test-{chunk}", response_prefix_pool = ["atlassian-domain-verification=", "onetrust-domain-verification=", "docusign=" ] } +"wa*.aint.nothing.org" = "sequoia banshee boogers" +"ns*.shit.fuck.org" = { file = "tests/thicc-bin-test", chunk_size = 189, response_format = "{prefix}test-{chunk}", response_prefix_pool = ["atlassian-domain-verification=", "onetrust-domain-verification=", "docusign=" ] } +"ns*.fronted.brick.org" = { file = "tests/thicc-bin-test" } +"ns*.filtered.crack.org" = { file = "tests/thicc-bin-test", chunk_size = 50, response_format = "{prefix}test-{chunk}", response_prefix_pool = ["atlassian-domain-verification=", "onetrust-domain-verification=", "docusign=" ] } [TXT."*.wattahog.org"] -file = "./requirements.txt" +file = "tests/thicc-bin-test" chunk_size = 189 response_format = "{prefix}test-{chunk}" response_prefix_pool = [ "atlassian-domain-verification=", "onetrust-domain-verification=" , "docusign=" ] diff --git a/tests/small-bin-test b/tests/small-bin-test new file mode 100755 index 0000000..fa4f671 --- /dev/null +++ b/tests/small-bin-test @@ -0,0 +1,3 @@ +#!/bin/sh +cmd=${0##*/} +exec grep -F "$@" diff --git a/tests/test_dns_server.py b/tests/test_dns_server.py index a68986e..b9dddfe 100644 --- a/tests/test_dns_server.py +++ b/tests/test_dns_server.py @@ -35,10 +35,10 @@ async def test_correct_wildcard_behavior(dns_client): assert answers[0].address == "1.1.2.2" answers = await dns_client.resolve("wa1.aint.nothing.org", "TXT", tcp=proto) - assert answers[0].to_text().strip('"') == 'sequoia banshee buggers' + assert answers[0].to_text().strip('"') == 'sequoia banshee boogers' answers = await dns_client.resolve("wattahog.aint.nothing.org", "TXT", tcp=proto) - assert answers[0].to_text().strip('"') == 'sequoia banshee buggers' + assert answers[0].to_text().strip('"') == 'sequoia banshee boogers' @pytest.mark.asyncio diff --git a/tests/test_file_staging.py b/tests/test_file_staging.py new file mode 100644 index 0000000..85b056a --- /dev/null +++ b/tests/test_file_staging.py @@ -0,0 +1,75 @@ +import pytest +import hashlib +from base64 import b64decode +from ipaddress import IPv4Address, IPv6Address + + +def compare_file_digests(tmp_file_path, orig_file_path): + with tmp_file_path.open('rb') as staged_file: + with open(orig_file_path, 'rb') as orig_file: + staged_file_digest = hashlib.file_digest(staged_file, "md5").digest() + orig_file_digest = hashlib.file_digest(orig_file, "md5").digest() + + return staged_file_digest == orig_file_digest + +@pytest.mark.asyncio +async def test_A_file_staging(dns_client, tmp_path, random_string_gen): + orig_file_path = "tests/small-bin-test" + for proto in [False, True]: + chunk_n = 0 + tmp_file_path = tmp_path / next(random_string_gen) + with tmp_file_path.open('ab') as f: + while True: + answers = await dns_client.resolve(f"lala{chunk_n}dayum.wat.org", "A", tcp=proto, raise_on_no_answer=False) + print(list(answers)) + for answer in answers: + data = IPv4Address(answer.address).packed + data = data.replace(b'\x00', b'') + f.write(data) + + if not len(answers): + break + + chunk_n += 1 + + assert compare_file_digests(tmp_file_path, orig_file_path) == True + +@pytest.mark.asyncio +async def test_AAAA_file_staging(dns_client, tmp_path, random_string_gen): + orig_file_path = "tests/small-bin-test" + for proto in [False, True]: + chunk_n = 0 + tmp_file_path = tmp_path / next(random_string_gen) + with tmp_file_path.open('ab') as f: + while True: + answers = await dns_client.resolve(f"lala{chunk_n}dayum.wat.org", "AAAA", tcp=proto, raise_on_no_answer=False) + for answer in answers: + data = IPv6Address(answer.address).packed + data = data.replace(b'\x00', b'') + f.write(data) + + if not len(answers): + break + + chunk_n += 1 + + assert compare_file_digests(tmp_file_path, orig_file_path) == True + +@pytest.mark.asyncio +async def test_TXT_file_staging(dns_client, tmp_path, random_string_gen): + orig_file_path = "tests/thicc-bin-test" + for proto in [False, True]: + chunk_n = 0 + tmp_file_path = tmp_path / next(random_string_gen) + with tmp_file_path.open('ab') as f: + while True: + answers = await dns_client.resolve(f"ns{chunk_n}.fronted.brick.org", "TXT", tcp=proto, raise_on_no_answer=False) + for answer in answers: + f.write(b64decode(answer.to_text().strip('"'))) + + if not len(answers): + break + + chunk_n += 1 + + assert compare_file_digests(tmp_file_path, orig_file_path) == True diff --git a/tests/thicc-bin-test b/tests/thicc-bin-test new file mode 100755 index 0000000..18baa06 Binary files /dev/null and b/tests/thicc-bin-test differ