001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
020import org.fusesource.hawtdispatch.Task;
021
022import javax.net.ssl.*;
023import java.io.EOFException;
024import java.io.IOException;
025import java.net.Socket;
026import java.net.URI;
027import java.nio.ByteBuffer;
028import java.nio.channels.*;
029import java.security.cert.Certificate;
030import java.security.cert.X509Certificate;
031import java.util.ArrayList;
032
033import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
034import static javax.net.ssl.SSLEngineResult.Status.*;
035
036/**
037 * An SSL Transport for secure communications.
038 *
039 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
040 */
041public class SslTransport extends TcpTransport implements SecuredSession {
042
043    /**
044     * Maps uri schemes to a protocol algorithm names.
045     * Valid algorithm names listed at:
046     * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
047     */
048    public static String protocol(String scheme) {
049        if( scheme.equals("tls") ) {
050            return "TLS";
051        } else if( scheme.startsWith("tlsv") ) {
052            return "TLSv"+scheme.substring(4);
053        } else if( scheme.equals("ssl") ) {
054            return "SSL";
055        } else if( scheme.startsWith("sslv") ) {
056            return "SSLv"+scheme.substring(4);
057        }
058        return null;
059    }
060
    /**
     * Client certificate policy applied to engines created in server mode by
     * {@link #connected(SocketChannel)}: WANT calls setWantClientAuth(true),
     * NEED calls setNeedClientAuth(true) and NONE calls setWantClientAuth(false).
     */
    enum ClientAuth {
        WANT, NEED, NONE
    };

    // Policy used for server mode engines; client certificates are requested by default.
    private ClientAuth clientAuth = ClientAuth.WANT;
    // Comma separated list; any supported cipher suite whose name contains one
    // of the entries is left out of the enabled set.  Null means no filtering.
    private String disabledCypherSuites = null;

    // Source of SSL engines; must be set before connecting()/connected() run.
    private SSLContext sslContext;
    // Created in connecting() (client mode) or connected() (server mode).
    private SSLEngine engine;

    // Encrypted bytes received from the socket that still need to be unwrapped.
    private ByteBuffer readBuffer;
    // True when the last unwrap reported BUFFER_UNDERFLOW, i.e. more socket
    // data has to be read before unwrapping can continue.
    private boolean readUnderflow;

    // Encrypted bytes produced by wrap() that still need to go out to the socket.
    private ByteBuffer writeBuffer;
    // True while writeBuffer is flipped and being written to the socket.
    private boolean writeFlushing;

    // Decrypted bytes that did not fit in the caller's buffer during secure_read().
    private ByteBuffer readOverflowBuffer;
    // Plain text channel view handed out by getReadChannel()/getWriteChannel().
    private SSLChannel ssl_channel = new SSLChannel();
079
080
081    public void setSSLContext(SSLContext ctx) {
082        this.sslContext = ctx;
083    }
084
085    /**
086     * Allows subclasses of TcpTransportFactory to create custom instances of
087     * TcpTransport.
088     */
089    public static SslTransport createTransport(URI uri) throws Exception {
090        String protocol = protocol(uri.getScheme());
091        if( protocol !=null ) {
092            SslTransport rc = new SslTransport();
093            rc.setSSLContext(SSLContext.getInstance(protocol));
094            return rc;
095        }
096        return null;
097    }
098
099    public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
100
101        public int write(ByteBuffer plain) throws IOException {
102            return secure_write(plain);
103        }
104
105        public int read(ByteBuffer plain) throws IOException {
106            return secure_read(plain);
107        }
108
109        public boolean isOpen() {
110            return getSocketChannel().isOpen();
111        }
112
113        public void close() throws IOException {
114            getSocketChannel().close();
115        }
116
117        public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
118            if(offset+length > srcs.length || length<0 || offset<0) {
119                throw new IndexOutOfBoundsException();
120            }
121            long rc=0;
122            for (int i = 0; i < length; i++) {
123                ByteBuffer src = srcs[offset+i];
124                if(src.hasRemaining()) {
125                    rc += write(src);
126                }
127                if( src.hasRemaining() ) {
128                    return rc;
129                }
130            }
131            return rc;
132        }
133
134        public long write(ByteBuffer[] srcs) throws IOException {
135            return write(srcs, 0, srcs.length);
136        }
137
138        public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
139            if(offset+length > dsts.length || length<0 || offset<0) {
140                throw new IndexOutOfBoundsException();
141            }
142            long rc=0;
143            for (int i = 0; i < length; i++) {
144                ByteBuffer dst = dsts[offset+i];
145                if(dst.hasRemaining()) {
146                    rc += read(dst);
147                }
148                if( dst.hasRemaining() ) {
149                    return rc;
150                }
151            }
152            return rc;
153        }
154
155        public long read(ByteBuffer[] dsts) throws IOException {
156            return read(dsts, 0, dsts.length);
157        }
158        
159        public Socket socket() {
160            SocketChannel c = channel;
161            if( c == null ) {
162                return null;
163            }
164            return c.socket();
165        }
166    }
167
168    public SSLSession getSSLSession() {
169        return engine==null ? null : engine.getSession();
170    }
171
172    public X509Certificate[] getPeerX509Certificates() {
173        if( engine==null ) {
174            return null;
175        }
176        try {
177            ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
178            for( Certificate c:engine.getSession().getPeerCertificates() ) {
179                if(c instanceof X509Certificate) {
180                    rc.add((X509Certificate) c);
181                }
182            }
183            return rc.toArray(new X509Certificate[rc.size()]);
184        } catch (SSLPeerUnverifiedException e) {
185            return null;
186        }
187    }
188
189    @Override
190    public void connecting(URI remoteLocation, URI localLocation) throws Exception {
191        assert engine == null;
192        engine = sslContext.createSSLEngine(remoteLocation.getHost(), remoteLocation.getPort());
193        engine.setUseClientMode(true);
194        super.connecting(remoteLocation, localLocation);
195    }
196
197    @Override
198    public void connected(SocketChannel channel) throws Exception {
199        if (engine == null) {
200            engine = sslContext.createSSLEngine();
201            engine.setUseClientMode(false);
202            switch (clientAuth) {
203                case WANT: engine.setWantClientAuth(true); break;
204                case NEED: engine.setNeedClientAuth(true); break;
205                case NONE: engine.setWantClientAuth(false); break;
206            }
207
208        }
209
210        if( disabledCypherSuites!=null ) {
211            ArrayList<String> disabledList = new ArrayList<String>();
212            for( String x : disabledCypherSuites.split(",") ) {
213                disabledList.add(x.trim());
214            }
215            ArrayList<String> enabled = new ArrayList<String>();
216            for (String suite : engine.getSupportedCipherSuites()) {
217                boolean add = true;
218                for (String disabled : disabledList) {
219                    if( suite.contains(disabled) ) {
220                        add = false;
221                        break;
222                    }
223                }
224                if( add ) {
225                    enabled.add(suite);
226                }
227            }
228            engine.setEnabledCipherSuites(enabled.toArray(new String[enabled.size()]));
229        }
230
231        super.connected(channel);
232    }
233
234    @Override
235    protected void initializeChannel() throws Exception {
236        super.initializeChannel();
237        SSLSession session = engine.getSession();
238        readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
239        readBuffer.flip();
240        writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
241    }
242
243    @Override
244    protected void onConnected() throws IOException {
245        super.onConnected();
246        engine.beginHandshake();
247        handshake();
248    }
249
250    @Override
251    public void flush() {
252        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
253            handshake();
254        } else {
255            super.flush();
256        }
257    }
258
259    @Override
260    public void drainInbound() {
261        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
262            handshake();
263        } else {
264            super.drainInbound();
265        }
266    }
267
268    /**
269     * @return true if fully flushed.
270     * @throws IOException
271     */
272    protected boolean transportFlush() throws IOException {
273        while (true) {
274            if(writeFlushing) {
275                int count = super.getWriteChannel().write(writeBuffer);
276                if( !writeBuffer.hasRemaining() ) {
277                    writeBuffer.clear();
278                    writeFlushing = false;
279                    suspendWrite();
280                    return true;
281                } else {
282                    return false;
283                }
284            } else {
285                if( writeBuffer.position()!=0 ) {
286                    writeBuffer.flip();
287                    writeFlushing = true;
288                    resumeWrite();
289                } else {
290                    return true;
291                }
292            }
293        }
294    }
295
    /**
     * Encrypts data from the given plain text buffer into the write buffer
     * and flushes it to the socket.
     *
     * @param plain the plain text to send; handshake() passes an empty buffer
     *              to let the engine emit handshake data
     * @return the number of plain text bytes consumed; 0 when previously
     *         encrypted data has not been fully flushed yet
     * @throws IOException if flushing or wrapping fails
     */
    private int secure_write(ByteBuffer plain) throws IOException {
        if( !transportFlush() ) {
            // Nothing more can be wrapped until the encrypted write buffer
            // has been fully flushed out.
            return 0;
        }
        int rc = 0;
        // XOR: keep wrapping while exactly one of "plain has data" and
        // "engine needs to wrap" holds.  That covers normal data writes and
        // handshake wraps with an empty buffer; when both hold at once the
        // loop is skipped for this call.
        while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
            SSLEngineResult result = engine.wrap(plain, writeBuffer);
            // NOTE(review): overflow is presumably impossible because the
            // write buffer was just flushed and is sized to the packet buffer
            // size -- confirm.
            assert result.getStatus()!= BUFFER_OVERFLOW;
            rc += result.bytesConsumed();
            // Stop when the socket can't take more or the engine is closed.
            if( !transportFlush() || result.getStatus() == CLOSED) {
                break;
            }
        }
        // Everything was consumed but the engine still wants to handshake:
        // continue the handshake asynchronously on the dispatch queue.
        if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
            dispatchQueue.execute(new Task() {
                public void run() {
                    handshake();
                }
            });
        }
        return rc;
    }
