Skip to content

Commit

Permalink
Moar tests 🧪, dockerfile fixes and code shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
byt3bl33d3r committed Nov 23, 2023
1 parent b423256 commit 31d6762
Show file tree
Hide file tree
Showing 10 changed files with 235 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ COPY . .

RUN pip wheel --wheel-dir ./dist '.[api]'

FROM python:3.10-slim
FROM python:3.11-slim

WORKDIR /app

Expand Down
3 changes: 2 additions & 1 deletion dnschef/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dnschef import __version__
from dnschef import kitchen
from dnschef.protocols import start_server
from dnschef.utils import header, parse_config_file
from dnschef.logger import (
log,
Expand Down Expand Up @@ -58,7 +59,7 @@ async def startup_event():

# Launch DNSChef
asyncio.create_task(
kitchen.start_cooking(
start_server(
interface=settings.interface,
nameservers=settings.nameservers,
tcp=settings.tcp,
Expand Down
66 changes: 24 additions & 42 deletions dnschef/kitchen.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,45 +180,27 @@ def findnametodns(self, qname, qtype):
#return { qtype: { k:v for k,v in CONFIG[qtype].items() if k == top_matched_domains[0] } }
return CONFIG[qtype][top_matched_domains[0]]

async def we_cookin(self, logger, data, addr):
try:
d = DNSRecord.parse(data)
except Exception:
logger.error("invalid DNS request")

else:
# Only Process DNS Queries
if QR[d.header.qr] == "QUERY":

qtype = QTYPE[d.q.qtype]
# Create a custom response to the query
response = DNSRecord(
DNSHeader(id=d.header.id, bitmap=d.header.bitmap, qr=1, aa=1, ra=1),
q=d.q
)

# Gather query parameters
# NOTE: Do not lowercase qname here, because we want to see
# any case request weirdness in the logs.
qname = str(d.q.qname)

# Chop off the last period
if qname[-1] == '.': qname = qname[:-1]

cooked_reply = self.findnametodns(qname, qtype)

# Check if there is a fake record for the current request qtype
if CONFIG.get(qtype) and cooked_reply:
logger.info("cooking response", type=qtype, name=qname) #record=record)

response_func = getattr(
self,
f"do_{qtype}",
self.do_default
)

response.add_answer(
(await response_func(addr, qname, qtype, cooked_reply))
)

return response
async def we_cookin(self, logger, d, qtype, qname, addr):
# Create a custom response to the query
response = DNSRecord(
DNSHeader(id=d.header.id, bitmap=d.header.bitmap, qr=1, aa=1, ra=1),
q=d.q
)

cooked_reply = self.findnametodns(qname, qtype)

# Check if there is a fake record for the current request qtype
if CONFIG.get(qtype) and cooked_reply:
logger.info("cooking response")

response_func = getattr(
self,
f"do_{qtype}",
self.do_default
)

response.add_answer(
(await response_func(addr, qname, qtype, cooked_reply))
)

return response
114 changes: 80 additions & 34 deletions dnschef/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
import re
import random
import functools
import enum
from dnslib import DNSRecord, QR, QTYPE
from typing import List
from dnschef.logger import log

from dnschef import kitchen

class ClientProtocol(enum.Enum):
UDP = 1
TCP = 2

class UdpDnsClientProtocol:
def __init__(self, request, on_con_lost):
self.transport = None
Expand Down Expand Up @@ -53,11 +59,11 @@ def connection_lost(self, exc):
self.on_con_lost.set_result(True)

# Obtain a response from a real DNS server.
async def proxy_request(request, host, port=53, protocol="udp"):
async def proxy_request(request, host, protocol: ClientProtocol, port: int = 53):
loop = asyncio.get_running_loop()
on_con_lost = loop.create_future()

if protocol == "udp":
if protocol == ClientProtocol.UDP:
transport, protocol = await loop.create_datagram_endpoint(
lambda: UdpDnsClientProtocol(request, on_con_lost),
remote_addr=(host, int(port)))
Expand All @@ -76,7 +82,7 @@ async def proxy_request(request, host, port=53, protocol="udp"):

class UdpDnsServerProtocol:
def __init__(self, nameservers, dns_kitchen):
self.nameservers = nameservers
self.nameservers = [ re.split('[:#]', ns) for ns in nameservers ]
self.dns_kitchen = dns_kitchen

def connection_made(self, transport):
Expand All @@ -85,24 +91,44 @@ def connection_made(self, transport):
def datagram_received(self, data, addr):
logger = log.bind(address=addr[0], proto="udp")

def _cooked_cb(future):
response = future.result()
if response:
logger.debug("dns packet", packet=response.pack())
self.transport.sendto(response.pack(), addr)
else:
logger.info("proxying response")
nameserver_tuple = re.split('[:#]', random.choice(self.nameservers))

task = asyncio.create_task(proxy_request(data, *nameserver_tuple, protocol='udp'))
task.add_done_callback(functools.partial(lambda c, t, a: t.sendto(c.result(), a), t=self.transport, a=addr))

task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, data, addr))
task.add_done_callback(_cooked_cb)
try:
d = DNSRecord.parse(data)
except Exception:
logger.error("invalid DNS request")
else:
# Only Process DNS Queries
if not QR[d.header.qr] == "QUERY":
logger.warning("received a non-query DNS request")
return

qtype = QTYPE[d.q.qtype]
qname = str(d.q.qname).rstrip('.')
logger = logger.bind(name=qname, type=qtype)

def _cooked_cb(future):
response = future.result()
if response:
logger.debug("dns packet", packet=response.pack())
self.transport.sendto(response.pack(), addr)
else:
logger.info("proxying response")
task = asyncio.create_task(
proxy_request(
data,
*random.choice(self.nameservers),
protocol=ClientProtocol.UDP
)
)
task.add_done_callback(functools.partial(
lambda c, t, a: t.sendto(c.result(), a), t=self.transport, a=addr
))

task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, d, qtype, qname, addr))
task.add_done_callback(_cooked_cb)

