#!/bin/bash

# NFT_TEST_REQUIRES(NFT_TEST_HAVE_socat)

# Tests reject functionality for both IPv4 and IPv6 with TCP and ICMP traffic on
# the loopback interface.
#
# - check reject works, i.e. ping and connect fail
# - check we don't reply to tcp resets with another tcp reset
# - check we don't reply to icmp error with another icmp error

# Overall test result: starts out as success and is set to 1 by the checks
# below when something unexpected is observed; used as the exit status.
ret=0
# TCP port used by both the socat listeners and the connect attempts.
port=14512

# All test traffic is sent over the loopback interface, so bring it up first.
ip link set lo up

# Load the netdev-family test table "t": an ingress chain on device lo that
# rejects ICMP and ICMPv6 with the default reject type and answers every TCP
# packet with a tcp reset.  Each rule carries an anonymous counter.
# The ruleset is fed to $NFT on stdin; the function returns $NFT's status.
load_ruleset_netdev()
{
	printf '%s\n' "load netdev test ruleset"

	$NFT -f - <<EOF
table netdev t {
	chain in {
		type filter hook ingress device lo priority 0

		ip protocol icmp counter reject
		ip6 nexthdr icmpv6 counter reject
		meta l4proto tcp counter reject with tcp reset
	}
}
EOF
}

# Load the inet-family test table "t": a prerouting filter chain (policy
# accept) which, for packets whose input interface name is lo, rejects
# ICMP/ICMPv6 echo requests with port-unreachable and TCP packets that have
# the syn flag set with a tcp reset.
# The ruleset is fed to $NFT on stdin; the function returns $NFT's status.
load_ruleset_inet()
{
	printf '%s\n' "load inet test ruleset"

	$NFT -f - <<EOF
table inet t {
        chain in {
                type filter hook prerouting priority filter; policy accept;
                iifname lo jump {
	                icmp type echo-request counter reject with icmp port-unreachable
	                icmpv6 type echo-request counter reject with icmpv6 port-unreachable
	                meta l4proto tcp tcp flags syn counter reject with tcp reset
		}
        }
}
EOF
}

# Try to get nf_tables to reply to a tcp reset with another tcp reset and
# to an icmp port-unreachable with another port-unreachable.
# Should NOT be possible.
# Note that this ruleset is nonsensical:
# meta l4proto tcp ... reject with tcp reset, on loopback,
# will drop the reset packet so the client times out.
#
# This isn't an issue for remote clients, as the reset
# won't appear in prerouting.
# Load the inet-family "loop" table "t": like the plain inet ruleset, but the
# rules match ALL icmp, icmpv6 and tcp packets arriving on lo (not only echo
# requests / syns) and account them in the named counters icmp4c, icmp6c and
# tcprstc, which check_counters inspects afterwards.
# The ruleset is fed to $NFT on stdin; the function returns $NFT's status.
load_ruleset_inet_loop()
{
	printf '%s\n' "load inet loop ruleset"

	$NFT -f - <<EOF
table inet t {
	counter tcprstc { }
	counter icmp4c { }
	counter icmp6c { }

        chain in {
                type filter hook prerouting priority filter; policy accept;
                iifname lo jump {
	                ip protocol icmp counter name icmp4c reject with icmp port-unreachable
	                ip6 nexthdr icmpv6 counter name icmp6c reject with icmpv6 port-unreachable
	                meta l4proto tcp counter name tcprstc reject with tcp reset
		}
        }
}
EOF
}

# Load the netdev-family "loop" table "t": an ingress chain on device lo that
# rejects all icmp, icmpv6 and tcp packets and accounts them in the named
# counters icmp4c, icmp6c and tcprstc, which check_counters inspects
# afterwards.
# The ruleset is fed to $NFT on stdin; the function returns $NFT's status.
load_ruleset_netdev_loop()
{
	printf '%s\n' "load netdev loop ruleset"

	$NFT -f - <<EOF
table netdev t {
	counter tcprstc { }
	counter icmp4c { }
	counter icmp6c { }

        chain in {
		type filter hook ingress device lo priority 0
                ip protocol icmp counter name icmp4c reject with icmp port-unreachable
                ip6 nexthdr icmpv6 counter name icmp6c reject with icmpv6 port-unreachable
                meta l4proto tcp counter name tcprstc reject with tcp reset
        }
}
EOF
}

