diff --git a/include/ntlmclient.h b/include/ntlmclient.h index d109a5c..2923662 100644 --- a/include/ntlmclient.h +++ b/include/ntlmclient.h @@ -22,6 +22,19 @@ extern "C" { typedef struct ntlm_client ntlm_client; +typedef enum { + /** + * An error occurred; more details are available by querying + * `ntlm_client_errmsg`. + */ + NTLM_CLIENT_ERROR = -1, + + /** + * The input provided to the function is missing or invalid. + */ + NTLM_CLIENT_ERROR_INVALID_INPUT = -2, +} ntlm_error_code; + /* * Flags for initializing the `ntlm_client` context. A combination of * these flags can be provided to `ntlm_client_init`. diff --git a/src/ntlm.c b/src/ntlm.c index 979d612..3393be9 100644 --- a/src/ntlm.c +++ b/src/ntlm.c @@ -9,7 +9,6 @@ #include #include #include -#include #include #include #include @@ -24,6 +23,18 @@ #include "compat.h" #include "util.h" +#define NTLM_ASSERT_ARG(expr) do { \ + if (!(expr)) \ + return NTLM_CLIENT_ERROR_INVALID_INPUT; \ + } while(0) + +#define NTLM_ASSERT(ntlm, expr) do { \ + if (!(expr)) { \ + ntlm_client_set_errmsg(ntlm, "internal error: " #expr); \ + return -1; \ + } \ + } while(0) + unsigned char ntlm_client_signature[] = NTLM_SIGNATURE; static bool supports_unicode(ntlm_client *ntlm) @@ -74,7 +85,9 @@ void ntlm_client_set_errmsg(ntlm_client *ntlm, const char *errmsg) const char *ntlm_client_errmsg(ntlm_client *ntlm) { - assert(ntlm); + if (!ntlm) + return "internal error"; + return ntlm->errmsg ? ntlm->errmsg : "no error"; } @@ -84,7 +97,7 @@ int ntlm_client_set_version( uint8_t minor, uint16_t build) { - assert(ntlm); + NTLM_ASSERT_ARG(ntlm); ntlm->host_version.major = major; ntlm->host_version.minor = minor; @@ -111,8 +124,7 @@ int ntlm_client_set_hostname( const char *hostname, const char *domain) { - assert(ntlm); - + NTLM_ASSERT_ARG(ntlm); ENSURE_INITIALIZED(ntlm); free_hostname(ntlm); @@ -168,8 +180,7 @@ int ntlm_client_set_credentials( const char *domain, const char *password) { - assert(ntlm); - + NTLM_ASSERT_ARG(ntlm); ENSURE_INITIALIZED(ntlm); free_credentials(ntlm); @@ -218,8 +229,7 @@ int ntlm_client_set_credentials( int ntlm_client_set_target(ntlm_client *ntlm, const char *target) { - assert(ntlm); - + NTLM_ASSERT_ARG(ntlm); ENSURE_INITIALIZED(ntlm); free(ntlm->target); @@ -248,14 +258,16 @@ int ntlm_client_set_target(ntlm_client *ntlm, const char *target) int ntlm_client_set_nonce(ntlm_client *ntlm, uint64_t nonce) { - assert(ntlm); + NTLM_ASSERT_ARG(ntlm); + ntlm->nonce = nonce; return 0; } int ntlm_client_set_timestamp(ntlm_client *ntlm, uint64_t timestamp) { - assert(ntlm); + NTLM_ASSERT_ARG(ntlm); + ntlm->timestamp = timestamp; return 0; } @@ -601,7 +613,9 @@ int ntlm_client_negotiate( size_t hostname_offset = 0; uint32_t flags = 0; - assert(out && out_len && ntlm); + NTLM_ASSERT_ARG(out); + NTLM_ASSERT_ARG(out_len); + NTLM_ASSERT_ARG(ntlm); *out = NULL; *out_len = 0; @@ -684,20 +698,22 @@ int ntlm_client_negotiate( return -1; if (hostname_len > 0) { - assert(hostname_offset == ntlm->negotiate.pos); + NTLM_ASSERT(ntlm, hostname_offset == ntlm->negotiate.pos); + if (!write_buf(ntlm, &ntlm->negotiate, (const unsigned char *)ntlm->hostname, hostname_len)) return -1; } if (domain_len > 0) { - assert(domain_offset == ntlm->negotiate.pos); + NTLM_ASSERT(ntlm, domain_offset == ntlm->negotiate.pos); + if (!write_buf(ntlm, &ntlm->negotiate, (const unsigned char *)ntlm->hostdomain, domain_len)) return -1; } - assert(ntlm->negotiate.pos == ntlm->negotiate.len); + NTLM_ASSERT(ntlm, ntlm->negotiate.pos == ntlm->negotiate.len); ntlm->state = NTLM_STATE_CHALLENGE; @@ -719,7 +735,8 @@ int ntlm_client_set_challenge( uint32_t name_offset, info_offset = 0; bool unicode, has_target_info = false; - assert(ntlm && (challenge_msg || !challenge_msg_len)); + NTLM_ASSERT_ARG(ntlm); + NTLM_ASSERT_ARG(challenge_msg || !challenge_msg_len); ENSURE_INITIALIZED(ntlm); @@ -1101,7 +1118,7 @@ static bool generate_ntlm2_hash( return false; } - assert(out_len == NTLM_NTLM2_HASH_LEN); + NTLM_ASSERT(ntlm, out_len == NTLM_NTLM2_HASH_LEN); return true; } @@ -1122,7 +1139,7 @@ static bool generate_ntlm2_challengehash( return false; } - assert(out_len == 16); + NTLM_ASSERT(ntlm, out_len == 16); return true; } @@ -1143,7 +1160,7 @@ static bool generate_lm2_response(ntlm_client *ntlm, return false; } - assert(lm2_len == 16); + NTLM_ASSERT(ntlm, lm2_len == 16); memcpy(&ntlm->lm_response[0], lm2_challengehash, 16); memcpy(&ntlm->lm_response[16], &local_nonce, 8); @@ -1237,7 +1254,9 @@ int ntlm_client_response( uint32_t flags = 0; bool unicode; - assert(out && out_len && ntlm); + NTLM_ASSERT_ARG(out); + NTLM_ASSERT_ARG(out_len); + NTLM_ASSERT_ARG(ntlm); ENSURE_INITIALIZED(ntlm); @@ -1362,7 +1381,7 @@ int ntlm_client_response( !write_buf(ntlm, &ntlm->response, session, session_len)) return -1; - assert(ntlm->response.pos == ntlm->response.len); + NTLM_ASSERT(ntlm, ntlm->response.pos == ntlm->response.len); ntlm->state = NTLM_STATE_COMPLETE; @@ -1374,7 +1393,8 @@ int ntlm_client_response( void ntlm_client_reset(ntlm_client *ntlm) { - assert(ntlm); + if (!ntlm) + return; ntlm->state = NTLM_STATE_NEGOTIATE; diff --git a/tests/inputs.c b/tests/inputs.c index 276c9c9..7178877 100644 --- a/tests/inputs.c +++ b/tests/inputs.c @@ -1,5 +1,6 @@ #include "clar.h" #include "ntlm.h" +#include "ntlm_tests.h" static ntlm_client *ntlm; @@ -13,6 +14,31 @@ void test_inputs__cleanup(void) ntlm_client_free(ntlm); } +void test_inputs__null(void) +{ + const unsigned char *msg; + size_t msg_len; + + cl_assert(ntlm_client_errmsg(NULL) != NULL); + + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_set_hostname(NULL, "hostname", "HOSTDOMAIN")); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_set_credentials(NULL, "user", "DOMAIN", "pass!")); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_set_target(NULL, "target")); + + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_negotiate(NULL, &msg_len, ntlm)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_negotiate(&msg, NULL, ntlm)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_negotiate(&msg, &msg_len, NULL)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_negotiate(NULL, NULL, NULL)); + + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_set_challenge(NULL, (const unsigned char *)"foo", 3)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_set_challenge(ntlm, NULL, 3)); + + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_response(NULL, &msg_len, ntlm)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_response(&msg, NULL, ntlm)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_response(&msg, &msg_len, NULL)); + cl_must_fail_with(NTLM_CLIENT_ERROR_INVALID_INPUT, ntlm_client_response(NULL, NULL, NULL)); +} + void test_inputs__set_hostname(void) { cl_must_pass(ntlm_client_set_hostname(ntlm, "hostname", "HOSTDOMAIN")); diff --git a/tests/ntlm_tests.h b/tests/ntlm_tests.h index 8ccec24..6344bc4 100644 --- a/tests/ntlm_tests.h +++ b/tests/ntlm_tests.h @@ -5,6 +5,10 @@ #include "ntlm.h" #include "util.h" +#define cl_must_fail_with_(val, expr, desc) clar__assert((expr) == (val), __FILE__, __LINE__, "Expected function call to fail with " #val ": " #expr, desc, 0) + +#define cl_must_fail_with(val, expr) cl_must_fail_with_(val, expr, NULL) + #define cl_ntlm_pass(ntlm, expr) cl_ntlm_expect((ntlm), (expr), 0, __FILE__, __LINE__) #define cl_ntlm_expect(ntlm, expr, expected, file, line) do { \