Test that SOCK_DESTROY affects poll() the same way as a TCP RST.
Bug: 65684232
Test: all_tests.sh passes on 4.4 device kernel
Change-Id: I95ed9502cb55baa4f292edc8cbb28397b7881f65
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 24fe9dd..198ce61 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -18,6 +18,7 @@
from errno import * # pylint: disable=wildcard-import
import os
import random
+import select
from socket import * # pylint: disable=wildcard-import
import struct
import threading
@@ -108,19 +109,37 @@
self.assertFalse("???" in decoded)
return bytecode
-  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
+    """Simulates an external event during a blocking call on sock.
+
+    Runs call(sock) on a background thread, waits briefly so that the call
+    has a chance to block, triggers event(sock), and then checks how the
+    call finished and that sock ended up closed.
+
+    Args:
+      sock: The socket to use.
+      call: A function, the call to make. Takes one parameter, sock.
+      expected_errno: The value that call is expected to fail with, or None if
+        call is expected to succeed.
+      event: A function, the event that will happen during the blocking call.
+        Takes one parameter, sock.
+    """
     thread = SocketExceptionThread(sock, call)
     thread.start()
     time.sleep(0.1)
-    self.sock_diag.CloseSocketFromFd(sock)
+    event(sock)
     thread.join(1)
     self.assertFalse(thread.is_alive())
-    self.assertIsNotNone(thread.exception)
-    self.assertTrue(isinstance(thread.exception, IOError),
-                    "Expected IOError, got %s" % thread.exception)
-    self.assertEqual(expected_errno, thread.exception.errno)
+    if expected_errno is not None:
+      self.assertIsNotNone(thread.exception)
+      self.assertIsInstance(thread.exception, IOError,
+                            "Expected IOError, got %s" % thread.exception)
+      self.assertEqual(expected_errno, thread.exception.errno)
+    else:
+      # Report what was raised, not just "is not None", so failures are
+      # debuggable from the test log alone.
+      self.assertIsNone(thread.exception,
+                        "Expected no exception, got %s" % thread.exception)
     self.assertSocketClosed(sock)

+  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+    """Destroys sock with SOCK_DESTROY while call(sock) is blocked on it."""
+    def Destroy(s):
+      return self.sock_diag.CloseSocketFromFd(s)
+    self._EventDuringBlockingCall(sock, call, expected_errno, Destroy)
+
def setUp(self):
super(SockDiagBaseTest, self).setUp()
self.sock_diag = sock_diag.SockDiag()
@@ -652,6 +671,83 @@
self.ExpectNoPacketsOn(self.netid, msg)
+class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
+  """Tests that the effect of SOCK_DESTROY on poll() matches a TCP RST.
+
+  The behaviour of poll() in these cases is not what we might expect: if only
+  POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
+  is (also) specified, it will only return POLLOUT.
+  """
+
+  POLLIN_OUT = select.POLLIN | select.POLLOUT
+  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP
+
+  def setUp(self):
+    super(PollOnCloseTest, self).setUp()
+    # list() so that random.choice() can index it on Python 3 too, where
+    # keys() returns a non-indexable view.
+    self.netid = random.choice(list(self.tuns.keys()))
+
+  def BlockingPoll(self, sock, mask, expected_event):
+    """Blocks in poll() on sock and checks the single event it reports.
+
+    Args:
+      sock: The socket to poll.
+      mask: The event mask to register sock with.
+      expected_event: The event bits poll() is expected to report for sock.
+        A poll() timeout returns no events and therefore fails the check.
+    """
+    p = select.poll()
+    p.register(sock, mask)
+    expected_fds = [(sock.fileno(), expected_event)]
+    # Don't block forever or we'll hang continuous test runs on failure.
+    # A 5-second timeout should be long enough not to be flaky.
+    actual_fds = p.poll(5000)
+    self.assertEqual(expected_fds, actual_fds)
+
+  def RstDuringBlockingCall(self, sock, call, expected_errno):
+    """Delivers a TCP RST on self.netid while call(sock) is blocked."""
+    self._EventDuringBlockingCall(
+        sock, call, expected_errno,
+        lambda _: self.ReceiveRstPacketOn(self.netid))
+
+  def assertSocketErrors(self, errno):
+    """Checks how self.accepted behaves after being reset or destroyed.
+
+    Args:
+      errno: The errno that the first recv() on self.accepted must fail with.
+    """
+    # The first operation returns the expected errno.
+    self.assertRaisesErrno(errno, self.accepted.recv, 4096)
+
+    # Subsequent operations behave as normal.
+    # Bytes literals: identical to str on Python 2, required on Python 3.
+    self.assertRaisesErrno(EPIPE, self.accepted.send, b"foo")
+    self.assertEqual(b"", self.accepted.recv(4096))
+    self.assertEqual(b"", self.accepted.recv(4096))
+
+  def _CheckPollInterrupted(self, interrupt, mask, expected_event,
+                            expected_errno):
+    """Interrupts a blocking poll() on a fresh accepted socket per version.
+
+    For each of versions 4, 5 and 6 (as interpreted by IncomingConnection),
+    sets up an established incoming connection, interrupts a poll() on it,
+    and checks both the poll() result and the socket's subsequent errors.
+
+    Args:
+      interrupt: A function with the same signature as
+        CloseDuringBlockingCall: (sock, call, expected_errno).
+      mask: The event mask to poll for.
+      expected_event: The event bits poll() is expected to report.
+      expected_errno: The errno that the first recv() after the interruption
+        must fail with.
+    """
+    for version in [4, 5, 6]:
+      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
+      interrupt(
+          self.accepted,
+          lambda sock: self.BlockingPoll(sock, mask, expected_event),
+          None)
+      self.assertSocketErrors(expected_errno)
+
+  def CheckPollDestroy(self, mask, expected_event):
+    """Interrupts a poll() with SOCK_DESTROY."""
+    self._CheckPollInterrupted(self.CloseDuringBlockingCall, mask,
+                               expected_event, ECONNABORTED)
+
+  def CheckPollRst(self, mask, expected_event):
+    """Interrupts a poll() by receiving a TCP RST."""
+    self._CheckPollInterrupted(self.RstDuringBlockingCall, mask,
+                               expected_event, ECONNRESET)
+
+  def testReadPollRst(self):
+    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP)
+
+  def testWritePollRst(self):
+    self.CheckPollRst(select.POLLOUT, select.POLLOUT)
+
+  def testReadWritePollRst(self):
+    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT)
+
+  def testReadPollDestroy(self):
+    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP)
+
+  def testWritePollDestroy(self):
+    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT)
+
+  def testReadWritePollDestroy(self):
+    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT)
+
+
class SockDestroyUdpTest(SockDiagBaseTest):
"""Tests SOCK_DESTROY on UDP sockets.