From nobody Fri Jan 13 21:25:19 2023
Date: Fri, 13 Jan 2023 21:25:19 GMT
Message-Id: <202301132125.30DLPJ4O041676@gitrepo.freebsd.org>
To: src-committers@FreeBSD.org, dev-commits-src-all@FreeBSD.org, dev-commits-src-branches@FreeBSD.org
From: "Alexander V. Chernikov"
Subject: git: 0a76e8d75bea - stable/13 - testing: Add basic atf support to pytest.
List-Id: Commits to the stable branches of the FreeBSD src repository
List-Archive: https://lists.freebsd.org/archives/dev-commits-src-branches
Sender: owner-dev-commits-src-branches@freebsd.org
X-BeenThere: dev-commits-src-branches@freebsd.org
MIME-Version: 1.0
Content-Type: text/plain; charset=utf-8
Content-Transfer-Encoding: 8bit
X-Git-Committer: melifaro
X-Git-Repository: src
X-Git-Refname: refs/heads/stable/13
X-Git-Reftype: branch
X-Git-Commit: 0a76e8d75beabf5a4274b3335866b6c853f42b00
Auto-Submitted: auto-generated
X-ThisMailContainsUnwantedMimeParts: N

The branch stable/13 has been updated by melifaro:

URL: https://cgit.FreeBSD.org/src/commit/?id=0a76e8d75beabf5a4274b3335866b6c853f42b00

commit 0a76e8d75beabf5a4274b3335866b6c853f42b00
Author:     Alexander V. Chernikov
AuthorDate: 2022-06-25 19:01:45 +0000
Commit:     Alexander V. Chernikov
CommitDate: 2023-01-13 21:24:10 +0000

    testing: Add basic atf support to pytest.

    Implementation consists of the pytest plugin implementing ATF format and
    a simple C++ wrapper, which reorders the provided arguments from ATF format
    to the format understandable by pytest. Each test has this wrapper specified
    after the shebang. When kyua executes the test, wrapper calls pytest, which
    loads atf plugin, does the work and returns the result. Additionally, a
    separate python "package", `/usr/tests/atf_python` has been added to collect
    code that may be useful across different tests.

    Current limitations:
    * Opaque metadata passing via X-Name properties. Require some fixtures to write
    * `-s srcdir` parameter passed by the runner is ignored.
    * No `atf-c-api(3)` or similar - relying on pytest framework & existing python libraries
    * No support for `atf_tc__config_var()` & `atf_tc_set_md_var()`.
Can be probably implemented with env variables & autoload fixtures Differential Revision: https://reviews.freebsd.org/D31084 Reviewed by: kp, ngie (cherry picked from commit 8eb2bee6c0f4957c6c1cea826e59cda4d18a2a64) --- share/mk/atf.test.mk | 47 ++ tests/Makefile | 4 +- tests/__init__.py | 0 tests/atf_python/Makefile | 12 + tests/atf_python/__init__.py | 4 + tests/atf_python/atf_pytest.py | 218 +++++++++ tests/atf_python/sys/Makefile | 11 + tests/atf_python/sys/__init__.py | 0 tests/atf_python/sys/net/Makefile | 10 + tests/atf_python/sys/net/__init__.py | 0 tests/atf_python/sys/net/rtsock.py | 604 ++++++++++++++++++++++++ tests/atf_python/sys/net/tools.py | 33 ++ tests/atf_python/sys/net/vnet.py | 203 ++++++++ tests/conftest.py | 121 +++++ tests/freebsd_test_suite/Makefile | 13 + tests/freebsd_test_suite/atf_pytest_wrapper.cpp | 192 ++++++++ 16 files changed, 1471 insertions(+), 1 deletion(-) diff --git a/share/mk/atf.test.mk b/share/mk/atf.test.mk index e7a8c82b7a8f..c00c00bc9b3d 100644 --- a/share/mk/atf.test.mk +++ b/share/mk/atf.test.mk @@ -22,6 +22,7 @@ ATF_TESTS_C?= ATF_TESTS_CXX?= ATF_TESTS_SH?= ATF_TESTS_KSH93?= +ATF_TESTS_PYTEST?= .if !empty(ATF_TESTS_C) PROGS+= ${ATF_TESTS_C} @@ -109,3 +110,49 @@ ${_T}: ${ATF_TESTS_KSH93_SRC_${_T}} mv ${.TARGET}.tmp ${.TARGET} .endfor .endif + +.if !empty(ATF_TESTS_PYTEST) +# bsd.prog.mk SCRIPTS interface removes file extension unless +# SCRIPTSNAME is set, which is not possible to do here. +# Workaround this by appending another extension (.xtmp) to the +# file name. Use separate loop to avoid dealing with explicitly +# stating expansion for each and every variable. +# +# ATF_TESTS_PYTEST -> contains list of files as is (test_something.py ..) +# _ATF_TESTS_PYTEST -> (test_something.py.xtmp ..) +# +# Former array is iterated to construct Kyuafile, where original file +# names need to be written. +# Latter array is iterated to enable bsd.prog.mk scripts framework - +# namely, installing scripts without .xtmp prefix. 
Note: this allows to +# not bother about the fact that make target needs to be different from +# the source file. +_TESTS+= ${ATF_TESTS_PYTEST} +_ATF_TESTS_PYTEST= +.for _T in ${ATF_TESTS_PYTEST} +_ATF_TESTS_PYTEST += ${_T}.xtmp +TEST_INTERFACE.${_T}= atf +TEST_METADATA.${_T}+= required_programs="pytest" +.endfor + +SCRIPTS+= ${_ATF_TESTS_PYTEST} +.for _T in ${_ATF_TESTS_PYTEST} +SCRIPTSDIR_${_T}= ${TESTSDIR} +CLEANFILES+= ${_T} ${_T}.tmp +# TODO(jmmv): It seems to me that this SED and SRC functionality should +# exist in bsd.prog.mk along the support for SCRIPTS. Move it there if +# this proves to be useful within the tests. +ATF_TESTS_PYTEST_SED_${_T}?= # empty +ATF_TESTS_PYTEST_SRC_${_T}?= ${.CURDIR}/${_T:S,.xtmp$,,} +${_T}: + echo "#!${TESTSBASE}/atf_pytest_wrapper -P ${TESTSBASE}" > ${.TARGET}.tmp +.if empty(ATF_TESTS_PYTEST_SED_${_T}) + cat ${ATF_TESTS_PYTEST_SRC_${_T}} >>${.TARGET}.tmp +.else + cat ${ATF_TESTS_PYTEST_SRC_${_T}} \ + | sed ${ATF_TESTS_PYTEST_SED_${_T}} >>${.TARGET}.tmp +.endif + chmod +x ${.TARGET}.tmp + mv ${.TARGET}.tmp ${.TARGET} +.endfor +.endif diff --git a/tests/Makefile b/tests/Makefile index 561a0ec5fcab..cfd065d61539 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -4,12 +4,14 @@ PACKAGE= tests TESTSDIR= ${TESTSBASE} -${PACKAGE}FILES+= README +${PACKAGE}FILES+= README __init__.py conftest.py KYUAFILE= yes SUBDIR+= etc SUBDIR+= sys +SUBDIR+= atf_python +SUBDIR+= freebsd_test_suite SUBDIR_PARALLEL= diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile new file mode 100644 index 000000000000..26d419743257 --- /dev/null +++ b/tests/atf_python/Makefile @@ -0,0 +1,12 @@ +.include + +.PATH: ${.CURDIR} + +FILES= __init__.py atf_pytest.py +SUBDIR= sys + +.include +FILESDIR= ${TESTSBASE}/atf_python + + +.include diff --git a/tests/atf_python/__init__.py b/tests/atf_python/__init__.py new file mode 100644 index 
000000000000..6d5ec22ef054 --- /dev/null +++ b/tests/atf_python/__init__.py @@ -0,0 +1,4 @@ +import pytest + +pytest.register_assert_rewrite("atf_python.sys.net.rtsock") +pytest.register_assert_rewrite("atf_python.sys.net.vnet") diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py new file mode 100644 index 000000000000..89c0e3a515b9 --- /dev/null +++ b/tests/atf_python/atf_pytest.py @@ -0,0 +1,218 @@ +import types +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple + +import pytest + + +class ATFCleanupItem(pytest.Item): + def runtest(self): + """Runs cleanup procedure for the test instead of the test""" + instance = self.parent.cls() + instance.cleanup(self.nodeid) + + def setup_method_noop(self, method): + """Overrides runtest setup method""" + pass + + def teardown_method_noop(self, method): + """Overrides runtest teardown method""" + pass + + +class ATFTestObj(object): + def __init__(self, obj, has_cleanup): + # Use nodeid without name to properly name class-derived tests + self.ident = obj.nodeid.split("::", 1)[1] + self.description = self._get_test_description(obj) + self.has_cleanup = has_cleanup + self.obj = obj + + def _get_test_description(self, obj): + """Returns first non-empty line from func docstring or func name""" + docstr = obj.function.__doc__ + if docstr: + for line in docstr.split("\n"): + if line: + return line + return obj.name + + def _convert_marks(self, obj) -> Dict[str, Any]: + wj_func = lambda x: " ".join(x) # noqa: E731 + _map: Dict[str, Dict] = { + "require_user": {"name": "require.user"}, + "require_arch": {"name": "require.arch", "fmt": wj_func}, + "require_diskspace": {"name": "require.diskspace"}, + "require_files": {"name": "require.files", "fmt": wj_func}, + "require_machine": {"name": "require.machine", "fmt": wj_func}, + "require_memory": {"name": "require.memory"}, + "require_progs": {"name": "require.progs", "fmt": wj_func}, 
+ "timeout": {}, + } + ret = {} + for mark in obj.iter_markers(): + if mark.name in _map: + name = _map[mark.name].get("name", mark.name) + if "fmt" in _map[mark.name]: + val = _map[mark.name]["fmt"](mark.args[0]) + else: + val = mark.args[0] + ret[name] = val + return ret + + def as_lines(self) -> List[str]: + """Output test definition in ATF-specific format""" + ret = [] + ret.append("ident: {}".format(self.ident)) + ret.append("descr: {}".format(self._get_test_description(self.obj))) + if self.has_cleanup: + ret.append("has.cleanup: true") + for key, value in self._convert_marks(self.obj).items(): + ret.append("{}: {}".format(key, value)) + return ret + + +class ATFHandler(object): + class ReportState(NamedTuple): + state: str + reason: str + + def __init__(self): + self._tests_state_map: Dict[str, ReportStatus] = {} + + def override_runtest(self, obj): + # Override basic runtest command + obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj) + # Override class setup/teardown + obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop + obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop + + def get_object_cleanup_class(self, obj): + if hasattr(obj, "parent") and obj.parent is not None: + if hasattr(obj.parent, "cls") and obj.parent.cls is not None: + if hasattr(obj.parent.cls, "cleanup"): + return obj.parent.cls + return None + + def has_object_cleanup(self, obj): + return self.get_object_cleanup_class(obj) is not None + + def list_tests(self, tests: List[str]): + print('Content-Type: application/X-atf-tp; version="1"') + print() + for test_obj in tests: + has_cleanup = self.has_object_cleanup(test_obj) + atf_test = ATFTestObj(test_obj, has_cleanup) + for line in atf_test.as_lines(): + print(line) + print() + + def set_report_state(self, test_name: str, state: str, reason: str): + self._tests_state_map[test_name] = self.ReportState(state, reason) + + def _extract_report_reason(self, report): + data = report.longrepr + if data is 
None: + return None + if isinstance(data, Tuple): + # ('/path/to/test.py', 23, 'Skipped: unable to test') + reason = data[2] + for prefix in "Skipped: ": + if reason.startswith(prefix): + reason = reason[len(prefix):] + return reason + else: + # string/ traceback / exception report. Capture the last line + return str(data).split("\n")[-1] + return None + + def add_report(self, report): + # MAP pytest report state to the atf-desired state + # + # ATF test states: + # (1) expected_death, (2) expected_exit, (3) expected_failure + # (4) expected_signal, (5) expected_timeout, (6) passed + # (7) skipped, (8) failed + # + # Note that ATF don't have the concept of "soft xfail" - xpass + # is a failure. It also calls teardown routine in a separate + # process, thus teardown states (pytest-only) are handled as + # body continuation. + + # (stage, state, wasxfail) + + # Just a passing test: WANT: passed + # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F) + # + # Failing body test: WHAT: failed + # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F) + # + # pytest.skip test decorator: WANT: skipped + # GOT: (setup,skipped, False), (teardown, passed, False) + # + # pytest.skip call inside test function: WANT: skipped + # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F) + # + # mark.xfail decorator+pytest.xfail: WANT: expected_failure + # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F) + # + # mark.xfail decorator+pass: WANT: failed + # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F) + + test_name = report.location[2] + stage = report.when + state = report.outcome + reason = self._extract_report_reason(report) + + # We don't care about strict xfail - it gets translated to False + + if stage == "setup": + if state in ("skipped", "failed"): + # failed init -> failed test, skipped setup -> xskip + # for the whole test + self.set_report_state(test_name, state, reason) + elif stage == "call": + # 
"call" stage shouldn't matter if setup failed + if test_name in self._tests_state_map: + if self._tests_state_map[test_name].state == "failed": + return + if state == "failed": + # Record failure & override "skipped" state + self.set_report_state(test_name, state, reason) + elif state == "skipped": + if hasattr(reason, "wasxfail"): + # xfail() called in the test body + state = "expected_failure" + else: + # skip inside the body + pass + self.set_report_state(test_name, state, reason) + elif state == "passed": + if hasattr(reason, "wasxfail"): + # the test was expected to fail but didn't + # mark as hard failure + state = "failed" + self.set_report_state(test_name, state, reason) + elif stage == "teardown": + if state == "failed": + # teardown should be empty, as the cleanup + # procedures should be implemented as a separate + # function/method, so mark teardown failure as + # global failure + self.set_report_state(test_name, state, reason) + + def write_report(self, path): + if self._tests_state_map: + # If we're executing in ATF mode, there has to be just one test + # Anyway, deterministically pick the first one + first_test_name = next(iter(self._tests_state_map)) + test = self._tests_state_map[first_test_name] + if test.state == "passed": + line = test.state + else: + line = "{}: {}".format(test.state, test.reason) + with open(path, mode="w") as f: + print(line, file=f) diff --git a/tests/atf_python/sys/Makefile b/tests/atf_python/sys/Makefile new file mode 100644 index 000000000000..ff4cf17b85d2 --- /dev/null +++ b/tests/atf_python/sys/Makefile @@ -0,0 +1,11 @@ +.include + +.PATH: ${.CURDIR} + +FILES= __init__.py +SUBDIR= net + +.include +FILESDIR= ${TESTSBASE}/atf_python/sys + +.include diff --git a/tests/atf_python/sys/__init__.py b/tests/atf_python/sys/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/atf_python/sys/net/Makefile b/tests/atf_python/sys/net/Makefile new file mode 100644 index 000000000000..05b1d8afe863 --- 
/dev/null +++ b/tests/atf_python/sys/net/Makefile @@ -0,0 +1,10 @@ +.include + +.PATH: ${.CURDIR} + +FILES= __init__.py rtsock.py tools.py vnet.py + +.include +FILESDIR= ${TESTSBASE}/atf_python/sys/net + +.include diff --git a/tests/atf_python/sys/net/__init__.py b/tests/atf_python/sys/net/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/atf_python/sys/net/rtsock.py b/tests/atf_python/sys/net/rtsock.py new file mode 100755 index 000000000000..788e863f8b28 --- /dev/null +++ b/tests/atf_python/sys/net/rtsock.py @@ -0,0 +1,604 @@ +#!/usr/local/bin/python3 +import os +import socket +import struct +import sys +from ctypes import c_byte +from ctypes import c_char +from ctypes import c_int +from ctypes import c_long +from ctypes import c_uint32 +from ctypes import c_ulong +from ctypes import c_ushort +from ctypes import sizeof +from ctypes import Structure +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +def roundup2(val: int, num: int) -> int: + if val % num: + return (val | (num - 1)) + 1 + else: + return val + + +class RtSockException(OSError): + pass + + +class RtConst: + RTM_VERSION = 5 + ALIGN = sizeof(c_long) + + AF_INET = socket.AF_INET + AF_INET6 = socket.AF_INET6 + AF_LINK = socket.AF_LINK + + RTA_DST = 0x1 + RTA_GATEWAY = 0x2 + RTA_NETMASK = 0x4 + RTA_GENMASK = 0x8 + RTA_IFP = 0x10 + RTA_IFA = 0x20 + RTA_AUTHOR = 0x40 + RTA_BRD = 0x80 + + RTM_ADD = 1 + RTM_DELETE = 2 + RTM_CHANGE = 3 + RTM_GET = 4 + + RTF_UP = 0x1 + RTF_GATEWAY = 0x2 + RTF_HOST = 0x4 + RTF_REJECT = 0x8 + RTF_DYNAMIC = 0x10 + RTF_MODIFIED = 0x20 + RTF_DONE = 0x40 + RTF_XRESOLVE = 0x200 + RTF_LLINFO = 0x400 + RTF_LLDATA = 0x400 + RTF_STATIC = 0x800 + RTF_BLACKHOLE = 0x1000 + RTF_PROTO2 = 0x4000 + RTF_PROTO1 = 0x8000 + RTF_PROTO3 = 0x40000 + RTF_FIXEDMTU = 0x80000 + RTF_PINNED = 0x100000 + RTF_LOCAL = 0x200000 + RTF_BROADCAST = 0x400000 + RTF_MULTICAST = 0x800000 + RTF_STICKY = 0x10000000 + 
RTF_RNH_LOCKED = 0x40000000 + RTF_GWFLAG_COMPAT = 0x80000000 + + RTV_MTU = 0x1 + RTV_HOPCOUNT = 0x2 + RTV_EXPIRE = 0x4 + RTV_RPIPE = 0x8 + RTV_SPIPE = 0x10 + RTV_SSTHRESH = 0x20 + RTV_RTT = 0x40 + RTV_RTTVAR = 0x80 + RTV_WEIGHT = 0x100 + + @staticmethod + def get_props(prefix: str) -> List[str]: + return [n for n in dir(RtConst) if n.startswith(prefix)] + + @staticmethod + def get_name(prefix: str, value: int) -> str: + props = RtConst.get_props(prefix) + for prop in props: + if getattr(RtConst, prop) == value: + return prop + return "U:{}:{}".format(prefix, value) + + @staticmethod + def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]: + props = RtConst.get_props(prefix) + propmap = {getattr(RtConst, prop): prop for prop in props} + v = 1 + ret = {} + while value: + if v & value: + if v in propmap: + ret[v] = propmap[v] + else: + ret[v] = hex(v) + value -= v + v *= 2 + return ret + + @staticmethod + def get_bitmask_str(prefix: str, value: int) -> str: + bmap = RtConst.get_bitmask_map(prefix, value) + return ",".join([v for k, v in bmap.items()]) + + +class RtMetrics(Structure): + _fields_ = [ + ("rmx_locks", c_ulong), + ("rmx_mtu", c_ulong), + ("rmx_hopcount", c_ulong), + ("rmx_expire", c_ulong), + ("rmx_recvpipe", c_ulong), + ("rmx_sendpipe", c_ulong), + ("rmx_ssthresh", c_ulong), + ("rmx_rtt", c_ulong), + ("rmx_rttvar", c_ulong), + ("rmx_pksent", c_ulong), + ("rmx_weight", c_ulong), + ("rmx_nhidx", c_ulong), + ("rmx_filler", c_ulong * 2), + ] + + +class RtMsgHdr(Structure): + _fields_ = [ + ("rtm_msglen", c_ushort), + ("rtm_version", c_byte), + ("rtm_type", c_byte), + ("rtm_index", c_ushort), + ("_rtm_spare1", c_ushort), + ("rtm_flags", c_int), + ("rtm_addrs", c_int), + ("rtm_pid", c_int), + ("rtm_seq", c_int), + ("rtm_errno", c_int), + ("rtm_fmask", c_int), + ("rtm_inits", c_ulong), + ("rtm_rmx", RtMetrics), + ] + + +class SockaddrIn(Structure): + _fields_ = [ + ("sin_len", c_byte), + ("sin_family", c_byte), + ("sin_port", c_ushort), + ("sin_addr", 
c_uint32), + ("sin_zero", c_char * 8), + ] + + +class SockaddrIn6(Structure): + _fields_ = [ + ("sin6_len", c_byte), + ("sin6_family", c_byte), + ("sin6_port", c_ushort), + ("sin6_flowinfo", c_uint32), + ("sin6_addr", c_byte * 16), + ("sin6_scope_id", c_uint32), + ] + + +class SockaddrDl(Structure): + _fields_ = [ + ("sdl_len", c_byte), + ("sdl_family", c_byte), + ("sdl_index", c_ushort), + ("sdl_type", c_byte), + ("sdl_nlen", c_byte), + ("sdl_alen", c_byte), + ("sdl_slen", c_byte), + ("sdl_data", c_byte * 8), + ] + + +class SaHelper(object): + @staticmethod + def is_ipv6(ip: str) -> bool: + return ":" in ip + + @staticmethod + def ip_sa(ip: str, scopeid: int = 0) -> bytes: + if SaHelper.is_ipv6(ip): + return SaHelper.ip6_sa(ip, scopeid) + else: + return SaHelper.ip4_sa(ip) + + @staticmethod + def ip4_sa(ip: str) -> bytes: + addr_int = int.from_bytes(socket.inet_pton(2, ip), sys.byteorder) + sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int) + return bytes(sin) + + @staticmethod + def ip6_sa(ip6: str, scopeid: int) -> bytes: + addr_bytes = (c_byte * 16)() + for i, b in enumerate(socket.inet_pton(socket.AF_INET6, ip6)): + addr_bytes[i] = b + sin6 = SockaddrIn6( + sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid + ) + return bytes(sin6) + + @staticmethod + def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes: + sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype) + return bytes(sa) + + @staticmethod + def pxlen4_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen)) + + @staticmethod + def pxlen_to_ip4(pxlen: int) -> str: + if pxlen == 32: + return "255.255.255.255" + else: + addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1) + addr_bytes = struct.pack("!I", addr) + return socket.inet_ntop(socket.AF_INET, addr_bytes) + + @staticmethod + def pxlen6_sa(pxlen: int) -> bytes: + return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen)) + + @staticmethod + def pxlen_to_ip6(pxlen: int) -> str: + ip6_b = 
[0] * 16 + start = 0 + while pxlen > 8: + ip6_b[start] = 0xFF + pxlen -= 8 + start += 1 + ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1) + return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b)) + + @staticmethod + def print_sa_inet(sa: bytes): + if len(sa) < 8: + raise RtSockException("IPv4 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET, sa[4:8]) + return "{}".format(addr) + + @staticmethod + def print_sa_inet6(sa: bytes): + if len(sa) < sizeof(SockaddrIn6): + raise RtSockException("IPv6 sa size too small: {}".format(len(sa))) + addr = socket.inet_ntop(socket.AF_INET6, sa[8:24]) + scopeid = struct.unpack(">I", sa[24:28])[0] + return "{} scopeid {}".format(addr, scopeid) + + @staticmethod + def print_sa_link(sa: bytes, hd: Optional[bool] = True): + if len(sa) < sizeof(SockaddrDl): + raise RtSockException("LINK sa size too small: {}".format(len(sa))) + sdl = SockaddrDl.from_buffer_copy(sa) + if sdl.sdl_index: + ifindex = "link#{} ".format(sdl.sdl_index) + else: + ifindex = "" + if sdl.sdl_nlen: + iface_offset = 8 + if sdl.sdl_nlen + iface_offset > len(sa): + raise RtSockException( + "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)) + ) + ifname = "ifname:{} ".format( + bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen]) + ) + else: + ifname = "" + return "{}{}".format(ifindex, ifname) + + @staticmethod + def print_sa_unknown(sa: bytes): + return "unknown_type:{}".format(sa[1]) + + @classmethod + def print_sa(cls, sa: bytes, hd: Optional[bool] = False): + if sa[0] != len(sa): + raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa))) + + if len(sa) < 2: + raise Exception( + "sa type {} too short: {}".format( + RtConst.get_name("AF_", sa[1]), len(sa) + ) + ) + + if sa[1] == socket.AF_INET: + text = cls.print_sa_inet(sa) + elif sa[1] == socket.AF_INET6: + text = cls.print_sa_inet6(sa) + elif sa[1] == socket.AF_LINK: + text = cls.print_sa_link(sa) + else: + text = cls.print_sa_unknown(sa) + if hd: 
+ dump = " [{!r}]".format(sa) + else: + dump = "" + return "{}{}".format(text, dump) + + +class BaseRtsockMessage(object): + def __init__(self, rtm_type): + self.rtm_type = rtm_type + self.sa = SaHelper() + + @staticmethod + def print_rtm_type(rtm_type): + return RtConst.get_name("RTM_", rtm_type) + + @property + def rtm_type_str(self): + return self.print_rtm_type(self.rtm_type) + + +class RtsockRtMessage(BaseRtsockMessage): + messages = [ + RtConst.RTM_ADD, + RtConst.RTM_DELETE, + RtConst.RTM_CHANGE, + RtConst.RTM_GET, + ] + + def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None): + super().__init__(rtm_type) + self.rtm_flags = 0 + self.rtm_seq = rtm_seq + self._attrs = {} + self.rtm_errno = 0 + self.rtm_pid = 0 + self.rtm_inits = 0 + self.rtm_rmx = RtMetrics() + self._orig_data = None + if dst_sa: + self.add_sa_attr(RtConst.RTA_DST, dst_sa) + if mask_sa: + self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa) + + def add_sa_attr(self, attr_type, attr_bytes: bytes): + self._attrs[attr_type] = attr_bytes + + def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0): + if ":" in ip_addr: + self.add_ip6_attr(attr_type, ip_addr, scopeid) + else: + self.add_ip4_attr(attr_type, ip_addr) + + def add_ip4_attr(self, attr_type, ip: str): + self.add_sa_attr(attr_type, self.sa.ip_sa(ip)) + + def add_ip6_attr(self, attr_type, ip6: str, scopeid: int): + self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid)) + + def add_link_attr(self, attr_type, ifindex: Optional[int] = 0): + self.add_sa_attr(attr_type, self.sa.link_sa(ifindex)) + + def get_sa(self, attr_type) -> bytes: + return self._attrs.get(attr_type) + + def print_message(self): + # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags: + if self._orig_data: + rtm_len = len(self._orig_data) + else: + rtm_len = len(bytes(self)) + print( + "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format( + self.rtm_type_str, + rtm_len, + self.rtm_pid, + self.rtm_seq, + self.rtm_errno, + 
RtConst.get_bitmask_str("RTF_", self.rtm_flags), + ) + ) + rtm_addrs = sum(list(self._attrs.keys())) + print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs))) + for attr in sorted(self._attrs.keys()): + sa_data = SaHelper.print_sa(self._attrs[attr]) + print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data)) + + def print_in_message(self): + print("vvvvvvvv IN vvvvvvvv") + self.print_message() + print() + + def verify_sa_inet(self, sa_data): + if len(sa_data) < 8: + raise Exception("IPv4 sa size too small: {}".format(sa_data)) + if sa_data[0] > len(sa_data): + raise Exception( + "IPv4 sin_len too big: {} vs sa size {}: {}".format( + sa_data[0], len(sa_data), sa_data + ) + ) + sin = SockaddrIn.from_buffer_copy(sa_data) + assert sin.sin_port == 0 + assert sin.sin_zero == [0] * 8 + + def compare_sa(self, sa_type, sa_data): + if len(sa_data) < 4: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception( + "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data)) + ) + our_sa = self._attrs[sa_type] + assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa) + assert len(sa_data) == len(our_sa) + assert sa_data == our_sa + + def verify(self, rtm_type: int, rtm_sa): + assert self.rtm_type_str == self.print_rtm_type(rtm_type) + assert self.rtm_errno == 0 + hdr = RtMsgHdr.from_buffer_copy(self._orig_data) + assert hdr._rtm_spare1 == 0 + for sa_type, sa_data in rtm_sa.items(): + if sa_type not in self._attrs: + sa_type_name = RtConst.get_name("RTA_", sa_type) + raise Exception("SA type {} not present".format(sa_type_name)) + self.compare_sa(sa_type, sa_data) + + @classmethod + def from_bytes(cls, data: bytes): + if len(data) < sizeof(RtMsgHdr): + raise Exception( + "messages size {} is less than expected {}".format( + len(data), sizeof(RtMsgHdr) + ) + ) + hdr = RtMsgHdr.from_buffer_copy(data) + + self = cls(hdr.rtm_type) + self.rtm_flags = hdr.rtm_flags + self.rtm_seq = hdr.rtm_seq + self.rtm_errno = hdr.rtm_errno + 
self.rtm_pid = hdr.rtm_pid + self.rtm_inits = hdr.rtm_inits + self.rtm_rmx = hdr.rtm_rmx + self._orig_data = data + + off = sizeof(RtMsgHdr) + v = 1 + addrs_mask = hdr.rtm_addrs + while addrs_mask: + if addrs_mask & v: + addrs_mask -= v + + if off + data[off] > len(data): + raise Exception( + "SA sizeof for {} > total message length: {}+{} > {}".format( + RtConst.get_name("RTA_", v), off, data[off], len(data) + ) + ) + self._attrs[v] = data[off : off + data[off]] + off += roundup2(data[off], RtConst.ALIGN) + v *= 2 + return self + + def __bytes__(self): + sz = sizeof(RtMsgHdr) + addrs_mask = 0 + for k, v in self._attrs.items(): + sz += roundup2(len(v), RtConst.ALIGN) + addrs_mask += k + hdr = RtMsgHdr( + rtm_msglen=sz, + rtm_version=RtConst.RTM_VERSION, + rtm_type=self.rtm_type, + rtm_flags=self.rtm_flags, + rtm_seq=self.rtm_seq, + rtm_addrs=addrs_mask, + rtm_inits=self.rtm_inits, + rtm_rmx=self.rtm_rmx, + ) + buf = bytearray(sz) + buf[0 : sizeof(RtMsgHdr)] = hdr + off = sizeof(RtMsgHdr) + for attr in sorted(self._attrs.keys()): + v = self._attrs[attr] + sa_len = len(v) + buf[off : off + sa_len] = v + off += roundup2(len(v), RtConst.ALIGN) + return bytes(buf) + + +class Rtsock: + def __init__(self): + self.socket = self._setup_rtsock() + self.rtm_seq = 1 + self.msgmap = self.build_msgmap() + + def build_msgmap(self): + classes = [RtsockRtMessage] + xmap = {} + for cls in classes: + for message in cls.messages: + xmap[message] = cls + return xmap + + def get_seq(self): + ret = self.rtm_seq + self.rtm_seq += 1 + return ret + + def get_weight(self, weight) -> int: + if weight: + return weight + else: + return 1 # RT_DEFAULT_WEIGHT + + def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]): + px = prefix.split("/") + addr_sa = SaHelper.ip_sa(px[0]) + if len(px) > 1: + pxlen = int(px[1]) + if SaHelper.is_ipv6(px[0]): + mask_sa = SaHelper.pxlen6_sa(pxlen) + else: + mask_sa = SaHelper.pxlen4_sa(pxlen) + else: + mask_sa = None + msg = RtsockRtMessage(msg_type, 
self.get_seq(), addr_sa, mask_sa) + if isinstance(gw, bytes): + msg.add_sa_attr(RtConst.RTA_GATEWAY, gw) + else: + # String + msg.add_ip_attr(RtConst.RTA_GATEWAY, gw) + return msg + + def new_rtm_add(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw) + + def new_rtm_del(self, prefix: str, gw: Union[str, bytes]): + return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw) + + def new_rtm_change(self, prefix: str, gw: Union[str, bytes]): *** 641 LINES SKIPPED ***
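
Reader's note on the state mapping: the comment table in `ATFHandler.add_report` lists which ATF result each sequence of pytest (stage, outcome, wasxfail) reports should produce. The sketch below restates that table as a small pure function so it can be checked in isolation. It is not part of the commit: the names `Report` and `atf_state` are hypothetical, and it takes `wasxfail` as an explicit boolean rather than reading it from a pytest report object.

```python
from typing import Iterable, Optional, Tuple

# Hypothetical shape for illustration: (stage, outcome, wasxfail)
Report = Tuple[str, str, bool]


def atf_state(reports: Iterable[Report]) -> Optional[str]:
    """Fold pytest stage reports for one test into a single ATF state."""
    state: Optional[str] = None
    for stage, outcome, wasxfail in reports:
        if stage == "setup":
            # A failed or skipped setup decides the whole test.
            if outcome in ("skipped", "failed"):
                state = outcome
        elif stage == "call":
            if state == "failed":
                # The body result does not matter once setup has failed.
                continue
            if outcome == "failed":
                state = "failed"
            elif outcome == "skipped":
                # xfail() in the body is reported as a skip with wasxfail set.
                state = "expected_failure" if wasxfail else "skipped"
            elif outcome == "passed":
                # ATF has no soft xfail: an unexpected pass is a failure.
                state = "failed" if wasxfail else "passed"
        elif stage == "teardown":
            # Cleanup lives in a separate method, so teardown should be
            # empty; any failure here fails the test as a whole.
            if outcome == "failed":
                state = "failed"
    return state


if __name__ == "__main__":
    ok = [("setup", "passed", False), ("call", "passed", False),
          ("teardown", "passed", False)]
    print(atf_state(ok))
```

Keeping the mapping free of pytest objects makes each row of the comment table a one-line assertion, which is the main reason for writing it this way here.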