forked from Corollarium/localtls
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dnsserver.py
375 lines (337 loc) · 11.6 KB
/
dnsserver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
#!/usr/bin/env python3.6
# -*- coding: utf-8 -*-
import json
import logging
import os
import sys
import signal
import re
import socket
import argparse
import ipaddress
from datetime import datetime
from time import sleep
import threading
from multiprocessing.connection import Listener
import dnslib
from dnslib import DNSLabel, QTYPE, RR, dns
from dnslib.proxy import ProxyResolver
from dnslib.server import DNSServer, DNSLogger
import httpserver
import confs
# Module-wide logger for this server. Its handler uses StreamHandler's
# default stream and stamps each message with the time of day only.
logger = logging.getLogger('localtls')
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('%(asctime)s: %(message)s', datefmt='%H:%M:%S')
)
handler.setLevel(logging.INFO)
logger.addHandler(handler)
# Record-type name -> (dnslib rdata class, numeric QTYPE code).
TYPE_LOOKUP = {
    rtype: (getattr(dns, rtype), getattr(QTYPE, rtype))
    for rtype in (
        'A', 'AAAA', 'CAA', 'CNAME', 'DNSKEY', 'MX', 'NAPTR',
        'NS', 'PTR', 'RRSIG', 'SOA', 'SRV', 'TXT',
    )
}

