1 // excerpts from http://code.google.com/p/muduo/
2 //
3 // Use of this source code is governed by a BSD-style license
4 // that can be found in the License file.
5 //
6 // Author: Shuo Chen (chenshuo at chenshuo dot com)
7 
8 #include "TcpConnection.h"
9 
10 #include "logging/Logging.h"
11 #include "Channel.h"
12 #include "EventLoop.h"
13 #include "Socket.h"
14 #include "SocketsOps.h"
15 
16 #include <boost/bind.hpp>
17 
18 #include <errno.h>
19 #include <stdio.h>
20 
21 using namespace muduo;
22 
23 TcpConnection::TcpConnection(EventLoop* loop,
24                              const std::string& nameArg,
25                              int sockfd,
26                              const InetAddress& localAddr,
27                              const InetAddress& peerAddr)
28   : loop_(CHECK_NOTNULL(loop)),
29     name_(nameArg),
30     state_(kConnecting),
31     socket_(new Socket(sockfd)),
32     channel_(new Channel(loop, sockfd)),
33     localAddr_(localAddr),
34     peerAddr_(peerAddr)
35 {
36   LOG_DEBUG << "TcpConnection::ctor[" <<  name_ << "] at " << this
37             << " fd=" << sockfd;
38   channel_->setReadCallback(
39       boost::bind(&TcpConnection::handleRead, this, _1));
40   channel_->setWriteCallback(
41       boost::bind(&TcpConnection::handleWrite, this));
42   channel_->setCloseCallback(
43       boost::bind(&TcpConnection::handleClose, this));
44   channel_->setErrorCallback(
45       boost::bind(&TcpConnection::handleError, this));
46 }
47 
48 TcpConnection::~TcpConnection()
49 {
50   LOG_DEBUG << "TcpConnection::dtor[" <<  name_ << "] at " << this
51             << " fd=" << channel_->fd();
52 }
53 
54 void TcpConnection::send(const std::string& message)
55 {
56   if (state_ == kConnected) {
57     if (loop_->isInLoopThread()) {
58       sendInLoop(message);
59     } else {
60       loop_->runInLoop(
61           boost::bind(&TcpConnection::sendInLoop, this, message));
62     }
63   }
64 }
65 
66 void TcpConnection::sendInLoop(const std::string& message)
67 {
68   loop_->assertInLoopThread();
69   ssize_t nwrote = 0;
70   // if no thing in output queue, try writing directly
71   if (!channel_->isWriting() && outputBuffer_.readableBytes() == 0) {
72     nwrote = ::write(channel_->fd(), message.data(), message.size());
73     if (nwrote >= 0) {
74       if (implicit_cast<size_t>(nwrote) < message.size()) {
75         LOG_TRACE << "I am going to write more data";
76+      } else if (writeCompleteCallback_) {
77+        loop_->queueInLoop(
78+            boost::bind(writeCompleteCallback_, shared_from_this()));
79       }
80     } else {
81       nwrote = 0;
82       if (errno != EWOULDBLOCK) {
83         LOG_SYSERR << "TcpConnection::sendInLoop";
84       }
85     }
86   }
87 
88   assert(nwrote >= 0);
89   if (implicit_cast<size_t>(nwrote) < message.size()) {
90     outputBuffer_.append(message.data()+nwrote, message.size()-nwrote);
91     if (!channel_->isWriting()) {
92       channel_->enableWriting();
93     }
94   }
95 }
96 
97 void TcpConnection::shutdown()
98 {
99   // FIXME: use compare and swap
100   if (state_ == kConnected)
101   {
102     setState(kDisconnecting);
103     // FIXME: shared_from_this()?
104     loop_->runInLoop(boost::bind(&TcpConnection::shutdownInLoop, this));
105   }
106 }
107 
108 void TcpConnection::shutdownInLoop()
109 {
110   loop_->assertInLoopThread();
111   if (!channel_->isWriting())
112   {
113     // we are not writing
114     socket_->shutdownWrite();
115   }
116 }
117 
118+void TcpConnection::setTcpNoDelay(bool on)
119+{
120+  socket_->setTcpNoDelay(on);
121+}
122+
123 void TcpConnection::connectEstablished()
124 {
125   loop_->assertInLoopThread();
126   assert(state_ == kConnecting);
127   setState(kConnected);
128   channel_->enableReading();
129   connectionCallback_(shared_from_this());
130 }
131 
132 void TcpConnection::connectDestroyed()
133 {
134   loop_->assertInLoopThread();
135   assert(state_ == kConnected || state_ == kDisconnecting);
136   setState(kDisconnected);
137   channel_->disableAll();
138   connectionCallback_(shared_from_this());
139 
140   loop_->removeChannel(get_pointer(channel_));
141 }
142 
143 void TcpConnection::handleRead(Timestamp receiveTime)
144 {
145   int savedErrno = 0;
146   ssize_t n = inputBuffer_.readFd(channel_->fd(), &savedErrno);
147   if (n > 0) {
148     messageCallback_(shared_from_this(), &inputBuffer_, receiveTime);
149   } else if (n == 0) {
150     handleClose();
151   } else {
152     // FIXME: check savedErrno
153     handleError();
154   }
155 }
156 
157 void TcpConnection::handleWrite()
158 {
159   loop_->assertInLoopThread();
160   if (channel_->isWriting()) {
161     ssize_t n = ::write(channel_->fd(),
162                         outputBuffer_.peek(),
163                         outputBuffer_.readableBytes());
164     if (n > 0) {
165       outputBuffer_.retrieve(n);
166       if (outputBuffer_.readableBytes() == 0) {
167         channel_->disableWriting();
168+        if (writeCompleteCallback_) {
169+          loop_->queueInLoop(
170+              boost::bind(writeCompleteCallback_, shared_from_this()));
171+        }
172         if (state_ == kDisconnecting) {
173           shutdownInLoop();
174         }
175       } else {
176         LOG_TRACE << "I am going to write more data";
177       }
178     } else {
179       LOG_SYSERR << "TcpConnection::handleWrite";
180       abort();  // FIXME
181     }
182   } else {
183     LOG_TRACE << "Connection is down, no more writing";
184   }
185 }
186 
187 void TcpConnection::handleClose()
188 {
189   loop_->assertInLoopThread();
190   LOG_TRACE << "TcpConnection::handleClose state = " << state_;
191   assert(state_ == kConnected || state_ == kDisconnecting);
192   // we don't close fd, leave it to dtor, so we can find leaks easily.
193   channel_->disableAll();
194   // must be the last line
195   closeCallback_(shared_from_this());
196 }
197 
198 void TcpConnection::handleError()
199 {
200   int err = sockets::getSocketError(channel_->fd());
201   LOG_ERROR << "TcpConnection::handleError [" << name_
202             << "] - SO_ERROR = " << err << " " << strerror_tl(err);
203 }
204