001/** 002 * Copyright (C) 2012 FuseSource, Inc. 003 * http://fusesource.com 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.fusesource.hawtdispatch.transport; 019 020import org.fusesource.hawtdispatch.Task; 021 022import javax.net.ssl.*; 023import java.io.EOFException; 024import java.io.IOException; 025import java.net.Socket; 026import java.net.URI; 027import java.nio.ByteBuffer; 028import java.nio.channels.*; 029import java.security.cert.Certificate; 030import java.security.cert.X509Certificate; 031import java.util.ArrayList; 032 033import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*; 034import static javax.net.ssl.SSLEngineResult.Status.*; 035 036/** 037 * An SSL Transport for secure communications. 038 * 039 * @author <a href="http://hiramchirino.com">Hiram Chirino</a> 040 */ 041public class SslTransport extends TcpTransport implements SecuredSession { 042 043 /** 044 * Maps uri schemes to a protocol algorithm names. 045 * Valid algorithm names listed at: 046 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext 047 */ 048 public static String protocol(String scheme) { 049 if( scheme.equals("tls") ) { 050 return "TLS"; 051 } else if( scheme.startsWith("tlsv") ) { 052 return "TLSv"+scheme.substring(4); 053 } else if( scheme.equals("ssl") ) { 054 return "SSL"; 055 } else if( scheme.startsWith("sslv") ) { 056 return "SSLv"+scheme.substring(4); 057 } 058 return null; 059 } 060 061 enum ClientAuth { 062 WANT, NEED, NONE 063 }; 064 065 private ClientAuth clientAuth = ClientAuth.WANT; 066 private String disabledCypherSuites = null; 067 068 private SSLContext sslContext; 069 private SSLEngine engine; 070 071 private ByteBuffer readBuffer; 072 private boolean readUnderflow; 073 074 private ByteBuffer writeBuffer; 075 private boolean writeFlushing; 076 077 private ByteBuffer readOverflowBuffer; 078 private SSLChannel ssl_channel = new SSLChannel(); 079 080 081 public void setSSLContext(SSLContext ctx) { 082 this.sslContext = ctx; 083 } 084 085 /** 086 * Allows subclasses of TcpTransportFactory to create custom instances of 087 * TcpTransport. 088 */ 089 public static SslTransport createTransport(URI uri) throws Exception { 090 String protocol = protocol(uri.getScheme()); 091 if( protocol !=null ) { 092 SslTransport rc = new SslTransport(); 093 rc.setSSLContext(SSLContext.getInstance(protocol)); 094 return rc; 095 } 096 return null; 097 } 098 099 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel { 100 101 public int write(ByteBuffer plain) throws IOException { 102 return secure_write(plain); 103 } 104 105 public int read(ByteBuffer plain) throws IOException { 106 return secure_read(plain); 107 } 108 109 public boolean isOpen() { 110 return getSocketChannel().isOpen(); 111 } 112 113 public void close() throws IOException { 114 getSocketChannel().close(); 115 } 116 117 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { 118 if(offset+length > srcs.length || length<0 || offset<0) { 119 throw new IndexOutOfBoundsException(); 120 } 121 long rc=0; 122 for (int i = 0; i < length; i++) { 123 ByteBuffer src = srcs[offset+i]; 124 if(src.hasRemaining()) { 125 rc += write(src); 126 } 127 if( src.hasRemaining() ) { 128 return rc; 129 } 130 } 131 return rc; 132 } 133 134 public long write(ByteBuffer[] srcs) throws IOException { 135 return write(srcs, 0, srcs.length); 136 } 137 138 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { 139 if(offset+length > dsts.length || length<0 || offset<0) { 140 throw new IndexOutOfBoundsException(); 141 } 142 long rc=0; 143 for (int i = 0; i < length; i++) { 144 ByteBuffer dst = dsts[offset+i]; 145 if(dst.hasRemaining()) { 146 rc += read(dst); 147 } 148 if( dst.hasRemaining() ) { 149 return rc; 150 } 151 } 152 return rc; 153 } 154 155 public long read(ByteBuffer[] dsts) throws IOException { 156 return read(dsts, 0, dsts.length); 157 } 158 159 public Socket socket() { 160 SocketChannel c = channel; 161 if( c == null ) { 162 return null; 163 } 164 return c.socket(); 165 } 166 } 167 168 public SSLSession getSSLSession() { 169 return engine==null ? null : engine.getSession(); 170 } 171 172 public X509Certificate[] getPeerX509Certificates() { 173 if( engine==null ) { 174 return null; 175 } 176 try { 177 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>(); 178 for( Certificate c:engine.getSession().getPeerCertificates() ) { 179 if(c instanceof X509Certificate) { 180 rc.add((X509Certificate) c); 181 } 182 } 183 return rc.toArray(new X509Certificate[rc.size()]); 184 } catch (SSLPeerUnverifiedException e) { 185 return null; 186 } 187 } 188 189 @Override 190 public void connecting(URI remoteLocation, URI localLocation) throws Exception { 191 assert engine == null; 192 engine = sslContext.createSSLEngine(remoteLocation.getHost(), remoteLocation.getPort()); 193 engine.setUseClientMode(true); 194 super.connecting(remoteLocation, localLocation); 195 } 196 197 @Override 198 public void connected(SocketChannel channel) throws Exception { 199 if (engine == null) { 200 engine = sslContext.createSSLEngine(); 201 engine.setUseClientMode(false); 202 switch (clientAuth) { 203 case WANT: engine.setWantClientAuth(true); break; 204 case NEED: engine.setNeedClientAuth(true); break; 205 case NONE: engine.setWantClientAuth(false); break; 206 } 207 208 } 209 210 if( disabledCypherSuites!=null ) { 211 ArrayList<String> disabledList = new ArrayList<String>(); 212 for( String x : disabledCypherSuites.split(",") ) { 213 disabledList.add(x.trim()); 214 } 215 ArrayList<String> enabled = new ArrayList<String>(); 216 for (String suite : engine.getSupportedCipherSuites()) { 217 boolean add = true; 218 for (String disabled : disabledList) { 219 if( suite.contains(disabled) ) { 220 add = false; 221 break; 222 } 223 } 224 if( add ) { 225 enabled.add(suite); 226 } 227 } 228 engine.setEnabledCipherSuites(enabled.toArray(new String[enabled.size()])); 229 } 230 231 super.connected(channel); 232 } 233 234 @Override 235 protected void initializeChannel() throws Exception { 236 super.initializeChannel(); 237 SSLSession session = engine.getSession(); 238 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 239 readBuffer.flip(); 240 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 241 } 242 243 @Override 244 protected void onConnected() throws IOException { 245 super.onConnected(); 246 engine.beginHandshake(); 247 handshake(); 248 } 249 250 @Override 251 public void flush() { 252 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 253 handshake(); 254 } else { 255 super.flush(); 256 } 257 } 258 259 @Override 260 public void drainInbound() { 261 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 262 handshake(); 263 } else { 264 super.drainInbound(); 265 } 266 } 267 268 /** 269 * @return true if fully flushed. 270 * @throws IOException 271 */ 272 protected boolean transportFlush() throws IOException { 273 while (true) { 274 if(writeFlushing) { 275 int count = super.getWriteChannel().write(writeBuffer); 276 if( !writeBuffer.hasRemaining() ) { 277 writeBuffer.clear(); 278 writeFlushing = false; 279 suspendWrite(); 280 return true; 281 } else { 282 return false; 283 } 284 } else { 285 if( writeBuffer.position()!=0 ) { 286 writeBuffer.flip(); 287 writeFlushing = true; 288 resumeWrite(); 289 } else { 290 return true; 291 } 292 } 293 } 294 } 295 296 private int secure_write(ByteBuffer plain) throws IOException { 297 if( !transportFlush() ) { 298 // can't write anymore until the write_secured_buffer gets fully flushed out.. 299 return 0; 300 } 301 int rc = 0; 302 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) { 303 SSLEngineResult result = engine.wrap(plain, writeBuffer); 304 assert result.getStatus()!= BUFFER_OVERFLOW; 305 rc += result.bytesConsumed(); 306 if( !transportFlush() || result.getStatus() == CLOSED) { 307 break; 308 } 309 } 310 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 311 dispatchQueue.execute(new Task() { 312 public void run() { 313 handshake(); 314 } 315 }); 316 } 317 return rc; 318 } 319 320 private int secure_read(ByteBuffer plain) throws IOException { 321 int rc=0; 322 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) { 323 if( readOverflowBuffer !=null ) { 324 if( plain.hasRemaining() ) { 325 // lets drain the overflow buffer before trying to suck down anymore 326 // network bytes. 327 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining()); 328 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size); 329 readOverflowBuffer.position(readOverflowBuffer.position()+size); 330 if( !readOverflowBuffer.hasRemaining() ) { 331 readOverflowBuffer = null; 332 } 333 rc += size; 334 } else { 335 return rc; 336 } 337 } else if( readUnderflow ) { 338 int count = super.getReadChannel().read(readBuffer); 339 if( count == -1 ) { // peer closed socket. 340 if (rc==0) { 341 return -1; 342 } else { 343 return rc; 344 } 345 } 346 if( count==0 ) { // no data available right now. 347 return rc; 348 } 349 // read in some more data, perhaps now we can unwrap. 350 readUnderflow = false; 351 readBuffer.flip(); 352 } else { 353 SSLEngineResult result = engine.unwrap(readBuffer, plain); 354 rc += result.bytesProduced(); 355 if( result.getStatus() == BUFFER_OVERFLOW ) { 356 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); 357 result = engine.unwrap(readBuffer, readOverflowBuffer); 358 if( readOverflowBuffer.position()==0 ) { 359 readOverflowBuffer = null; 360 } else { 361 readOverflowBuffer.flip(); 362 } 363 } 364 switch( result.getStatus() ) { 365 case CLOSED: 366 if (rc==0) { 367 engine.closeInbound(); 368 return -1; 369 } else { 370 return rc; 371 } 372 case OK: 373 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 374 dispatchQueue.execute(new Task() { 375 public void run() { 376 handshake(); 377 } 378 }); 379 } 380 break; 381 case BUFFER_UNDERFLOW: 382 readBuffer.compact(); 383 readUnderflow = true; 384 break; 385 case BUFFER_OVERFLOW: 386 throw new AssertionError("Unexpected case."); 387 } 388 } 389 } 390 return rc; 391 } 392 393 public void handshake() { 394 try { 395 if( !transportFlush() ) { 396 return; 397 } 398 switch (engine.getHandshakeStatus()) { 399 case NEED_TASK: 400 final Runnable task = engine.getDelegatedTask(); 401 if( task!=null ) { 402 blockingExecutor.execute(new Task() { 403 public void run() { 404 task.run(); 405 dispatchQueue.execute(new Task() { 406 public void run() { 407 if (isConnected()) { 408 handshake(); 409 } 410 } 411 }); 412 } 413 }); 414 } 415 break; 416 417 case NEED_WRAP: 418 secure_write(ByteBuffer.allocate(0)); 419 break; 420 421 case NEED_UNWRAP: 422 if( secure_read(ByteBuffer.allocate(0)) == -1) { 423 throw new EOFException("Peer disconnected during ssl handshake"); 424 } 425 break; 426 427 case FINISHED: 428 case NOT_HANDSHAKING: 429 break; 430 431 default: 432 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus()); 433 break; 434 } 435 } catch (IOException e ) { 436 onTransportFailure(e); 437 } finally { 438 if( engine.getHandshakeStatus() == NOT_HANDSHAKING ) { 439 drainOutboundSource.merge(1); 440 super.drainInbound(); 441 } 442 } 443 } 444 445 446 public ReadableByteChannel getReadChannel() { 447 return ssl_channel; 448 } 449 450 public WritableByteChannel getWriteChannel() { 451 return ssl_channel; 452 } 453 454 public String getClientAuth() { 455 return clientAuth.name(); 456 } 457 458 public void setClientAuth(String clientAuth) { 459 this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase()); 460 } 461 462 public String getDisabledCypherSuites() { 463 return disabledCypherSuites; 464 } 465 466 public void setDisabledCypherSuites(String disabledCypherSuites) { 467 this.disabledCypherSuites = disabledCypherSuites; 468 } 469} 470 471