Crypto++  5.6.3
Free C++ class library of cryptographic schemes
socketft.cpp
1 // socketft.cpp - written and placed in the public domain by Wei Dai
2 
3 #include "pch.h"
4 
5 // TODO: http://github.com/weidai11/cryptopp/issues/19
6 #define _WINSOCK_DEPRECATED_NO_WARNINGS
7 #include "socketft.h"
8 
9 #ifdef SOCKETS_AVAILABLE
10 
11 #include "wait.h"
12 
13 #ifdef USE_BERKELEY_STYLE_SOCKETS
14 #include <errno.h>
15 #include <netdb.h>
16 #include <unistd.h>
17 #include <arpa/inet.h>
18 #include <netinet/in.h>
19 #include <sys/ioctl.h>
20 #endif
21 
22 #ifdef PREFER_WINDOWS_STYLE_SOCKETS
23 # pragma comment(lib, "ws2_32.lib")
24 #endif
25 
26 NAMESPACE_BEGIN(CryptoPP)
27 
// Give the platform's native error codes common names for use in this file.
#if defined(USE_WINDOWS_STYLE_SOCKETS)
typedef int socklen_t;
const int SOCKET_EINVAL = WSAEINVAL;
const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
#else
const int SOCKET_EINVAL = EINVAL;
const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
#endif

