00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include "ink_config.h"
00025
00026 #include "P_SSLCertLookup.h"
00027 #include "P_SSLUtils.h"
00028 #include "P_SSLConfig.h"
00029 #include "I_EventSystem.h"
00030 #include "I_Layout.h"
00031 #include "Regex.h"
00032 #include "Trie.h"
00033 #include "ts/TestBox.h"
00034
00035 struct SSLAddressLookupKey
00036 {
00037 explicit
00038 SSLAddressLookupKey(const IpEndpoint& ip) : sep(0)
00039 {
00040 static const char hextab[16] = {
00041 '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
00042 };
00043
00044 int nbytes;
00045 uint16_t port = ntohs(ip.port());
00046
00047
00048
00049
00050 nbytes = ats_ip_to_hex(&ip.sa, key, sizeof(key));
00051 if (port) {
00052 sep = nbytes;
00053 key[nbytes++] = '.';
00054 key[nbytes++] = hextab[ (port >> 12) & 0x000F ];
00055 key[nbytes++] = hextab[ (port >> 8) & 0x000F ];
00056 key[nbytes++] = hextab[ (port >> 4) & 0x000F ];
00057 key[nbytes++] = hextab[ (port ) & 0x000F ];
00058 }
00059 key[nbytes++] = 0;
00060 }
00061
00062 const char * get() const { return key; }
00063 void split() { key[sep] = '\0'; }
00064 void unsplit() { key[sep] = '.'; }
00065
00066 private:
00067 char key[(TS_IP6_SIZE * 2) + 1 + 4 + 1 ];
00068 unsigned char sep;
00069 };
00070
00071 struct SSLContextStorage
00072 {
00073 SSLContextStorage();
00074 ~SSLContextStorage();
00075
00076 bool insert(SSL_CTX * ctx, const char * name);
00077 SSL_CTX * lookup(const char * name) const;
00078 unsigned count() const { return this->references.count(); }
00079 SSL_CTX * get(unsigned i) const { return this->references[i]; }
00080
00081 private:
00082 struct SSLEntry
00083 {
00084 explicit SSLEntry(SSL_CTX * c) : ctx(c) {}
00085
00086 void Print() const { Debug("ssl", "SSLEntry=%p SSL_CTX=%p", this, ctx); }
00087
00088 SSL_CTX * ctx;
00089 LINK(SSLEntry, link);
00090 };
00091
00092 Trie<SSLEntry> wildcards;
00093 InkHashTable * hostnames;
00094 Vec<SSL_CTX *> references;
00095 };
00096
00097 SSLCertLookup::SSLCertLookup()
00098 : ssl_storage(new SSLContextStorage()), ssl_default(NULL)
00099 {
00100 }
00101
00102 SSLCertLookup::~SSLCertLookup()
00103 {
00104 delete this->ssl_storage;
00105 }
00106
00107 SSL_CTX *
00108 SSLCertLookup::findInfoInHash(const char * address) const
00109 {
00110 return this->ssl_storage->lookup(address);
00111 }
00112
00113 SSL_CTX *
00114 SSLCertLookup::findInfoInHash(const IpEndpoint& address) const
00115 {
00116 SSL_CTX * ctx;
00117 SSLAddressLookupKey key(address);
00118
00119
00120 if ((ctx = this->ssl_storage->lookup(key.get()))) {
00121 return ctx;
00122 }
00123
00124
00125 if (address.port()) {
00126 key.split();
00127 return this->ssl_storage->lookup(key.get());
00128 }
00129
00130 return NULL;
00131 }
00132
00133 bool
00134 SSLCertLookup::insert(SSL_CTX * ctx, const char * name)
00135 {
00136 return this->ssl_storage->insert(ctx, name);
00137 }
00138
00139 bool
00140 SSLCertLookup::insert(SSL_CTX * ctx, const IpEndpoint& address)
00141 {
00142 SSLAddressLookupKey key(address);
00143 return this->ssl_storage->insert(ctx, key.get());
00144 }
00145
00146 unsigned
00147 SSLCertLookup::count() const
00148 {
00149 return ssl_storage->count();
00150 }
00151
00152 SSL_CTX *
00153 SSLCertLookup::get(unsigned i) const
00154 {
00155 return ssl_storage->get(i);
00156 }
00157
00158 struct ats_wildcard_matcher
00159 {
00160 ats_wildcard_matcher() {
00161 if (regex.compile("^\\*\\.[^\\*.]+") != 0) {
00162 Fatal("failed to compile TLS wildcard matching regex");
00163 }
00164 }
00165
00166 ~ats_wildcard_matcher() {
00167 }
00168
00169 bool match(const char * hostname) const {
00170 return regex.match(hostname) != -1;
00171 }
00172
00173 private:
00174 DFA regex;
00175 };
00176
00177 static char *
00178 reverse_dns_name(const char * hostname, char (&reversed)[TS_MAX_HOST_NAME_LEN+1])
00179 {
00180 char * ptr = reversed + sizeof(reversed);
00181 const char * part = hostname;
00182
00183 *(--ptr) = '\0';
00184
00185 while (*part) {
00186 ssize_t len = strcspn(part, ".");
00187 ssize_t remain = ptr - reversed;
00188
00189 if (remain < (len + 1)) {
00190 return NULL;
00191 }
00192
00193 ptr -= len;
00194 memcpy(ptr, part, len);
00195
00196
00197
00198 part += len;
00199 if (*part == '.') {
00200 ++part;
00201 *(--ptr) = '.';
00202 }
00203 }
00204
00205 return ptr;
00206 }
00207
00208 SSLContextStorage::SSLContextStorage()
00209 :wildcards(), hostnames(ink_hash_table_create(InkHashTableKeyType_String))
00210 {
00211 }
00212
00213 SSLContextStorage::~SSLContextStorage()
00214 {
00215 for (unsigned i = 0; i < this->references.count(); ++i) {
00216 SSLReleaseContext(this->references[i]);
00217 }
00218
00219 ink_hash_table_destroy(this->hostnames);
00220 }
00221
00222 bool
00223 SSLContextStorage::insert(SSL_CTX * ctx, const char * name)
00224 {
00225 ats_wildcard_matcher wildcard;
00226 bool inserted = false;
00227
00228 if (wildcard.match(name)) {
00229
00230
00231 char namebuf[TS_MAX_HOST_NAME_LEN + 1];
00232 char * reversed;
00233 ats_scoped_obj<SSLEntry> entry;
00234
00235 reversed = reverse_dns_name(name + 1, namebuf);
00236 if (!reversed) {
00237 Error("wildcard name '%s' is too long", name);
00238 return false;
00239 }
00240
00241 entry = new SSLEntry(ctx);
00242 inserted = this->wildcards.Insert(reversed, entry, 0 , -1 );
00243 if (!inserted) {
00244 SSLEntry * found;
00245
00246
00247 found = this->wildcards.Search(reversed);
00248 if (found != NULL && found->ctx != ctx) {
00249 Warning("previously indexed wildcard certificate for '%s' as '%s', cannot index it with SSL_CTX %p now",
00250 name, reversed, ctx);
00251 }
00252
00253 goto done;
00254 }
00255
00256 Debug("ssl", "indexed wildcard certificate for '%s' as '%s' with SSL_CTX %p", name, reversed, ctx);
00257 entry.release();
00258 } else {
00259 InkHashTableValue value;
00260
00261 if (ink_hash_table_lookup(this->hostnames, name, &value) && (void *)ctx != value) {
00262 Warning("previously indexed '%s' with SSL_CTX %p, cannot index it with SSL_CTX %p now", name, value, ctx);
00263 goto done;
00264 }
00265
00266 inserted = true;
00267 ink_hash_table_insert(this->hostnames, name, (void *)ctx);
00268 Debug("ssl", "indexed '%s' with SSL_CTX %p", name, ctx);
00269 }
00270
00271 done:
00272
00273
00274
00275 if (inserted) {
00276 if (this->references.in(ctx) == NULL) {
00277 this->references.push_back(ctx);
00278 }
00279 }
00280
00281 return inserted;
00282 }
00283
00284 SSL_CTX *
00285 SSLContextStorage::lookup(const char * name) const
00286 {
00287 InkHashTableValue value;
00288
00289 if (ink_hash_table_lookup(const_cast<InkHashTable *>(this->hostnames), name, &value)) {
00290 return (SSL_CTX *)value;
00291 }
00292
00293 if (!this->wildcards.Empty()) {
00294 char namebuf[TS_MAX_HOST_NAME_LEN + 1];
00295 char * reversed;
00296 SSLEntry * entry;
00297
00298 reversed = reverse_dns_name(name, namebuf);
00299 if (!reversed) {
00300 Error("failed to reverse hostname name '%s' is too long", name);
00301 return NULL;
00302 }
00303
00304 Debug("ssl", "attempting wildcard match for %s", reversed);
00305 entry = this->wildcards.Search(reversed);
00306 if (entry) {
00307 return entry->ctx;
00308 }
00309 }
00310
00311 return NULL;
00312 }
00313
00314 #if TS_HAS_TESTS
00315
00316 REGRESSION_TEST(SSLWildcardMatch)(RegressionTest * t, int , int * pstatus)
00317 {
00318 TestBox box(t, pstatus);
00319 ats_wildcard_matcher wildcard;
00320
00321 box = REGRESSION_TEST_PASSED;
00322
00323 box.check(wildcard.match("foo.com") == false, "foo.com is not a wildcard");
00324 box.check(wildcard.match("*.foo.com") == true, "*.foo.com not a wildcard");
00325 box.check(wildcard.match("bar*.foo.com") == false, "bar*.foo.com not a wildcard");
00326 box.check(wildcard.match("*") == false, "* is not a wildcard");
00327 box.check(wildcard.match("") == false, "'' is not a wildcard");
00328 }
00329
00330 REGRESSION_TEST(SSLReverseHostname)(RegressionTest * t, int , int * pstatus)
00331 {
00332 TestBox box(t, pstatus);
00333
00334 char reversed[TS_MAX_HOST_NAME_LEN + 1];
00335
00336 #define _R(name) reverse_dns_name(name, reversed)
00337
00338 box = REGRESSION_TEST_PASSED;
00339
00340 box.check(strcmp(_R("foo.com"), "com.foo") == 0, "reversed foo.com");
00341 box.check(strcmp(_R("bar.foo.com"), "com.foo.bar") == 0, "reversed bar.foo.com");
00342 box.check(strcmp(_R("foo"), "foo") == 0, "reversed foo");
00343
00344 #undef _R
00345 }
00346
00347 #endif // TS_HAS_TESTS