Test SOCK_DESTROY on UDP sockets.
Change-Id: If781c33417ce46d7b408bf322b8ae7c6e3c196ab
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 57bc19c..b835261 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -54,7 +54,7 @@
525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
"""
@staticmethod
- def _CreateLotsOfSockets():
+ def _CreateLotsOfSockets(socktype):
# Dict mapping (addr, sport, dport) tuples to socketpairs.
socketpairs = {}
for _ in xrange(NUM_SOCKETS):
@@ -62,7 +62,7 @@
(AF_INET, "127.0.0.1"),
(AF_INET6, "::1"),
(AF_INET6, "::ffff:127.0.0.1")])
- socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
+ socketpair = net_test.CreateSocketPair(family, socktype, addr)
sport, dport = (socketpair[0].getsockname()[1],
socketpair[1].getsockname()[1])
socketpairs[(addr, sport, dport)] = socketpair
@@ -100,6 +100,19 @@
self.assertFalse("???" in decoded)
return bytecode
+  def CloseDuringBlockingCall(self, sock, call, expected_errno):
+    """Checks a blocking call(sock) fails with expected_errno on destroy."""
+    thread = SocketExceptionThread(sock, call)
+    thread.start()
+    time.sleep(0.1)
+    self.sock_diag.CloseSocketFromFd(sock)
+    thread.join(1)
+    self.assertFalse(thread.is_alive())
+    self.assertIsNotNone(thread.exception)
+    self.assertIsInstance(thread.exception, IOError)
+    self.assertEqual(expected_errno, thread.exception.errno)
+    self.assertSocketClosed(sock)
+
def setUp(self):
super(SockDiagBaseTest, self).setUp()
self.sock_diag = sock_diag.SockDiag()
@@ -126,7 +139,7 @@
def testFindsAllMySockets(self):
"""Tests that basic socket dumping works."""
- self.socketpairs = self._CreateLotsOfSockets()
+ self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
@@ -188,7 +201,7 @@
states = 1 << tcp_test.TCP_ESTABLISHED
self.assertMultiLineEqual(expected, bytecode.encode("hex"))
self.assertEquals(76, len(bytecode))
- self.socketpairs = self._CreateLotsOfSockets()
+ self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
states=states)
allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
@@ -336,7 +349,7 @@
"""
def testClosesSockets(self):
- self.socketpairs = self._CreateLotsOfSockets()
+ self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
for _, socketpair in self.socketpairs.iteritems():
# Close one of the sockets.
# This will send a RST that will close the other side as well.
@@ -361,12 +374,6 @@
# Check that both sockets in the pair are closed.
self.assertSocketsClosed(socketpair)
- def testNonTcpSockets(self):
- s = socket(AF_INET6, SOCK_DGRAM, 0)
- s.connect(("::1", 53))
- self.sock_diag.FindSockDiagFromFd(s) # No exceptions? Good.
- self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
-
# TODO:
# Test that killing unix sockets returns EOPNOTSUPP.
@@ -570,19 +577,6 @@
self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
- def CloseDuringBlockingCall(self, sock, call, expected_errno):
- thread = SocketExceptionThread(sock, call)
- thread.start()
- time.sleep(0.1)
- self.sock_diag.CloseSocketFromFd(sock)
- thread.join(1)
- self.assertFalse(thread.is_alive())
- self.assertIsNotNone(thread.exception)
- self.assertTrue(isinstance(thread.exception, IOError),
- "Expected IOError, got %s" % thread.exception)
- self.assertEqual(expected_errno, thread.exception.errno)
- self.assertSocketClosed(sock)
-
def testAcceptInterrupted(self):
"""Tests that accept() is interrupted by SOCK_DESTROY."""
for version in [4, 5, 6]:
@@ -622,14 +616,59 @@
@unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "does not yet exist")
+class SockDestroyUdpTest(SockDiagBaseTest):
+
+  """Tests SOCK_DESTROY on UDP sockets.
+
+  Relevant kernel commits:
+    upstream net-next:
+      5d77dca net: diag: support SOCK_DESTROY for UDP sockets
+      f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
+  """
+
+  def testClosesUdpSockets(self):
+    """Destroying one UDP socket of a pair must not close its peer."""
+    self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
+    for s1, s2 in self.socketpairs.itervalues():
+      self.assertSocketConnected(s1)
+      self.sock_diag.CloseSocketFromFd(s1)
+      self.assertSocketClosed(s1)
+      self.assertSocketConnected(s2)
+      self.sock_diag.CloseSocketFromFd(s2)
+      self.assertSocketClosed(s2)
+
+  def testReadInterrupted(self):
+    """Tests that read() is interrupted by SOCK_DESTROY."""
+    for version in [4, 5, 6]:
+      family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
+      s = net_test.UDPSocket(family)
+      self.SelectInterface(s, random.choice(self.NETIDS), "mark")
+      addr = self.GetRemoteAddress(version)
+
+      # Check that reads on connected sockets are interrupted.
+      s.connect((addr, 53))
+      self.assertEquals(3, s.send("foo"))
+      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
+                                   ECONNABORTED)
+
+      # A destroyed socket is no longer connected, but still usable.
+      self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
+      self.assertEquals(3, s.sendto("foo", (addr, 53)))
+
+      # Check that reads on unconnected sockets are also interrupted.
+      self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
+                                   ECONNABORTED)
+
+
+@unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "does not yet exist")
class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
"""Tests SOCK_DIAG bytecode filters that use marks.
Relevant kernel commits:
upstream net-next:
- a52e95a net: diag: allow socket bytecode filters to match socket marks
627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
+ a52e95a net: diag: allow socket bytecode filters to match socket marks
"""
def FilterEstablishedSockets(self, instructions):