# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause from collections import namedtuple from enum import Enum import functools import os import random import socket import struct from struct import Struct import sys import yaml import ipaddress import uuid import queue import selectors import time from .nlspec import SpecFamily # # Generic Netlink code which should really be in some library, but I can't quickly find one. # class Netlink: # Netlink socket SOL_NETLINK = 270 NETLINK_ADD_MEMBERSHIP = 1 NETLINK_CAP_ACK = 10 NETLINK_EXT_ACK = 11 NETLINK_GET_STRICT_CHK = 12 # Netlink message NLMSG_ERROR = 2 NLMSG_DONE = 3 NLM_F_REQUEST = 1 NLM_F_ACK = 4 NLM_F_ROOT = 0x100 NLM_F_MATCH = 0x200 NLM_F_REPLACE = 0x100 NLM_F_EXCL = 0x200 NLM_F_CREATE = 0x400 NLM_F_APPEND = 0x800 NLM_F_CAPPED = 0x100 NLM_F_ACK_TLVS = 0x200 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH NLA_F_NESTED = 0x8000 NLA_F_NET_BYTEORDER = 0x4000 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER # Genetlink defines NETLINK_GENERIC = 16 GENL_ID_CTRL = 0x10 # nlctrl CTRL_CMD_GETFAMILY = 3 CTRL_ATTR_FAMILY_ID = 1 CTRL_ATTR_FAMILY_NAME = 2 CTRL_ATTR_MAXATTR = 5 CTRL_ATTR_MCAST_GROUPS = 7 CTRL_ATTR_MCAST_GRP_NAME = 1 CTRL_ATTR_MCAST_GRP_ID = 2 # Extack types NLMSGERR_ATTR_MSG = 1 NLMSGERR_ATTR_OFFS = 2 NLMSGERR_ATTR_COOKIE = 3 NLMSGERR_ATTR_POLICY = 4 NLMSGERR_ATTR_MISS_TYPE = 5 NLMSGERR_ATTR_MISS_NEST = 6 # Policy types NL_POLICY_TYPE_ATTR_TYPE = 1 NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 NL_POLICY_TYPE_ATTR_PAD = 11 NL_POLICY_TYPE_ATTR_MASK = 12 AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64', 's8', 's16', 's32', 's64', 'binary', 'string', 'nul-string', 'nested', 'nested-array', 'bitfield32', 'sint', 'uint']) class NlError(Exception): def __init__(self, nl_msg): self.nl_msg = nl_msg self.error = -nl_msg.error def __str__(self): return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}" class ConfigError(Exception): pass class NlAttr: ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) type_formats = { 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("h"), Struct("I"), Struct("i"), Struct("Q"), Struct("q"), Struct(">= 1 i += 1 else: value = enum.entries_by_val[raw].name return value def _decode_binary(self, attr, attr_spec): if attr_spec.struct_name: decoded = self._decode_struct(attr.raw, attr_spec.struct_name) elif attr_spec.sub_type: decoded = attr.as_c_array(attr_spec.sub_type) else: decoded = attr.as_bin() if attr_spec.display_hint: decoded = self._formatted_string(decoded, attr_spec.display_hint) return decoded def _decode_array_attr(self, attr, attr_spec): decoded = [] offset = 0 while offset < len(attr.raw): item = NlAttr(attr.raw, offset) offset += item.full_len if attr_spec["sub-type"] == 'nest': subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) decoded.append({ item.type: subattrs }) elif attr_spec["sub-type"] == 'binary': subattrs = item.as_bin() if attr_spec.display_hint: subattrs = self._formatted_string(subattrs, attr_spec.display_hint) decoded.append(subattrs) elif attr_spec["sub-type"] in NlAttr.type_formats: subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) if attr_spec.display_hint: subattrs = self._formatted_string(subattrs, attr_spec.display_hint) decoded.append(subattrs) else: raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') return decoded def _decode_nest_type_value(self, attr, attr_spec): decoded = {} value = attr for name in attr_spec['type-value']: value = NlAttr(value.raw, 0) decoded[name] = value.type subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) decoded.update(subattrs) return decoded def _decode_unknown(self, attr): if attr.is_nest: return self._decode(NlAttrs(attr.raw), None) else: return attr.as_bin() def _rsp_add(self, rsp, name, is_multi, decoded): if is_multi == None: if name in rsp and type(rsp[name]) is not list: rsp[name] = [rsp[name]] is_multi = True else: is_multi = False if not is_multi: rsp[name] = decoded elif name in rsp: rsp[name].append(decoded) else: rsp[name] = [decoded] def _resolve_selector(self, attr_spec, search_attrs): sub_msg = attr_spec.sub_message if sub_msg not in self.sub_msgs: raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") sub_msg_spec = self.sub_msgs[sub_msg] selector = attr_spec.selector value = search_attrs.lookup(selector) if value not in sub_msg_spec.formats: raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") spec = sub_msg_spec.formats[value] return spec def _decode_sub_msg(self, attr, attr_spec, search_attrs): msg_format = self._resolve_selector(attr_spec, search_attrs) decoded = {} offset = 0 if msg_format.fixed_header: decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); offset = self._struct_size(msg_format.fixed_header) if msg_format.attr_set: if msg_format.attr_set in self.attr_sets: subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) decoded.update(subdict) else: raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") return decoded def _decode(self, attrs, space, outer_attrs = None): rsp = dict() if space: attr_space = self.attr_sets[space] search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) for attr in attrs: try: attr_spec = attr_space.attrs_by_val[attr.type] except (KeyError, UnboundLocalError): if not self.process_unknown: raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") attr_name = f"UnknownAttr({attr.type})" self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) continue if attr_spec["type"] == 'nest': subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) decoded = subdict elif attr_spec["type"] == 'string': decoded = attr.as_strz() elif attr_spec["type"] == 'binary': decoded = self._decode_binary(attr, attr_spec) elif attr_spec["type"] == 'flag': decoded = True elif attr_spec.is_auto_scalar: decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) elif attr_spec["type"] in NlAttr.type_formats: decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) if 'enum' in attr_spec: decoded = self._decode_enum(decoded, attr_spec) elif attr_spec.display_hint: decoded = self._formatted_string(decoded, attr_spec.display_hint) elif attr_spec["type"] == 'indexed-array': decoded = self._decode_array_attr(attr, attr_spec) elif attr_spec["type"] == 'bitfield32': value, selector = struct.unpack("II", attr.raw) if 'enum' in attr_spec: value = self._decode_enum(value, attr_spec) selector = self._decode_enum(selector, attr_spec) decoded = {"value": value, "selector": selector} elif attr_spec["type"] == 'sub-message': decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) elif attr_spec["type"] == 'nest-type-value': decoded = self._decode_nest_type_value(attr, attr_spec) else: if not self.process_unknown: raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') decoded = self._decode_unknown(attr) self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) return rsp def _decode_extack_path(self, attrs, attr_set, offset, target): for attr in attrs: try: attr_spec = attr_set.attrs_by_val[attr.type] except KeyError: raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") if offset > target: break if offset == target: return '.' + attr_spec.name if offset + attr.full_len <= target: offset += attr.full_len continue if attr_spec['type'] != 'nest': raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") offset += 4 subpath = self._decode_extack_path(NlAttrs(attr.raw), self.attr_sets[attr_spec['nested-attributes']], offset, target) if subpath is None: return None return '.' + attr_spec.name + subpath return None def _decode_extack(self, request, op, extack): if 'bad-attr-offs' not in extack: return msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, extack['bad-attr-offs']) if path: del extack['bad-attr-offs'] extack['bad-attr'] = path def _struct_size(self, name): if name: members = self.consts[name].members size = 0 for m in members: if m.type in ['pad', 'binary']: if m.struct: size += self._struct_size(m.struct) else: size += m.len else: format = NlAttr.get_format(m.type, m.byte_order) size += format.size return size else: return 0 def _decode_struct(self, data, name): members = self.consts[name].members attrs = dict() offset = 0 for m in members: value = None if m.type == 'pad': offset += m.len elif m.type == 'binary': if m.struct: len = self._struct_size(m.struct) value = self._decode_struct(data[offset : offset + len], m.struct) offset += len else: value = data[offset : offset + m.len] offset += m.len else: format = NlAttr.get_format(m.type, m.byte_order) [ value ] = format.unpack_from(data, offset) offset += format.size if value is not None: if m.enum: value = self._decode_enum(value, m) elif m.display_hint: value = self._formatted_string(value, m.display_hint) attrs[m.name] = value return attrs def _encode_struct(self, name, vals): members = self.consts[name].members attr_payload = b'' for m in members: value = vals.pop(m.name) if m.name in vals else None if m.type == 'pad': attr_payload += bytearray(m.len) elif m.type == 'binary': if m.struct: if value is None: value = dict() attr_payload += self._encode_struct(m.struct, value) else: if value is None: attr_payload += bytearray(m.len) else: attr_payload += bytes.fromhex(value) else: if value is None: value = 0 format = NlAttr.get_format(m.type, m.byte_order) attr_payload += format.pack(value) return attr_payload def _formatted_string(self, raw, display_hint): if display_hint == 'mac': formatted = ':'.join('%02x' % b for b in raw) elif display_hint == 'hex': if isinstance(raw, int): formatted = hex(raw) else: formatted = bytes.hex(raw, ' ') elif display_hint in [ 'ipv4', 'ipv6' ]: formatted = format(ipaddress.ip_address(raw)) elif display_hint == 'uuid': formatted = str(uuid.UUID(bytes=raw)) else: formatted = raw return formatted def handle_ntf(self, decoded): msg = dict() if self.include_raw: msg['raw'] = decoded op = self.rsp_by_value[decoded.cmd()] attrs = self._decode(decoded.raw_attrs, op.attr_set.name) if op.fixed_header: attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) msg['name'] = op['name'] msg['msg'] = attrs self.async_msg_queue.put(msg) def check_ntf(self): while True: try: reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) except BlockingIOError: return nms = NlMsgs(reply) self._recv_dbg_print(reply, nms) for nl_msg in nms: if nl_msg.error: print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) print(nl_msg) continue if nl_msg.done: print("Netlink done while checking for ntf!?") continue decoded = self.nlproto.decode(self, nl_msg, None) if decoded.cmd() not in self.async_msg_ids: print("Unexpected msg id while checking for ntf", decoded) continue self.handle_ntf(decoded) def poll_ntf(self, duration=None): start_time = time.time() selector = selectors.DefaultSelector() selector.register(self.sock, selectors.EVENT_READ) while True: try: yield self.async_msg_queue.get_nowait() except queue.Empty: if duration is not None: timeout = start_time + duration - time.time() if timeout <= 0: return else: timeout = None events = selector.select(timeout) if events: self.check_ntf() def operation_do_attributes(self, name): """ For a given operation name, find and return a supported set of attributes (as a dict). """ op = self.find_operation(name) if not op: return None return op['do']['request']['attributes'].copy() def _encode_message(self, op, vals, flags, req_seq): nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK for flag in flags or []: nl_flags |= flag msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) if op.fixed_header: msg += self._encode_struct(op.fixed_header, vals) search_attrs = SpaceAttrs(op.attr_set, vals) for name, value in vals.items(): msg += self._add_attr(op.attr_set.name, name, value, search_attrs) msg = _genl_msg_finalize(msg) return msg def _ops(self, ops): reqs_by_seq = {} req_seq = random.randint(1024, 65535) payload = b'' for (method, vals, flags) in ops: op = self.ops[method] msg = self._encode_message(op, vals, flags, req_seq) reqs_by_seq[req_seq] = (op, msg, flags) payload += msg req_seq += 1 self.sock.send(payload, 0) done = False rsp = [] op_rsp = [] while not done: reply = self.sock.recv(self._recv_size) nms = NlMsgs(reply, attr_space=op.attr_set) self._recv_dbg_print(reply, nms) for nl_msg in nms: if nl_msg.nl_seq in reqs_by_seq: (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] if nl_msg.extack: self._decode_extack(req_msg, op, nl_msg.extack) else: op = None req_flags = [] if nl_msg.error: raise NlError(nl_msg) if nl_msg.done: if nl_msg.extack: print("Netlink warning:") print(nl_msg) if Netlink.NLM_F_DUMP in req_flags: rsp.append(op_rsp) elif not op_rsp: rsp.append(None) elif len(op_rsp) == 1: rsp.append(op_rsp[0]) else: rsp.append(op_rsp) op_rsp = [] del reqs_by_seq[nl_msg.nl_seq] done = len(reqs_by_seq) == 0 break decoded = self.nlproto.decode(self, nl_msg, op) # Check if this is a reply to our request if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: if decoded.cmd() in self.async_msg_ids: self.handle_ntf(decoded) continue else: print('Unexpected message: ' + repr(decoded)) continue rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) if op.fixed_header: rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) op_rsp.append(rsp_msg) return rsp def _op(self, method, vals, flags=None, dump=False): req_flags = flags or [] if dump: req_flags.append(Netlink.NLM_F_DUMP) ops = [(method, vals, req_flags)] return self._ops(ops)[0] def do(self, method, vals, flags=None): return self._op(method, vals, flags) def dump(self, method, vals): return self._op(method, vals, dump=True) def do_multi(self, ops): return self._ops(ops)