00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 #include "ink_config.h"
00025 
00026 #include "P_SSLCertLookup.h"
00027 #include "P_SSLUtils.h"
00028 #include "P_SSLConfig.h"
00029 #include "I_EventSystem.h"
00030 #include "I_Layout.h"
00031 #include "Regex.h"
00032 #include "Trie.h"
00033 #include "ts/TestBox.h"
00034 
00035 struct SSLAddressLookupKey
00036 {
00037   explicit
00038   SSLAddressLookupKey(const IpEndpoint& ip) : sep(0)
00039   {
00040     static const char hextab[16] = {
00041       '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
00042     };
00043 
00044     int nbytes;
00045     uint16_t port = ntohs(ip.port());
00046 
00047     
00048     
00049     
00050     nbytes = ats_ip_to_hex(&ip.sa, key, sizeof(key));
00051     if (port) {
00052       sep = nbytes;
00053       key[nbytes++] = '.';
00054       key[nbytes++] = hextab[ (port >> 12) & 0x000F ];
00055       key[nbytes++] = hextab[ (port >>  8) & 0x000F ];
00056       key[nbytes++] = hextab[ (port >>  4) & 0x000F ];
00057       key[nbytes++] = hextab[ (port      ) & 0x000F ];
00058     }
00059     key[nbytes++] = 0;
00060   }
00061 
00062   const char * get() const { return key; }
00063   void split() { key[sep] = '\0'; }
00064   void unsplit() { key[sep] = '.'; }
00065 
00066 private:
00067   char key[(TS_IP6_SIZE * 2)  + 1  + 4  + 1 ];
00068   unsigned char sep; 
00069 };
00070 
00071 struct SSLContextStorage
00072 {
00073   SSLContextStorage();
00074   ~SSLContextStorage();
00075 
00076   bool insert(SSL_CTX * ctx, const char * name);
00077   SSL_CTX * lookup(const char * name) const;
00078   unsigned count() const { return this->references.count(); }
00079   SSL_CTX * get(unsigned i) const { return this->references[i]; }
00080 
00081 private:
00082   struct SSLEntry
00083   {
00084     explicit SSLEntry(SSL_CTX * c) : ctx(c) {}
00085 
00086     void Print() const { Debug("ssl", "SSLEntry=%p SSL_CTX=%p", this, ctx); }
00087 
00088     SSL_CTX * ctx;
00089     LINK(SSLEntry, link);
00090   };
00091 
00092   Trie<SSLEntry>  wildcards;
00093   InkHashTable *  hostnames;
00094   Vec<SSL_CTX *>  references;
00095 };
00096 
00097 SSLCertLookup::SSLCertLookup()
00098   : ssl_storage(new SSLContextStorage()), ssl_default(NULL)
00099 {
00100 }
00101 
00102 SSLCertLookup::~SSLCertLookup()
00103 {
00104   delete this->ssl_storage;
00105 }
00106 
00107 SSL_CTX *
00108 SSLCertLookup::findInfoInHash(const char * address) const
00109 {
00110   return this->ssl_storage->lookup(address);
00111 }
00112 
00113 SSL_CTX *
00114 SSLCertLookup::findInfoInHash(const IpEndpoint& address) const
00115 {
00116   SSL_CTX * ctx;
00117   SSLAddressLookupKey key(address);
00118 
00119   
00120   if ((ctx = this->ssl_storage->lookup(key.get()))) {
00121     return ctx;
00122   }
00123 
00124   
00125   if (address.port()) {
00126     key.split();
00127     return this->ssl_storage->lookup(key.get());
00128   }
00129 
00130   return NULL;
00131 }
00132 
00133 bool
00134 SSLCertLookup::insert(SSL_CTX * ctx, const char * name)
00135 {
00136   return this->ssl_storage->insert(ctx, name);
00137 }
00138 
00139 bool
00140 SSLCertLookup::insert(SSL_CTX * ctx, const IpEndpoint& address)
00141 {
00142   SSLAddressLookupKey key(address);
00143   return this->ssl_storage->insert(ctx, key.get());
00144 }
00145 
00146 unsigned
00147 SSLCertLookup::count() const
00148 {
00149   return ssl_storage->count();
00150 }
00151 
00152 SSL_CTX *
00153 SSLCertLookup::get(unsigned i) const
00154 {
00155   return ssl_storage->get(i);
00156 }
00157 
00158 struct ats_wildcard_matcher
00159 {
00160   ats_wildcard_matcher() {
00161     if (regex.compile("^\\*\\.[^\\*.]+") != 0) {
00162       Fatal("failed to compile TLS wildcard matching regex");
00163     }
00164   }
00165 
00166   ~ats_wildcard_matcher() {
00167   }
00168 
00169   bool match(const char * hostname) const {
00170     return regex.match(hostname) != -1;
00171   }
00172 
00173 private:
00174   DFA regex;
00175 };
00176 
00177 static char *
00178 reverse_dns_name(const char * hostname, char (&reversed)[TS_MAX_HOST_NAME_LEN+1])
00179 {
00180   char * ptr = reversed + sizeof(reversed);
00181   const char * part = hostname;
00182 
00183   *(--ptr) = '\0'; 
00184 
00185   while (*part) {
00186     ssize_t len = strcspn(part, ".");
00187     ssize_t remain = ptr - reversed;
00188 
00189     if (remain < (len + 1)) {
00190       return NULL;
00191     }
00192 
00193     ptr -= len;
00194     memcpy(ptr, part, len);
00195 
00196     
00197     
00198     part += len;
00199     if (*part == '.') {
00200       ++part;
00201       *(--ptr) = '.';
00202     }
00203   }
00204 
00205   return ptr;
00206 }
00207 
00208 SSLContextStorage::SSLContextStorage()
00209   :wildcards(), hostnames(ink_hash_table_create(InkHashTableKeyType_String))
00210 {
00211 }
00212 
00213 SSLContextStorage::~SSLContextStorage()
00214 {
00215   for (unsigned i = 0; i < this->references.count(); ++i) {
00216     SSLReleaseContext(this->references[i]);
00217   }
00218 
00219   ink_hash_table_destroy(this->hostnames);
00220 }
00221 
// Index @a ctx under @a name.
//
// Names accepted by ats_wildcard_matcher ("*.<label>...") go into the wildcard
// trie, keyed by the reversed suffix after the '*' (so "*.foo.com" is stored
// as "com.foo."). Every other name goes into the exact-match hostname table.
// Whenever a name is indexed, @a ctx is also recorded once in references so
// the destructor can release it.
//
// Returns true if the name was indexed, false if it was too long or already
// taken.
// NOTE(review): re-indexing the same name with the same ctx returns true for
// hostnames but false for wildcards (the trie Insert fails on the duplicate).
bool
SSLContextStorage::insert(SSL_CTX * ctx, const char * name)
{
  ats_wildcard_matcher wildcard;
  bool inserted = false;

  if (wildcard.match(name)) {
    // Skip the leading '*' and reverse the remaining ".foo.com" so lookups
    // of reversed hostnames can find it by prefix in the trie.
    char namebuf[TS_MAX_HOST_NAME_LEN + 1];
    char * reversed;
    ats_scoped_obj<SSLEntry> entry;

    reversed = reverse_dns_name(name + 1, namebuf);
    if (!reversed) {
      Error("wildcard name '%s' is too long", name);
      return false;
    }

    entry = new SSLEntry(ctx);
    // NOTE(review): 0 and -1 are presumably the trie rank and key length
    // (-1 = NUL-terminated); confirm against Trie.h.
    inserted = this->wildcards.Insert(reversed, entry, 0 , -1 );
    if (!inserted) {
      SSLEntry * found;

      // The key is already present. Only warn if it maps to a different context.
      found = this->wildcards.Search(reversed);
      if (found != NULL && found->ctx != ctx) {
        Warning("previously indexed wildcard certificate for '%s' as '%s', cannot index it with SSL_CTX %p now",
            name, reversed, ctx);
      }

      // Leaving this scope lets the scoped pointer delete the unused entry.
      goto done;
    }

    Debug("ssl", "indexed wildcard certificate for '%s' as '%s' with SSL_CTX %p", name, reversed, ctx);
    // Insert succeeded: give up ownership so the entry is not deleted here
    // (presumably the trie now owns it).
    entry.release();
  } else {
    InkHashTableValue value;

    // A name already bound to a *different* context is not overwritten.
    if (ink_hash_table_lookup(this->hostnames, name, &value) && (void *)ctx != value) {
      Warning("previously indexed '%s' with SSL_CTX %p, cannot index it with SSL_CTX %p now", name, value, ctx);
      goto done;
    }

    inserted = true;
    ink_hash_table_insert(this->hostnames, name, (void *)ctx);
    Debug("ssl", "indexed '%s' with SSL_CTX %p", name, ctx);
  }

done:
  // Record each distinct context exactly once. The same SSL_CTX may be
  // indexed under many names, but the destructor must release it only once.
  if (inserted) {
    if (this->references.in(ctx) == NULL) {
      this->references.push_back(ctx);
    }
  }

  return inserted;
}
00283 
00284 SSL_CTX *
00285 SSLContextStorage::lookup(const char * name) const
00286 {
00287   InkHashTableValue value;
00288 
00289   if (ink_hash_table_lookup(const_cast<InkHashTable *>(this->hostnames), name, &value)) {
00290     return (SSL_CTX *)value;
00291   }
00292 
00293   if (!this->wildcards.Empty()) {
00294     char namebuf[TS_MAX_HOST_NAME_LEN + 1];
00295     char * reversed;
00296     SSLEntry * entry;
00297 
00298     reversed = reverse_dns_name(name, namebuf);
00299     if (!reversed) {
00300       Error("failed to reverse hostname name '%s' is too long", name);
00301       return NULL;
00302     }
00303 
00304     Debug("ssl", "attempting wildcard match for %s", reversed);
00305     entry = this->wildcards.Search(reversed);
00306     if (entry) {
00307       return entry->ctx;
00308     }
00309   }
00310 
00311   return NULL;
00312 }
00313 
00314 #if TS_HAS_TESTS
00315 
00316 REGRESSION_TEST(SSLWildcardMatch)(RegressionTest * t, int , int * pstatus)
00317 {
00318   TestBox box(t, pstatus);
00319   ats_wildcard_matcher wildcard;
00320 
00321   box = REGRESSION_TEST_PASSED;
00322 
00323   box.check(wildcard.match("foo.com") == false, "foo.com is not a wildcard");
00324   box.check(wildcard.match("*.foo.com") == true, "*.foo.com not a wildcard");
00325   box.check(wildcard.match("bar*.foo.com") == false, "bar*.foo.com not a wildcard");
00326   box.check(wildcard.match("*") == false, "* is not a wildcard");
00327   box.check(wildcard.match("") == false, "'' is not a wildcard");
00328 }
00329 
00330 REGRESSION_TEST(SSLReverseHostname)(RegressionTest * t, int , int * pstatus)
00331 {
00332   TestBox box(t, pstatus);
00333 
00334   char reversed[TS_MAX_HOST_NAME_LEN + 1];
00335 
00336 #define _R(name) reverse_dns_name(name, reversed)
00337 
00338   box = REGRESSION_TEST_PASSED;
00339 
00340   box.check(strcmp(_R("foo.com"), "com.foo") == 0, "reversed foo.com");
00341   box.check(strcmp(_R("bar.foo.com"), "com.foo.bar") == 0, "reversed bar.foo.com");
00342   box.check(strcmp(_R("foo"), "foo") == 0, "reversed foo");
00343 
00344 #undef _R
00345 }
00346 
00347 #endif // TS_HAS_TESTS