# Verify the packet count of a named counter in table "t".
#
# Arguments:
#   $1 - table family (e.g. inet, netdev)
#   $2 - counter name
#   $3 - exact packet count that is expected
#   $4 - upper bound that is still tolerated; 0 (or omitted) disables the
#        tolerance and only the exact count is accepted
# Globals:
#   NFT (read) - nft command to run
#   ret (written) - set to 1 if the count is neither exact nor tolerated
# Outputs: progress/diagnostic messages on stdout.
check_counter()
{
	local family="$1"
	local countername="$2"
	local wanted_packetcount="$3"
	local max_packetcount="${4:-0}"
	local listing pcount

	# Query the counter once; declaration and assignment are split so that
	# 'local' does not hide a failing command substitution.
	listing=$($NFT list counter "$family" "t" "$countername")

	# Pull out the digits following "packets ".  If the listing is empty or
	# has no such field, pcount ends up empty and the check fails below.
	pcount=${listing#*packets }
	pcount=${pcount%%[!0-9]*}

	echo "counter $family t $countername has ${pcount:-unknown} packets"

	if [ -n "$pcount" ]; then
		# Compare numerically: a textual prefix match would accept
		# e.g. 20 packets when 2 are wanted.
		if [ "$pcount" -eq "$wanted_packetcount" ]; then
			return 0
		fi

		# the _loop rulesets drop tcp resets, so we must tolerate retransmitted syns.
		if [ "$max_packetcount" -gt 0 ] &&
		   [ "$pcount" -gt 0 ] && [ "$pcount" -le "$max_packetcount" ]; then
			echo "Tolerated $pcount packets (max $max_packetcount)"
			return 0
		fi
	fi

	echo "Unexpected packetcount, expected $wanted_packetcount / max $max_packetcount"
	printf '%s\n' "$listing"
	ret=1
}

# Check the three named counters that the *_loop rulesets maintain.
#
# Arguments:
#   $1 - table family holding the counters (inet or netdev)
check_counters()
{
	local family="$1"
	local icmp_counter

	# test_all makes one IPv4 and one IPv6 connect attempt; each accounts
	# for one syn and one rst, i.e. 4 packets.  Up to 16 are tolerated.
	check_counter "$family" tcprstc 4 16

	# one for echo, one for dst-unreach, per address family
	for icmp_counter in icmp4c icmp6c; do
		check_counter "$family" "$icmp_counter" 2 2
	done
}

# Compare a command's exit status against the expected outcome and record a
# test failure on mismatch.
#
# Arguments:
#   $1  - exit status of the command that was run
#   $2  - 0 if the command was expected to succeed, non-zero if it was
#         expected to fail
#   $3… - description of the command, used in the diagnostic message
# Globals:
#   ret (written) - set to 1 when outcome and expectation disagree
# Outputs: a diagnostic on stdout on mismatch, nothing otherwise.
maybe_error()
{
	# Must NOT be called "ret": a local of that name would shadow the
	# script-wide result variable and the failure would be lost on return.
	local rc="$1"
	local err_wanted="$2"
	shift 2

	local errmsg="$*"

	if [ "$rc" -eq 0 ]; then
		if [ "$err_wanted" -ne 0 ]; then
			echo "$errmsg succeeded but expected to fail"
			ret=1
		fi
	elif [ "$err_wanted" -eq 0 ]; then
		echo "$errmsg failed ($rc) but expected to work"
		ret=1
	fi

	return 0
}

# Run the four connectivity probes over loopback (IPv4 ping, IPv4 TCP
# connect, IPv6 ping, IPv6 TCP connect) and hand each exit status to
# maybe_error together with the expectation.
#
# Arguments:
#   $1 - 0 if the probes are expected to succeed, non-zero if they are
#        expected to fail
# Globals:
#   port (read) - TCP port to connect to
test_all()
{
	local err_wanted="$1"

	ping -W 1 -q -c 1 127.0.0.1 > /dev/null
	maybe_error $? "$err_wanted" "ping 127.0.0.1"

	socat -u STDIN "TCP-CONNECT:127.0.0.1:$port,connect-timeout=1" < /dev/null 2>/dev/null
	maybe_error $? "$err_wanted" "connect 127.0.0.1"

	ping -W 1 -q -c 1 ::1 > /dev/null
	maybe_error $? "$err_wanted" "ping ::1"

	# The address is quoted: an unquoted [::1] is a glob bracket expression.
	socat -u STDIN "TCP-CONNECT:[::1]:$port,connect-timeout=1" < /dev/null 2>/dev/null
	maybe_error $? "$err_wanted" "connect ::1"
}

# Start socat listeners in background, one bound to 127.0.0.1 and one to ::1.
# Each runs under "timeout 10", so $! below is the pid of the timeout wrapper
# and a listener is terminated after 10 seconds even without the final kill.
timeout 10 socat TCP-LISTEN:$port,bind=127.0.0.1,reuseaddr PIPE &
SOCAT_PID4=$!

timeout 10 socat TCP6-LISTEN:$port,bind=::1,reuseaddr PIPE &
SOCAT_PID6=$!

# Give listeners time to start
sleep 1

# empty ruleset: ping and connect are expected to work (err_wanted=0)
test_all 0

# From here on every probe is expected to fail (err_wanted=1).  Each ruleset
# creates a table named "t", which is deleted again before the next one is
# loaded.
load_ruleset_inet
test_all 1
$NFT delete table inet t

load_ruleset_netdev
test_all 1
$NFT delete table netdev t

# The _loop rulesets additionally maintain named counters, which are checked
# after the probes have run.
load_ruleset_inet_loop
test_all 1
check_counters inet
$NFT delete table inet t

load_ruleset_netdev_loop
test_all 1
check_counters netdev
$NFT delete table netdev t

# Clean up listeners; errors are discarded since they may have exited already
kill $SOCAT_PID4 $SOCAT_PID6 2>/dev/null

echo "Exiting with $ret"
exit $ret
