De-flakify sock_diag_test.

SockDestroyUdpTest.testClosesUdpSockets was failing a low
percentage of the time because of two bugs:

1. When creating a UDP socket pair, we were connecting one socket
   to the other, but the other socket to itself. This made it
   possible for sockets on ::ffff:127.0.0.1 and ::1 to be bound
   to the same port, and:
2. The code that attempts to get a diag_msg for a given socket
   did not consider the fact that socket dumps, unlike operations
   that operate on a specific socket with a cookie, do not use
   the IP addresses in idiag_src and idiag_dst. This meant that
   attempting to close a socket on [::1]:1234 might instead have
   closed a different socket on [::ffff:127.0.0.1]:1234.

Fix #1 by connecting each socket to the other socket's address,
and fix #2 by using the inode number to ensure that
FindSockInfoFromFd returns the requested socket.

Also remove FindSockInfoFromReq, since it was only used once and
was potentially misleading: although it returns only one socket,
that socket might not be the one the caller expects. This allows us
to tighten the code and ensure that operations that are supposed to
return one socket only ever return one socket.

Test: all_tests.sh passes on android-3.18.
Bug: 31119353
Change-Id: I5d65e5a30c37490db516b5c6e730f89b5fea1b27
diff --git a/net/test/net_test.py b/net/test/net_test.py
index f58e2b9..0048ae6 100755
--- a/net/test/net_test.py
+++ b/net/test/net_test.py
@@ -176,14 +176,14 @@
   addr = listensock.getsockname()
   if socktype == SOCK_STREAM:
     listensock.listen(1)
-  clientsock.connect(addr)
+  clientsock.connect(listensock.getsockname())
   if socktype == SOCK_STREAM:
     acceptedsock, _ = listensock.accept()
     DisableFinWait(clientsock)
     DisableFinWait(acceptedsock)
     listensock.close()
   else:
-    listensock.connect(addr)
+    listensock.connect(clientsock.getsockname())
     acceptedsock = listensock
   return clientsock, acceptedsock
 
diff --git a/net/test/sock_diag.py b/net/test/sock_diag.py
index 4efd0c5..13ed0a8 100755
--- a/net/test/sock_diag.py
+++ b/net/test/sock_diag.py
@@ -19,6 +19,7 @@
 # pylint: disable=g-bad-todo
 
 import errno
+import os
 from socket import *  # pylint: disable=wildcard-import
 import struct
 
@@ -374,15 +375,21 @@
     sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
     return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
 
-  def FindSockInfoFromReq(self, req):
-    for diag_msg, attrs in self.Dump(req, ""):
-      return diag_msg, attrs
-    raise ValueError("Dump of %s returned no sockets" % req)
-
   def FindSockInfoFromFd(self, s):
     """Gets a diag_msg and attrs from the kernel for the specified socket."""
     req = self.DiagReqFromSocket(s)
-    return self.FindSockInfoFromReq(req)
+    # The kernel doesn't use idiag_src and idiag_dst when dumping sockets, it
+    # only uses them when targeting a specific socket with a cookie. Check the
+    # inode number to ensure we don't mistakenly match another socket on
+    # the same port but with a different IP address.
+    inode = os.fstat(s.fileno()).st_ino
+    results = self.Dump(req, "")
+    if len(results) == 0:
+      raise ValueError("Dump of %s returned no sockets" % req)
+    for diag_msg, attrs in results:
+      if diag_msg.inode == inode:
+        return diag_msg, attrs
+    raise ValueError("Dump of %s did not contain inode %d" % (req, inode))
 
   def FindSockDiagFromFd(self, s):
     """Gets an InetDiagMsg from the kernel for the specified socket."""
diff --git a/net/test/sock_diag_test.py b/net/test/sock_diag_test.py
index 1d57c34..c7ac0d4 100755
--- a/net/test/sock_diag_test.py
+++ b/net/test/sock_diag_test.py
@@ -504,8 +504,10 @@
       # to work on 3.10.
       if net_test.LINUX_VERSION >= (3, 18):
         diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
-        diag_msg, attrs = self.sock_diag.FindSockInfoFromReq(diag_req)
-        self.assertEquals(tcp_test.TCP_FIN_WAIT2, diag_msg.state)
+        infos = self.sock_diag.Dump(diag_req, "")
+        self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
+                            for diag_msg, attrs in infos),
+                        "Expected to find FIN_WAIT2 socket in %s" % infos)
 
   def FindChildSockets(self, s):
     """Finds the SYN_RECV child sockets of a given listening socket."""