#!/usr/bin/python2.7
#
# Univention IP Calculator
"""helper script: based on an IP address and a network mask, the script
calculates the broadcast address, the network address and DNS entries"""
#
# Copyright 2004-2020 Univention GmbH
#
# https://www.univention.de/
#
# All rights reserved.
#
# The source code of this program is made available
# under the terms of the GNU Affero General Public License version 3
# (GNU AGPL V3) as published by the Free Software Foundation.
#
# Binary versions of this program provided by Univention to you as
# well as other copyrighted, protected or trademarked materials like
# Logos, graphics, fonts, specific documentations and configurations,
# cryptographic keys etc. are subject to a license agreement between
# you and Univention and not subject to the GNU AGPL V3.
#
# In the case you use this program under the terms of the GNU AGPL V3,
# the program is provided in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public
# License with the Debian GNU/Linux or Univention distribution in file
# /usr/share/common-licenses/AGPL-3; if not, see
# <https://www.gnu.org/licenses/>.

from __future__ import print_function
import optparse
import sys


def parse_ipv4(addr):
	"""Convert a dotted IPv4 string into a list of four integers.

	A ValueError is raised when a part is not an integer, when there are not
	exactly four parts, or when a part is outside the range 0..255.

	>>> parse_ipv4('1.2.3.4')
	[1, 2, 3, 4]
	>>> parse_ipv4('1.2.3')
	Traceback (most recent call last):
	...
	ValueError: need 4 octets
	>>> parse_ipv4('1.2.3.4.5')
	Traceback (most recent call last):
	...
	ValueError: need 4 octets
	>>> parse_ipv4('1.2.3.-4')
	Traceback (most recent call last):
	...
	ValueError: need 4 octets
	>>> parse_ipv4('1.2.3.256')
	Traceback (most recent call last):
	...
	ValueError: need 4 octets
	"""
	# Convert first: a non-numeric part fails with int()'s own ValueError
	# before the count and range checks are looked at.
	numbers = list(map(int, addr.split('.')))
	has_four_parts = len(numbers) == 4
	if not has_four_parts or not all(0 <= number <= 255 for number in numbers):
		raise ValueError('need 4 octets')
	return numbers


def format_ipv4(addr, extend=None):
	"""Join a list of octets into a dotted string.

	When extend is given (anything but None), the list is first padded on the
	right with that value and then cut to four octets.

	>>> format_ipv4([1, 2, 3, 4])
	'1.2.3.4'
	>>> format_ipv4([1, 2, 3], extend=0)
	'1.2.3.0'
	"""
	octets = addr
	if extend is not None:
		# Four fillers always suffice, even for an empty list.
		octets = (addr + [extend] * 4)[:4]
	return '.'.join(['%d' % (octet,) for octet in octets])


def parse_options():
	"""Parse and validate the command line options.

	--ip and --netmask are both required and must be dotted IPv4 values that
	parse_ipv4() accepts. When one is missing or invalid, the problem is
	reported through parser.error() unless --quiet was given, and the process
	is terminated.

	Returns a tuple (addr, netmask, options): addr and netmask are the lists
	of integers produced by parse_ipv4(), options is the object returned by
	optparse holding all parsed option values.
	"""
	usage = '%prog [options] --ip <addr> --netmask <mask> [--output <type>]'
	epilog = 'Calculate network values from network address for DNS records.'
	parser = optparse.OptionParser(usage=usage, epilog=epilog)
	# Both values below are parsed with parse_ipv4() only, so the help texts
	# describe dotted IPv4 notation.
	parser.add_option(
		'--ip', dest='ip',
		help='IPv4 address')
	parser.add_option(
		'--netmask', dest='netmask',
		help='IPv4 netmask in dotted notation')
	# 'broadcast' has to be listed here: main() has an output branch for it
	# which could never be selected while the choice was missing.
	parser.add_option(
		'--output', dest='output', default='all',
		choices=('network', 'reverse', 'pointer', 'broadcast', 'all'),
		help='Specify requested output type')
	parser.add_option(
		'--calcdns', dest='calcdns',
		action='store_true',
		help='Request to calculate DNS record entries')
	parser.add_option(
		'--quiet', dest='quiet',
		action='store_true',
		help='Suppress error output')
	parser.add_option(
		'--full', dest='full',
		action='store_true',
		help='Extend network address to 4-tuple')
	parser.add_option(
		'--test-internal', action='callback', callback=__test,
		help=optparse.SUPPRESS_HELP)
	(options, _args) = parser.parse_args()

	def fail(message):
		"""Report message unless --quiet was given; never returns."""
		if not options.quiet:
			parser.error(message)
		sys.exit(1)

	if not options.ip or not options.netmask:
		fail("Needed --ip and --netmask")

	try:
		addr = parse_ipv4(options.ip)
	except ValueError:
		fail('Invalid ip')

	try:
		netmask = parse_ipv4(options.netmask)
	except ValueError:
		fail('Invalid netmask')

	return addr, netmask, options


