Skip to content

Commit

Permalink
Merge pull request #4300 from jestabro/configd-inspect-by-ast
Browse files Browse the repository at this point in the history
T7042: drop use of inspect module in favor of ast for source analysis
  • Loading branch information
jestabro authored Jan 11, 2025
2 parents 9f6a986 + d5b1bfc commit 8a83a97
Showing 1 changed file with 134 additions and 76 deletions.
210 changes: 134 additions & 76 deletions src/tests/test_configd_inspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (C) 2020-2024 VyOS maintainers and contributors
# Copyright (C) 2020-2025 VyOS maintainers and contributors
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
Expand All @@ -12,93 +12,151 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import re
import ast
import json

import warnings
import importlib.util
from inspect import signature
from inspect import getsource
from functools import wraps
from unittest import TestCase

INC_FILE = 'data/configd-include.json'
CONF_DIR = 'src/conf_mode'

f_list = ['get_config', 'verify', 'generate', 'apply']

def import_script(s):
path = os.path.join(CONF_DIR, s)
name = os.path.splitext(s)[0].replace('-', '_')
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

# importing conf_mode scripts imports jinja2 with deprecation warning
def ignore_deprecation_warning(f):
@wraps(f)
def decorated_function(*args, **kwargs):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f(*args, **kwargs)
return decorated_function
funcs = ['get_config', 'verify', 'generate', 'apply']


class FunctionSig(ast.NodeVisitor):
def __init__(self):
self.func_sig_len = dict.fromkeys(funcs, None)
self.get_config_default_values = []

def visit_FunctionDef(self, node):
func_name = node.name
if func_name in funcs:
self.func_sig_len[func_name] = len(node.args.args)

if func_name == 'get_config':
for default in node.args.defaults:
if isinstance(default, ast.Constant):
self.get_config_default_values.append(default.value)

self.generic_visit(node)

def get_sig_lengths(self):
return self.func_sig_len

def get_config_default(self):
return self.get_config_default_values[0]


class LegacyCall(ast.NodeVisitor):
def __init__(self):
self.legacy_func_count = 0

def visit_Constant(self, node):
value = node.value
if isinstance(value, str):
if 'my_set' in value or 'my_delete' in value:
self.legacy_func_count += 1

self.generic_visit(node)

def get_legacy_func_count(self):
return self.legacy_func_count


class ConfigInstance(ast.NodeVisitor):
def __init__(self):
self.count = 0

def visit_Call(self, node):
if isinstance(node.func, ast.Name):
name = node.func.id
if name == 'Config':
self.count += 1
self.generic_visit(node)

def get_count(self):
return self.count


class FunctionConfigInstance(ast.NodeVisitor):
def __init__(self):
self.func_config_instance = dict.fromkeys(funcs, 0)

def visit_FunctionDef(self, node):
func_name = node.name
if func_name in funcs:
config_instance = ConfigInstance()
config_instance.visit(node)
self.func_config_instance[func_name] = config_instance.get_count()
self.generic_visit(node)

def get_func_config_instance(self):
return self.func_config_instance


class TestConfigdInspect(TestCase):
def setUp(self):
self.ast_list = []

with open(INC_FILE) as f:
self.inc_list = json.load(f)

@ignore_deprecation_warning
def test_signatures(self):
for s in self.inc_list:
m = import_script(s)
for i in f_list:
f = getattr(m, i, None)
self.assertIsNotNone(f, f"'{s}': missing function '{i}'")
sig = signature(f)
par = sig.parameters
l = len(par)
self.assertEqual(l, 1,
f"'{s}': '{i}' incorrect signature")
if i == 'get_config':
for p in par.values():
self.assertTrue(p.default is None,
f"'{s}': '{i}' incorrect signature")

@ignore_deprecation_warning
def test_function_instance(self):
for s in self.inc_list:
m = import_script(s)
for i in f_list:
f = getattr(m, i, None)
if not f:
continue
str_f = getsource(f)
# Regex not XXXConfig() T3108
n = len(re.findall(r'[^a-zA-Z]Config\(\)', str_f))
if i == 'get_config':
self.assertEqual(n, 1,
f"'{s}': '{i}' no instance of Config")
if i != 'get_config':
self.assertEqual(n, 0,
f"'{s}': '{i}' instance of Config")

@ignore_deprecation_warning
def test_file_instance(self):
for s in self.inc_list:
m = import_script(s)
str_m = getsource(m)
# Regex not XXXConfig T3108
n = len(re.findall(r'[^a-zA-Z]Config\(\)', str_m))
self.assertEqual(n, 1,
f"'{s}' more than one instance of Config")

@ignore_deprecation_warning
s_path = f'{CONF_DIR}/{s}'
with open(s_path) as f:
s_str = f.read()
s_tree = ast.parse(s_str)
self.ast_list.append((s, s_tree))

def test_signatures(self):
for s, t in self.ast_list:
visitor = FunctionSig()
visitor.visit(t)
sig_lens = visitor.get_sig_lengths()

for f in funcs:
self.assertIsNotNone(sig_lens[f], f"'{s}': '{f}' missing")
self.assertEqual(sig_lens[f], 1, f"'{s}': '{f}' incorrect signature")

self.assertEqual(
visitor.get_config_default(),
None,
f"'{s}': 'get_config' incorrect signature",
)

def test_file_config_instance(self):
for s, t in self.ast_list:
visitor = ConfigInstance()
visitor.visit(t)
count = visitor.get_count()

self.assertEqual(count, 1, f"'{s}' more than one instance of Config")

def test_function_config_instance(self):
for s, t in self.ast_list:
visitor = FunctionConfigInstance()
visitor.visit(t)
func_config_instance = visitor.get_func_config_instance()

for f in funcs:
if f == 'get_config':
self.assertTrue(
func_config_instance[f] > 0,
f"'{s}': '{f}' no instance of Config",
)
self.assertTrue(
func_config_instance[f] < 2,
f"'{s}': '{f}' more than one instance of Config",
)
else:
self.assertEqual(
func_config_instance[f], 0, f"'{s}': '{f}' instance of Config"
)

def test_config_modification(self):
for s in self.inc_list:
m = import_script(s)
str_m = getsource(m)
n = str_m.count('my_set')
self.assertEqual(n, 0, f"'{s}' modifies config")
for s, t in self.ast_list:
visitor = LegacyCall()
visitor.visit(t)
legacy_func_count = visitor.get_legacy_func_count()

self.assertEqual(legacy_func_count, 0, f"'{s}' modifies config")

0 comments on commit 8a83a97

Please sign in to comment.