Test that SOCK_DESTROY affects poll() like a TCP RST.

Bug: 65684232
Test: all_tests.sh passes on 4.4 device kernel
Change-Id: I95ed9502cb55baa4f292edc8cbb28397b7881f65
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 24fe9dd..198ce61 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -18,6 +18,7 @@
 from errno import *  # pylint: disable=wildcard-import
 import os
 import random
+import select
 from socket import *  # pylint: disable=wildcard-import
 import struct
 import threading
@@ -108,19 +109,43 @@
     self.assertFalse("???" in decoded)
     return bytecode
 
-  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+  def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
+    """Simulates an external event during a blocking call on sock.
+
+    Runs call(sock) in a SocketExceptionThread, sleeps 0.1 seconds to give the
+    call time to block, triggers event(sock), checks that the call finished
+    within one second with the expected outcome, and finally calls
+    assertSocketClosed(sock).
+
+    Args:
+      sock: The socket to use.
+      call: A function, the call to make. Takes one parameter, sock.
+      expected_errno: The value that call is expected to fail with, or None if
+        call is expected to succeed.
+      event: A function, the event that will happen during the blocking call.
+        Takes one parameter, sock.
+    """
     thread = SocketExceptionThread(sock, call)
     thread.start()
     time.sleep(0.1)
-    self.sock_diag.CloseSocketFromFd(sock)
+    event(sock)
     thread.join(1)
     self.assertFalse(thread.is_alive())
-    self.assertIsNotNone(thread.exception)
-    self.assertTrue(isinstance(thread.exception, IOError),
-                    "Expected IOError, got %s" % thread.exception)
-    self.assertEqual(expected_errno, thread.exception.errno)
+    if expected_errno is not None:
+      self.assertIsNotNone(thread.exception)
+      self.assertIsInstance(thread.exception, IOError,
+                            "Expected IOError, got %s" % thread.exception)
+      self.assertEqual(expected_errno, thread.exception.errno)
+    else:
+      self.assertIsNone(thread.exception,
+                        "Expected no exception, got %s" % thread.exception)
     self.assertSocketClosed(sock)
 
+  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+    """Destroys sock with SOCK_DESTROY while call is blocked on it."""
+    destroy = self.sock_diag.CloseSocketFromFd
+    self._EventDuringBlockingCall(sock, call, expected_errno, destroy)
+
   def setUp(self):
     super(SockDiagBaseTest, self).setUp()
     self.sock_diag = sock_diag.SockDiag()
@@ -652,6 +671,109 @@
       self.ExpectNoPacketsOn(self.netid, msg)
 
 
+class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
+  """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.
+
+  The behaviour of poll() in these cases is not what we might expect: if only
+  POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
+  is (also) specified, it will only return POLLOUT.
+  """
+
+  POLLIN_OUT = select.POLLIN | select.POLLOUT
+  POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP
+
+  def setUp(self):
+    super(PollOnCloseTest, self).setUp()
+    # list() so that random.choice() can index it even if keys() is a view.
+    self.netid = random.choice(list(self.tuns.keys()))
+
+  def BlockingPoll(self, sock, mask, expected_event):
+    """Polls sock for mask and asserts that poll() returns expected_event.
+
+    Args:
+      sock: The socket to poll.
+      mask: The event mask to register, e.g. select.POLLIN.
+      expected_event: The revents value that poll() is expected to report.
+    """
+    p = select.poll()
+    p.register(sock, mask)
+    expected_fds = [(sock.fileno(), expected_event)]
+    # Don't block forever or we'll hang continuous test runs on failure.
+    # A 5-second timeout should be long enough not to be flaky.
+    actual_fds = p.poll(5000)
+    self.assertEqual(expected_fds, actual_fds)
+
+  def RstDuringBlockingCall(self, sock, call, expected_errno):
+    """Receives a TCP RST on self.netid while call is blocked on sock."""
+    self._EventDuringBlockingCall(
+        sock, call, expected_errno,
+        lambda _: self.ReceiveRstPacketOn(self.netid))
+
+  def assertSocketErrors(self, errno):
+    """Checks the behaviour of self.accepted after a RST or SOCK_DESTROY.
+
+    Args:
+      errno: The errno that the first recv() is expected to fail with.
+    """
+    # The first operation returns the expected errno.
+    self.assertRaisesErrno(errno, self.accepted.recv, 4096)
+
+    # Subsequent operations behave as normal.
+    self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
+    self.assertEqual("", self.accepted.recv(4096))
+    self.assertEqual("", self.accepted.recv(4096))
+
+  def _CheckPollInterrupted(self, interrupt, mask, expected_event, recv_errno):
+    """Interrupts a blocking poll() on self.accepted and checks the results.
+
+    For each of versions 4, 5 and 6, calls IncomingConnection to set up
+    self.accepted, blocks in BlockingPoll while interrupt() triggers an event,
+    and then calls assertSocketErrors(recv_errno).
+
+    Args:
+      interrupt: A function with the signature of CloseDuringBlockingCall
+        that triggers an event while poll() is blocked.
+      mask: The event mask to pass to poll().
+      expected_event: The revents value that poll() is expected to report.
+      recv_errno: The errno that the first recv() is expected to fail with.
+    """
+    for version in [4, 5, 6]:
+      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
+      interrupt(
+          self.accepted,
+          lambda sock: self.BlockingPoll(sock, mask, expected_event),
+          expected_errno=None)
+      self.assertSocketErrors(recv_errno)
+
+  def CheckPollDestroy(self, mask, expected_event):
+    """Interrupts a poll() with SOCK_DESTROY."""
+    self._CheckPollInterrupted(self.CloseDuringBlockingCall, mask,
+                               expected_event, ECONNABORTED)
+
+  def CheckPollRst(self, mask, expected_event):
+    """Interrupts a poll() by receiving a TCP RST."""
+    self._CheckPollInterrupted(self.RstDuringBlockingCall, mask,
+                               expected_event, ECONNRESET)
+
+  def testReadPollRst(self):
+    self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP)
+
+  def testWritePollRst(self):
+    self.CheckPollRst(select.POLLOUT, select.POLLOUT)
+
+  def testReadWritePollRst(self):
+    self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT)
+
+  def testReadPollDestroy(self):
+    self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP)
+
+  def testWritePollDestroy(self):
+    self.CheckPollDestroy(select.POLLOUT, select.POLLOUT)
+
+  def testReadWritePollDestroy(self):
+    self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT)
+
+
 class SockDestroyUdpTest(SockDiagBaseTest):
 
   """Tests SOCK_DESTROY on UDP sockets.