From eafe6ae7c0eceec78a1633a9a3c857c9b82459c1 Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Fri, 17 Jan 2025 21:49:03 +0000 Subject: [PATCH 1/3] [Frontend] Generalise SourceCodeFile class Signed-off-by: Arthur Chan --- src/fuzz_introspector/frontends/datatypes.py | 77 +++++++++++- src/fuzz_introspector/frontends/frontend_c.py | 64 ++++------ .../frontends/frontend_cpp.py | 59 ++++----- .../frontends/frontend_go.py | 58 +++------ .../frontends/frontend_jvm.py | 66 ++++------ .../frontends/frontend_rust.py | 66 ++++------ src/fuzz_introspector/frontends/oss_fuzz.py | 118 +++--------------- 7 files changed, 202 insertions(+), 306 deletions(-) diff --git a/src/fuzz_introspector/frontends/datatypes.py b/src/fuzz_introspector/frontends/datatypes.py index a6469ddb..6abce25d 100644 --- a/src/fuzz_introspector/frontends/datatypes.py +++ b/src/fuzz_introspector/frontends/datatypes.py @@ -14,13 +14,78 @@ # ################################################################################ -from typing import Any, Optional +from tree_sitter import Language, Parser +import tree_sitter_c +import tree_sitter_cpp +import tree_sitter_go +import tree_sitter_java +import tree_sitter_rust +import logging -class Project(): +logger = logging.getLogger(name=__name__) + +from typing import Any, Optional, Generic, TypeVar + +T = TypeVar('T', bound='SourceCodeFile') + + +class SourceCodeFile(): + """Class for holding file-specific information.""" + LANGUAGE: dict[str, Language] = { + 'c': Language(tree_sitter_c.language()), + 'cpp': Language(tree_sitter_cpp.language()), + 'c++': Language(tree_sitter_cpp.language()), + 'go': Language(tree_sitter_go.language()), + 'jvm': Language(tree_sitter_java.language()), + 'rust': Language(tree_sitter_rust.language()), + } + + def __init__(self, + language: str, + source_file: str, + entrypoint: str = '', + source_content: Optional[bytes] = None): + logger.info('Processing %s' % source_file) + + self.root = None + self.source_file = source_file + self.language = language + self.entrypoint = entrypoint + self.tree_sitter_lang = self.LANGUAGE.get(language) + self.parser = Parser(self.tree_sitter_lang) + + if source_content: + self.source_content = source_content + else: + with open(self.source_file, 'rb') as f: + self.source_content = f.read() + + # Initialization ruotines + self.load_tree() + + # Language specific process + self.language_specific_process() + + def load_tree(self): + """Load the the source code into a treesitter tree, and set + the root node.""" + if not self.root: + self.root = self.parser.parse(self.source_content).root_node + + def language_specific_process(self): + """Dummy function to perform some specific processes in subclasses.""" + pass + + def has_libfuzzer_harness(self) -> bool: + """Dummy function for source code files.""" + return False + + +class Project(Generic[T]): """Wrapper for doing analysis of a collection of source files.""" - def __init__(self, source_code_files: list[Any]): + def __init__(self, source_code_files: list[T]): self.source_code_files = source_code_files def dump_module_logic(self, @@ -35,7 +100,7 @@ def dump_module_logic(self, def extract_calltree(self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[T] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None, depth: int = 0, @@ -48,14 +113,14 @@ def extract_calltree(self, def get_reachable_functions( self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[T] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" # Dummy function for subclasses return set() - def get_source_codes_with_harnesses(self) -> list[Any]: + def get_source_codes_with_harnesses(self) -> list[T]: """Gets the source codes that holds libfuzzer harnesses.""" harnesses = [] for source_code in self.source_code_files: diff --git a/src/fuzz_introspector/frontends/frontend_c.py b/src/fuzz_introspector/frontends/frontend_c.py index 98145f32..e3932234 100644 --- a/src/fuzz_introspector/frontends/frontend_c.py +++ b/src/fuzz_introspector/frontends/frontend_c.py @@ -20,23 +20,21 @@ import logging from tree_sitter import Language, Parser -import tree_sitter_c import yaml -from typing import Any, Optional, Set +from typing import Any, Optional -from fuzz_introspector.frontends.datatypes import Project +from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile logger = logging.getLogger(name=__name__) -tree_sitter_languages = {'c': Language(tree_sitter_c.language())} -language_parsers = {'c': Parser(Language(tree_sitter_c.language()))} - - -class CProject(Project): +class CProject(Project['CSourceCodeFile']): """Wrapper for doing analysis of a collection of source files.""" + def __init__(self, source_code_files: list['CSourceCodeFile']): + super().__init__(source_code_files) + def dump_module_logic(self, report_name, entry_function: str = '', @@ -86,8 +84,8 @@ def dump_module_logic(self, 'functionLinenumberEnd'] = func_def.root.end_point.row func_dict['linkageType'] = '' func_dict['func_position'] = { - 'start': source_code.root.start_point.row, - 'end': source_code.root.end_point.row, + 'start': func_def.root.start_point.row, + 'end': func_def.root.end_point.row, } cc_str = 'CyclomaticComplexity' func_dict[cc_str] = func_def.get_function_complexity() @@ -130,9 +128,12 @@ def get_source_code_with_target(self, target_func_name): return source_code return None + def get_source_codes_with_harnesses(self) -> list['CSourceCodeFile']: + return super().get_source_codes_with_harnesses() + def extract_calltree(self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional['CSourceCodeFile'] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None, depth: int = 0, @@ -182,9 +183,9 @@ def extract_calltree(self, def get_reachable_functions( self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional['CSourceCodeFile'] = None, function: Optional[str] = None, - visited_functions: Optional[set[str]] = None) -> Set[str]: + visited_functions: Optional[set[str]] = None) -> set[str]: """Gets the reachable frunctions from a given function.""" # Create calltree from a given function # Find the function in the source code @@ -456,28 +457,17 @@ def callsites(self): return callsites -class SourceCodeFile(): +class CSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def __init__(self, source_file, language, source_content=""): - self.source_file = source_file - self.language = language - self.parser = language_parsers.get(self.language) - self.tree_sitter_lang = tree_sitter_languages[self.language] - - self.root = None + def language_specific_process(self): + """Perform some language specific processes in subclasses.""" self.function_names = [] self.line_range_pairs = [] self.struct_defs = [] self.typedefs = [] self.includes = set() - if source_content: - self.source_content = source_content - else: - with open(self.source_file, 'rb') as f: - self.source_content = f.read() - # List of function definitions in the source file. self.func_defs = [] @@ -488,12 +478,6 @@ def __init__(self, source_file, language, source_content=""): self._set_function_defintions() self.extract_types() - def load_tree(self) -> None: - """Load the the source code into a treesitter tree, and set - the root node.""" - if self.language == 'c' and not self.root: - self.root = self.parser.parse(self.source_content).root_node - def extract_types(self): """Extracts the types of the source code""" # Extract all structs @@ -640,7 +624,7 @@ def get_linenumber(self, bytepos): def load_treesitter_trees(source_files: list[str], - is_log: bool = True) -> list[SourceCodeFile]: + is_log: bool = True) -> CProject: """Creates treesitter trees for all files in a given list of source files.""" results = [] @@ -648,7 +632,7 @@ def load_treesitter_trees(source_files: list[str], if not os.path.isfile(code_file): continue - source_cls = SourceCodeFile(code_file, 'c') + source_cls = CSourceCodeFile('c', code_file) if is_log: if source_cls.has_libfuzzer_harness(): @@ -656,12 +640,12 @@ def load_treesitter_trees(source_files: list[str], results.append(source_cls) - return results + return CProject(results) -def analyse_source_code(source_content: str) -> SourceCodeFile: +def analyse_source_code(source_content: str) -> CSourceCodeFile: """Returns a source abstraction based on a single source string.""" - source_code = SourceCodeFile(source_file='in-memory string', - language='c', - source_content=source_content.encode()) + source_code = CSourceCodeFile('c', + source_file='in-memory string', + source_content=source_content.encode()) return source_code diff --git a/src/fuzz_introspector/frontends/frontend_cpp.py b/src/fuzz_introspector/frontends/frontend_cpp.py index 97c3484a..d2711581 100644 --- a/src/fuzz_introspector/frontends/frontend_cpp.py +++ b/src/fuzz_introspector/frontends/frontend_cpp.py @@ -23,44 +23,23 @@ import tree_sitter_cpp import yaml -from fuzz_introspector.frontends.datatypes import Project +from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile logger = logging.getLogger(name=__name__) LOG_FMT = '%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s' -class SourceCodeFile(): +class CppSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def __init__(self, - source_file: str, - source_content: Optional[bytes] = None): - logger.info('Processing %s', source_file) - - self.source_file = source_file - self.tree_sitter_lang = Language(tree_sitter_cpp.language()) - self.parser = Parser(self.tree_sitter_lang) - - self.root = None + def language_specific_process(self): + """Function to perform some language specific processes in subclasses.""" self.func_defs: list['FunctionDefinition'] = [] - - if source_content: - self.source_content = source_content - else: - with open(self.source_file, 'rb') as f: - self.source_content = f.read() - if self.source_content: # Initialization routines self.load_tree() self.process_tree(self.root) - def load_tree(self): - """Load the the source code into a treesitter tree, and set - the root node.""" - if not self.root: - self.root = self.parser.parse(self.source_content).root_node - def process_tree(self, node: Node, namespace: str = ''): """Process the node from the parsed tree.""" for child in node.children: @@ -139,7 +118,7 @@ class FunctionDefinition(): """Wrapper for a function definition""" def __init__(self, root: Node, tree_sitter_lang: Language, - source_code: 'SourceCodeFile', namespace: str): + source_code: CppSourceCodeFile, namespace: str): self.root = root self.tree_sitter_lang = tree_sitter_lang self.parent_source = source_code @@ -572,10 +551,10 @@ def extract_callsites(self, project): self.detailed_callsites.append({'Src': src_loc, 'Dst': dst}) -class CppProject(Project): +class CppProject(Project[CppSourceCodeFile]): """Wrapper for doing analysis of a collection of source files.""" - def __init__(self, source_code_files: list[SourceCodeFile]): + def __init__(self, source_code_files: list[CppSourceCodeFile]): super().__init__(source_code_files) self.all_functions: list[FunctionDefinition] = [] @@ -665,7 +644,7 @@ def get_function_from_name(self, function_name): def extract_calltree(self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[CppSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None, depth: int = 0, @@ -744,7 +723,7 @@ def extract_calltree(self, def get_reachable_functions( self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[CppSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Gets the reachable frunctions from a given function.""" @@ -809,9 +788,12 @@ def find_function_from_approximate_name( return None + def get_source_codes_with_harnesses(self) -> list[CppSourceCodeFile]: + return super().get_source_codes_with_harnesses() + def find_source_with_func_def( - self, - name: str) -> Optional[tuple[SourceCodeFile, FunctionDefinition]]: + self, name: str + ) -> Optional[tuple[CppSourceCodeFile, FunctionDefinition]]: """Finds the source code with a given function.""" return_func = None @@ -894,7 +876,7 @@ def _recursive_function_depth(function: FunctionDefinition) -> int: return func_depth -def load_treesitter_trees(source_files, is_log=True): +def load_treesitter_trees(source_files, is_log=True) -> CppProject: """Creates treesitter trees for all files in a given list of source files.""" results = [] @@ -902,20 +884,21 @@ def load_treesitter_trees(source_files, is_log=True): if not os.path.isfile(code_file): continue - source_cls = SourceCodeFile(code_file) + source_cls = CppSourceCodeFile('c++', code_file) results.append(source_cls) if is_log: if source_cls.has_libfuzzer_harness(): logger.info('harness: %s', code_file) - return results + return CppProject(results) -def analyse_source_code(source_content: str) -> SourceCodeFile: +def analyse_source_code(source_content: str) -> CppSourceCodeFile: """Returns a source abstraction based on a single source string.""" - source_code = SourceCodeFile(source_file='in-memory string', - source_content=source_content.encode()) + source_code = CppSourceCodeFile('c++', + source_file='in-memory string', + source_content=source_content.encode()) return source_code diff --git a/src/fuzz_introspector/frontends/frontend_go.py b/src/fuzz_introspector/frontends/frontend_go.py index a5b609bc..45cf592b 100644 --- a/src/fuzz_introspector/frontends/frontend_go.py +++ b/src/fuzz_introspector/frontends/frontend_go.py @@ -23,7 +23,7 @@ import tree_sitter_go import yaml -from fuzz_introspector.frontends.datatypes import Project +from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile logger = logging.getLogger(name=__name__) @@ -39,33 +39,16 @@ } -class SourceCodeFile(): +class GoSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def __init__(self, - source_file: str, - source_content: Optional[bytes] = None): - logger.info('Processing %s', source_file) - - self.root = None + def language_specific_process(self): + """Perform some language specific processes in subclasses.""" self.imports: list[str] = [] - self.source_file = source_file - self.tree_sitter_lang = Language(tree_sitter_go.language()) - self.parser = Parser(self.tree_sitter_lang) - - if source_content: - self.source_content = source_content - else: - with open(self.source_file, 'rb') as f: - self.source_content = f.read() - # List of function definitions in the source file. self.functions: list['FunctionMethod'] = [] self.methods: list['FunctionMethod'] = [] - # Initialization ruotines - self.load_tree() - # Load function/method declaration self._set_function_declaration() self._set_method_declaration() @@ -73,11 +56,6 @@ def __init__(self, # Parse import package self._set_imports() - def load_tree(self): - """Load the the source code into a treesitter tree, and set - the root node.""" - self.root = self.parser.parse(self.source_content).root_node - def _set_function_declaration(self): """Internal helper for retrieving all functions.""" func_query_str = '( function_declaration ) @fd ' @@ -178,10 +156,10 @@ def get_entry_function_name(self) -> str: return '' -class GoProject(Project): +class GoProject(Project[GoSourceCodeFile]): """Wrapper for doing analysis of a collection of source files.""" - def __init__(self, source_code_files: list[SourceCodeFile]): + def __init__(self, source_code_files: list[GoSourceCodeFile]): super().__init__(source_code_files) full_functions_methods = [ @@ -274,7 +252,7 @@ def dump_module_logic(self, def extract_calltree(self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[GoSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None, depth: int = 0, @@ -324,10 +302,13 @@ def extract_calltree(self, line_number=line_number) return line_to_print + def get_source_codes_with_harnesses(self) -> list[GoSourceCodeFile]: + return super().get_source_codes_with_harnesses() + def get_reachable_functions( self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[GoSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" @@ -362,7 +343,7 @@ def get_reachable_functions( return visited_functions def find_source_with_func_def( - self, target_function_name: str) -> Optional[SourceCodeFile]: + self, target_function_name: str) -> Optional[GoSourceCodeFile]: """Finds the source code with a given function.""" for source_code in self.source_code_files: if source_code.has_function_definition(target_function_name): @@ -375,7 +356,7 @@ class FunctionMethod(): """Wrapper for a General Declaration for function/method""" def __init__(self, root: Node, tree_sitter_lang: Language, - source_code: SourceCodeFile, is_function: bool): + source_code: GoSourceCodeFile, is_function: bool): self.root = root self.tree_sitter_lang = tree_sitter_lang self.parent_source = source_code @@ -782,22 +763,23 @@ def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']): def load_treesitter_trees(source_files: list[str], - is_log: bool = True) -> list[SourceCodeFile]: + is_log: bool = True) -> GoProject: """Creates treesitter trees for all files in a given list of source files.""" results = [] for code_file in source_files: - source_cls = SourceCodeFile(code_file) + source_cls = GoSourceCodeFile('go', code_file) if is_log: if source_cls.has_libfuzzer_harness(): logger.info('harness: %s', code_file) results.append(source_cls) - return results + return GoProject(results) -def analyse_source_code(source_content: str) -> SourceCodeFile: +def analyse_source_code(source_content: str) -> GoSourceCodeFile: """Returns a source abstraction based on a single source string.""" - source_code = SourceCodeFile(source_file='in-memory string', - source_content=source_content.encode()) + source_code = GoSourceCodeFile('go', + source_file='in-memory string', + source_content=source_content.encode()) return source_code diff --git a/src/fuzz_introspector/frontends/frontend_jvm.py b/src/fuzz_introspector/frontends/frontend_jvm.py index 36ad2e99..473ed176 100644 --- a/src/fuzz_introspector/frontends/frontend_jvm.py +++ b/src/fuzz_introspector/frontends/frontend_jvm.py @@ -23,7 +23,7 @@ import tree_sitter_java import yaml -from fuzz_introspector.frontends.datatypes import Project +from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile logger = logging.getLogger(name=__name__) @@ -69,27 +69,11 @@ } -class SourceCodeFile(): +class JvmSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def __init__(self, - source_file: str, - entrypoint: str = 'fuzzerTestOneInput', - source_content: Optional[bytes] = None): - logger.info('Processing %s' % source_file) - - self.root = None - self.source_file = source_file - self.entrypoint = entrypoint - self.tree_sitter_lang = Language(tree_sitter_java.language()) - self.parser = Parser(self.tree_sitter_lang) - - if source_content: - self.source_content = source_content - else: - with open(self.source_file, 'rb') as f: - self.source_content = f.read() - + def language_specific_process(self): + """Perform some language specific processes in subclasses.""" # List of definitions in the source file. self.package = '' self.classes: list['JavaClassInterface'] = [] @@ -107,11 +91,6 @@ def __init__(self, # Load import statements self._set_import_declaration() - def load_tree(self): - """Load the the source code into a treesitter tree, and set - the root node.""" - self.root = self.parser.parse(self.source_content).root_node - def post_process_imports(self, classes: list['JavaClassInterface']): """Add in full qualified name for classes in projects.""" for cls in classes: @@ -247,7 +226,7 @@ def __init__(self, self.class_interface = class_interface self.tree_sitter_lang = self.class_interface.tree_sitter_lang self.parent_source: Optional[ - SourceCodeFile] = self.class_interface.parent_source + JvmSourceCodeFile] = self.class_interface.parent_source self.is_constructor = is_constructor self.is_default_constructor = is_default_constructor self.name: str = '' @@ -813,7 +792,7 @@ class JavaClassInterface(): def __init__(self, root: Node, tree_sitter_lang: Language, - source_code: SourceCodeFile, + source_code: JvmSourceCodeFile, parent: Optional['JavaClassInterface'] = None): self.root = root self.parent = parent @@ -1008,10 +987,10 @@ def has_method_definition( return False, None -class JvmProject(Project): +class JvmProject(Project[JvmSourceCodeFile]): """Wrapper for doing analysis of a collection of source files.""" - def __init__(self, source_code_files: list[SourceCodeFile]): + def __init__(self, source_code_files: list[JvmSourceCodeFile]): super().__init__(source_code_files) self.all_classes = [] for source_code in self.source_code_files: @@ -1045,9 +1024,9 @@ def dump_module_logic(self, # Log entry method if provided if harness_name and source_code.has_class(harness_name): - entry_function = source_code.get_entry_method_name(True) - if entry_function: - report['Fuzzing method'] = entry_function + entry_func = source_code.get_entry_method_name(True) + if entry_func: + report['Fuzzing method'] = entry_func # Retrieve full proejct methods and information methods = source_code.get_all_methods() @@ -1128,7 +1107,8 @@ def dump_module_logic(self, with open(report_name, 'w', encoding='utf-8') as f: f.write(yaml.dump(report)) - def find_source_with_method(self, name: str) -> Optional[SourceCodeFile]: + def find_source_with_method(self, + name: str) -> Optional[JvmSourceCodeFile]: """Finds the source code with a given method name.""" for source_code in self.source_code_files: if source_code.has_method_definition(name): @@ -1181,7 +1161,7 @@ def _recursive_method_depth(method: JavaMethod) -> int: def extract_calltree(self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[JvmSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None, depth: int = 0, @@ -1231,10 +1211,13 @@ def extract_calltree(self, return line_to_print + def get_source_codes_with_harnesses(self) -> list[JvmSourceCodeFile]: + return super().get_source_codes_with_harnesses() + def get_reachable_functions( self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[JvmSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" @@ -1272,23 +1255,24 @@ def get_reachable_functions( def load_treesitter_trees(source_files: list[str], entrypoint: str, - is_log: bool = True) -> list[SourceCodeFile]: + is_log: bool = True) -> JvmProject: """Creates treesitter trees for all files in a given list of source files.""" results = [] for code_file in source_files: - source_cls = SourceCodeFile(code_file, entrypoint) + source_cls = JvmSourceCodeFile('jvm', code_file, entrypoint) if is_log: if source_cls.has_libfuzzer_harness(): logger.info('harness: %s', code_file) results.append(source_cls) - return results + return JvmProject(results) def analyse_source_code(source_content: str, - entrypoint: str) -> SourceCodeFile: + entrypoint: str) -> JvmSourceCodeFile: """Returns a source abstraction based on a single source string.""" - source_code = SourceCodeFile(source_file='in-memory string', - source_content=source_content.encode()) + source_code = JvmSourceCodeFile('jvm', + source_file='in-memory string', + source_content=source_content.encode()) return source_code diff --git a/src/fuzz_introspector/frontends/frontend_rust.py b/src/fuzz_introspector/frontends/frontend_rust.py index b6cacf76..0348972e 100644 --- a/src/fuzz_introspector/frontends/frontend_rust.py +++ b/src/fuzz_introspector/frontends/frontend_rust.py @@ -23,47 +23,25 @@ import tree_sitter_rust import yaml -from fuzz_introspector.frontends.datatypes import Project +from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile logger = logging.getLogger(name=__name__) LOG_FMT = '%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s' -class SourceCodeFile(): +class RustSourceCodeFile(SourceCodeFile): """Class for holding file-specific information.""" - def __init__(self, - source_file: str, - source_content: Optional[bytes] = None): - logger.info('Processing %s' % source_file) - - self.root = None - self.entrypoint = None - self.source_file = source_file - self.tree_sitter_lang = Language(tree_sitter_rust.language()) - self.parser = Parser(self.tree_sitter_lang) + def language_specific_process(self): + """Perform some language specific processes in subclasses.""" self.uses: dict[str, str] = {} - if source_content: - self.source_content = source_content - else: - with open(self.source_file, 'rb') as f: - self.source_content = f.read() - # Definition initialisation self.functions: list['RustFunction'] = [] - # Initialization ruotines - self.load_tree() - # Load functions/methods delcaration self._set_function_method_declaration(self.root) - def load_tree(self): - """Load the the source code into a treesitter tree, and set - the root node.""" - self.root = self.parser.parse(self.source_content).root_node - def _set_function_method_declaration(self, start_object: Node, start_prefix: list[str] = []): @@ -239,7 +217,7 @@ class RustFunction(): def __init__(self, root: Node, tree_sitter_lang: Language, - source_code: SourceCodeFile, + source_code: RustSourceCodeFile, prefix: list[str], is_macro: bool = False): self.root = root @@ -587,9 +565,12 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: self.detailed_callsites.append({'Src': src_loc, 'Dst': dst}) -class RustProject(Project): +class RustProject(Project[RustSourceCodeFile]): """Wrapper for doing analysis of a collection of source files.""" + def __init__(self, source_code_files: list[RustSourceCodeFile]): + self.source_code_files = source_code_files + def dump_module_logic(self, report_name: str, entry_function: str = '', @@ -677,7 +658,7 @@ def dump_module_logic(self, f.write(yaml.dump(report)) def _find_source_with_function(self, - name: str) -> Optional[SourceCodeFile]: + name: str) -> Optional[RustSourceCodeFile]: """Finds the source code with a given function name.""" for source_code in self.source_code_files: if get_function_node( @@ -735,7 +716,7 @@ def _recursive_function_depth(function: RustFunction) -> int: def extract_calltree(self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[RustSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None, depth: int = 0, @@ -803,19 +784,13 @@ def extract_calltree(self, return line_to_print - def get_source_codes_with_harnesses(self) -> list[SourceCodeFile]: - """Gets the source codes that holds libfuzzer harnesses.""" - harnesses = [] - for source_code in self.source_code_files: - if source_code.has_libfuzzer_harness(): - harnesses.append(source_code) - - return harnesses + def get_source_codes_with_harnesses(self) -> list[RustSourceCodeFile]: + return super().get_source_codes_with_harnesses() def get_reachable_functions( self, source_file: str = '', - source_code: Optional[Any] = None, + source_code: Optional[RustSourceCodeFile] = None, function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" @@ -858,25 +833,26 @@ def get_reachable_functions( def load_treesitter_trees(source_files: list[str], - is_log: bool = True) -> list[SourceCodeFile]: + is_log: bool = True) -> RustProject: """Creates treesitter trees for all files in a given list of source files.""" results = [] for code_file in source_files: - source_cls = SourceCodeFile(code_file) + source_cls = RustSourceCodeFile('rust', code_file) if is_log: if source_cls.has_libfuzzer_harness(): logger.info('harness: %s', code_file) results.append(source_cls) - return results + return RustProject(results) def analyse_source_code(source_content: str, - entrypoint: str) -> SourceCodeFile: + entrypoint: str) -> RustSourceCodeFile: """Returns a source abstraction based on a single source string.""" - source_code = SourceCodeFile(source_file='in-memory string', - source_content=source_content.encode()) + source_code = RustSourceCodeFile('rust', + source_file='in-memory string', + source_content=source_content.encode()) return source_code diff --git a/src/fuzz_introspector/frontends/oss_fuzz.py b/src/fuzz_introspector/frontends/oss_fuzz.py index 17648403..1ccfb232 100644 --- a/src/fuzz_introspector/frontends/oss_fuzz.py +++ b/src/fuzz_introspector/frontends/oss_fuzz.py @@ -100,11 +100,8 @@ def process_c_project(target_dir: str, logger.info('Going C route') logger.info('Found %d files to include in analysis', len(source_files)) - logger.info('Loading tree-sitter trees') - source_codes = frontend_c.load_treesitter_trees(source_files) - - logger.info('Creating base project.') - project = frontend_c.CProject(source_codes) + logger.info('Loading tree-sitter trees and create base project') + project = frontend_c.load_treesitter_trees(source_files) # We may not need to do this, but will do it while refactoring into # the new frontends. @@ -140,7 +137,6 @@ def process_c_project(target_dir: str, harness.source_file, dump_output) logger.info('Extracting calltree for %s', harness.source_file) - calltree = project.extract_calltree(harness, entrypoint) calltree = project.extract_calltree(source_code=harness, function=entrypoint) with open(os.path.join(out, 'fuzzerLogFile-%d.data' % (idx)), @@ -153,84 +149,6 @@ def process_c_project(target_dir: str, return project -def process_cpp_project(entrypoint: str, - out: str, - source_files: list[str], - dump_output: bool = True) -> Project: - """Process a project in CPP language""" - # Default entrypoint - if not entrypoint: - entrypoint = 'LLVMFuzzerTestOneInput' - - # Process tree sitter for c++ source files - logger.info('Going C++ route') - logger.info('Found %d files to include in analysis', len(source_files)) - logger.info('Loading tree-sitter trees') - source_codes = frontend_cpp.load_treesitter_trees(source_files) - - # Create and dump project - logger.info('Creating base project.') - project = frontend_cpp.CppProject(source_codes) - - return project - - -def process_go_project(out: str, - source_files: list[str], - dump_output: bool = True) -> Project: - """Process a project in Go language""" - # Process tree sitter for go source files - logger.info('Going Go route') - logger.info('Found %d files to include in analysis', len(source_files)) - logger.info('Loading tree-sitter trees') - source_codes = frontend_go.load_treesitter_trees(source_files) - - # Create and dump project - logger.info('Creating base project.') - project = frontend_go.GoProject(source_codes) - - return project - - -def process_jvm_project(entrypoint: str, - out: str, - source_files: list[str], - dump_output: bool = True) -> Project: - """Process a project in JVM based language""" - # Default entrypoint - if not entrypoint: - entrypoint = 'fuzzerTestOneInput' - - # Process tree sitter for go source files - logger.info('Going JVM route') - logger.info('Found %d files to include in analysis', len(source_files)) - logger.info('Loading tree-sitter trees') - source_codes = frontend_jvm.load_treesitter_trees(source_files, entrypoint) - - # Create and dump project - logger.info('Creating base project.') - project = frontend_jvm.JvmProject(source_codes) - - return project - - -def process_rust_project(out: str, - source_files: list[str], - dump_output: bool = True) -> Project: - """Process a project in Rust based language""" - # Process tree sitter for rust source files - logger.info('Going Rust route') - logger.info('Found %d files to include in analysis', len(source_files)) - logger.info('Loading tree-sitter trees') - source_codes = frontend_rust.load_treesitter_trees(source_files) - - # Create and dump project - logger.info('Creating base project.') - project = frontend_rust.RustProject(source_codes) - - return project - - def analyse_folder(language: str = '', directory: str = '', entrypoint: str = '', @@ -241,6 +159,7 @@ def analyse_folder(language: str = '', # Extract source files for target language source_files = capture_source_files_in_tree(directory, language) + logger.info('Found %d files to include in analysis', len(source_files)) if language == 'c': project = process_c_project(directory, @@ -252,23 +171,26 @@ def analyse_folder(language: str = '', else: # Process for different language if language.lower() in ['cpp', 'c++']: - project = process_cpp_project(entrypoint, - out, - source_files, - dump_output=dump_output) + logger.info('Going C++ route') + logger.info('Loading tree-sitter trees') + if not entrypoint: + entrypoint = 'LLVMFuzzerTestOneInput' + project = frontend_cpp.load_treesitter_trees(source_files) elif language == 'go': - project = process_go_project(out, - source_files, - dump_output=dump_output) + logger.info('Going Go route') + logger.info('Loading tree-sitter trees and create base project') + project = frontend_go.load_treesitter_trees(source_files) elif language == 'jvm': - project = process_jvm_project(entrypoint, - out, - source_files, - dump_output=dump_output) + logger.info('Going JVM route') + logger.info('Loading tree-sitter trees and create base project') + if not entrypoint: + entrypoint = 'fuzzerTestOneInput' + project = frontend_jvm.load_treesitter_trees( + source_files, entrypoint) elif language == 'rust': - project = process_rust_project(out, - source_files, - dump_output=dump_output) + logger.info('Going Rust route') + logger.info('Loading tree-sitter trees and create base project') + project = frontend_rust.load_treesitter_trees(source_files) else: logger.error('Unsupported language: %s' % language) return Project([]) From c3ec19dab7e00e0a6fb72c50c381642428255a9d Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Fri, 17 Jan 2025 22:01:19 +0000 Subject: [PATCH 2/3] Fix formatting Signed-off-by: Arthur Chan --- src/fuzz_introspector/frontends/datatypes.py | 3 ++- src/fuzz_introspector/frontends/frontend_c.py | 7 ++----- src/fuzz_introspector/frontends/frontend_cpp.py | 5 ++--- src/fuzz_introspector/frontends/frontend_go.py | 5 ++--- src/fuzz_introspector/frontends/frontend_jvm.py | 5 ++--- src/fuzz_introspector/frontends/frontend_rust.py | 5 ++--- 6 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/fuzz_introspector/frontends/datatypes.py b/src/fuzz_introspector/frontends/datatypes.py index 6abce25d..c88df265 100644 --- a/src/fuzz_introspector/frontends/datatypes.py +++ b/src/fuzz_introspector/frontends/datatypes.py @@ -14,6 +14,8 @@ # ################################################################################ +from typing import Any, Optional, Generic, TypeVar + from tree_sitter import Language, Parser import tree_sitter_c import tree_sitter_cpp @@ -25,7 +27,6 @@ logger = logging.getLogger(name=__name__) -from typing import Any, Optional, Generic, TypeVar T = TypeVar('T', bound='SourceCodeFile') diff --git a/src/fuzz_introspector/frontends/frontend_c.py b/src/fuzz_introspector/frontends/frontend_c.py index e3932234..88fccd40 100644 --- a/src/fuzz_introspector/frontends/frontend_c.py +++ b/src/fuzz_introspector/frontends/frontend_c.py @@ -15,15 +15,12 @@ ################################################################################ """Fuzz Introspector Light frontend""" -import os +from typing import Any, Optional +import os import logging - -from tree_sitter import Language, Parser import yaml -from typing import Any, Optional - from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile logger = logging.getLogger(name=__name__) diff --git a/src/fuzz_introspector/frontends/frontend_cpp.py b/src/fuzz_introspector/frontends/frontend_cpp.py index d9001bb6..222adcca 100644 --- a/src/fuzz_introspector/frontends/frontend_cpp.py +++ b/src/fuzz_introspector/frontends/frontend_cpp.py @@ -16,11 +16,10 @@ from typing import Any, Optional +from tree_sitter import Language, Node + import os import logging - -from tree_sitter import Language, Parser, Node -import tree_sitter_cpp import yaml from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile diff --git a/src/fuzz_introspector/frontends/frontend_go.py b/src/fuzz_introspector/frontends/frontend_go.py index 45cf592b..eecd43aa 100644 --- a/src/fuzz_introspector/frontends/frontend_go.py +++ b/src/fuzz_introspector/frontends/frontend_go.py @@ -17,10 +17,9 @@ from typing import Any, Optional -import logging +from tree_sitter import Language, Node -from tree_sitter import Language, Parser, Node -import tree_sitter_go +import logging import yaml from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile diff --git a/src/fuzz_introspector/frontends/frontend_jvm.py b/src/fuzz_introspector/frontends/frontend_jvm.py index 473ed176..38e9181a 100644 --- a/src/fuzz_introspector/frontends/frontend_jvm.py +++ b/src/fuzz_introspector/frontends/frontend_jvm.py @@ -17,10 +17,9 @@ from typing import Any, Optional -import logging +from tree_sitter import Language, Node -from tree_sitter import Language, Parser, Node -import tree_sitter_java +import logging import yaml from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile diff --git a/src/fuzz_introspector/frontends/frontend_rust.py b/src/fuzz_introspector/frontends/frontend_rust.py index 47450983..3f5b8395 100644 --- a/src/fuzz_introspector/frontends/frontend_rust.py +++ b/src/fuzz_introspector/frontends/frontend_rust.py @@ -17,10 +17,9 @@ from typing import Any, Optional -import logging +from tree_sitter import Language, Node -from tree_sitter import Language, Parser, Node -import tree_sitter_rust +import logging import yaml from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile From cd3440b2575d8ffbd7855af3132603f924568e63 Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Fri, 17 Jan 2025 22:02:56 +0000 Subject: [PATCH 3/3] Fix formatting Signed-off-by: Arthur Chan --- src/fuzz_introspector/frontends/datatypes.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fuzz_introspector/frontends/datatypes.py b/src/fuzz_introspector/frontends/datatypes.py index c88df265..0355064e 100644 --- a/src/fuzz_introspector/frontends/datatypes.py +++ b/src/fuzz_introspector/frontends/datatypes.py @@ -27,7 +27,6 @@ logger = logging.getLogger(name=__name__) - T = TypeVar('T', bound='SourceCodeFile')