From f7d0bed7aa7b206fb03fbbbb31173537b3e50921 Mon Sep 17 00:00:00 2001 From: Vadim Vetrov Date: Sun, 2 Feb 2025 23:34:10 +0300 Subject: [PATCH] Use Aho-Corasick algorithm in tls parsing --- src/args.c | 66 +++++++++--------------------------- src/config.h | 11 +++--- src/tls.c | 96 ++++++++++++++++------------------------------------ src/trie.c | 16 ++++++++- src/trie.h | 7 ---- test/tls.c | 15 ++++---- test/trie.c | 43 +++++++++++++++++++++++ 7 files changed, 117 insertions(+), 137 deletions(-) diff --git a/src/args.c b/src/args.c index b8b66d4..d18b026 100644 --- a/src/args.c +++ b/src/args.c @@ -94,11 +94,9 @@ close_file: } #endif -static int parse_sni_domains(struct domains_list **dlist, const char *domains_str, size_t domains_strlen) { - // Empty and shouldn't be used - struct domains_list ndomain = {0}; - struct domains_list *cdomain = &ndomain; - +static int parse_sni_domains(struct trie_container *trie, const char *domains_str, size_t domains_strlen) { + trie_init(trie); + unsigned int j = 0; for (unsigned int i = 0; i <= domains_strlen; i++) { if (( i == domains_strlen || @@ -119,38 +117,17 @@ static int parse_sni_domains(struct domains_list **dlist, const char *domains_st unsigned int domain_len = (i - j); const char *domain_startp = domains_str + j; - struct domains_list *edomain = malloc(sizeof(struct domains_list)); - *edomain = (struct domains_list){0}; - if (edomain == NULL) { - return -ENOMEM; - } - - edomain->domain_len = domain_len; - edomain->domain_name = malloc(domain_len + 1); - if (edomain->domain_name == NULL) { - return -ENOMEM; - } - - strncpy(edomain->domain_name, domain_startp, domain_len); - edomain->domain_name[domain_len] = '\0'; - cdomain->next = edomain; - cdomain = edomain; + trie_add_string(trie, (const uint8_t *)domain_startp, domain_len); j = i + 1; } } - *dlist = ndomain.next; return 0; } -static void free_sni_domains(struct domains_list *dlist) { - for (struct domains_list *ldl = dlist; ldl != NULL;) { - struct domains_list *ndl = ldl->next; - SFREE(ldl->domain_name); - SFREE(ldl); - ldl = ndl; - } +static void free_sni_domains(struct trie_container *trie) { + trie_destroy(trie); } static long parse_numeric_option(const char* value) { @@ -633,7 +610,7 @@ int yparse_args(struct config_t *config, int argc, char *argv[]) { break; case OPT_SNI_DOMAINS: - free_sni_domains(sect_config->sni_domains); + free_sni_domains(§_config->sni_domains); sect_config->all_domains = 0; if (!strcmp(optarg, "all")) { sect_config->all_domains = 1; @@ -649,7 +626,7 @@ int yparse_args(struct config_t *config, int argc, char *argv[]) { goto error; #else { - free_sni_domains(sect_config->sni_domains); + free_sni_domains(§_config->sni_domains); ret = read_file(optarg); if (ret < 0) { goto error; @@ -662,7 +639,7 @@ int yparse_args(struct config_t *config, int argc, char *argv[]) { } #endif case OPT_EXCLUDE_DOMAINS: - free_sni_domains(sect_config->exclude_sni_domains); + free_sni_domains(§_config->exclude_sni_domains); ret = parse_sni_domains(§_config->exclude_sni_domains, optarg, strlen(optarg)); if (ret < 0) goto error; @@ -674,7 +651,7 @@ int yparse_args(struct config_t *config, int argc, char *argv[]) { goto error; #else { - free_sni_domains(sect_config->exclude_sni_domains); + free_sni_domains(§_config->exclude_sni_domains); ret = read_file(optarg); if (ret < 0) { goto error; @@ -1068,20 +1045,11 @@ static size_t print_config_section(const struct section_config_t *section, char if (section->all_domains) { print_cnf_buf("--sni-domains=all"); - } else if (section->sni_domains != NULL) { - print_cnf_raw("--sni-domains="); - - for (struct domains_list *sne = section->sni_domains; sne != NULL; sne = sne->next) { - print_cnf_raw("%s,", sne->domain_name); - } - print_cnf_raw(" "); + } else if (section->sni_domains.vx != NULL) { + print_cnf_buf("--sni-domains=", section->sni_domains.sz); } - if (section->exclude_sni_domains != NULL) { - print_cnf_raw("--exclude-domains="); - for (struct domains_list *sne = section->exclude_sni_domains; sne != NULL; sne = sne->next) { - print_cnf_raw("%s,", sne->domain_name); - } - print_cnf_raw(" "); + if (section->exclude_sni_domains.vx != NULL) { + print_cnf_buf("--exclude-domains=", section->sni_domains.sz); } switch(section->sni_detection) { @@ -1281,10 +1249,8 @@ void free_config_section(struct section_config_t *section) { SFREE(section->udp_dport_range); } - free_sni_domains(section->sni_domains); - section->sni_domains = NULL; - free_sni_domains(section->exclude_sni_domains); - section->exclude_sni_domains = NULL; + free_sni_domains(§ion->sni_domains); + free_sni_domains(§ion->exclude_sni_domains); section->fake_custom_pkt_sz = 0; SFREE(section->fake_custom_pkt); diff --git a/src/config.h b/src/config.h index 36579d0..f009791 100644 --- a/src/config.h +++ b/src/config.h @@ -25,6 +25,7 @@ #endif #include "types.h" +#include "trie.h" typedef int (*raw_send_t)(const unsigned char *data, size_t data_len); /** @@ -64,8 +65,10 @@ struct section_config_t { struct section_config_t *next; struct section_config_t *prev; - struct domains_list *sni_domains; - struct domains_list *exclude_sni_domains; + // struct domains_list *sni_domains; + // struct domains_list *exclude_sni_domains; + struct trie_container sni_domains; + struct trie_container exclude_sni_domains; unsigned int all_domains; int tls_enabled; @@ -237,8 +240,8 @@ enum { }; #define default_section_config { \ - .sni_domains = NULL, \ - .exclude_sni_domains = NULL, \ + .sni_domains = {0}, \ + .exclude_sni_domains = {0}, \ .all_domains = 0, \ .tls_enabled = 1, \ .frag_sni_reverse = 1, \ diff --git a/src/tls.c b/src/tls.c index 73721b2..73eb118 100644 --- a/src/tls.c +++ b/src/tls.c @@ -33,6 +33,8 @@ int bruteforce_analyze_sni_str( const uint8_t *data, size_t dlen, struct tls_verdict *vrd ) { + size_t offset, offlen; + int ret; *vrd = (struct tls_verdict){0}; if (dlen <= 1) { @@ -47,50 +49,17 @@ int bruteforce_analyze_sni_str( vrd->target_sni_len = vrd->sni_len; return 0; } - int max_domain_len = 0; - for (struct domains_list *sne = section->sni_domains; sne != NULL; - sne = sne->next) { - max_domain_len = max((int)sne->domain_len, max_domain_len); - } - - size_t buf_size = max_domain_len + dlen + 1; - uint8_t *buf = malloc(buf_size); - if (buf == NULL) { - return -ENOMEM; - } - int *nzbuf = malloc(buf_size * sizeof(int)); - if (nzbuf == NULL) { - free(buf); - return -ENOMEM; + // It is safe for multithreading, so dp mutability is ok + ret = trie_process_str((struct trie_container *)§ion->sni_domains, data, dlen, 0, &offset, &offlen); + if (ret) { + vrd->target_sni = 1; + vrd->sni_len = offlen; + vrd->sni_ptr = data + offset; + vrd->target_sni_ptr = vrd->sni_ptr; + vrd->target_sni_len = vrd->sni_len; } - for (struct domains_list *sne = section->sni_domains; sne != NULL; sne = sne->next) { - const char *domain_startp = sne->domain_name; - int domain_len = sne->domain_len; - - int *zbuf = (void *)nzbuf; - - memcpy(buf, domain_startp, domain_len); - memcpy(buf + domain_len, "#", 1); - memcpy(buf + domain_len + 1, data, dlen); - - z_function((char *)buf, zbuf, domain_len + 1 + dlen); - - for (size_t k = 0; k < domain_len + 1 + dlen; k++) { - if (zbuf[k] == domain_len) { - vrd->target_sni = 1; - vrd->sni_len = domain_len; - vrd->sni_ptr = data + (k - domain_len - 1); - vrd->target_sni_ptr = vrd->sni_ptr; - vrd->target_sni_len = vrd->sni_len; - goto return_vrd; - } - } - } -return_vrd: - free(buf); - free(nzbuf); return 0; } static int analyze_sni_str( @@ -98,42 +67,35 @@ static int analyze_sni_str( const char *sni_name, int sni_len, struct tls_verdict *vrd ) { + int ret; + size_t offset, offlen; + if (section->all_domains) { vrd->target_sni = 1; goto check_domain; } - - for (struct domains_list *sne = section->sni_domains; sne != NULL; sne = sne->next) { - const char *sni_startp = sni_name + sni_len - sne->domain_len; - const char *domain_startp = sne->domain_name; - if (sni_len >= sne->domain_len && - sni_len < 128 && - !strncmp(sni_startp, - domain_startp, - sne->domain_len)) { - vrd->target_sni = 1; - vrd->target_sni_ptr = (const uint8_t *)sni_startp; - vrd->target_sni_len = sne->domain_len; - break; - } + lgtrace_addp("abacaba"); + + // It is safe for multithreading, so dp mutability is ok + ret = trie_process_str((struct trie_container *)§ion->sni_domains, + (const uint8_t *)sni_name, sni_len, TRIE_OPT_MAP_TO_END, &offset, &offlen); + if (ret) { + vrd->target_sni = 1; + vrd->target_sni_ptr = (const uint8_t *)sni_name + offset; + vrd->target_sni_len = offlen; } check_domain: if (vrd->target_sni == 1) { - for (struct domains_list *sne = section->exclude_sni_domains; sne != NULL; sne = sne->next) { - const char *sni_startp = sni_name + sni_len - sne->domain_len; - const char *domain_startp = sne->domain_name; - if (sni_len >= sne->domain_len && - sni_len < 128 && - !strncmp(sni_startp, - domain_startp, - sne->domain_len)) { - vrd->target_sni = 0; - lgdebug("Excluded SNI: %.*s", - vrd->sni_len, vrd->sni_ptr); - } + // It is safe for multithreading, so dp mutability is ok + ret = trie_process_str((struct trie_container *)§ion->exclude_sni_domains, + (const uint8_t *)sni_name, sni_len, TRIE_OPT_MAP_TO_END, &offset, &offlen); + if (ret) { + vrd->target_sni = 0; + lgdebug("Excluded SNI: %.*s", + vrd->sni_len, vrd->sni_ptr); } } diff --git a/src/trie.c b/src/trie.c index 1d4af2a..30f08e2 100644 --- a/src/trie.c +++ b/src/trie.c @@ -52,7 +52,13 @@ void trie_destroy(struct trie_container *trie) { trie->vx = NULL; } -int trie_push_vertex(struct trie_container *trie) { +/** + * + * Increases trie vertex container size. + * Returns new vertex index or ret < 0 on error + * + */ +static int trie_push_vertex(struct trie_container *trie) { if (trie->sz == NMAX - 1) { return -EINVAL; } @@ -74,6 +80,10 @@ int trie_push_vertex(struct trie_container *trie) { int trie_add_string(struct trie_container *trie, const uint8_t *str, size_t strlen) { + if (trie == NULL || trie->vx == NULL) { + return -EINVAL; + } + int v = 0; int nv; @@ -145,6 +155,10 @@ int trie_process_str( int flags, size_t *offset, size_t *offlen ) { + if (trie == NULL || trie->vx == NULL) { + return 0; + } + int v = 0; size_t i = 0; uint8_t c; diff --git a/src/trie.h b/src/trie.h index 3bcceea..7d8b565 100644 --- a/src/trie.h +++ b/src/trie.h @@ -66,13 +66,6 @@ struct trie_container { int trie_init(struct trie_container *trie); void trie_destroy(struct trie_container *trie); -/** - * - * Increases trie vertex container size. - * Returns new vertex index or ret < 0 on error - * - */ -int trie_push_vertex(struct trie_container *trie); int trie_add_string(struct trie_container *trie, const uint8_t *str, size_t strlen); diff --git a/test/tls.c b/test/tls.c index cd25aa0..f5803db 100644 --- a/test/tls.c +++ b/test/tls.c @@ -36,22 +36,21 @@ TEST(TLSTest, Test_CHLO_message_detect) TEST(TLSTest, Test_Bruteforce_detects) { struct tls_verdict tlsv; - struct domains_list dmns = { - .domain_name = "youtube.com", - .domain_len = 11, - .next = NULL - }; - sconf.sni_domains = &dmns; + struct trie_container trie; + int ret; + ret = trie_init(&trie); + ret = trie_add_string(&trie, (uint8_t *)"youtube.com", 11); + sconf.sni_domains = trie; - int ret = bruteforce_analyze_sni_str(&sconf, (const uint8_t *)tls_bruteforce_message, sizeof(tls_bruteforce_message) - 1, &tlsv); + ret = bruteforce_analyze_sni_str(&sconf, (const uint8_t *)tls_bruteforce_message, sizeof(tls_bruteforce_message) - 1, &tlsv); TEST_ASSERT_EQUAL(0, ret); TEST_ASSERT_EQUAL(11, tlsv.sni_len); TEST_ASSERT_EQUAL_STRING_LEN("youtube.com", tlsv.sni_ptr, 11); TEST_ASSERT_EQUAL_PTR(tls_bruteforce_message + sizeof(tls_bruteforce_message) - 12, tlsv.sni_ptr); + trie_destroy(&trie); } - TEST_GROUP_RUNNER(TLSTest) { RUN_TEST_CASE(TLSTest, Test_CHLO_message_detect); diff --git a/test/trie.c b/test/trie.c index 642da76..34c26fb 100644 --- a/test/trie.c +++ b/test/trie.c @@ -95,10 +95,53 @@ TEST(TrieTest, Trie_string_finds_opt_end) trie_destroy(&trie); } +TEST(TrieTest, Trie_single_vertex) +{ + int ret; + size_t offset; + size_t offlen; + struct trie_container trie; + + ret = trie_init(&trie); + + ret = trie_process_str(&trie, + (uint8_t *)tstr, sizeof(tstr) - 1, + 0, + &offset, &offlen + ); + TEST_ASSERT_EQUAL(0, ret); + + trie_destroy(&trie); + +} + +TEST(TrieTest, Trie_uninitialized) +{ + int ret; + size_t offset; + size_t offlen; + struct trie_container trie = {0}; + + // ret = trie_init(&trie); + + ret = trie_add_string(&trie, (uint8_t *)ASTR, sizeof(ASTR) - 1); + TEST_ASSERT_EQUAL(-EINVAL, ret); + + ret = trie_process_str(&trie, + (uint8_t *)tstr, sizeof(tstr) - 1, + 0, + &offset, &offlen + ); + TEST_ASSERT_EQUAL(0, ret); + +} + TEST_GROUP_RUNNER(TrieTest) { RUN_TEST_CASE(TrieTest, Trie_string_adds); RUN_TEST_CASE(TrieTest, Trie_string_finds); RUN_TEST_CASE(TrieTest, Trie_string_finds_opt_end); + RUN_TEST_CASE(TrieTest, Trie_single_vertex); + RUN_TEST_CASE(TrieTest, Trie_uninitialized); }