# Pending TXT values (key -> value), filled in by messageListener from
# ADDTXT/REMOVETXT messages and served by Resolver for the base domain
# and its _acme-challenge name.
TXT_RECORDS = {}
def get_ipv4():
    """Return this machine's primary IPv4 address, or '' if it cannot be found.

    A UDP socket is "connected" to an arbitrary private address so the
    kernel selects the outgoing interface; that interface's address is then
    read back with getsockname(). The destination doesn't even have to be
    reachable.

    :returns: the dotted-quad address as a string, or '' on any socket error.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('10.255.255.255', 1))
            return s.getsockname()[0]
    except OSError:
        # No route / no IPv4 networking: callers treat '' as "unknown".
        return ''
def get_ipv6():
    """Return this machine's primary IPv6 address, or '' if it cannot be found.

    Same trick as get_ipv4: "connect" a UDP socket to a documentation-range
    address (2001:db8::/32) so the kernel picks the outgoing interface, then
    read the local address back with getsockname().

    Socket creation is inside the try block because hosts without IPv6
    support raise OSError from socket.socket(AF_INET6) itself.

    :returns: the IPv6 address as a string, or '' on any socket error.
    """
    try:
        with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
            s.connect(('2001:0db8:85a3:0000:0000:8a2e:0370:7334', 1))
            return s.getsockname()[0]
    except OSError:
        return ''
class Resolver(ProxyResolver):
    """Answer queries under BASE_DOMAIN locally and proxy everything else.

    * BASE_DOMAIN and ``_acme-challenge.BASE_DOMAIN`` get, all in the answer
      section and regardless of the query type: an A record for LOCAL_IPV4,
      the SOA record (if configured), the NS records, an AAAA record for
      LOCAL_IPV6 (if set), and one TXT record holding every pending value
      in TXT_RECORDS.
    * ``<ip>.BASE_DOMAIN``, where ``<ip>`` is an IPv4 or IPv6 address with
      its separators written as '-' (e.g. ``192-168-0-1``), gets an A/AAAA
      record for that address, subject to the ONLY_PRIVATE_IPS and
      NO_RESERVED_IPS settings. Multicast and unspecified addresses are
      never answered. Other names under BASE_DOMAIN get an empty reply.
    * Any other name is forwarded to the upstream server, unless the
      upstream address is empty, in which case an empty reply is returned.
    """

    def __init__(self, upstream):
        """Set up the upstream proxy and precompute the SOA and NS rdata.

        :param upstream: address of the fallback DNS server, passed to
            ProxyResolver with port 53 and timeout 5.
        """
        super().__init__(upstream, 53, 5)
        if confs.SOA_MNAME and confs.SOA_RNAME:
            self.SOA = dnslib.SOA(
                mname=DNSLabel(confs.SOA_MNAME),
                # The admin mailbox is encoded with '@' turned into '.'.
                # TODO: a '.' in the local part (before '@') should be escaped.
                rname=DNSLabel(confs.SOA_RNAME.replace('@', '.')),
                times=(
                    confs.SOA_SERIAL,  # serial number
                    60 * 60 * 1,       # refresh
                    60 * 60 * 2,       # retry
                    60 * 60 * 24,      # expire
                    60 * 60 * 1,       # minimum
                ),
            )
        else:
            self.SOA = None
        self.NS = [dnslib.NS(server) for server in (confs.NS_SERVERS or [])]

    @staticmethod
    def _base_domain():
        """Return BASE_DOMAIN lower-cased and without a trailing dot."""
        return confs.BASE_DOMAIN.rstrip('.').lower()

    def match_suffix_insensitive(self, request):
        """Return True if the query name is BASE_DOMAIN or a name under it.

        The comparison ignores case and respects label boundaries: with
        BASE_DOMAIN 'example.com', 'foo.example.com' matches but
        'fooexample.com' does not.
        """
        qname = str(request.q.qname).rstrip('.').lower()
        base = self._base_domain()
        return qname == base or qname.endswith('.' + base)

    def resolve(self, request, handler):
        """Build the reply for *request*, answering locally or proxying it.

        :param request: the parsed dnslib query record.
        :param handler: the dnslib request handler, passed through to
            ProxyResolver.resolve for forwarded queries.
        :returns: the reply record (possibly with no answers).
        """
        reply = request.reply()
        name = request.q.qname
        logger.info("query %s", request.q.qname)

        # handle the main domain (and its ACME DNS-challenge name)
        if name == confs.BASE_DOMAIN or name == '_acme-challenge.' + confs.BASE_DOMAIN:
            self._answer_base_domain(name, reply)
            return reply

        # handle subdomains: answered here, never proxied upstream
        if self.match_suffix_insensitive(request):
            self._answer_ip_subdomain(request, reply)
            return reply

        if self.address == "":
            return reply
        return super().resolve(request, handler)

    def _answer_base_domain(self, name, reply):
        """Add the records served for BASE_DOMAIN / _acme-challenge.BASE_DOMAIN."""
        hour = 60 * 60
        reply.add_answer(RR(rname=name, rdata=dns.A(confs.LOCAL_IPV4), rtype=QTYPE.A, ttl=hour))
        if self.SOA:
            reply.add_answer(RR(rname=name, rdata=self.SOA, rtype=QTYPE.SOA, ttl=hour))
        for ns in self.NS:
            reply.add_answer(RR(rname=name, rdata=ns, rtype=QTYPE.NS, ttl=hour))
        if confs.LOCAL_IPV6:
            reply.add_answer(
                RR(rname=name, rdata=dns.AAAA(confs.LOCAL_IPV6), rtype=QTYPE.AAAA, ttl=hour)
            )
        # Snapshot the values: messageListener mutates TXT_RECORDS from
        # another thread, and iterating the live dict could then raise
        # "dictionary changed size during iteration".
        txt_values = list(TXT_RECORDS.values())
        if txt_values:
            reply.add_answer(
                RR(rname=name, rdata=dns.TXT([str(v) for v in txt_values]), rtype=QTYPE.TXT)
            )

    @staticmethod
    def _parse_ip_label(label):
        """Decode a '-'-separated address label, trying IPv4 then IPv6; None if neither."""
        for separator in ('.', ':'):
            try:
                return ipaddress.ip_address(label.replace('-', separator))
            except ValueError:
                continue
        return None

    def _answer_ip_subdomain(self, request, reply):
        """Add an A/AAAA answer when the query is ``<ip-label>.BASE_DOMAIN``."""
        labelstr = str(request.q.qname)
        logger.info("requestx: %s, %s", labelstr, confs.ONLY_PRIVATE_IPS)
        qname = labelstr.rstrip('.')
        # Strip ".BASE_DOMAIN"; what remains must be exactly one label.
        # Works for base domains with any number of labels.
        ip_label = qname[:len(qname) - len(self._base_domain()) - 1]
        if not ip_label or '.' in ip_label:
            return

        ip = self._parse_ip_label(ip_label)
        if ip is None:
            logger.info('invalid ip %s', labelstr)
            return
        # check if we only want private ips
        if confs.ONLY_PRIVATE_IPS and not ip.is_private:
            return
        if confs.NO_RESERVED_IPS and ip.is_reserved:
            return
        # check if it's a valid ip for a machine
        if ip.is_multicast or ip.is_unspecified:
            return

        if ip.version == 4:
            address = ip_label.replace('-', '.')
            logger.info("ip is ipv4 %s", address)
            rdata, rtype = dns.A(address), QTYPE.A
        else:
            address = ip_label.replace('-', ':')
            logger.info("ip is ipv6 %s", address)
            rdata, rtype = dns.AAAA(address), QTYPE.AAAA
        reply.add_answer(RR(rname=request.q.qname, rdata=rdata, rtype=rtype, ttl=24 * 60 * 60))
        logger.info('found zone for %s, %d replies', request.q.qname, len(reply.rr))
def handle_sig(signum, frame):
    """Signal handler (installed for SIGTERM by main) that stops the process.

    :param signum: the received signal number.
    :param frame: the interrupted stack frame (unused).
    :raises SystemExit: always, with status 0.
    """
    logger.info('pid=%d, got signal: %s, stopping...', os.getpid(), signal.Signals(signum).name)
    # sys.exit rather than the exit() builtin: exit() is injected by the
    # site module and is absent under `python -S` or in frozen apps.
    sys.exit(0)
def messageListener():
    """Receive TXT-record updates from the local certbot DNS hook script.

    We need the records as soon as possible so certbot's DNS challenge can
    be validated. Runs forever, so it is meant to be a thread target.

    Each connection to localhost:6000 must authenticate with the KEY
    environment variable (default 'secret') and send one JSON message:

    * ``{"command": "ADDTXT", "key": ..., "val": ...}`` stores val under key;
    * ``{"command": "REMOVETXT", "key": ...}`` removes key if present.

    Errors on a single connection are logged and do not stop the loop.
    """
    address = ('localhost', 6000)  # family is deduced to be 'AF_INET'
    # Listener needs a bytes authkey, but os.getenv returns str when KEY is
    # set, so always encode. Not very secret, but we're bound to localhost.
    authkey = os.getenv('KEY', 'secret').encode('utf-8')
    with Listener(address, authkey=authkey) as listener:
        while True:
            conn = None
            try:
                conn = listener.accept()
                # SECURITY NOTE(review): Connection.recv() unpickles the payload.
                # With a guessable authkey, any local user can send a malicious
                # pickle. Consider switching both ends to send_bytes/recv_bytes.
                raw = conn.recv()
                # json.loads accepts str or bytes; its old `encoding` argument
                # was removed in Python 3.9 and made every message fail.
                msg = json.loads(raw)
                command = msg['command']
                if command == "ADDTXT":
                    TXT_RECORDS[msg["key"]] = msg["val"]
                elif command == "REMOVETXT":
                    TXT_RECORDS.pop(msg["key"], None)
            except Exception as e:
                logger.error('TXT message listener error: %s', e)
            finally:
                if conn is not None:
                    conn.close()
def _build_arg_parser():
    """Return the command-line argument parser for the LocalTLS DNS server."""
    parser = argparse.ArgumentParser(description='LocalTLS')
    parser.add_argument(
        '--domain',
        required=True,
        help="Your domain or subdomain."
    )
    parser.add_argument(
        '--soa-master',
        help="Primary master name server for SOA record. You should fill this."
    )
    parser.add_argument(
        '--soa-email',
        help="Email address for administrator for SOA record. You should fill this."
    )
    parser.add_argument(
        '--ns-servers',
        help="List of ns servers, separated by comma"
    )
    parser.add_argument(
        '--dns-port',
        type=int,
        default=53,
        help="DNS server port"
    )
    parser.add_argument(
        '--domain-ipv4',
        default='',
        help="The IPV4 for the naked domain. If empty, use this machine's."
    )
    parser.add_argument(
        '--domain-ipv6',
        default='',
        help="The IPV6 for the naked domain. If empty, use this machine's."
    )
    parser.add_argument(
        '--only-private-ips',
        action='store_true',
        default=False,
        help="Resolve only IPs in private ranges."
    )
    parser.add_argument(
        '--no-reserved-ips',
        action='store_true',
        default=False,
        help="If true ignore ips that are reserved."
    )
    parser.add_argument(
        '--dns-fallback',
        default='1.1.1.1',
        help="DNS fallback server. Default: 1.1.1.1"
    )
    parser.add_argument(
        '--http-port',
        help="HTTP server port. If not set, no HTTP server is started"
    )
    parser.add_argument(
        '--http-index-file',
        default='index.html',
        help="HTTP index.html file."
    )
    parser.add_argument(
        '--log-level',
        default='ERROR',
        help="INFO|WARNING|ERROR|DEBUG"
    )
    return parser


def main():
    """Parse arguments, fill in `confs`, and run the DNS (and optional HTTP) servers.

    Blocks until the UDP DNS server stops or the process is interrupted.
    Exits with status 1 if the naked-domain IPv4/IPv6 address is invalid.
    """
    signal.signal(signal.SIGTERM, handle_sig)
    args = _build_arg_parser().parse_args()

    # The primary addresses; auto-detected when not given on the command line.
    confs.LOCAL_IPV4 = args.domain_ipv4 or get_ipv4()
    confs.LOCAL_IPV6 = args.domain_ipv6 or get_ipv6()
    try:
        ipaddress.ip_address(confs.LOCAL_IPV4)
    except ValueError:
        logger.critical('Invalid IPV4 %s', confs.LOCAL_IPV4)
        sys.exit(1)
    if confs.LOCAL_IPV6:
        try:
            ipaddress.ip_address(confs.LOCAL_IPV6)
        except ValueError:
            logger.critical('Invalid IPV6 %s', confs.LOCAL_IPV6)
            sys.exit(1)

    # logger.setLevel() only accepts upper-case level names, so normalize
    # first; this also makes the DEBUG check for dnslib below reachable.
    log_level = args.log_level.upper()
    logger.setLevel(log_level)

    confs.ONLY_PRIVATE_IPS = args.only_private_ips
    confs.NO_RESERVED_IPS = args.no_reserved_ips
    confs.BASE_DOMAIN = args.domain
    confs.SOA_MNAME = args.soa_master
    confs.SOA_RNAME = args.soa_email
    if not confs.SOA_MNAME or not confs.SOA_RNAME:
        logger.error('Setting SOA is strongly recommended')
    if args.ns_servers:
        confs.NS_SERVERS = [s.strip() for s in args.ns_servers.split(',') if s.strip()]

    # Handle local messages that add TXT records. Daemon thread: it blocks
    # forever in accept(), and a non-daemon thread would keep the process
    # alive after SIGTERM (SystemExit) or Ctrl-C ends the main thread.
    threading.Thread(target=messageListener, daemon=True).start()

    # open the DNS server
    port = int(args.dns_port)
    upstream = args.dns_fallback
    resolver = Resolver(upstream)
    if log_level == 'DEBUG':
        logmode = "+request,+reply,+truncated,+error"
    else:
        logmode = "-request,-reply,-truncated,+error"
    dnslogger = DNSLogger(log=logmode, prefix=False)
    udp_server = DNSServer(resolver, port=port, logger=dnslogger)
    tcp_server = DNSServer(resolver, port=port, tcp=True, logger=dnslogger)
    logger.critical(
        'starting DNS server on %s/%s on port %d, upstream DNS server "%s"',
        confs.LOCAL_IPV4, confs.LOCAL_IPV6, port, upstream,
    )
    udp_server.start_thread()
    tcp_server.start_thread()

    # open the HTTP server (daemon for the same shutdown reason as above)
    if args.http_port:
        logger.critical('Starting httpd...')
        threading.Thread(
            target=httpserver.run,
            kwargs={"port": int(args.http_port), "index": args.http_index_file},
            daemon=True,
        ).start()

    try:
        while udp_server.isAlive():
            sleep(1)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()