319
    /**
     * Reads encrypted data from the socket and decrypts it into the given
     * plain text buffer.
     *
     * <p>Each loop iteration does one of three things: drains previously
     * decrypted overflow data, reads more encrypted bytes from the socket
     * (after an underflow), or unwraps what is in the read buffer.</p>
     *
     * @param plain the buffer to receive plain text; handshake() passes an
     *              empty buffer to let the engine consume handshake data
     * @return the number of plain text bytes placed in the buffer, or -1 if
     *         the peer closed the connection before any byte was produced
     * @throws IOException if reading or unwrapping fails
     */
    private int secure_read(ByteBuffer plain) throws IOException {
        int rc=0;
        // XOR: keep going while exactly one of "plain has room" and
        // "engine needs to unwrap" holds.
        while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
            if( readOverflowBuffer !=null ) {
                if(  plain.hasRemaining() ) {
                    // Drain the overflow buffer before pulling any more
                    // bytes from the network.
                    int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
                    plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
                    readOverflowBuffer.position(readOverflowBuffer.position()+size);
                    if( !readOverflowBuffer.hasRemaining() ) {
                        readOverflowBuffer = null;
                    }
                    rc += size;
                } else {
                    // No room to drain the overflow into; try again later.
                    return rc;
                }
            } else if( readUnderflow ) {
                // The read buffer was compacted on underflow, so it is in
                // write mode here and can accept more socket data.
                int count = super.getReadChannel().read(readBuffer);
                if( count == -1 ) {  // peer closed socket.
                    if (rc==0) {
                        return -1;
                    } else {
                        return rc;
                    }
                }
                if( count==0 ) {  // no data available right now.
                    return rc;
                }
                // read in some more data, perhaps now we can unwrap.
                readUnderflow = false;
                readBuffer.flip();
            } else {
                SSLEngineResult result = engine.unwrap(readBuffer, plain);
                rc += result.bytesProduced();
                if( result.getStatus() == BUFFER_OVERFLOW ) {
                    // The caller's buffer is too small for the next record:
                    // unwrap into a side buffer sized to the application
                    // buffer size and hand it out on later iterations.
                    readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
                    result = engine.unwrap(readBuffer, readOverflowBuffer);
                    if( readOverflowBuffer.position()==0 ) {
                        readOverflowBuffer = null;
                    } else {
                        readOverflowBuffer.flip();
                    }
                }
                switch( result.getStatus() ) {
                    case CLOSED:
                        if (rc==0) {
                            engine.closeInbound();
                            return -1;
                        } else {
                            return rc;
                        }
                    case OK:
                        // The engine wants to handshake: continue it
                        // asynchronously on the dispatch queue.
                        if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
                            dispatchQueue.execute(new Task() {
                                public void run() {
                                    handshake();
                                }
                            });
                        }
                        break;
                    case BUFFER_UNDERFLOW:
                        // Not enough encrypted data for a full record: switch
                        // the read buffer back to write mode and fetch more.
                        readBuffer.compact();
                        readUnderflow = true;
                        break;
                    case BUFFER_OVERFLOW:
                        throw new AssertionError("Unexpected case.");
                }
            }
        }
        return rc;
    }
