diff --git a/oletools/common/log_helper/_json_formatter.py b/oletools/common/log_helper/_json_formatter.py index 1f334a81e..8e1c6609b 100644 --- a/oletools/common/log_helper/_json_formatter.py +++ b/oletools/common/log_helper/_json_formatter.py @@ -8,6 +8,10 @@ class JsonFormatter(logging.Formatter): """ _is_first_line = True + def __init__(self, other_logger_has_first_line=False): + if other_logger_has_first_line: + self._is_first_line = False + def format(self, record): """ Since we don't buffer messages, we always prepend messages with a comma to make diff --git a/oletools/common/log_helper/_logger_adapter.py b/oletools/common/log_helper/_logger_adapter.py index ee52b26da..dfc748f89 100644 --- a/oletools/common/log_helper/_logger_adapter.py +++ b/oletools/common/log_helper/_logger_adapter.py @@ -54,4 +54,9 @@ def set_json_enabled_function(self, json_enabled): self._json_enabled = json_enabled def level(self): + """Return current level of logger.""" return self.logger.level + + def setLevel(self, new_level): + """Set level of underlying logger. Required only for python < 3.2.""" + return self.logger.setLevel(new_level) diff --git a/oletools/common/log_helper/log_helper.py b/oletools/common/log_helper/log_helper.py index 2ef7d8783..9ec9843ae 100644 --- a/oletools/common/log_helper/log_helper.py +++ b/oletools/common/log_helper/log_helper.py @@ -3,6 +3,26 @@ General logging helpers +Use as follows: + + # at the start of your file: + # import logging <-- replace this with next line + from oletools.common.log_helper import log_helper + + logger = log_helper.get_or_create_silent_logger("module_name") + def enable_logging(): + '''Enable logging in this module; for use by importing scripts''' + logger.setLevel(log_helper.NOTSET) + imported_oletool_module.enable_logging() + other_imported_oletool_module.enable_logging() + + # ... your code; use logger instead of logging ... + + def main(): + log_helper.enable_logging(level=...) # instead of logging.basicConfig + # ... your main code ... + log_helper.end_logging() + .. codeauthor:: Intra2net AG , Philippe Lagadec """ @@ -45,6 +65,7 @@ # TODO: +from __future__ import print_function from ._json_formatter import JsonFormatter from ._logger_adapter import OletoolsLoggerAdapter from . import _root_logger_wrapper @@ -61,15 +82,27 @@ 'critical': logging.CRITICAL } +#: provide this constant to modules, so they do not have to import +#: :py:mod:`logging` for themselves just for this one constant. +NOTSET = logging.NOTSET + DEFAULT_LOGGER_NAME = 'oletools' DEFAULT_MESSAGE_FORMAT = '%(levelname)-8s %(message)s' class LogHelper: + """ + Single helper class that creates and remembers loggers. + """ + + #: for convenience: here again (see also :py:data:`log_helper.NOTSET`) + NOTSET = logging.NOTSET + def __init__(self): self._all_names = set() # set so we do not have duplicates self._use_json = False self._is_enabled = False + self._target_stream = None def get_or_create_silent_logger(self, name=DEFAULT_LOGGER_NAME, level=logging.CRITICAL + 1): """ @@ -82,7 +115,8 @@ def get_or_create_silent_logger(self, name=DEFAULT_LOGGER_NAME, level=logging.CR """ return self._get_or_create_logger(name, level, logging.NullHandler()) - def enable_logging(self, use_json=False, level='warning', log_format=DEFAULT_MESSAGE_FORMAT, stream=None): + def enable_logging(self, use_json=False, level='warning', log_format=DEFAULT_MESSAGE_FORMAT, stream=None, + other_logger_has_first_line=False): """ This function initializes the root logger and enables logging. 
We set the level of the root logger to the one passed by calling logging.basicConfig. @@ -93,15 +127,26 @@ def enable_logging(self, use_json=False, level='warning', log_format=DEFAULT_MES which in turn will log to the stream set in this function. Since the root logger is the one doing the work, when using JSON we set its formatter so that every message logged is JSON-compatible. + + If other code also creates json output, all items should be pre-pended + with a comma like the `JsonFormatter` does. Except the first; use param + `other_logger_has_first_line` to clarify whether our logger or the + other code will produce the first json item. """ if self._is_enabled: raise ValueError('re-enabling logging. Not sure whether that is ok...') - if stream in (None, sys.stdout): + if stream is None: + self.target_stream = sys.stdout + else: + self.target_stream = stream + + if self.target_stream == sys.stdout: ensure_stdout_handles_unicode() log_level = LOG_LEVELS[level] - logging.basicConfig(level=log_level, format=log_format, stream=stream) + logging.basicConfig(level=log_level, format=log_format, + stream=self.target_stream) self._is_enabled = True self._use_json = use_json @@ -115,8 +160,8 @@ def enable_logging(self, use_json=False, level='warning', log_format=DEFAULT_MES # add a JSON formatter to the root logger, which will be used by every logger if self._use_json: - _root_logger_wrapper.set_formatter(JsonFormatter()) - print('[') + _root_logger_wrapper.set_formatter(JsonFormatter(other_logger_has_first_line)) + print('[', file=self.target_stream) def end_logging(self): """ @@ -133,7 +178,7 @@ def end_logging(self): # end json list if self._use_json: - print(']') + print(']', file=self.target_stream) self._use_json = False def _get_except_hook(self, old_hook): diff --git a/oletools/crypto.py b/oletools/crypto.py index f4d3ace5b..9079ae91c 100644 --- a/oletools/crypto.py +++ b/oletools/crypto.py @@ -95,6 +95,7 @@ def script_main_function(input_file, passwords, crypto_nesting=0, args): # 2019-05-23 PL: - added DEFAULT_PASSWORDS list # 2021-05-22 v0.60 PL: - added PowerPoint transparent password # '/01Hannes Ruescher/01' (issue #627) +# 2019-05-24 CH: - use log_helper __version__ = '0.60' @@ -104,7 +105,6 @@ def script_main_function(input_file, passwords, crypto_nesting=0, args): from os.path import splitext, isfile from tempfile import mkstemp import zipfile -import logging from olefile import OleFileIO @@ -134,44 +134,20 @@ def script_main_function(input_file, passwords, crypto_nesting=0, args): # === LOGGING ================================================================= -# TODO: use log_helper instead - -def get_logger(name, level=logging.CRITICAL+1): - """ - Create a suitable logger object for this module. - The goal is not to change settings of the root logger, to avoid getting - other modules' logs on the screen. - If a logger exists with same name, reuse it. (Else it would have duplicate - handlers and messages would be doubled.) - The level is set to CRITICAL+1 by default, to avoid any logging. - """ - # First, test if there is already a logger with the same name, else it - # will generate duplicate messages (due to duplicate handlers): - if name in logging.Logger.manager.loggerDict: - # NOTE: another less intrusive but more "hackish" solution would be to - # use getLogger then test if its effective level is not default. 
- logger = logging.getLogger(name) - # make sure level is OK: - logger.setLevel(level) - return logger - # get a new logger: - logger = logging.getLogger(name) - # only add a NullHandler for this logger, it is up to the application - # to configure its own logging: - logger.addHandler(logging.NullHandler()) - logger.setLevel(level) - return logger - # a global logger object used for debugging: -log = get_logger('crypto') +log = log_helper.get_or_create_silent_logger('crypto') + def enable_logging(): """ Enable logging for this module (disabled by default). + + For use by third-party libraries that import `crypto` as module. + This will set the module-specific logger level to NOTSET, which means the main application controls the actual logging level. """ - log.setLevel(logging.NOTSET) + log.setLevel(log_helper.NOTSET) def is_encrypted(some_file): diff --git a/oletools/mraptor.py b/oletools/mraptor.py index 80cfe3516..35bf6ed6d 100644 --- a/oletools/mraptor.py +++ b/oletools/mraptor.py @@ -71,7 +71,7 @@ #--- IMPORTS ------------------------------------------------------------------ -import sys, logging, optparse, re, os +import sys, optparse, re, os # IMPORTANT: it should be possible to run oletools directly as scripts # in any directory without installing them with pip or setup.py. @@ -90,11 +90,12 @@ from oletools import olevba from oletools.olevba import TYPE2TAG +from oletools.common.log_helper import log_helper # === LOGGING ================================================================= # a global logger object used for debugging: -log = olevba.get_logger('mraptor') +log = log_helper.get_or_create_silent_logger('mraptor') #--- CONSTANTS ---------------------------------------------------------------- @@ -230,15 +231,7 @@ def main(): """ Main function, called when olevba is run from the command line """ - global log DEFAULT_LOG_LEVEL = "warning" # Default log level - LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL - } usage = 'usage: mraptor [options] [filename2 ...]' parser = optparse.OptionParser(usage=usage) @@ -272,9 +265,9 @@ def main(): print('MacroRaptor %s - http://decalage.info/python/oletools' % __version__) print('This is work in progress, please report issues at %s' % URL_ISSUES) - logging.basicConfig(level=LOG_LEVELS[options.loglevel], format='%(levelname)-8s %(message)s') + log_helper.enable_logging(level=options.loglevel) # enable logging in the modules: - log.setLevel(logging.NOTSET) + olevba.enable_logging() t = tablestream.TableStream(style=tablestream.TableStyleSlim, header_row=['Result', 'Flags', 'Type', 'File'], @@ -346,6 +339,7 @@ def main(): global_result = result exitcode = result.exit_code + log_helper.end_logging() print('') print('Flags: A=AutoExec, W=Write, X=Execute') print('Exit code: %d - %s' % (exitcode, global_result.name)) diff --git a/oletools/olevba.py b/oletools/olevba.py index bec4ea161..0559c3013 100644 --- a/oletools/olevba.py +++ b/oletools/olevba.py @@ -239,7 +239,6 @@ #------------------------------------------------------------------------------ # TODO: -# + setup logging (common with other oletools) # + add xor bruteforcing like bbharvest # + options -a and -c should imply -d @@ -272,7 +271,6 @@ import traceback import sys import os -import logging import struct from io import BytesIO, StringIO import math @@ -346,6 +344,7 @@ from oletools.common.io_encoding import ensure_stdout_handles_unicode from oletools.common import codepages from 
oletools import ftguess +from oletools.common.log_helper import log_helper # === PYTHON 2+3 SUPPORT ====================================================== @@ -423,47 +422,29 @@ def bytes2str(bytes_string, encoding='utf8'): # === LOGGING ================================================================= -def get_logger(name, level=logging.CRITICAL+1): - """ - Create a suitable logger object for this module. - The goal is not to change settings of the root logger, to avoid getting - other modules' logs on the screen. - If a logger exists with same name, reuse it. (Else it would have duplicate - handlers and messages would be doubled.) - The level is set to CRITICAL+1 by default, to avoid any logging. - """ - # First, test if there is already a logger with the same name, else it - # will generate duplicate messages (due to duplicate handlers): - if name in logging.Logger.manager.loggerDict: - # NOTE: another less intrusive but more "hackish" solution would be to - # use getLogger then test if its effective level is not default. - logger = logging.getLogger(name) - # make sure level is OK: - logger.setLevel(level) - return logger - # get a new logger: - logger = logging.getLogger(name) - # only add a NullHandler for this logger, it is up to the application - # to configure its own logging: - logger.addHandler(logging.NullHandler()) - logger.setLevel(level) - return logger # a global logger object used for debugging: -log = get_logger('olevba') +log = log_helper.get_or_create_silent_logger('olevba') def enable_logging(): """ Enable logging for this module (disabled by default). - This will set the module-specific logger level to NOTSET, which + + For use by third-party libraries that import `olevba` as module. + + This will set the module-specific logger level to `NOTSET`, which means the main application controls the actual logging level. + + This also enables logging for the modules used by us, but not the global + common logging mechanism (:py:mod:`oletools.common.log_helper.log_helper`). + Use :py:func:`oletools.common.log_helper.log_helper.enable_logging` for + that. """ - log.setLevel(logging.NOTSET) - # Also enable logging in the ppt_parser module: + log.setLevel(log_helper.NOTSET) ppt_parser.enable_logging() crypto.enable_logging() - + # TODO: do not have enable_logging yet: oleform, rtfobj #=== EXCEPTIONS ============================================================== @@ -2462,18 +2443,18 @@ def json2ascii(json_obj, encoding='utf8', errors='replace'): return json_obj -def print_json(json_dict=None, _json_is_first=False, _json_is_last=False, - **json_parts): +def print_json(json_dict=None, _json_is_first=False, **json_parts): """ line-wise print of json.dumps(json2ascii(..)) with options and indent+1 can use in two ways: (1) print_json(some_dict) (2) print_json(key1=value1, key2=value2, ...) - :param bool _json_is_first: set to True only for very first entry to complete - the top-level json-list - :param bool _json_is_last: set to True only for very last entry to complete - the top-level json-list + This is compatible with :py:mod:`oletools.common.log_helper`: log messages + can be mixed if arg `use_json` was `True` in + :py:func:`log_helper.enable_logging` provided this function is called + before the first "regular" logging with `_json_is_first=True` (and + non-empty input) but after log_helper.enable_logging. 
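
    A rough sketch of that call order (the logged field values here are
    illustrative only, not the exact olevba output):

        log_helper.enable_logging(use_json=True, level='warning',
                                  other_logger_has_first_line=True)  # prints '['
        print_json(script_name='olevba', type='MetaInformation',
                   _json_is_first=True)              # first item, no leading comma
        print_json(file='example.docm', type='OpenXML')   # prepended with ', '
        log.warning('suspicious keyword found')      # JsonFormatter prepends ','
        log_helper.end_logging()                     # prints the closing ']'
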
""" if json_dict and json_parts: raise ValueError('Invalid json argument: want either single dict or ' @@ -2485,18 +2466,18 @@ def print_json(json_dict=None, _json_is_first=False, _json_is_last=False, if json_parts: json_dict = json_parts - if _json_is_first: - print('[') - lines = json.dumps(json2ascii(json_dict), check_circular=False, - indent=4, ensure_ascii=False).splitlines() - for line in lines[:-1]: - print(' {0}'.format(line)) - if _json_is_last: - print(' {0}'.format(lines[-1])) # print last line without comma - print(']') + indent=4, ensure_ascii=False).splitlines() + if not lines: + return + + if _json_is_first: + print(' ' + lines[0]) else: - print(' {0},'.format(lines[-1])) # print last line with comma + print(', ' + lines[0]) + + for line in lines[1:]: + print(' ' + line.rstrip()) class VBA_Scanner(object): @@ -4358,13 +4339,6 @@ def parse_args(cmd_line_args=None): """ parse command line arguments (given ones or per default sys.argv) """ DEFAULT_LOG_LEVEL = "warning" # Default log level - LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL - } usage = 'usage: olevba [options] [filename2 ...]' parser = argparse.ArgumentParser(usage=usage) @@ -4459,8 +4433,6 @@ def parse_args(cmd_line_args=None): if options.show_pcode and options.no_pcode: parser.error('You cannot combine options --no-pcode and --show-pcode') - options.loglevel = LOG_LEVELS[options.loglevel] - return options @@ -4474,6 +4446,8 @@ def process_file(filename, data, container, options, crypto_nesting=0): Returns a single code summarizing the status of processing of this file """ try: + vba_parser = None + # Open the file vba_parser = VBA_Parser_CLI(filename, data=data, container=container, relaxed=options.relaxed, @@ -4501,6 +4475,7 @@ def process_file(filename, data, container, options, crypto_nesting=0): no_xlm=options.no_xlm)) else: # (should be impossible) raise ValueError('unexpected output mode: "{0}"!'.format(options.output_mode)) + vba_parser.close() # even if processing succeeds, file might still be encrypted log.debug('Checking for encryption (normal)') @@ -4508,6 +4483,10 @@ def process_file(filename, data, container, options, crypto_nesting=0): log.debug('no encryption detected') return RETURN_OK except Exception as exc: + log.debug('Caught exception:', exc_info=True) + if vba_parser: + vba_parser.close() + log.debug('Checking for encryption (after exception)') if crypto.is_encrypted(filename): pass # deal with this below @@ -4582,10 +4561,15 @@ def main(cmd_line_args=None): """ options = parse_args(cmd_line_args) + # enable logging in the modules (for json, this prints the opening '['): + log_helper.enable_logging(options.output_mode=='json', options.loglevel, + other_logger_has_first_line=True) + # provide info about tool and its version if options.output_mode == 'json': - # print first json entry with meta info and opening '[' + # print first json entry with meta info print_json(script_name='olevba', version=__version__, + python_version=sys.version_info[0:3], url='http://decalage.info/python/oletools', type='MetaInformation', _json_is_first=True) else: @@ -4594,10 +4578,6 @@ def main(cmd_line_args=None): print('olevba %s on Python %s - http://decalage.info/python/oletools' % (__version__, python_version)) - logging.basicConfig(level=options.loglevel, format='%(levelname)-8s %(message)s') - # enable logging in the modules: - enable_logging() - # with the option --reveal, make sure --deobf is also 
enabled: if options.show_deobfuscated_code and not options.deobfuscate: log.debug('set --deobf because --reveal was set') @@ -4682,11 +4662,6 @@ def main(cmd_line_args=None): 'A=Auto-executable, S=Suspicious keywords, I=IOCs, H=Hex strings, ' \ 'B=Base64 strings, D=Dridex strings, V=VBA strings, ?=Unknown)\n') - if options.output_mode == 'json': - # print last json entry (a last one without a comma) and closing ] - print_json(type='MetaInformation', return_code=return_code, - n_processed=count, _json_is_last=True) - except crypto.CryptoErrorBase as exc: log.exception('Problems with encryption in main: {}'.format(exc), exc_info=True) @@ -4704,6 +4679,7 @@ def main(cmd_line_args=None): # done. exit log.debug('will exit now with code %s' % return_code) + log_helper.end_logging() sys.exit(return_code) if __name__ == '__main__': diff --git a/oletools/record_base.py b/oletools/record_base.py index 06128b171..9cf1015b9 100644 --- a/oletools/record_base.py +++ b/oletools/record_base.py @@ -42,6 +42,7 @@ # 2018-09-11 v0.54 PL: - olefile is now a dependency # 2019-01-30 PL: - fixed import to avoid mixing installed oletools # and dev version +# 2019-05-24 CH: - use log_helper __version__ = '0.60.dev1' @@ -64,7 +65,6 @@ import sys import os.path from io import SEEK_CUR -import logging import olefile @@ -74,6 +74,7 @@ if PARENT_DIR not in sys.path: sys.path.insert(0, PARENT_DIR) del PARENT_DIR +from oletools.common.log_helper import log_helper ############################################################################### @@ -100,11 +101,26 @@ } +logger = log_helper.get_or_create_silent_logger('record_base') + + def enable_olefile_logging(): - """ enable logging olefile e.g., to get debug info from OleFileIO """ + """ enable logging in olefile e.g., to get debug info from OleFileIO """ olefile.enable_logging() +def enable_logging(): + """ + Enable logging for this module (disabled by default). + + For use by third-party libraries that import `record_base` as module. + + This will set the module-specific logger level to NOTSET, which + means the main application controls the actual logging level. 
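
    A minimal sketch of the intended use from an importing application
    (the level chosen here is just an example):

        import logging
        from oletools import record_base

        logging.basicConfig(level=logging.DEBUG)  # main application picks the level
        record_base.enable_logging()              # let record_base messages through
        # records now propagate from record_base's logger to the root handler
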
+ """ + logger.setLevel(log_helper.NOTSET) + + ############################################################################### # Base Classes ############################################################################### @@ -139,7 +155,7 @@ def stream_class_for_name(cls, stream_name): def iter_streams(self): """ find all streams, including orphans """ - logging.debug('Finding streams in ole file') + logger.debug('Finding streams in ole file') for sid, direntry in enumerate(self.direntries): is_orphan = direntry is None @@ -147,7 +163,7 @@ def iter_streams(self): # this direntry is not part of the tree --> unused or orphan direntry = self._load_direntry(sid) is_stream = direntry.entry_type == olefile.STGTY_STREAM - logging.debug('direntry {:2d} {}: {}'.format( + logger.debug('direntry {:2d} {}: {}'.format( sid, '[orphan]' if is_orphan else direntry.name, 'is stream of size {}'.format(direntry.size) if is_stream else 'no stream ({})'.format(ENTRY_TYPE2STR[direntry.entry_type]))) @@ -216,8 +232,8 @@ def iter_records(self, fill_data=False): # read first few bytes, determine record type and size rec_type, rec_size, other = self.read_record_head() - # logging.debug('Record type {0} of size {1}' - # .format(rec_type, rec_size)) + # logger.debug('Record type {0} of size {1}' + # .format(rec_type, rec_size)) # determine what class to wrap this into rec_clz, force_read = self.record_class_for_type(rec_type) @@ -237,6 +253,7 @@ def iter_records(self, fill_data=False): yield rec_object def close(self): + """Close this stream (i.e. the stream given in constructor).""" self.stream.close() def __str__(self): @@ -348,25 +365,25 @@ def test(filenames, ole_file_class=OleRecordFile, if an error occurs while parsing a stream of type in must_parse, the error will be raised. 
Otherwise a message is printed """ - logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO) + log_helper.enable_logging(False, 'debug' if verbose else 'info') if do_per_record is None: def do_per_record(record): # pylint: disable=function-redefined pass # do nothing if not filenames: - logging.info('need file name[s]') + logger.info('need file name[s]') return 2 for filename in filenames: - logging.info('checking file {0}'.format(filename)) + logger.info('checking file {0}'.format(filename)) if not olefile.isOleFile(filename): - logging.info('not an ole file - skip') + logger.info('not an ole file - skip') continue ole = ole_file_class(filename) for stream in ole.iter_streams(): - logging.info(' parse ' + str(stream)) + logger.info(' parse ' + str(stream)) try: for record in stream.iter_records(): - logging.info(' ' + str(record)) + logger.info(' ' + str(record)) do_per_record(record) except Exception: if not must_parse: @@ -374,7 +391,9 @@ def do_per_record(record): # pylint: disable=function-redefined elif isinstance(stream, must_parse): raise else: - logging.info(' failed to parse', exc_info=True) + logger.info(' failed to parse', exc_info=True) + + log_helper.end_logging() return 0 diff --git a/tests/common/log_helper/log_helper_test_imported.py b/tests/common/log_helper/log_helper_test_imported.py index 8820a3e36..1be8181a5 100644 --- a/tests/common/log_helper/log_helper_test_imported.py +++ b/tests/common/log_helper/log_helper_test_imported.py @@ -4,7 +4,6 @@ """ from oletools.common.log_helper import log_helper -import logging DEBUG_MESSAGE = 'imported: debug log' INFO_MESSAGE = 'imported: info log' @@ -14,7 +13,11 @@ RESULT_MESSAGE = 'imported: result log' RESULT_TYPE = 'imported: result' -logger = log_helper.get_or_create_silent_logger('test_imported', logging.ERROR) +logger = log_helper.get_or_create_silent_logger('test_imported') + +def enable_logging(): + """Enable logging if imported by third party modules.""" + logger.setLevel(log_helper.NOTSET) def log(): diff --git a/tests/common/log_helper/log_helper_test_main.py b/tests/common/log_helper/log_helper_test_main.py index fb0ccca2f..c82f9bcf8 100644 --- a/tests/common/log_helper/log_helper_test_main.py +++ b/tests/common/log_helper/log_helper_test_main.py @@ -1,6 +1,7 @@ """ Test log_helpers """ import sys +import logging from tests.common.log_helper import log_helper_test_imported from oletools.common.log_helper import log_helper @@ -15,7 +16,13 @@ logger = log_helper.get_or_create_silent_logger('test_main') -def init_logging_and_log(args): +def enable_logging(): + """Enable logging if imported by third party modules.""" + logger.setLevel(log_helper.NOTSET) + log_helper_test_imported.enable_logging() + + +def main(args): """ Try to cover possible logging scenarios. 
For each scenario covered, here's the expected args and outcome: - Log without enabling: [''] @@ -36,13 +43,12 @@ def init_logging_and_log(args): throw = 'throw' in args percent_autoformat = '%-autoformat' in args + log_helper_test_imported.logger.setLevel(logging.ERROR) + if 'enable' in args: log_helper.enable_logging(use_json, level, stream=sys.stdout) - _log() - - if percent_autoformat: - logger.info('The %s is %d.', 'answer', 47) + do_log(percent_autoformat) if throw: raise Exception('An exception occurred before ending the logging') @@ -50,7 +56,10 @@ def init_logging_and_log(args): log_helper.end_logging() -def _log(): +def do_log(percent_autoformat=False): + if percent_autoformat: + logger.info('The %s is %d.', 'answer', 47) + logger.debug(DEBUG_MESSAGE) logger.info(INFO_MESSAGE) logger.warning(WARNING_MESSAGE) @@ -61,4 +70,4 @@ def _log(): if __name__ == '__main__': - init_logging_and_log(sys.argv[1:]) + main(sys.argv[1:]) diff --git a/tests/common/log_helper/test_log_helper.py b/tests/common/log_helper/test_log_helper.py index bcd0de0f4..f9b20a08b 100644 --- a/tests/common/log_helper/test_log_helper.py +++ b/tests/common/log_helper/test_log_helper.py @@ -15,18 +15,16 @@ from tests.test_utils import PROJECT_ROOT -# this is the common base of "tests" and "oletools" dirs +# test file we use as "main" module TEST_FILE = relpath(join(dirname(abspath(__file__)), 'log_helper_test_main.py'), PROJECT_ROOT) -PYTHON_EXECUTABLE = sys.executable -MAIN_LOG_MESSAGES = [ - log_helper_test_main.DEBUG_MESSAGE, - log_helper_test_main.INFO_MESSAGE, - log_helper_test_main.WARNING_MESSAGE, - log_helper_test_main.ERROR_MESSAGE, - log_helper_test_main.CRITICAL_MESSAGE -] +# test file simulating a third party main module that only imports oletools +TEST_FILE_3RD_PARTY = relpath(join(dirname(abspath(__file__)), + 'third_party_importer.py'), + PROJECT_ROOT) + +PYTHON_EXECUTABLE = sys.executable class TestLogHelper(unittest.TestCase): @@ -127,6 +125,22 @@ def test_json_correct_on_exceptions(self): log_helper_test_imported.CRITICAL_MESSAGE ]) + def test_import_by_third_party_disabled(self): + """Test that when imported by third party, logging is still disabled.""" + output = self._run_test([], run_third_party=True).splitlines() + self.assertEqual(len(output), 2) + self.assertEqual(output[0], + 'INFO:root:Start message from 3rd party importer') + self.assertEqual(output[1], + 'INFO:root:End message from 3rd party importer') + + def test_import_by_third_party_enabled(self): + """Test that when imported by third party, logging can be enabled.""" + output = self._run_test(['enable', ], run_third_party=True).splitlines() + self.assertEqual(len(output), 12) + self.assertIn('INFO:test_main:main: info log', output) + self.assertIn('INFO:test_imported:imported: info log', output) + def _assert_json_messages(self, output, messages): try: json_data = json.loads(output) @@ -139,14 +153,24 @@ def _assert_json_messages(self, output, messages): self.assertNotEqual(len(json_data), 0, msg='Output was empty') - def _run_test(self, args, should_succeed=True): + def _run_test(self, args, should_succeed=True, run_third_party=False): """ Use subprocess to better simulate the real scenario and avoid logging conflicts when running multiple tests (since logging depends on singletons, we might get errors or false positives between sequential tests runs) + + When arg `run_third_party` is `True`, we do not run the `TEST_FILE` as + main moduel but the `TEST_FILE_3RD_PARTY` and return contents of + `stderr` instead of `stdout`. 
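
        For example (path shown relative to the project root):

            output = self._run_test(['enable'], run_third_party=True)
            # runs roughly: python tests/common/log_helper/third_party_importer.py enable
            # and returns what that subprocess wrote to stderr instead of stdout
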
""" + all_args = [PYTHON_EXECUTABLE, ] + if run_third_party: + all_args.append(TEST_FILE_3RD_PARTY) + else: + all_args.append(TEST_FILE) + all_args.extend(args) child = subprocess.Popen( - [PYTHON_EXECUTABLE, TEST_FILE] + args, + all_args, shell=False, env={'PYTHONPATH': PROJECT_ROOT}, universal_newlines=True, @@ -157,6 +181,16 @@ def _run_test(self, args, should_succeed=True): ) (output, output_err) = child.communicate() + if False: # DEBUG + print() + for line in output_err.splitlines(): + print('ERR: {}'.format(line.rstrip())) + for line in output.splitlines(): + print('OUT: {}'.format(line.rstrip())) + + if run_third_party: + output = output_err + if not isinstance(output, str): output = output.decode('utf-8') diff --git a/tests/common/log_helper/third_party_importer.py b/tests/common/log_helper/third_party_importer.py new file mode 100644 index 000000000..16c29ad0d --- /dev/null +++ b/tests/common/log_helper/third_party_importer.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +""" +Module for testing import of common logging modules by third party modules. + +This module behaves like a third party module. It does not use the common +logging and enables logging on its own. But it imports log_helper_test_main. +""" + +import sys +import logging + +from tests.common.log_helper import log_helper_test_main + + +def main(args): + """ + Main function, called when running file as script + + see module doc for more info + """ + logging.basicConfig(level=logging.INFO) + if 'enable' in args: + log_helper_test_main.enable_logging() + + logging.debug('Should not show.') + logging.info('Start message from 3rd party importer') + + log_helper_test_main.do_log() + + logging.debug('Returning 0, but you will never see that ... .') + logging.info('End message from 3rd party importer') + return 0 + + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) diff --git a/tests/olevba/test_basic.py b/tests/olevba/test_basic.py index ef5ed2685..b5a8b7759 100644 --- a/tests/olevba/test_basic.py +++ b/tests/olevba/test_basic.py @@ -76,14 +76,14 @@ def test_rtf_behaviour(self): def test_crypt_return(self): """ - Tests that encrypted files give a certain return code. + Test that encrypted files give a certain return code. Currently, only the encryption applied by Office 2010 (CryptoApi RC4 Encryption) is tested. """ CRYPT_DIR = join(DATA_BASE_DIR, 'encrypted') CRYPT_RETURN_CODE = 9 - ADD_ARGS = [], ['-d', ], ['-a', ], ['-j', ], ['-t', ] + ADD_ARGS = [], ['-d', ], ['-a', ], ['-j', ], ['-t', ] # only 1st file EXCEPTIONS = ['autostart-encrypt-standardpassword.xls', # These ... 'autostart-encrypt-standardpassword.xlsm', # files ... 'autostart-encrypt-standardpassword.xlsb', # are ... 
@@ -103,6 +103,10 @@ def test_crypt_return(self): msg='Wrong return code {} for args {}'\ .format(ret_code, args + [filename, ])) + # test only first file with all arg combinations, others just + # without arg (test takes too long otherwise + ADD_ARGS = ([], ) + # just in case somebody calls this file as a script if __name__ == '__main__': diff --git a/tests/olevba/test_crypto.py b/tests/olevba/test_crypto.py index aad78df35..06000610d 100644 --- a/tests/olevba/test_crypto.py +++ b/tests/olevba/test_crypto.py @@ -40,26 +40,35 @@ def test_autostart(self): exclude_stderr=True) data = json.loads(output, object_pairs_hook=OrderedDict) # debug: json.dump(data, sys.stdout, indent=4) - self.assertEqual(len(data), 4) + self.assertIn(len(data), (3, 4)) + + # first 2 parts: general info about script and file self.assertIn('script_name', data[0]) self.assertIn('version', data[0]) self.assertEqual(data[0]['type'], 'MetaInformation') - self.assertIn('return_code', data[-1]) - self.assertEqual(data[-1]['type'], 'MetaInformation') self.assertEqual(data[1]['container'], None) self.assertEqual(data[1]['file'], example_file) self.assertEqual(data[1]['analysis'], None) self.assertEqual(data[1]['macros'], []) self.assertEqual(data[1]['type'], 'OLE') - self.assertEqual(data[2]['container'], example_file) - self.assertNotEqual(data[2]['file'], example_file) - self.assertEqual(data[2]['type'], "OpenXML") - analysis = data[2]['analysis'] + self.assertTrue(data[1]['json_conversion_successful']) + + # possible VBA stomping warning + if len(data) == 4: + self.assertEqual(data[2]['type'], 'msg') + self.assertIn('VBA stomping', data[2]['msg']) + + # last part is the actual result + self.assertEqual(data[-1]['container'], example_file) + self.assertNotEqual(data[-1]['file'], example_file) + self.assertEqual(data[-1]['type'], "OpenXML") + analysis = data[-1]['analysis'] self.assertEqual(analysis[0]['type'], 'AutoExec') self.assertEqual(analysis[0]['keyword'], 'Auto_Open') - macros = data[2]['macros'] + macros = data[-1]['macros'] self.assertEqual(macros[0]['vba_filename'], 'Modul1.bas') self.assertIn('Sub Auto_Open()', macros[0]['code']) + self.assertTrue(data[-1]['json_conversion_successful']) if __name__ == '__main__':
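
Taken together, the conversion applied to crypto.py, mraptor.py, olevba.py and
record_base.py above follows one pattern; for an oletools-style module with a
command-line entry point it looks roughly as follows (module name and messages
are illustrative, not taken from the patch):

    from oletools.common.log_helper import log_helper
    from oletools import olevba

    # module-level logger, silent until somebody enables logging:
    log = log_helper.get_or_create_silent_logger('mytool')

    def enable_logging():
        '''Enable logging in this module; for use by importing scripts.'''
        log.setLevel(log_helper.NOTSET)
        olevba.enable_logging()

    def main():
        # replaces logging.basicConfig(); with use_json=True this also
        # prints the opening '[' of the JSON output list
        log_helper.enable_logging(use_json=False, level='warning')
        # enable logging in the oletools modules we call, as mraptor does:
        olevba.enable_logging()
        log.info('processing ...')
        # ... actual work ...
        log_helper.end_logging()  # with use_json=True this prints the closing ']'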