diff --git a/docs/inputs/helper.md b/docs/inputs/helper.md index c26858220c..bb16a96de2 100644 --- a/docs/inputs/helper.md +++ b/docs/inputs/helper.md @@ -54,3 +54,20 @@ def stream_events(inputs: smi.InputDefinition, event_writer: smi.EventWriter): ``` The two methods' bodies should be filled by the developer. + +Alternatively, if you want to have access to the instance of the input script class, +you can also add the `self` parameter to the methods: + +```python +from splunklib import modularinput as smi + + +def validate_input(self, definition: smi.ValidationDefinition): + ... + + +def stream_events(self, inputs: smi.InputDefinition, event_writer: smi.EventWriter): + ... +``` + +Instead of `self`, you can also use any other name, but it must be the first parameter. diff --git a/splunk_add_on_ucc_framework/templates/input.template b/splunk_add_on_ucc_framework/templates/input.template index 2b1b4f63c9..d56d788e57 100644 --- a/splunk_add_on_ucc_framework/templates/input.template +++ b/splunk_add_on_ucc_framework/templates/input.template @@ -4,6 +4,7 @@ import import_declare_test import json {% endif -%} import sys +import inspect from splunklib import modularinput as smi {%- if input_helper_module %} @@ -42,14 +43,14 @@ class {{class_name}}(smi.Script): def validate_input(self, definition: smi.ValidationDefinition): {%- if input_helper_module %} - return validate_input(definition) + return self.call_with_args(validate_input, definition) {%- else %} return {%- endif %} def stream_events(self, inputs: smi.InputDefinition, ew: smi.EventWriter): {%- if input_helper_module %} - return stream_events(inputs, ew) + return self.call_with_args(stream_events, inputs, ew) {%- else %} input_items = [{'count': len(inputs.inputs)}] for input_name, input_item in inputs.inputs.items(): @@ -62,6 +63,16 @@ class {{class_name}}(smi.Script): ew.write_event(event) {%- endif %} + {% if input_helper_module -%} + def call_with_args(self, method, *args): + method_args_list = inspect.getfullargspec(method).args + + if len(method_args_list) == len(args) + 1: + return method(self, *args) + + return method(*args) + {%- endif %} + if __name__ == '__main__': exit_code = {{class_name}}().run(sys.argv) diff --git a/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_one.py b/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_one.py index 3c6317bc20..c9f85bbe36 100644 --- a/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_one.py +++ b/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_one.py @@ -2,6 +2,7 @@ import json import sys +import inspect from splunklib import modularinput as smi @@ -42,6 +43,8 @@ def stream_events(self, inputs: smi.InputDefinition, ew: smi.EventWriter): ew.write_event(event) + + if __name__ == '__main__': exit_code = EXAMPLE_INPUT_ONE().run(sys.argv) sys.exit(exit_code) \ No newline at end of file diff --git a/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_two.py b/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_two.py index 8b702d393c..56d53449d8 100644 --- a/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_two.py +++ b/tests/testdata/expected_addons/expected_addon_no_configuration/Splunk_TA_UCCExample/bin/example_input_two.py @@ -2,6 +2,7 @@ import json import sys +import inspect from splunklib import modularinput as smi @@ -42,6 +43,8 @@ def stream_events(self, inputs: smi.InputDefinition, ew: smi.EventWriter): ew.write_event(event) + + if __name__ == '__main__': exit_code = EXAMPLE_INPUT_TWO().run(sys.argv) sys.exit(exit_code) \ No newline at end of file diff --git a/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_one.py b/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_one.py index d662e2ce8e..ac1d15ea27 100644 --- a/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_one.py +++ b/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_one.py @@ -1,6 +1,7 @@ import import_declare_test import sys +import inspect from splunklib import modularinput as smi from helper_one import stream_events, validate_input @@ -124,12 +125,20 @@ def get_scheme(self): return scheme def validate_input(self, definition: smi.ValidationDefinition): - return validate_input(definition) + return self.call_with_args(validate_input, definition) def stream_events(self, inputs: smi.InputDefinition, ew: smi.EventWriter): - return stream_events(inputs, ew) + return self.call_with_args(stream_events, inputs, ew) + + def call_with_args(self, method, *args): + method_args_list = inspect.getfullargspec(method).args + + if len(method_args_list) == len(args) + 1: + return method(self, *args) + + return method(*args) if __name__ == '__main__': exit_code = EXAMPLE_INPUT_ONE().run(sys.argv) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_two.py b/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_two.py index f131132d97..f4221ebac4 100644 --- a/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_two.py +++ b/tests/testdata/expected_addons/expected_output_global_config_everything/Splunk_TA_UCCExample/bin/example_input_two.py @@ -1,6 +1,7 @@ import import_declare_test import sys +import inspect from splunklib import modularinput as smi from helper_two import stream_events, validate_input @@ -100,12 +101,20 @@ def get_scheme(self): return scheme def validate_input(self, definition: smi.ValidationDefinition): - return validate_input(definition) + return self.call_with_args(validate_input, definition) def stream_events(self, inputs: smi.InputDefinition, ew: smi.EventWriter): - return stream_events(inputs, ew) + return self.call_with_args(stream_events, inputs, ew) + + def call_with_args(self, method, *args): + method_args_list = inspect.getfullargspec(method).args + + if len(method_args_list) == len(args) + 1: + return method(self, *args) + + return method(*args) if __name__ == '__main__': exit_code = EXAMPLE_INPUT_TWO().run(sys.argv) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/tests/unit/test_templates.py b/tests/unit/test_templates.py new file mode 100644 index 0000000000..d30eba1968 --- /dev/null +++ b/tests/unit/test_templates.py @@ -0,0 +1,62 @@ +import importlib.util +import sys +from textwrap import dedent +from unittest.mock import MagicMock + +import pytest + +from splunk_add_on_ucc_framework import utils + + +class Script: + pass + + +@pytest.mark.parametrize("with_self", [True, False]) +def test_input_helpers(tmp_path, monkeypatch, with_self): + # To avoid issues with imports, we need to monkeypatch the sys.modules with a different dict + monkeypatch.setattr(sys, "modules", {k: v for k, v in sys.modules.items()}) + + content = ( + utils.get_j2_env() + .get_template("input.template") + .render( + input_name="MyInput", + class_name="MyClass", + description="My Input Description", + entity=[], + input_helper_module="my_input_helper", + ) + ) + + (tmp_path / "my_input.py").write_text(content) + + if with_self: + self_arg = "self, " + else: + self_arg = "" + + (tmp_path / "my_input_helper.py").write_text( + dedent( + f""" + def validate_input({self_arg}definition): + return definition + + + def stream_events({self_arg}inputs, event_writer): + return inputs, event_writer + """ + ) + ) + + for module in ["import_declare_test"]: + # mock module in sys.modules - set to MagicMock + mock_module = MagicMock() + mock_module.__file__ = str(tmp_path / f"{module}.py") + monkeypatch.setitem(sys.modules, module, mock_module) + + monkeypatch.syspath_prepend(str(tmp_path)) + my_obj = importlib.import_module("my_input").MyClass() + + assert my_obj.validate_input("arg1") == "arg1" + assert my_obj.stream_events("arg1", "arg2") == ("arg1", "arg2")