Crypto++
socketft.cpp
1 // socketft.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 #include "socketft.h"
5 
6 #ifdef SOCKETS_AVAILABLE
7 
8 #include "wait.h"
9 
10 #ifdef USE_BERKELEY_STYLE_SOCKETS
11 #include <errno.h>
12 #include <netdb.h>
13 #include <unistd.h>
14 #include <arpa/inet.h>
15 #include <netinet/in.h>
16 #include <sys/ioctl.h>
17 #endif
18 
19 NAMESPACE_BEGIN(CryptoPP)
20 
// Platform-neutral names for the socket error codes this file tests for or reports.
#ifndef USE_WINDOWS_STYLE_SOCKETS
const int SOCKET_EINVAL = EINVAL;
const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
#else
const int SOCKET_EINVAL = WSAEINVAL;
const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
// Winsock builds get socklen_t defined here as a plain int.
typedef int socklen_t;
#endif
29 
30 Socket::Err::Err(socket_t s, const std::string& operation, int error)
31  : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
32  , m_s(s)
33 {
34 }
35 
36 Socket::~Socket()
37 {
38  if (m_own)
39  {
40  try
41  {
42  CloseSocket();
43  }
44  catch (...)
45  {
46  }
47  }
48 }
49 
50 void Socket::AttachSocket(socket_t s, bool own)
51 {
52  if (m_own)
53  CloseSocket();
54 
55  m_s = s;
56  m_own = own;
57  SocketChanged();
58 }
59 
60 socket_t Socket::DetachSocket()
61 {
62  socket_t s = m_s;
63  m_s = INVALID_SOCKET;
64  SocketChanged();
65  return s;
66 }
67 
68 void Socket::Create(int nType)
69 {
70  assert(m_s == INVALID_SOCKET);
71  m_s = socket(AF_INET, nType, 0);
72  CheckAndHandleError("socket", m_s);
73  m_own = true;
74  SocketChanged();
75 }
76 
77 void Socket::CloseSocket()
78 {
79  if (m_s != INVALID_SOCKET)
80  {
81 #ifdef USE_WINDOWS_STYLE_SOCKETS
82  CancelIo((HANDLE) m_s);
83  CheckAndHandleError_int("closesocket", closesocket(m_s));
84 #else
85  CheckAndHandleError_int("close", close(m_s));
86 #endif
87  m_s = INVALID_SOCKET;
88  SocketChanged();
89  }
90 }
91 
92 void Socket::Bind(unsigned int port, const char *addr)
93 {
94  sockaddr_in sa;
95  memset(&sa, 0, sizeof(sa));
96  sa.sin_family = AF_INET;
97 
98  if (addr == NULL)
99  sa.sin_addr.s_addr = htonl(INADDR_ANY);
100  else
101  {
102  unsigned long result = inet_addr(addr);
103  if (result == -1) // Solaris doesn't have INADDR_NONE
104  {
105  SetLastError(SOCKET_EINVAL);
106  CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
107  }
108  sa.sin_addr.s_addr = result;
109  }
110 
111  sa.sin_port = htons((u_short)port);
112 
113  Bind((sockaddr *)&sa, sizeof(sa));
114 }
115 
116 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
117 {
118  assert(m_s != INVALID_SOCKET);
119  // cygwin workaround: needs const_cast
120  CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
121 }
122 
123 void Socket::Listen(int backlog)
124 {
125  assert(m_s != INVALID_SOCKET);
126  CheckAndHandleError_int("listen", listen(m_s, backlog));
127 }
128 
129 bool Socket::Connect(const char *addr, unsigned int port)
130 {
131  assert(addr != NULL);
132 
133  sockaddr_in sa;
134  memset(&sa, 0, sizeof(sa));
135  sa.sin_family = AF_INET;
136  sa.sin_addr.s_addr = inet_addr(addr);
137 
138  if (sa.sin_addr.s_addr == -1) // Solaris doesn't have INADDR_NONE
139  {
140  hostent *lphost = gethostbyname(addr);
141  if (lphost == NULL)
142  {
143  SetLastError(SOCKET_EINVAL);
144  CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
145  }
146 
147  sa.sin_addr.s_addr = ((in_addr *)lphost->h_addr)->s_addr;
148  }
149 
150  sa.sin_port = htons((u_short)port);
151 
152  return Connect((const sockaddr *)&sa, sizeof(sa));
153 }
154 
155 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
156 {
157  assert(m_s != INVALID_SOCKET);
158  int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
159  if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
160  return false;
161  CheckAndHandleError_int("connect", result);
162  return true;
163 }
164 
165 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
166 {
167  assert(m_s != INVALID_SOCKET);
168  socket_t s = accept(m_s, psa, psaLen);
169  if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
170  return false;
171  CheckAndHandleError("accept", s);
172  target.AttachSocket(s, true);
173  return true;
174 }
175 
176 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
177 {
178  assert(m_s != INVALID_SOCKET);
179  CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
180 }
181 
182 void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
183 {
184  assert(m_s != INVALID_SOCKET);
185  CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
186 }
187 
188 unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
189 {
190  assert(m_s != INVALID_SOCKET);
191  int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
192  CheckAndHandleError_int("send", result);
193  return result;
194 }
195 
196 unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
197 {
198  assert(m_s != INVALID_SOCKET);
199  int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
200  CheckAndHandleError_int("recv", result);
201  return result;
202 }
203 
204 void Socket::ShutDown(int how)
205 {
206  assert(m_s != INVALID_SOCKET);
207  int result = shutdown(m_s, how);
208  CheckAndHandleError_int("shutdown", result);
209 }
210 
211 void Socket::IOCtl(long cmd, unsigned long *argp)
212 {
213  assert(m_s != INVALID_SOCKET);
214 #ifdef USE_WINDOWS_STYLE_SOCKETS
215  CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
216 #else
217  CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
218 #endif
219 }
220 
221 bool Socket::SendReady(const timeval *timeout)
222 {
223  fd_set fds;
224  FD_ZERO(&fds);
225  FD_SET(m_s, &fds);
226  int ready;
227  if (timeout == NULL)
228  ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
229  else
230  {
231  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
232  ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
233  }
234  CheckAndHandleError_int("select", ready);
235  return ready > 0;
236 }
237 
238 bool Socket::ReceiveReady(const timeval *timeout)
239 {
240  fd_set fds;
241  FD_ZERO(&fds);
242  FD_SET(m_s, &fds);
243  int ready;
244  if (timeout == NULL)
245  ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
246  else
247  {
248  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
249  ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
250  }
251  CheckAndHandleError_int("select", ready);
252  return ready > 0;
253 }
254 
255 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
256 {
257  int port = atoi(name);
258  if (IntToString(port) == name)
259  return port;
260 
261  servent *se = getservbyname(name, protocol);
262  if (!se)
263  throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
264  return ntohs(se->s_port);
265 }
266 
268 {
269 #ifdef USE_WINDOWS_STYLE_SOCKETS
270  WSADATA wsd;
271  int result = WSAStartup(0x0202, &wsd);
272  if (result != 0)
273  throw Err(INVALID_SOCKET, "WSAStartup", result);
274 #endif
275 }
276 
278 {
279 #ifdef USE_WINDOWS_STYLE_SOCKETS
280  int result = WSACleanup();
281  if (result != 0)
282  throw Err(INVALID_SOCKET, "WSACleanup", result);
283 #endif
284 }
285 
287 {
288 #ifdef USE_WINDOWS_STYLE_SOCKETS
289  return WSAGetLastError();
290 #else
291  return errno;
292 #endif
293 }
294 
295 void Socket::SetLastError(int errorCode)
296 {
297 #ifdef USE_WINDOWS_STYLE_SOCKETS
298  WSASetLastError(errorCode);
299 #else
300  errno = errorCode;
301 #endif
302 }
303 
304 void Socket::HandleError(const char *operation) const
305 {
306  int err = GetLastError();
307  throw Err(m_s, operation, err);
308 }
309 
310 #ifdef USE_WINDOWS_STYLE_SOCKETS
311 
312 SocketReceiver::SocketReceiver(Socket &s)
313  : m_s(s), m_resultPending(false), m_eofReceived(false)
314 {
315  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
316  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
317  memset(&m_overlapped, 0, sizeof(m_overlapped));
318  m_overlapped.hEvent = m_event;
319 }
320 
321 SocketReceiver::~SocketReceiver()
322 {
323 #ifdef USE_WINDOWS_STYLE_SOCKETS
324  CancelIo((HANDLE) m_s.GetSocket());
325 #endif
326 }
327 
// Windows receiver: starts an overlapped WSARecv of at most 128 KiB into buf.
// Returns true if the result is already available (the receive completed immediately or the
// peer disconnected). Returns false if the operation is still pending; the caller then waits on
// GetWaitObjects() and collects the byte count with GetReceiveResult().
// Other errors go through CheckAndHandleError_int(); Socket::HandleError() throws Socket::Err.
bool SocketReceiver::Receive(byte* buf, size_t bufLen)
{
	assert(!m_resultPending && !m_eofReceived);

	DWORD flags = 0;
	// don't queue too much at once, or we might use up non-paged memory
	WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
	if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
	{
		// Completed immediately; a zero-byte result signals end-of-stream.
		if (m_lastResult == 0)
			m_eofReceived = true;
	}
	else
	{
		switch (WSAGetLastError())
		{
		default:
			// Unexpected error: report it. The case deliberately falls through, so if the error
			// handler returns instead of throwing, the failure is treated as end-of-stream.
			m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
		case WSAEDISCON:
			// Graceful disconnect: report zero bytes and end-of-stream.
			m_lastResult = 0;
			m_eofReceived = true;
			break;
		case WSA_IO_PENDING:
			// Queued; completion will signal m_overlapped.hEvent (m_event).
			m_resultPending = true;
		}
	}
	return !m_resultPending;
}
356 
358 {
359  if (m_resultPending)
360  container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
361  else if (!m_eofReceived)
362  container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
363 }
364 
// Windows receiver: returns the byte count of the most recent receive.
// If an overlapped receive is marked pending, its result is collected first with
// WSAGetOverlappedResult(fWait = false), which does not block. Callers presumably wait on
// GetWaitObjects() first. A zero-byte result, or WSAEDISCON, marks end-of-stream.
unsigned int SocketReceiver::GetReceiveResult()
{
	if (m_resultPending)
	{
		DWORD flags = 0;
		if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
		{
			if (m_lastResult == 0)
				m_eofReceived = true;
		}
		else
		{
			switch (WSAGetLastError())
			{
			default:
				// Report the failure. Deliberate fall-through: if the error handler returns
				// instead of throwing, the failure is treated as end-of-stream.
				m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
			case WSAEDISCON:
				m_lastResult = 0;
				m_eofReceived = true;
			}
		}
		m_resultPending = false;
	}
	return m_lastResult;
}
390 
391 // *************************************************************
392 
393 SocketSender::SocketSender(Socket &s)
394  : m_s(s), m_resultPending(false), m_lastResult(0)
395 {
396  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
397  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
398  memset(&m_overlapped, 0, sizeof(m_overlapped));
399  m_overlapped.hEvent = m_event;
400 }
401 
402 
403 SocketSender::~SocketSender()
404 {
405 #ifdef USE_WINDOWS_STYLE_SOCKETS
406  CancelIo((HANDLE) m_s.GetSocket());
407 #endif
408 }
409 
410 void SocketSender::Send(const byte* buf, size_t bufLen)
411 {
412  assert(!m_resultPending);
413  DWORD written = 0;
414  // don't queue too much at once, or we might use up non-paged memory
415  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
416  if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
417  {
418  m_resultPending = false;
419  m_lastResult = written;
420  }
421  else
422  {
423  if (WSAGetLastError() != WSA_IO_PENDING)
424  m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
425 
426  m_resultPending = true;
427  }
428 }
429 
430 void SocketSender::SendEof()
431 {
432  assert(!m_resultPending);
433  m_s.ShutDown(SD_SEND);
434  m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
435  m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
436  m_resultPending = true;
437 }
438 
// Windows sender: completes a SendEof(). If a result is pending, it reads the socket's network
// events and requires that FD_CLOSE was reported without error; otherwise it throws Socket::Err.
// NOTE(review): the return value is m_lastResult != 0, i.e. whether the last Send() transferred
// any bytes, not whether FD_CLOSE was observed. Confirm this is the intended contract.
bool SocketSender::EofSent()
{
	if (m_resultPending)
	{
		WSANETWORKEVENTS events;
		m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
		// FD_CLOSE must have been recorded, since that is the event SendEof() registered for.
		if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
			throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
		// The close itself may carry an error code.
		if (events.iErrorCode[FD_CLOSE_BIT] != 0)
			throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
		m_resultPending = false;
	}
	return m_lastResult != 0;
}
453 
455 {
456  if (m_resultPending)
457  container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
458  else
459  container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
460 }
461 
462 unsigned int SocketSender::GetSendResult()
463 {
464  if (m_resultPending)
465  {
466  DWORD flags = 0;
467  BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
468  m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
469  m_resultPending = false;
470  }
471  return m_lastResult;
472 }
473 
474 #endif
475 
476 #ifdef USE_BERKELEY_STYLE_SOCKETS
477 
478 SocketReceiver::SocketReceiver(Socket &s)
479  : m_s(s), m_lastResult(0), m_eofReceived(false)
480 {
481 }
482 
483 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
484 {
485  if (!m_eofReceived)
486  container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
487 }
488 
489 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
490 {
491  m_lastResult = m_s.Receive(buf, bufLen);
492  if (bufLen > 0 && m_lastResult == 0)
493  m_eofReceived = true;
494  return true;
495 }
496 
497 unsigned int SocketReceiver::GetReceiveResult()
498 {
499  return m_lastResult;
500 }
501 
502 SocketSender::SocketSender(Socket &s)
503  : m_s(s), m_lastResult(0)
504 {
505 }
506 
507 void SocketSender::Send(const byte* buf, size_t bufLen)
508 {
509  m_lastResult = m_s.Send(buf, bufLen);
510 }
511 
512 void SocketSender::SendEof()
513 {
514  m_s.ShutDown(SD_SEND);
515 }
516 
517 unsigned int SocketSender::GetSendResult()
518 {
519  return m_lastResult;
520 }
521 
522 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
523 {
524  container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
525 }
526 
527 #endif
528 
529 NAMESPACE_END
530 
531 #endif // #ifdef SOCKETS_AVAILABLE