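Test SOCK_DESTROY on UDP sockets.

Add SockDestroyUdpTest, which checks that SOCK_DESTROY closes UDP
sockets and interrupts blocking reads on them. To support this,
parameterize _CreateLotsOfSockets by socket type and move
CloseDuringBlockingCall into SockDiagBaseTest. Remove
testNonTcpSockets, which expected SOCK_DESTROY on a UDP socket to
fail with EOPNOTSUPP.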

Change-Id: If781c33417ce46d7b408bf322b8ae7c6e3c196ab
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 57bc19c..b835261 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -54,7 +54,7 @@
         525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
   """
   @staticmethod
-  def _CreateLotsOfSockets():
+  def _CreateLotsOfSockets(socktype):  # e.g. SOCK_STREAM or SOCK_DGRAM
     # Dict mapping (addr, sport, dport) tuples to socketpairs.
     socketpairs = {}
     for _ in xrange(NUM_SOCKETS):
@@ -62,7 +62,7 @@
           (AF_INET, "127.0.0.1"),
           (AF_INET6, "::1"),
           (AF_INET6, "::ffff:127.0.0.1")])
-      socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
+      socketpair = net_test.CreateSocketPair(family, socktype, addr)
       sport, dport = (socketpair[0].getsockname()[1],
                       socketpair[1].getsockname()[1])
       socketpairs[(addr, sport, dport)] = socketpair
@@ -100,6 +100,19 @@
     self.assertFalse("???" in decoded)
     return bytecode
 
+  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+    thread = SocketExceptionThread(sock, call)  # presumably runs call(sock)
+    thread.start()
+    time.sleep(0.1)  # presumably lets the thread block in call() first
+    self.sock_diag.CloseSocketFromFd(sock)
+    thread.join(1)  # wait up to 1s; asserted below to have exited
+    self.assertFalse(thread.is_alive())
+    self.assertIsNotNone(thread.exception)
+    self.assertTrue(isinstance(thread.exception, IOError),
+                    "Expected IOError, got %s" % thread.exception)
+    self.assertEqual(expected_errno, thread.exception.errno)
+    self.assertSocketClosed(sock)
+
   def setUp(self):
     super(SockDiagBaseTest, self).setUp()
     self.sock_diag = sock_diag.SockDiag()
@@ -126,7 +139,7 @@
 
   def testFindsAllMySockets(self):
     """Tests that basic socket dumping works."""
-    self.socketpairs = self._CreateLotsOfSockets()
+    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
     sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
     self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
 
@@ -188,7 +201,7 @@
     states = 1 << tcp_test.TCP_ESTABLISHED
     self.assertMultiLineEqual(expected, bytecode.encode("hex"))
     self.assertEquals(76, len(bytecode))
-    self.socketpairs = self._CreateLotsOfSockets()
+    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
     filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
                                                         states=states)
     allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
@@ -336,7 +349,7 @@
   """
 
   def testClosesSockets(self):
-    self.socketpairs = self._CreateLotsOfSockets()
+    self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
     for _, socketpair in self.socketpairs.iteritems():
       # Close one of the sockets.
       # This will send a RST that will close the other side as well.
@@ -361,12 +374,6 @@
       # Check that both sockets in the pair are closed.
       self.assertSocketsClosed(socketpair)
 
-  def testNonTcpSockets(self):
-    s = socket(AF_INET6, SOCK_DGRAM, 0)
-    s.connect(("::1", 53))
-    self.sock_diag.FindSockDiagFromFd(s)  # No exceptions? Good.
-    self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
-
   # TODO:
   # Test that killing unix sockets returns EOPNOTSUPP.
 
@@ -570,19 +577,6 @@
       self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
       self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
 
-  def CloseDuringBlockingCall(self, sock, call, expected_errno):
-    thread = SocketExceptionThread(sock, call)
-    thread.start()
-    time.sleep(0.1)
-    self.sock_diag.CloseSocketFromFd(sock)
-    thread.join(1)
-    self.assertFalse(thread.is_alive())
-    self.assertIsNotNone(thread.exception)
-    self.assertTrue(isinstance(thread.exception, IOError),
-                    "Expected IOError, got %s" % thread.exception)
-    self.assertEqual(expected_errno, thread.exception.errno)
-    self.assertSocketClosed(sock)
-
   def testAcceptInterrupted(self):
     """Tests that accept() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
@@ -622,14 +616,59 @@
 
 
 @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "does not yet exist")
+class SockDestroyUdpTest(SockDiagBaseTest):
+
+  """Tests SOCK_DESTROY on UDP sockets.
+
+    Relevant kernel commits:
+      upstream net-next:
+        5d77dca net: diag: support SOCK_DESTROY for UDP sockets
+        f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
+  """
+
+  def testClosesUdpSockets(self):
+    """Tests that SOCK_DESTROY closes UDP sockets without affecting peers."""
+    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
+    for s1, s2 in self.socketpairs.itervalues():
+      self.assertSocketConnected(s1)
+      self.sock_diag.CloseSocketFromFd(s1)
+      self.assertSocketClosed(s1)
+      self.assertSocketConnected(s2)
+      self.sock_diag.CloseSocketFromFd(s2)
+      self.assertSocketClosed(s2)
+
+  def testReadInterrupted(self):
+    """Tests that read() is interrupted by SOCK_DESTROY."""
+    for version in [4, 5, 6]:
+      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
+      s = net_test.UDPSocket(family)
+      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
+      addr = self.GetRemoteAddress(version)
+
+      # Check that reads on connected sockets are interrupted.
+      s.connect((addr, 53))
+      self.assertEquals(3, s.send("foo"))
+      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
+                                   ECONNABORTED)
+
+      # A destroyed socket is no longer connected, but still usable.
+      self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
+      self.assertEquals(3, s.sendto("foo", (addr, 53)))
+
+      # Check that reads on unconnected sockets are also interrupted.
+      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
+                                   ECONNABORTED)
+
+
+@unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "does not yet exist")
 class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
 
   """Tests SOCK_DIAG bytecode filters that use marks.
 
     Relevant kernel commits:
       upstream net-next:
-        a52e95a net: diag: allow socket bytecode filters to match socket marks
         627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
+        a52e95a net: diag: allow socket bytecode filters to match socket marks
   """
 
   def FilterEstablishedSockets(self, instructions):