def calc_dns(addr, netmask):
	"""Derive network, broadcast and pointer at whole-octet granularity.

	The network consists of the leading address octets up to (not including)
	the first netmask octet below 255; the first octet is always part of it.
	Returns a tuple (network, broadcast, pointer) of lists; the pointer holds
	the remaining octets (at least the last one) in reverse order.

	>>> calc_dns([170,85,204,51],[0,0,0,0])
	([170], [170, 255, 255, 255], [51, 204, 85])
	>>> calc_dns([170,85,204,51],[192,0,0,0])
	([170], [170, 255, 255, 255], [51, 204, 85])
	>>> calc_dns([170,85,204,51],[255,0,0,0])
	([170], [170, 255, 255, 255], [51, 204, 85])
	>>> calc_dns([170,85,204,51],[255,255,255,0])
	([170, 85, 204], [170, 85, 204, 255], [51])
	>>> calc_dns([170,85,204,51],[255,255,255,128])
	([170, 85, 204], [170, 85, 204, 255], [51])
	>>> calc_dns([170,85,204,51],[255,255,255,255])
	([170, 85, 204, 51], [170, 85, 204, 51], [51])
	"""
	# The scan starts at the second octet, so the value of the first netmask
	# octet never shortens the network below one octet.
	octet_count = next((pos for pos in (1, 2, 3) if netmask[pos] < 255), 4)
	network = addr[:octet_count]
	broadcast = (network + [255] * 4)[:4]
	# Even for a full mask the last octet is kept as the pointer.
	pointer = addr[min(octet_count, 3):][::-1]
	return network, broadcast, pointer


def calc_full(addr, netmask):
	"""Derive network, broadcast and pointer at bit granularity.

	Network and broadcast are computed per octet from address and netmask.
	The pointer holds the address octets from the first netmask octet below
	255 onwards (at least the last octet), in reverse order.
	Returns a tuple (network, broadcast, pointer) of lists.

	>>> calc_full([170,85,204,51],[0,0,0,0])
	([0, 0, 0, 0], [255, 255, 255, 255], [51, 204, 85, 170])
	>>> calc_full([170,85,204,51],[192,0,0,0])
	([128, 0, 0, 0], [191, 255, 255, 255], [51, 204, 85, 170])
	>>> calc_full([170,85,204,51],[255,0,0,0])
	([170, 0, 0, 0], [170, 255, 255, 255], [51, 204, 85])
	>>> calc_full([170,85,204,51],[255,255,255,0])
	([170, 85, 204, 0], [170, 85, 204, 255], [51])
	>>> calc_full([170,85,204,51],[255,255,255,128])
	([170, 85, 204, 0], [170, 85, 204, 127], [51])
	>>> calc_full([170,85,204,51],[255,255,255,255])
	([170, 85, 204, 51], [170, 85, 204, 51], [51])
	"""
	network = []
	broadcast = []
	for octet, mask in zip(addr, netmask):
		network.append(octet & mask)
		# Set every bit that the mask leaves open, limited to 8 bits.
		broadcast.append(octet | (~mask & 255))
	# A full mask still leaves the last octet as the pointer.
	start = next((pos for pos in range(4) if netmask[pos] < 255), 3)
	pointer = addr[start:][::-1]
	return network, broadcast, pointer


def main():
	"""Print the values selected by --output for the given address and mask.

	The values come from calc_dns() when --calcdns was given and from
	calc_full() otherwise.
	"""
	addr, netmask, options = parse_options()
	calculate = calc_dns if options.calcdns else calc_full
	network, broadcast, pointer = calculate(addr, netmask)

	# With the all-ones netmask the last octet is left out of the reverse value.
	is_host_mask = options.netmask == "255.255.255.255"
	reverse = network[:-1] if is_host_mask else network[:]

	requested = options.output
	if requested == 'all':
		print('Network: %s' % format_ipv4(network))
		# NOTE: only this combined listing prints the reverse value back to
		# front; the single 'reverse' output below keeps the forward order.
		print('Reverse: %s' % format_ipv4(reverse[::-1]))
		print('Pointer: %s' % format_ipv4(pointer))
		if options.full:
			print('Network full: %s' % format_ipv4(network, extend=0))
		print('Broadcast: %s' % format_ipv4(broadcast))
		return

	if requested == 'network':
		padding = 0 if options.full else None
		print(format_ipv4(network, extend=padding))
		return

	single_values = {
		'reverse': reverse,
		'pointer': pointer,
		'broadcast': broadcast,
	}
	if requested in single_values:
		print(format_ipv4(single_values[requested]))


def __test(_option, _opt_str, _value, _parser):
	"""Run the doctests and exit: status 1 if any failed, 0 otherwise.

	The four parameters are accepted but not used.
	"""
	import doctest
	failures = doctest.testmod().failed
	sys.exit(1 if failures else 0)


# Run the calculator only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
	main()