392
    /**
     * Performs one step of the SSL handshake based on the engine's current
     * handshake status.
     *
     * <p>Nothing is done while previously encrypted data is still waiting to
     * be flushed.  I/O failures are reported through onTransportFailure().
     * Whenever the engine is not handshaking on exit, the outbound drain is
     * triggered and inbound data is drained.</p>
     */
    public void handshake() {
        try {
            if( !transportFlush() ) {
                return;
            }
            switch (engine.getHandshakeStatus()) {
                case NEED_TASK:
                    final Runnable task = engine.getDelegatedTask();
                    if( task!=null ) {
                        // Run the delegated task on the blocking executor,
                        // then resume the handshake on the dispatch queue if
                        // the transport is still connected.
                        blockingExecutor.execute(new Task() {
                            public void run() {
                                task.run();
                                dispatchQueue.execute(new Task() {
                                    public void run() {
                                        if (isConnected()) {
                                            handshake();
                                        }
                                    }
                                });
                            }
                        });
                    }
                    break;

                case NEED_WRAP:
                    // An empty plain text buffer lets the engine emit only
                    // handshake data.
                    secure_write(ByteBuffer.allocate(0));
                    break;

                case NEED_UNWRAP:
                    // An empty plain text buffer lets the engine consume only
                    // handshake data.
                    if( secure_read(ByteBuffer.allocate(0)) == -1) {
                        throw new EOFException("Peer disconnected during ssl handshake");
                    }
                    break;

                case FINISHED:
                case NOT_HANDSHAKING:
                    break;

                default:
                    System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
                    break;
            }
        } catch (IOException e ) {
            onTransportFailure(e);
        } finally {
            // Not handshaking (any more): let the regular outbound and
            // inbound processing proceed.
            if( engine.getHandshakeStatus() == NOT_HANDSHAKING ) {
                drainOutboundSource.merge(1);
                super.drainInbound();
            }
        }
    }
444
445
446    public ReadableByteChannel getReadChannel() {
447        return ssl_channel;
448    }
449
450    public WritableByteChannel getWriteChannel() {
451        return ssl_channel;
452    }
453
454    public String getClientAuth() {
455        return clientAuth.name();
456    }
457
458    public void setClientAuth(String clientAuth) {
459        this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
460    }
461
462    public String getDisabledCypherSuites() {
463        return disabledCypherSuites;
464    }
465
466    public void setDisabledCypherSuites(String disabledCypherSuites) {
467        this.disabledCypherSuites = disabledCypherSuites;
468    }
469}
470
471