// Solaris doesn't have INADDR_NONE, so supply the conventional value.
#if !defined(INADDR_NONE)
# define INADDR_NONE 0xffffffff
#endif /* INADDR_NONE */
41 
42 Socket::Err::Err(socket_t s, const std::string& operation, int error)
43  : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
44  , m_s(s)
45 {
46 }
47 
48 Socket::~Socket()
49 {
50  if (m_own)
51  {
52  try
53  {
54  CloseSocket();
55  }
56  catch (const Exception&)
57  {
58  assert(0);
59  }
60  }
61 }
62 
63 void Socket::AttachSocket(socket_t s, bool own)
64 {
65  if (m_own)
66  CloseSocket();
67 
68  m_s = s;
69  m_own = own;
70  SocketChanged();
71 }
72 
73 socket_t Socket::DetachSocket()
74 {
75  socket_t s = m_s;
76  m_s = INVALID_SOCKET;
77  SocketChanged();
78  return s;
79 }
80 
81 void Socket::Create(int nType)
82 {
83  assert(m_s == INVALID_SOCKET);
84  m_s = socket(AF_INET, nType, 0);
85  CheckAndHandleError("socket", m_s);
86  m_own = true;
87  SocketChanged();
88 }
89 
90 void Socket::CloseSocket()
91 {
92  if (m_s != INVALID_SOCKET)
93  {
94 #ifdef USE_WINDOWS_STYLE_SOCKETS
95  CancelIo((HANDLE) m_s);
96  CheckAndHandleError_int("closesocket", closesocket(m_s));
97 #else
98  CheckAndHandleError_int("close", close(m_s));
99 #endif
100  m_s = INVALID_SOCKET;
101  SocketChanged();
102  }
103 }
104 
105 void Socket::Bind(unsigned int port, const char *addr)
106 {
107  sockaddr_in sa;
108  memset(&sa, 0, sizeof(sa));
109  sa.sin_family = AF_INET;
110 
111  if (addr == NULL)
112  sa.sin_addr.s_addr = htonl(INADDR_ANY);
113  else
114  {
115  unsigned long result = inet_addr(addr);
116  if (result == INADDR_NONE)
117  {
118  SetLastError(SOCKET_EINVAL);
119  CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
120  }
121  sa.sin_addr.s_addr = result;
122  }
123 
124  sa.sin_port = htons((u_short)port);
125 
126  Bind((sockaddr *)&sa, sizeof(sa));
127 }
128 
129 void Socket::Bind(const sockaddr *psa, socklen_t saLen)
130 {
131  assert(m_s != INVALID_SOCKET);
132  // cygwin workaround: needs const_cast
133  CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
134 }
135 
136 void Socket::Listen(int backlog)
137 {
138  assert(m_s != INVALID_SOCKET);
139  CheckAndHandleError_int("listen", listen(m_s, backlog));
140 }
141 
142 bool Socket::Connect(const char *addr, unsigned int port)
143 {
144  assert(addr != NULL);
145 
146  sockaddr_in sa;
147  memset(&sa, 0, sizeof(sa));
148  sa.sin_family = AF_INET;
149  sa.sin_addr.s_addr = inet_addr(addr);
150 
151  if (sa.sin_addr.s_addr == INADDR_NONE)
152  {
153  hostent *lphost = gethostbyname(addr);
154  if (lphost == NULL)
155  {
156  SetLastError(SOCKET_EINVAL);
157  CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
158  }
159  else
160  {
161  sa.sin_addr.s_addr = ((in_addr *)lphost->h_addr)->s_addr;
162  }
163  }
164 
165  sa.sin_port = htons((u_short)port);
166 
167  return Connect((const sockaddr *)&sa, sizeof(sa));
168 }
169 
170 bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
171 {
172  assert(m_s != INVALID_SOCKET);
173  int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
174  if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
175  return false;
176  CheckAndHandleError_int("connect", result);
177  return true;
178 }
179 
180 bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
181 {
182  assert(m_s != INVALID_SOCKET);
183  socket_t s = accept(m_s, psa, psaLen);
184  if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
185  return false;
186  CheckAndHandleError("accept", s);
187  target.AttachSocket(s, true);
188  return true;
189 }
190 
191 void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
192 {
193  assert(m_s != INVALID_SOCKET);
194  CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
195 }
196 
197 void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
198 {
199  assert(m_s != INVALID_SOCKET);
200  CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
201 }
202 
203 unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
204 {
205  assert(m_s != INVALID_SOCKET);
206  int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
207  CheckAndHandleError_int("send", result);
208  return result;
209 }
210 
211 unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
212 {
213  assert(m_s != INVALID_SOCKET);
214  int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
215  CheckAndHandleError_int("recv", result);
216  return result;
217 }
218 
219 void Socket::ShutDown(int how)
220 {
221  assert(m_s != INVALID_SOCKET);
222  int result = shutdown(m_s, how);
223  CheckAndHandleError_int("shutdown", result);
224 }
225 
226 void Socket::IOCtl(long cmd, unsigned long *argp)
227 {
228  assert(m_s != INVALID_SOCKET);
229 #ifdef USE_WINDOWS_STYLE_SOCKETS
230  CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
231 #else
232  CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
233 #endif
234 }
235 
236 bool Socket::SendReady(const timeval *timeout)
237 {
238  fd_set fds;
239  FD_ZERO(&fds);
240  FD_SET(m_s, &fds);
241  int ready;
242  if (timeout == NULL)
243  ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
244  else
245  {
246  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
247  ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
248  }
249  CheckAndHandleError_int("select", ready);
250  return ready > 0;
251 }
252 
253 bool Socket::ReceiveReady(const timeval *timeout)
254 {
255  fd_set fds;
256  FD_ZERO(&fds);
257  FD_SET(m_s, &fds);
258  int ready;
259  if (timeout == NULL)
260  ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
261  else
262  {
263  timeval timeoutCopy = *timeout; // select() modified timeout on Linux
264  ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
265  }
266  CheckAndHandleError_int("select", ready);
267  return ready > 0;
268 }
269 
270 unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
271 {
272  int port = atoi(name);
273  if (IntToString(port) == name)
274  return port;
275 
276  servent *se = getservbyname(name, protocol);
277  if (!se)
278  throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
279  return ntohs(se->s_port);
280 }
281 
282 void Socket::StartSockets()
283 {
284 #ifdef USE_WINDOWS_STYLE_SOCKETS
285  WSADATA wsd;
286  int result = WSAStartup(0x0202, &wsd);
287  if (result != 0)
288  throw Err(INVALID_SOCKET, "WSAStartup", result);
289 #endif
290 }
291 
292 void Socket::ShutdownSockets()
293 {
294 #ifdef USE_WINDOWS_STYLE_SOCKETS
295  int result = WSACleanup();
296  if (result != 0)
297  throw Err(INVALID_SOCKET, "WSACleanup", result);
298 #endif
299 }
300 
301 int Socket::GetLastError()
302 {
303 #ifdef USE_WINDOWS_STYLE_SOCKETS
304  return WSAGetLastError();
305 #else
306  return errno;
307 #endif
308 }
309 
310 void Socket::SetLastError(int errorCode)
311 {
312 #ifdef USE_WINDOWS_STYLE_SOCKETS
313  WSASetLastError(errorCode);
314 #else
315  errno = errorCode;
316 #endif
317 }
318 
319 void Socket::HandleError(const char *operation) const
320 {
321  int err = GetLastError();
322  throw Err(m_s, operation, err);
323 }
324 
325 #ifdef USE_WINDOWS_STYLE_SOCKETS
326 
327 SocketReceiver::SocketReceiver(Socket &s)
328  : m_s(s), m_eofReceived(false), m_resultPending(false)
329 {
330  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
331  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
332  memset(&m_overlapped, 0, sizeof(m_overlapped));
333  m_overlapped.hEvent = m_event;
334 }
335 
336 SocketReceiver::~SocketReceiver()
337 {
338 #ifdef USE_WINDOWS_STYLE_SOCKETS
339  CancelIo((HANDLE) m_s.GetSocket());
340 #endif
341 }
342 
343 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
344 {
345  assert(!m_resultPending && !m_eofReceived);
346 
347  DWORD flags = 0;
348  // don't queue too much at once, or we might use up non-paged memory
349  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
350  if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
351  {
352  if (m_lastResult == 0)
353  m_eofReceived = true;
354  }
355  else
356  {
357  switch (WSAGetLastError())
358  {
359  default:
360  m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
361  case WSAEDISCON:
362  m_lastResult = 0;
363  m_eofReceived = true;
364  break;
365  case WSA_IO_PENDING:
366  m_resultPending = true;
367  }
368  }
369  return !m_resultPending;
370 }
371 
372 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
373 {
374  if (m_resultPending)
375  container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
376  else if (!m_eofReceived)
377  container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
378 }
379 
380 unsigned int SocketReceiver::GetReceiveResult()
381 {
382  if (m_resultPending)
383  {
384  DWORD flags = 0;
385  if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
386  {
387  if (m_lastResult == 0)
388  m_eofReceived = true;
389  }
390  else
391  {
392  switch (WSAGetLastError())
393  {
394  default:
395  m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
396  case WSAEDISCON:
397  m_lastResult = 0;
398  m_eofReceived = true;
399  }
400  }
401  m_resultPending = false;
402  }
403  return m_lastResult;
404 }
405 
406 // *************************************************************
407 
408 SocketSender::SocketSender(Socket &s)
409  : m_s(s), m_resultPending(false), m_lastResult(0)
410 {
411  m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
412  m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
413  memset(&m_overlapped, 0, sizeof(m_overlapped));
414  m_overlapped.hEvent = m_event;
415 }
416 
417 
418 SocketSender::~SocketSender()
419 {
420 #ifdef USE_WINDOWS_STYLE_SOCKETS
421  CancelIo((HANDLE) m_s.GetSocket());
422 #endif
423 }
424 
425 void SocketSender::Send(const byte* buf, size_t bufLen)
426 {
427  assert(!m_resultPending);
428  DWORD written = 0;
429  // don't queue too much at once, or we might use up non-paged memory
430  WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
431  if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
432  {
433  m_resultPending = false;
434  m_lastResult = written;
435  }
436  else
437  {
438  if (WSAGetLastError() != WSA_IO_PENDING)
439  m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
440 
441  m_resultPending = true;
442  }
443 }
444 
445 void SocketSender::SendEof()
446 {
447  assert(!m_resultPending);
448  m_s.ShutDown(SD_SEND);
449  m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
450  m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
451  m_resultPending = true;
452 }
453 
454 bool SocketSender::EofSent()
455 {
456  if (m_resultPending)
457  {
458  WSANETWORKEVENTS events;
459  m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
460  if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
461  throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
462  if (events.iErrorCode[FD_CLOSE_BIT] != 0)
463  throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
464  m_resultPending = false;
465  }
466  return m_lastResult != 0;
467 }
468 
469 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
470 {
471  if (m_resultPending)
472  container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
473  else
474  container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
475 }
476 
477 unsigned int SocketSender::GetSendResult()
478 {
479  if (m_resultPending)
480  {
481  DWORD flags = 0;
482  BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
483  m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
484  m_resultPending = false;
485  }
486  return m_lastResult;
487 }
488 
489 #endif
490 
491 #ifdef USE_BERKELEY_STYLE_SOCKETS
492 
493 SocketReceiver::SocketReceiver(Socket &s)
494  : m_s(s), m_eofReceived(false), m_lastResult(0)
495 {
496 }
497 
498 void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
499 {
500  if (!m_eofReceived)
501  container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
502 }
503 
504 bool SocketReceiver::Receive(byte* buf, size_t bufLen)
505 {
506  m_lastResult = m_s.Receive(buf, bufLen);
507  if (bufLen > 0 && m_lastResult == 0)
508  m_eofReceived = true;
509  return true;
510 }
511 
512 unsigned int SocketReceiver::GetReceiveResult()
513 {
514  return m_lastResult;
515 }
516 
517 SocketSender::SocketSender(Socket &s)
518  : m_s(s), m_lastResult(0)
519 {
520 }
521 
522 void SocketSender::Send(const byte* buf, size_t bufLen)
523 {
524  m_lastResult = m_s.Send(buf, bufLen);
525 }
526 
527 void SocketSender::SendEof()
528 {
529  m_s.ShutDown(SD_SEND);
530 }
531 
532 unsigned int SocketSender::GetSendResult()
533 {
534  return m_lastResult;
535 }
536 
537 void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
538 {
539  container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
540 }
541 
542 #endif
543 
544 NAMESPACE_END
545 
546 #endif // #ifdef SOCKETS_AVAILABLE
Base class for all exceptions thrown by Crypto++.
Definition: cryptlib.h:124
container of wait objects
Definition: wait.h:151
The operating system reported an error.
Definition: cryptlib.h:201
const T1 UnsignedMin(const T1 &a, const T2 &b)
Safe comparison of values that could be negative and incorrectly promoted.
Definition: misc.h:422
std::string IntToString(T value, unsigned int base=10)
Converts a value to a string.
Definition: misc.h:449
Crypto++ library namespace.