001/**
002 * Copyright (C) 2012 FuseSource, Inc.
003 * http://fusesource.com
004 *
005 * Licensed under the Apache License, Version 2.0 (the "License");
006 * you may not use this file except in compliance with the License.
007 * You may obtain a copy of the License at
008 *
009 *    http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.fusesource.hawtdispatch.transport;
019
020import org.fusesource.hawtdispatch.*;
021
022import java.io.IOException;
023import java.net.*;
024import java.nio.ByteBuffer;
025import java.nio.channels.ReadableByteChannel;
026import java.nio.channels.SelectionKey;
027import java.nio.channels.SocketChannel;
028import java.nio.channels.WritableByteChannel;
029import java.util.LinkedList;
030import java.util.concurrent.Executor;
031import java.util.concurrent.TimeUnit;
032
033/**
034 * An implementation of the {@link org.fusesource.hawtdispatch.transport.Transport} interface using raw tcp/ip
035 *
036 * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
037 */
038public class TcpTransport extends ServiceBase implements Transport {
039
040    static InetAddress localhost;
041    synchronized static public InetAddress getLocalHost() throws UnknownHostException {
042        // cache it...
043        if( localhost==null ) {
044            // this can be slow on some systems and we use repeatedly.
045            localhost = InetAddress.getLocalHost();
046        }
047        return localhost;
048    }
049
050    abstract static class SocketState {
051        void onStop(Task onCompleted) {
052        }
053        void onCanceled() {
054        }
055        boolean is(Class<? extends SocketState> clazz) {
056            return getClass()==clazz;
057        }
058    }
059
060    static class DISCONNECTED extends SocketState{}
061
062    class CONNECTING extends SocketState{
063        void onStop(Task onCompleted) {
064            trace("CONNECTING.onStop");
065            CANCELING state = new CANCELING();
066            socketState = state;
067            state.onStop(onCompleted);
068        }
069        void onCanceled() {
070            trace("CONNECTING.onCanceled");
071            CANCELING state = new CANCELING();
072            socketState = state;
073            state.onCanceled();
074        }
075    }
076
077    class CONNECTED extends SocketState {
078
079        public CONNECTED() {
080            localAddress = channel.socket().getLocalSocketAddress();
081            remoteAddress = channel.socket().getRemoteSocketAddress();
082        }
083
084        void onStop(Task onCompleted) {
085            trace("CONNECTED.onStop");
086            CANCELING state = new CANCELING();
087            socketState = state;
088            state.add(createDisconnectTask());
089            state.onStop(onCompleted);
090        }
091        void onCanceled() {
092            trace("CONNECTED.onCanceled");
093            CANCELING state = new CANCELING();
094            socketState = state;
095            state.add(createDisconnectTask());
096            state.onCanceled();
097        }
098        Task createDisconnectTask() {
099            return new Task(){
100                public void run() {
101                    listener.onTransportDisconnected();
102                }
103            };
104        }
105    }
106
107    class CANCELING extends SocketState {
108        private LinkedList<Task> runnables =  new LinkedList<Task>();
109        private int remaining;
110        private boolean dispose;
111
112        public CANCELING() {
113            if( readSource!=null ) {
114                remaining++;
115                readSource.cancel();
116            }
117            if( writeSource!=null ) {
118                remaining++;
119                writeSource.cancel();
120            }
121        }
122        void onStop(Task onCompleted) {
123            trace("CANCELING.onCompleted");
124            add(onCompleted);
125            dispose = true;
126        }
127        void add(Task onCompleted) {
128            if( onCompleted!=null ) {
129                runnables.add(onCompleted);
130            }
131        }
132        void onCanceled() {
133            trace("CANCELING.onCanceled");
134            remaining--;
135            if( remaining!=0 ) {
136                return;
137            }
138            try {
139                if( closeOnCancel ) {
140                    channel.close();
141                }
142            } catch (IOException ignore) {
143            }
144            socketState = new CANCELED(dispose);
145            for (Task runnable : runnables) {
146                runnable.run();
147            }
148            if (dispose) {
149                dispose();
150            }
151        }
152    }
153
154    class CANCELED extends SocketState {
155        private boolean disposed;
156
157        public CANCELED(boolean disposed) {
158            this.disposed=disposed;
159        }
160
161        void onStop(Task onCompleted) {
162            trace("CANCELED.onStop");
163            if( !disposed ) {
164                disposed = true;
165                dispose();
166            }
167            onCompleted.run();
168        }
169    }
170
171    protected URI remoteLocation;
172    protected URI localLocation;
173    protected TransportListener listener;
174    protected ProtocolCodec codec;
175
176    protected SocketChannel channel;
177
178    protected SocketState socketState = new DISCONNECTED();
179
180    protected DispatchQueue dispatchQueue;
181    private DispatchSource readSource;
182    private DispatchSource writeSource;
183    protected CustomDispatchSource<Integer, Integer> drainOutboundSource;
184    protected CustomDispatchSource<Integer, Integer> yieldSource;
185
186    protected boolean useLocalHost = true;
187
188    int maxReadRate;
189    int maxWriteRate;
190    int receiveBufferSize = 1024*64;
191    int sendBufferSize = 1024*64;
192    boolean closeOnCancel = true;
193
194    boolean keepAlive = true;
195
196    public static final int IPTOS_LOWCOST = 0x02;
197    public static final int IPTOS_RELIABILITY = 0x04;
198    public static final int IPTOS_THROUGHPUT = 0x08;
199    public static final int IPTOS_LOWDELAY = 0x10;
200
201    int trafficClass = IPTOS_THROUGHPUT;
202
203    protected RateLimitingChannel rateLimitingChannel;
204    SocketAddress localAddress;
205    SocketAddress remoteAddress;
206    protected Executor blockingExecutor;
207
208    class RateLimitingChannel implements ReadableByteChannel, WritableByteChannel {
209
210        int read_allowance = maxReadRate;
211        boolean read_suspended = false;
212        int read_resume_counter = 0;
213        int write_allowance = maxWriteRate;
214        boolean write_suspended = false;
215
216        public void resetAllowance() {
217            if( read_allowance != maxReadRate || write_allowance != maxWriteRate) {
218                read_allowance = maxReadRate;
219                write_allowance = maxWriteRate;
220                if( write_suspended ) {
221                    write_suspended = false;
222                    resumeWrite();
223                }
224                if( read_suspended ) {
225                    read_suspended = false;
226                    resumeRead();
227                    for( int i=0; i < read_resume_counter ; i++ ) {
228                        resumeRead();
229                    }
230                }
231            }
232        }
233
234        public int read(ByteBuffer dst) throws IOException {
235            if( maxReadRate ==0 ) {
236                return channel.read(dst);
237            } else {
238                int remaining = dst.remaining();
239                if( read_allowance ==0 || remaining ==0 ) {
240                    return 0;
241                }
242
243                int reduction = 0;
244                if( remaining > read_allowance) {
245                    reduction = remaining - read_allowance;
246                    dst.limit(dst.limit() - reduction);
247                }
248                int rc=0;
249                try {
250                    rc = channel.read(dst);
251                    read_allowance -= rc;
252                } finally {
253                    if( reduction!=0 ) {
254                        if( dst.remaining() == 0 ) {
255                            // we need to suspend the read now until we get
256                            // a new allowance..
257                            readSource.suspend();
258                            read_suspended = true;
259                        }
260                        dst.limit(dst.limit() + reduction);
261                    }
262                }
263                return rc;
264            }
265        }
266
267        public int write(ByteBuffer src) throws IOException {
268            if( maxWriteRate ==0 ) {
269                return channel.write(src);
270            } else {
271                int remaining = src.remaining();
272                if( write_allowance ==0 || remaining ==0 ) {
273                    return 0;
274                }
275
276                int reduction = 0;
277                if( remaining > write_allowance) {
278                    reduction = remaining - write_allowance;
279                    src.limit(src.limit() - reduction);
280                }
281                int rc = 0;
282                try {
283                    rc = channel.write(src);
284                    write_allowance -= rc;
285                } finally {
286                    if( reduction!=0 ) {
287                        if( src.remaining() == 0 ) {
288                            // we need to suspend the read now until we get
289                            // a new allowance..
290                            write_suspended = true;
291                            suspendWrite();
292                        }
293                        src.limit(src.limit() + reduction);
294                    }
295                }
296                return rc;
297            }
298        }
299
300        public boolean isOpen() {
301            return channel.isOpen();
302        }
303
304        public void close() throws IOException {
305            channel.close();
306        }
307
308        public void resumeRead() {
309            if( read_suspended ) {
310                read_resume_counter += 1;
311            } else {
312                _resumeRead();
313            }
314        }
315
316    }
317
318    private final Task CANCEL_HANDLER = new Task() {
319        public void run() {
320            socketState.onCanceled();
321        }
322    };
323
324    static final class OneWay {
325        final Object command;
326        final Retained retained;
327
328        public OneWay(Object command, Retained retained) {
329            this.command = command;
330            this.retained = retained;
331        }
332    }
333
334    public void connected(SocketChannel channel) throws IOException, Exception {
335        this.channel = channel;
336        initializeChannel();
337        this.socketState = new CONNECTED();
338    }
339
340    protected void initializeChannel() throws Exception {
341        this.channel.configureBlocking(false);
342        Socket socket = channel.socket();
343        try {
344            socket.setReuseAddress(true);
345        } catch (SocketException e) {
346        }
347        try {
348            socket.setSoLinger(true, 0);
349        } catch (SocketException e) {
350        }
351        try {
352            socket.setTrafficClass(trafficClass);
353        } catch (SocketException e) {
354        }
355        try {
356            socket.setKeepAlive(keepAlive);
357        } catch (SocketException e) {
358        }
359        try {
360            socket.setTcpNoDelay(true);
361        } catch (SocketException e) {
362        }
363        try {
364            socket.setReceiveBufferSize(receiveBufferSize);
365        } catch (SocketException e) {
366        }
367        try {
368            socket.setSendBufferSize(sendBufferSize);
369        } catch (SocketException e) {
370        }
371
372        if( channel!=null && codec!=null ) {
373            initializeCodec();
374        }
375    }
376
377    protected void initializeCodec() throws Exception {
378        codec.setTransport(this);
379    }
380
381    public void connecting(final URI remoteLocation, final URI localLocation) throws Exception {
382        this.channel = SocketChannel.open();
383        initializeChannel();
384        this.remoteLocation = remoteLocation;
385        this.localLocation = localLocation;
386        socketState = new CONNECTING();
387    }
388
389
390    public DispatchQueue getDispatchQueue() {
391        return dispatchQueue;
392    }
393
394    public void setDispatchQueue(DispatchQueue queue) {
395        this.dispatchQueue = queue;
396        if(readSource!=null) readSource.setTargetQueue(queue);
397        if(writeSource!=null) writeSource.setTargetQueue(queue);
398        if(drainOutboundSource!=null) drainOutboundSource.setTargetQueue(queue);
399        if(yieldSource!=null) yieldSource.setTargetQueue(queue);
400    }
401
402    public void _start(Task onCompleted) {
403        try {
404            if (socketState.is(CONNECTING.class)) {
405
406                // Resolving host names might block.. so do it on the blocking executor.
407                this.blockingExecutor.execute(new Runnable() {
408                    public void run() {
409                        try {
410
411                            final InetSocketAddress localAddress = (localLocation != null) ?
412                                    new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort())
413                                    : null;
414
415                            String host = resolveHostName(remoteLocation.getHost());
416                            final InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort());
417
418                            // Done resolving.. switch back to the dispatch queue.
419                            dispatchQueue.execute(new Task() {
420                                @Override
421                                public void run() {
422                                    // No need to complete if we have been canceled.
423                                    if( ! socketState.is(CONNECTING.class) ) {
424                                        return;
425                                    }
426                                    try {
427
428                                        if (localAddress != null) {
429                                            channel.socket().bind(localAddress);
430                                        }
431                                        trace("connecting...");
432                                        channel.connect(remoteAddress);
433
434                                        // this allows the connect to complete..
435                                        readSource = Dispatch.createSource(channel, SelectionKey.OP_CONNECT, dispatchQueue);
436                                        readSource.setEventHandler(new Task() {
437                                            public void run() {
438                                                if (getServiceState() != STARTED) {
439                                                    return;
440                                                }
441                                                try {
442                                                    trace("connected.");
443                                                    channel.finishConnect();
444                                                    readSource.setCancelHandler(null);
445                                                    readSource.cancel();
446                                                    readSource = null;
447                                                    socketState = new CONNECTED();
448                                                    onConnected();
449                                                } catch (IOException e) {
450                                                    onTransportFailure(e);
451                                                }
452                                            }
453                                        });
454                                        readSource.setCancelHandler(CANCEL_HANDLER);
455                                        readSource.resume();
456
457                                    } catch (IOException e) {
458                                        try {
459                                            channel.close();
460                                        } catch (IOException ignore) {
461                                        }
462                                        socketState = new CANCELED(true);
463                                        listener.onTransportFailure(e);
464                                    }
465                                }
466                            });
467
468                        } catch (IOException e) {
469                            try {
470                                channel.close();
471                            } catch (IOException ignore) {
472                            }
473                            socketState = new CANCELED(true);
474                            listener.onTransportFailure(e);
475                        }
476                    }
477                });
478            } else if (socketState.is(CONNECTED.class)) {
479                dispatchQueue.execute(new Task() {
480                    public void run() {
481                        try {
482                            trace("was connected.");
483                            onConnected();
484                        } catch (IOException e) {
485                            onTransportFailure(e);
486                        }
487                    }
488                });
489            } else {
490                System.err.println("cannot be started.  socket state is: " + socketState);
491            }
492        } finally {
493            if (onCompleted != null) {
494                onCompleted.run();
495            }
496        }
497    }
498
499    public void _stop(final Task onCompleted) {
500        trace("stopping.. at state: "+socketState);
501        socketState.onStop(onCompleted);
502    }
503
504    protected String resolveHostName(String host) throws UnknownHostException {
505        String localName = getLocalHost().getHostName();
506        if (localName != null && isUseLocalHost()) {
507            if (localName.equals(host)) {
508                return "localhost";
509            }
510        }
511        return host;
512    }
513
514    protected void onConnected() throws IOException {
515        yieldSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
516        yieldSource.setEventHandler(new Task() {
517            public void run() {
518                drainInbound();
519            }
520        });
521        yieldSource.resume();
522        drainOutboundSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
523        drainOutboundSource.setEventHandler(new Task() {
524            public void run() {
525                flush();
526            }
527        });
528        drainOutboundSource.resume();
529
530        readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
531        writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
532
533        readSource.setCancelHandler(CANCEL_HANDLER);
534        writeSource.setCancelHandler(CANCEL_HANDLER);
535
536        readSource.setEventHandler(new Task() {
537            public void run() {
538                drainInbound();
539            }
540        });
541        writeSource.setEventHandler(new Task() {
542            public void run() {
543                flush();
544            }
545        });
546
547        if( maxReadRate !=0 || maxWriteRate !=0 ) {
548            rateLimitingChannel = new RateLimitingChannel();
549            schedualRateAllowanceReset();
550        }
551        listener.onTransportConnected();
552    }
553
554    private void schedualRateAllowanceReset() {
555        dispatchQueue.executeAfter(1, TimeUnit.SECONDS, new Task(){
556            public void run() {
557                if( !socketState.is(CONNECTED.class) ) {
558                    return;
559                }
560                rateLimitingChannel.resetAllowance();
561                schedualRateAllowanceReset();
562            }
563        });
564    }
565
566    private void dispose() {
567        if( readSource!=null ) {
568            readSource.cancel();
569            readSource=null;
570        }
571
572        if( writeSource!=null ) {
573            writeSource.cancel();
574            writeSource=null;
575        }
576    }
577
578    public void onTransportFailure(IOException error) {
579        listener.onTransportFailure(error);
580        socketState.onCanceled();
581    }
582
583
584    public boolean full() {
585        return codec==null ||
586               codec.full() ||
587               !socketState.is(CONNECTED.class) ||
588               getServiceState() != STARTED;
589    }
590
591    boolean rejectingOffers;
592
593    public boolean offer(Object command) {
594        dispatchQueue.assertExecuting();
595        if( full() ) {
596            return false;
597        }
598        try {
599            ProtocolCodec.BufferState rc = codec.write(command);
600            rejectingOffers = codec.full();
601            switch (rc ) {
602                case FULL:
603                    return false;
604                default:
605                    drainOutboundSource.merge(1);
606            }
607        } catch (IOException e) {
608            onTransportFailure(e);
609        }
610        return true;
611    }
612
613    boolean writeResumedForCodecFlush = false;
614
615    /**
616     *
617     */
618    public void flush() {
619        dispatchQueue.assertExecuting();
620        if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
621            return;
622        }
623        try {
624            if( codec.flush() == ProtocolCodec.BufferState.EMPTY && transportFlush() ) {
625                if( writeResumedForCodecFlush) {
626                    writeResumedForCodecFlush = false;
627                    suspendWrite();
628                }
629                rejectingOffers = false;
630                listener.onRefill();
631
632            } else {
633                if(!writeResumedForCodecFlush) {
634                    writeResumedForCodecFlush = true;
635                    resumeWrite();
636                }
637            }
638        } catch (IOException e) {
639            onTransportFailure(e);
640        }
641    }
642
643    protected boolean transportFlush() throws IOException {
644        return true;
645    }
646
647    public void drainInbound() {
648        if (!getServiceState().isStarted() || readSource.isSuspended()) {
649            return;
650        }
651        try {
652            long initial = codec.getReadCounter();
653            // Only process upto 2 x the read buffer worth of data at a time so we can give
654            // other connections a chance to process their requests.
655            while( codec.getReadCounter()-initial < codec.getReadBufferSize()<<2 ) {
656                Object command = codec.read();
657                if ( command!=null ) {
658                    try {
659                        listener.onTransportCommand(command);
660                    } catch (Throwable e) {
661                        e.printStackTrace();
662                        onTransportFailure(new IOException("Transport listener failure."));
663                    }
664
665                    // the transport may be suspended after processing a command.
666                    if (getServiceState() == STOPPED || readSource.isSuspended()) {
667                        return;
668                    }
669                } else {
670                    return;
671                }
672            }
673            yieldSource.merge(1);
674        } catch (IOException e) {
675            onTransportFailure(e);
676        }
677    }
678
679    public SocketAddress getLocalAddress() {
680        return localAddress;
681    }
682
683    public SocketAddress getRemoteAddress() {
684        return remoteAddress;
685    }
686
687    private boolean assertConnected() {
688        try {
689            if ( !isConnected() ) {
690                throw new IOException("Not connected.");
691            }
692            return true;
693        } catch (IOException e) {
694            onTransportFailure(e);
695        }
696        return false;
697    }
698
699    public void suspendRead() {
700        if( isConnected() && readSource!=null ) {
701            readSource.suspend();
702        }
703    }
704
705
706    public void resumeRead() {
707        if( isConnected() && readSource!=null ) {
708            if( rateLimitingChannel!=null ) {
709                rateLimitingChannel.resumeRead();
710            } else {
711                _resumeRead();
712            }
713        }
714    }
715
716    private void _resumeRead() {
717        readSource.resume();
718        dispatchQueue.execute(new Task(){
719            public void run() {
720                drainInbound();
721            }
722        });
723    }
724
725    protected void suspendWrite() {
726        if( isConnected() && writeSource!=null ) {
727            writeSource.suspend();
728        }
729    }
730
731    protected void resumeWrite() {
732        if( isConnected() && writeSource!=null ) {
733            writeSource.resume();
734        }
735    }
736
737    public TransportListener getTransportListener() {
738        return listener;
739    }
740
741    public void setTransportListener(TransportListener transportListener) {
742        this.listener = transportListener;
743    }
744
745    public ProtocolCodec getProtocolCodec() {
746        return codec;
747    }
748
749    public void setProtocolCodec(ProtocolCodec protocolCodec) throws Exception {
750        this.codec = protocolCodec;
751        if( channel!=null && codec!=null ) {
752            initializeCodec();
753        }
754    }
755
756    public boolean isConnected() {
757        return socketState.is(CONNECTED.class);
758    }
759
760    public boolean isClosed() {
761        return getServiceState() == STOPPED;
762    }
763
764    public boolean isUseLocalHost() {
765        return useLocalHost;
766    }
767
768    /**
769     * Sets whether 'localhost' or the actual local host name should be used to
770     * make local connections. On some operating systems such as Macs its not
771     * possible to connect as the local host name so localhost is better.
772     */
773    public void setUseLocalHost(boolean useLocalHost) {
774        this.useLocalHost = useLocalHost;
775    }
776
777    private void trace(String message) {
778        // TODO:
779    }
780
781    public SocketChannel getSocketChannel() {
782        return channel;
783    }
784
785    public ReadableByteChannel getReadChannel() {
786        if(rateLimitingChannel!=null) {
787            return rateLimitingChannel;
788        } else {
789            return channel;
790        }
791    }
792
793    public WritableByteChannel getWriteChannel() {
794        if(rateLimitingChannel!=null) {
795            return rateLimitingChannel;
796        } else {
797            return channel;
798        }
799    }
800
801    public int getMaxReadRate() {
802        return maxReadRate;
803    }
804
805    public void setMaxReadRate(int maxReadRate) {
806        this.maxReadRate = maxReadRate;
807    }
808
809    public int getMaxWriteRate() {
810        return maxWriteRate;
811    }
812
813    public void setMaxWriteRate(int maxWriteRate) {
814        this.maxWriteRate = maxWriteRate;
815    }
816
817    public int getTrafficClass() {
818        return trafficClass;
819    }
820
821    public void setTrafficClass(int trafficClass) {
822        this.trafficClass = trafficClass;
823    }
824
825    public int getReceiveBufferSize() {
826        return receiveBufferSize;
827    }
828
829    public void setReceiveBufferSize(int receiveBufferSize) {
830        this.receiveBufferSize = receiveBufferSize;
831        if( channel!=null ) {
832            try {
833                channel.socket().setReceiveBufferSize(receiveBufferSize);
834            } catch (SocketException ignore) {
835            }
836        }
837    }
838
839    public int getSendBufferSize() {
840        return sendBufferSize;
841    }
842
843    public void setSendBufferSize(int sendBufferSize) {
844        this.sendBufferSize = sendBufferSize;
845        if( channel!=null ) {
846            try {
847                channel.socket().setReceiveBufferSize(sendBufferSize);
848            } catch (SocketException ignore) {
849            }
850        }
851    }
852
853    public boolean isKeepAlive() {
854        return keepAlive;
855    }
856
857    public void setKeepAlive(boolean keepAlive) {
858        this.keepAlive = keepAlive;
859    }
860
861    public Executor getBlockingExecutor() {
862        return blockingExecutor;
863    }
864
865    public void setBlockingExecutor(Executor blockingExecutor) {
866        this.blockingExecutor = blockingExecutor;
867    }
868
869    public boolean isCloseOnCancel() {
870        return closeOnCancel;
871    }
872
873    public void setCloseOnCancel(boolean closeOnCancel) {
874        this.closeOnCancel = closeOnCancel;
875    }
876}