class TcpDnsServerProtocol(asyncio.Protocol):
def __init__(self, nameservers, dns_kitchen):
self.nameservers = nameservers
self.nameservers = [ re.split('[:#]', ns) for ns in nameservers ]
self.dns_kitchen = dns_kitchen

def connection_made(self, transport):
Expand All @@ -112,22 +138,42 @@ def data_received(self, data):
addr = self.transport.get_extra_info('peername')
logger = log.bind(address=addr[0], proto="tcp")

def _cooked_cb(future):
response = future.result()
if response:
logger.debug("dns packet", packet=response.pack())
self.transport.write(
len(response.pack()).to_bytes(2, byteorder='big') + response.pack()
)
else:
logger.info("proxying response")
nameserver_tuple = re.split('[:#]', random.choice(self.nameservers))

task = asyncio.create_task(proxy_request(data, *nameserver_tuple, protocol='tcp'))
task.add_done_callback(functools.partial(lambda c, t: t.write(c.result()), t=self.transport))

task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, data[2:], addr))
task.add_done_callback(_cooked_cb)
try:
d = DNSRecord.parse(data[2:])
except Exception:
logger.error("invalid DNS request")
else:
# Only Process DNS Queries
if not QR[d.header.qr] == "QUERY":
logger.warning("received a non-query DNS request")
return

qtype = QTYPE[d.q.qtype]
qname = str(d.q.qname).rstrip('.')
logger = logger.bind(name=qname, type=qtype)

def _cooked_cb(future):
response = future.result()
if response:
logger.debug("dns packet", packet=response.pack())
self.transport.write(
len(response.pack()).to_bytes(2, byteorder='big') + response.pack()
)
else:
logger.info("proxying response")
task = asyncio.create_task(
proxy_request(
data,
*random.choice(self.nameservers),
protocol=ClientProtocol.TCP
)
)
task.add_done_callback(functools.partial(
lambda c, t: t.write(c.result()), t=self.transport
))

task = asyncio.create_task(self.dns_kitchen.we_cookin(logger, d, qtype, qname, addr))
task.add_done_callback(_cooked_cb)

async def start_server(interface: str, nameservers: List[str], tcp: bool = False, ipv6: bool = False, port: int = 53):
loop = asyncio.get_running_loop()
Expand Down
2 changes: 1 addition & 1 deletion dnschef/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dnslib import RDMAP

header = " _ _ __ \n"
header += " | | version {} | | / _| \n".format(__version__)
header += " | | v{} | | / _| \n".format(__version__)
header += " __| |_ __ ___ ___| |__ ___| |_ \n"
header += " / _` | '_ \/ __|/ __| '_ \ / _ \ _|\n"
header += " | (_| | | | \__ \ (__| | | | __/ | \n"
Expand Down
47 changes: 46 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ log_cli = false
log_cli_level = "INFO"
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_cli_date_format = "%Y-%m-%d %H:%M:%S"
filterwarnings = [
# note the use of single quote below to denote "raw" strings in TOML
'ignore:`general_plain_validator_function` is deprecated',
]

[tool.poetry.scripts]
dnschef = 'dnschef.__main__:main'
Expand All @@ -41,6 +45,7 @@ poetry-plugin-export = "^1.6.0"
ruff = "^0.1.6"
dnspython = "^2.4.2"
pytest-cov = "^4.1.0"
httpx = "^0.25.1"

[build-system]
requires = ["poetry-core"]
Expand Down
6 changes: 6 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,12 @@ filelock==3.13.1 ; python_version >= "3.11" and python_version < "4.0" \
h11==0.14.0 ; python_version >= "3.11" and python_version < "4.0" \
--hash=sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d \
--hash=sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761
httpcore==1.0.2 ; python_version >= "3.11" and python_version < "4.0" \
--hash=sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7 \
--hash=sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535
httpx==0.25.1 ; python_version >= "3.11" and python_version < "4.0" \
--hash=sha256:fec7d6cc5c27c578a391f7e87b9aa7d3d8fbcd034f6399f9f79b45bcc12a866a \
--hash=sha256:ffd96d5cf901e63863d9f1b4b6807861dbea4d301613415d9e6e57ead15fc5d0
idna==3.4 ; python_version >= "3.11" and python_version < "4.0" \
--hash=sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4 \
--hash=sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2
Expand Down
16 changes: 13 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,27 @@
import dns.asyncresolver

from dnschef import kitchen
from dnschef.api import app
from dnschef.protocols import start_server
from dnschef.utils import parse_config_file
from dnschef.logger import log, debug_formatter
from dnschef.logger import log, debug_formatter, json_capture_formatter
from fastapi.testclient import TestClient

log.setLevel(logging.DEBUG)
log.handlers[0].setFormatter(debug_formatter)
#log.setLevel(logging.DEBUG)
#log.handlers[0].setFormatter(debug_formatter)

jh = logging.StreamHandler()
jh.setFormatter(json_capture_formatter)
log.addHandler(jh)

@pytest.fixture
def random_string():
return ''.join(random.choices(string.ascii_letters, k=6))

@pytest.fixture
def api_test_client():
return TestClient(app)

@pytest.fixture(scope="session")
def event_loop():
loop = asyncio.get_event_loop_policy().new_event_loop()
Expand Down
Loading

0 comments on commit 31d6762

Please sign in to comment.