From f4d834e9fed8909526c12290c95ee0a59a1e40a9 Mon Sep 17 00:00:00 2001 From: Pavlo Penenko Date: Sun, 15 Dec 2024 15:58:01 -0500 Subject: [PATCH] Count the number of compilation failures and and raise an exception at the end if non-0 (#18) --- src/generate.py | 64 ++++++++++++++++++++++++------------------ tests/test_generate.py | 3 +- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/src/generate.py b/src/generate.py index d4da896..05b9a2f 100644 --- a/src/generate.py +++ b/src/generate.py @@ -50,27 +50,27 @@ def generate(self, material, primitive): # self._generate(shader_file, material, primitive) + class CompileResult(NamedTuple): + log : str + success : bool + @abc.abstractmethod - def compile(self, to_glsl : bool) -> str: + def compile(self, to_glsl : bool) -> CompileResult: pass -def _compile_shader( - shader, - to_glsl : bool, - ignore_compile_errors : bool -) -> str: +def _compile_shader(shader, to_glsl : bool) -> _Shader.CompileResult: ''' Helper function to compile a shader in a process pool. Without it, the pool would not be able to pickle the method. ''' - return shader.compile(to_glsl, ignore_compile_errors) + return shader.compile(to_glsl) class _HlslShader(_Shader): @abc.abstractmethod def _get_hlsl_profile(): pass - def compile(self, to_glsl : bool, ignore_compile_errors : bool) -> str: + def compile(self, to_glsl : bool) -> _Shader.CompileResult: log = io.StringIO() log, sys.stdout = sys.stdout, log @@ -102,12 +102,12 @@ def compile(self, to_glsl : bool, ignore_compile_errors : bool) -> str: entry_point_name = _impl.entry_point_name, output_path = spv_path ) + success = True except subprocess.CalledProcessError as err: - if not ignore_compile_errors: - raise err + success = False log, sys.stdout = sys.stdout, log - return log.getvalue() + return _Shader.CompileResult(log.getvalue(), success) class _HlslVertexShader(_HlslShader): def __init__( @@ -154,7 +154,7 @@ def _generate(self, shader_file, material, primitive): ) class _GlslShader(_Shader): - def compile(self, to_glsl : bool, ignore_compile_errors : bool ) -> str: + def compile(self, to_glsl : bool) -> _Shader.CompileResult: log = io.StringIO() log, sys.stdout = sys.stdout, log @@ -166,12 +166,12 @@ def compile(self, to_glsl : bool, ignore_compile_errors : bool ) -> str: shader_stage = 'frag', output_path = glsl_output_path ) + success = True except subprocess.CalledProcessError as err: - if not ignore_compile_errors: - raise err + success = False log, sys.stdout = sys.stdout, log - return log.getvalue() + return _Shader.CompileResult(log.getvalue(), success) class _GlslFragmentShader(_GlslShader): def __init__( @@ -237,8 +237,7 @@ def generate( compile : bool, to_glsl : bool, skip_codegen : bool, - serial : bool, - ignore_compile_errors : bool + serial : bool ): if not gltf_dir_path.is_dir(): raise NotADirectoryError(gltf_dir_path) @@ -277,24 +276,34 @@ def generate( glslc.identify() spirv_cross.identify() + num_failed = 0 + if serial: for shader in shaders: - log = shader.compile( - to_glsl = to_glsl, - ignore_compile_errors = ignore_compile_errors - ) - print(log, end = '') + result = shader.compile(to_glsl = to_glsl) + if not result.success: + num_failed += 1 + print(result.log, end = '') else: with mp.Pool() as pool: - for log in pool.imap_unordered( + for result in pool.imap_unordered( functools.partial( _compile_shader, - to_glsl = to_glsl, - ignore_compile_errors = ignore_compile_errors + to_glsl = to_glsl ), shaders ): - print(log, end = '') + if not result.success: + num_failed += 1 + print(result.log, end = '') + + if num_failed > 0: + raise RuntimeError( + f'{num_failed} out of {len(shaders)} shaders failed to ' + 'compile - see the log above.' + ) + else: + print(f'\nAll {len(shaders)} shaders compiled successfully.') if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -331,6 +340,5 @@ def generate( compile = args.compile, to_glsl = args.to_glsl, skip_codegen = args.skip_codegen, - serial = args.serial, - ignore_compile_errors = True # just collect all errors in stdout + serial = args.serial ) \ No newline at end of file diff --git a/tests/test_generate.py b/tests/test_generate.py index 0a7a343..da6015f 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -36,6 +36,5 @@ def test_generate(self): compile = True, to_glsl = False, skip_codegen = False, - serial = False, - ignore_compile_errors = False + serial = False )