diff --git a/Dockerfile b/Dockerfile index a7815c6..fdb30c3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ COPY . . RUN pip wheel --wheel-dir ./dist '.[api]' -FROM python:3.10-slim +FROM python:3.11-slim WORKDIR /app diff --git a/dnschef/api.py b/dnschef/api.py index 86412e0..23e1c95 100644 --- a/dnschef/api.py +++ b/dnschef/api.py @@ -1,5 +1,6 @@ from dnschef import __version__ from dnschef import kitchen +from dnschef.protocols import start_server from dnschef.utils import header, parse_config_file from dnschef.logger import ( log, @@ -58,7 +59,7 @@ async def startup_event(): # Launch DNSChef asyncio.create_task( - kitchen.start_cooking( + start_server( interface=settings.interface, nameservers=settings.nameservers, tcp=settings.tcp, diff --git a/dnschef/kitchen.py b/dnschef/kitchen.py index 6c9194d..c2242f0 100644 --- a/dnschef/kitchen.py +++ b/dnschef/kitchen.py @@ -180,45 +180,27 @@ def findnametodns(self, qname, qtype): #return { qtype: { k:v for k,v in CONFIG[qtype].items() if k == top_matched_domains[0] } } return CONFIG[qtype][top_matched_domains[0]] - async def we_cookin(self, logger, data, addr): - try: - d = DNSRecord.parse(data) - except Exception: - logger.error("invalid DNS request") - - else: - # Only Process DNS Queries - if QR[d.header.qr] == "QUERY": - - qtype = QTYPE[d.q.qtype] - # Create a custom response to the query - response = DNSRecord( - DNSHeader(id=d.header.id, bitmap=d.header.bitmap, qr=1, aa=1, ra=1), - q=d.q - ) - - # Gather query parameters - # NOTE: Do not lowercase qname here, because we want to see - # any case request weirdness in the logs. - qname = str(d.q.qname) - - # Chop off the last period - if qname[-1] == '.': qname = qname[:-1] - - cooked_reply = self.findnametodns(qname, qtype) - - # Check if there is a fake record for the current request qtype - if CONFIG.get(qtype) and cooked_reply: - logger.info("cooking response", type=qtype, name=qname) #record=record) - - response_func = getattr( - self, - f"do_{qtype}", - self.do_default - ) - - response.add_answer( - (await response_func(addr, qname, qtype, cooked_reply)) - ) - - return response + async def we_cookin(self, logger, d, qtype, qname, addr): + # Create a custom response to the query + response = DNSRecord( + DNSHeader(id=d.header.id, bitmap=d.header.bitmap, qr=1, aa=1, ra=1), + q=d.q + ) + + cooked_reply = self.findnametodns(qname, qtype) + + # Check if there is a fake record for the current request qtype + if CONFIG.get(qtype) and cooked_reply: + logger.info("cooking response") + + response_func = getattr( + self, + f"do_{qtype}", + self.do_default + ) + + response.add_answer( + (await response_func(addr, qname, qtype, cooked_reply)) + ) + + return response diff --git a/dnschef/protocols.py b/dnschef/protocols.py index 0253cd1..f7c231a 100644 --- a/dnschef/protocols.py +++ b/dnschef/protocols.py @@ -3,11 +3,17 @@ import re import random import functools +import enum +from dnslib import DNSRecord, QR, QTYPE from typing import List from dnschef.logger import log from dnschef import kitchen +class ClientProtocol(enum.Enum): + UDP = 1 + TCP = 2 + class UdpDnsClientProtocol: def __init__(self, request, on_con_lost): self.transport = None @@ -53,11 +59,11 @@ def connection_lost(self, exc): self.on_con_lost.set_result(True) # Obtain a response from a real DNS server. -async def proxy_request(request, host, port=53, protocol="udp"): +async def proxy_request(request, host, protocol: ClientProtocol, port: int = 53): loop = asyncio.get_running_loop() on_con_lost = loop.create_future() - if protocol == "udp": + if protocol == ClientProtocol.UDP: transport, protocol = await loop.create_datagram_endpoint( lambda: UdpDnsClientProtocol(request, on_con_lost), remote_addr=(host, int(port))) @@ -76,7 +82,7 @@ async def proxy_request(request, host, port=53, protocol="udp"): class UdpDnsServerProtocol: def __init__(self, nameservers, dns_kitchen): - self.nameservers = nameservers + self.nameservers = [ re.split('[:#]', ns) for ns in nameservers ] self.dns_kitchen = dns_kitchen def connection_made(self, transport): @@ -85,24 +91,44 @@ def connection_made(self, transport): def datagram_received(self, data, addr): logger = log.bind(address=addr[0], proto="udp") - def _cooked_cb(future): - response = future.result() - if response: - logger.debug("dns packet", packet=response.pack()) - self.transport.sendto(response.pack(), addr) - else: - logger.info("proxying response") - nameserver_tuple = re.split('[:#]', random.choice(self.nameservers)) - - task = asyncio.create_task(proxy_request(data, *nameserver_tuple, protocol='udp')) - task.add_done_callback(functools.partial(lambda c, t, a: t.sendto(c.result(), a), t=self.transport, a=addr)) - - task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, data, addr)) - task.add_done_callback(_cooked_cb) + try: + d = DNSRecord.parse(data) + except Exception: + logger.error("invalid DNS request") + else: + # Only Process DNS Queries + if not QR[d.header.qr] == "QUERY": + logger.warning("received a non-query DNS request") + return + + qtype = QTYPE[d.q.qtype] + qname = str(d.q.qname).rstrip('.') + logger = logger.bind(name=qname, type=qtype) + + def _cooked_cb(future): + response = future.result() + if response: + logger.debug("dns packet", packet=response.pack()) + self.transport.sendto(response.pack(), addr) + else: + logger.info("proxying response") + task = asyncio.create_task( + proxy_request( + data, + *random.choice(self.nameservers), + protocol=ClientProtocol.UDP + ) + ) + task.add_done_callback(functools.partial( + lambda c, t, a: t.sendto(c.result(), a), t=self.transport, a=addr + )) + + task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, d, qtype, qname, addr)) + task.add_done_callback(_cooked_cb) class TcpDnsServerProtocol(asyncio.Protocol): def __init__(self, nameservers, dns_kitchen): - self.nameservers = nameservers + self.nameservers = [ re.split('[:#]', ns) for ns in nameservers ] self.dns_kitchen = dns_kitchen def connection_made(self, transport): @@ -112,22 +138,42 @@ def data_received(self, data): addr = self.transport.get_extra_info('peername') logger = log.bind(address=addr[0], proto="tcp") - def _cooked_cb(future): - response = future.result() - if response: - logger.debug("dns packet", packet=response.pack()) - self.transport.write( - len(response.pack()).to_bytes(2, byteorder='big') + response.pack() - ) - else: - logger.info("proxying response") - nameserver_tuple = re.split('[:#]', random.choice(self.nameservers)) - - task = asyncio.create_task(proxy_request(data, *nameserver_tuple, protocol='tcp')) - task.add_done_callback(functools.partial(lambda c, t: t.write(c.result()), t=self.transport)) - - task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, data[2:], addr)) - task.add_done_callback(_cooked_cb) + try: + d = DNSRecord.parse(data[2:]) + except Exception: + logger.error("invalid DNS request") + else: + # Only Process DNS Queries + if not QR[d.header.qr] == "QUERY": + logger.warning("received a non-query DNS request") + return + + qtype = QTYPE[d.q.qtype] + qname = str(d.q.qname).rstrip('.') + logger = logger.bind(name=qname, type=qtype) + + def _cooked_cb(future): + response = future.result() + if response: + logger.debug("dns packet", packet=response.pack()) + self.transport.write( + len(response.pack()).to_bytes(2, byteorder='big') + response.pack() + ) + else: + logger.info("proxying response") + task = asyncio.create_task( + proxy_request( + data, + *random.choice(self.nameservers), + protocol=ClientProtocol.TCP + ) + ) + task.add_done_callback(functools.partial( + lambda c, t: t.write(c.result()), t=self.transport + )) + + task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, d, qtype, qname, addr)) + task.add_done_callback(_cooked_cb) async def start_server(interface: str, nameservers: List[str], tcp: bool = False, ipv6: bool = False, port: int = 53): loop = asyncio.get_running_loop() diff --git a/dnschef/utils.py b/dnschef/utils.py index 2ec35d5..8dc0648 100644 --- a/dnschef/utils.py +++ b/dnschef/utils.py @@ -4,7 +4,7 @@ from dnslib import RDMAP header = " _ _ __ \n" -header += " | | version {} | | / _| \n".format(__version__) +header += " | | v{} | | / _| \n".format(__version__) header += " __| |_ __ ___ ___| |__ ___| |_ \n" header += " / _` | '_ \/ __|/ __| '_ \ / _ \ _|\n" header += " | (_| | | | \__ \ (__| | | | __/ | \n" diff --git a/poetry.lock b/poetry.lock index 149d640..1c8a076 100644 --- a/poetry.lock +++ b/poetry.lock @@ -585,6 +585,51 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.25.1" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.25.1-py3-none-any.whl", hash = "sha256:fec7d6cc5c27c578a391f7e87b9aa7d3d8fbcd034f6399f9f79b45bcc12a866a"}, + {file = "httpx-0.25.1.tar.gz", hash = "sha256:ffd96d5cf901e63863d9f1b4b6807861dbea4d301613415d9e6e57ead15fc5d0"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "idna" version = "3.4" @@ -1655,4 +1700,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "08f9b2d3b9c7a7b2c3f77cfd06adc136b2612e3de5f8aba83a197649d5a57a76" +content-hash = "93b0598afa3abe86f8fab92c00f1ad1b832360860e4b922503fc12f7114aaf99" diff --git a/pyproject.toml b/pyproject.toml index b80dde8..b7a10ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,10 @@ log_cli = false log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" +filterwarnings = [ + # note the use of single quote below to denote "raw" strings in TOML + 'ignore:`general_plain_validator_function` is deprecated', +] [tool.poetry.scripts] dnschef = 'dnschef.__main__:main' @@ -41,6 +45,7 @@ poetry-plugin-export = "^1.6.0" ruff = "^0.1.6" dnspython = "^2.4.2" pytest-cov = "^4.1.0" +httpx = "^0.25.1" [build-system] requires = ["poetry-core"] diff --git a/requirements-dev.txt b/requirements-dev.txt index bc59370..0692ce7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -325,6 +325,12 @@ filelock==3.13.1 ; python_version >= "3.11" and python_version < "4.0" \ h11==0.14.0 ; python_version >= "3.11" and python_version < "4.0" \ --hash=sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d \ --hash=sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761 +httpcore==1.0.2 ; python_version >= "3.11" and python_version < "4.0" \ + --hash=sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7 \ + --hash=sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535 +httpx==0.25.1 ; python_version >= "3.11" and python_version < "4.0" \ + --hash=sha256:fec7d6cc5c27c578a391f7e87b9aa7d3d8fbcd034f6399f9f79b45bcc12a866a \ + --hash=sha256:ffd96d5cf901e63863d9f1b4b6807861dbea4d301613415d9e6e57ead15fc5d0 idna==3.4 ; python_version >= "3.11" and python_version < "4.0" \ --hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \ --hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2 diff --git a/tests/conftest.py b/tests/conftest.py index 98853cc..34df6c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,17 +8,27 @@ import dns.asyncresolver from dnschef import kitchen +from dnschef.api import app from dnschef.protocols import start_server from dnschef.utils import parse_config_file -from dnschef.logger import log, debug_formatter +from dnschef.logger import log, debug_formatter, json_capture_formatter +from fastapi.testclient import TestClient -log.setLevel(logging.DEBUG) -log.handlers[0].setFormatter(debug_formatter) +#log.setLevel(logging.DEBUG) +#log.handlers[0].setFormatter(debug_formatter) + +jh = logging.StreamHandler() +jh.setFormatter(json_capture_formatter) +log.addHandler(jh) @pytest.fixture def random_string(): return ''.join(random.choices(string.ascii_letters, k=6)) +@pytest.fixture +def api_test_client(): + return TestClient(app) + @pytest.fixture(scope="session") def event_loop(): loop = asyncio.get_event_loop_policy().new_event_loop() diff --git a/tests/test_http_api.py b/tests/test_http_api.py new file mode 100644 index 0000000..8addf78 --- /dev/null +++ b/tests/test_http_api.py @@ -0,0 +1,57 @@ +import json + +def test_get_records(api_test_client, config_file): + r = api_test_client.get("/") + assert r.status_code == 200 + assert r.json() == config_file + +def test_add_record(api_test_client): + r = api_test_client.put( + "/", + json={"type": "A", "domain": "*.nashvillenibblers.com", "value": "192.168.69.69"} + ) + assert r.status_code == 200 + + r = api_test_client.get("/") + assert r.status_code == 200 + assert r.json()["A"]["*.nashvillenibblers.com"] == "192.168.69.69" + +def test_delete_record(api_test_client): + r = api_test_client.request( + method="DELETE", + url="/", + content=json.dumps({"type": "A", "domain": "*.nashvillenibblers.com", "value": "192.168.69.69"}).encode() + ) + + assert r.status_code == 200 + + r = api_test_client.get("/") + assert r.status_code == 200 + assert not r.json()["A"].get("*.nashvillenibblers.com", None) + +def test_logs(api_test_client): + r = api_test_client.get("/logs") + assert r.status_code == 200 + + r = api_test_client.get( + "/logs", + params={"type": "A"} + ) + assert r.status_code == 200 + assert len(r.json()) + + r = api_test_client.get( + "/logs", + params={"name": "fuck.shit.com"} + ) + + assert r.status_code == 200 + assert len(r.json()) + + r = api_test_client.get( + "/logs", + params={"name": "fuck.shit.com", "type": "A"} + ) + + assert r.status_code == 200 + assert len(r.json())