diff --git a/src/main/java/org/xbill/DNS/DClass.java b/src/main/java/org/xbill/DNS/DClass.java index 42f5c09b7..3d09fc3a9 100644 --- a/src/main/java/org/xbill/DNS/DClass.java +++ b/src/main/java/org/xbill/DNS/DClass.java @@ -31,6 +31,9 @@ public final class DClass { /** Matches any class */ public static final int ANY = 255; + /** Indicates on mDNS that querier will accept unicast replies from a multicast request. */ + public static final int UNICAST_RESPONSE = 0x8000; + private static class DClassMnemonic extends Mnemonic { public DClassMnemonic() { super("DClass", CASE_UPPER); diff --git a/src/main/java/org/xbill/DNS/NioTcpClient.java b/src/main/java/org/xbill/DNS/NioTcpClient.java index 9eef1bbb3..76f014677 100644 --- a/src/main/java/org/xbill/DNS/NioTcpClient.java +++ b/src/main/java/org/xbill/DNS/NioTcpClient.java @@ -10,7 +10,9 @@ import java.nio.channels.Selector; import java.nio.channels.SocketChannel; import java.time.Duration; +import java.util.Collections; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Queue; import java.util.concurrent.CompletableFuture; @@ -74,7 +76,8 @@ private static class Transaction { private final byte[] queryData; private final long endTime; private final SocketChannel channel; - private final CompletableFuture f; + private final CompletableFuture> f; + private boolean sendDone; void send() throws IOException { @@ -206,7 +209,7 @@ private void processRead() { int id = ((data[0] & 0xFF) << 8) + (data[1] & 0xFF); int qid = t.query.getHeader().getID(); if (id == qid) { - t.f.complete(data); + t.f.complete(Collections.singletonList(data)); it.remove(); return; } @@ -235,13 +238,13 @@ private static class ChannelKey { final InetSocketAddress remote; } - static CompletableFuture sendrecv( + static CompletableFuture> sendrecv( InetSocketAddress local, InetSocketAddress remote, Message query, byte[] data, Duration timeout) { - CompletableFuture f = new CompletableFuture<>(); + CompletableFuture> f = new CompletableFuture<>(); try { final Selector selector = selector(); long endTime = System.nanoTime() + timeout.toNanos(); diff --git a/src/main/java/org/xbill/DNS/NioUdpClient.java b/src/main/java/org/xbill/DNS/NioUdpClient.java index 2c7df7264..3f3ffa2ca 100644 --- a/src/main/java/org/xbill/DNS/NioUdpClient.java +++ b/src/main/java/org/xbill/DNS/NioUdpClient.java @@ -4,6 +4,7 @@ import java.io.EOFException; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.net.SocketException; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; @@ -12,7 +13,10 @@ import java.nio.channels.Selector; import java.security.SecureRandom; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; +import java.util.List; import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentLinkedQueue; @@ -71,8 +75,7 @@ private static void checkTransactionTimeouts() { for (Iterator it = pendingTransactions.iterator(); it.hasNext(); ) { Transaction t = it.next(); if (t.endTime - System.nanoTime() < 0) { - t.silentCloseChannel(); - t.f.completeExceptionally(new SocketTimeoutException("Query timed out")); + t.closeTransaction(); it.remove(); } } @@ -81,19 +84,16 @@ private static void checkTransactionTimeouts() { @RequiredArgsConstructor private static class Transaction implements KeyProcessor { private final byte[] data; - private final int max; + final int max; private final long endTime; private final DatagramChannel channel; - private final CompletableFuture f; + private final SocketAddress remoteSocketAddress; + final CompletableFuture> f; void send() throws IOException { ByteBuffer buffer = ByteBuffer.wrap(data); - verboseLog( - "UDP write", - channel.socket().getLocalSocketAddress(), - channel.socket().getRemoteSocketAddress(), - data); - int n = channel.send(buffer, channel.socket().getRemoteSocketAddress()); + verboseLog("UDP write", channel.socket().getLocalSocketAddress(), remoteSocketAddress, data); + int n = channel.send(buffer, remoteSocketAddress); if (n <= 0) { throw new EOFException(); } @@ -109,10 +109,12 @@ public void processReadyKey(SelectionKey key) { DatagramChannel channel = (DatagramChannel) key.channel(); ByteBuffer buffer = ByteBuffer.allocate(max); + SocketAddress source; int read; try { - read = channel.read(buffer); - if (read <= 0) { + source = channel.receive(buffer); + read = buffer.position(); + if (read <= 0 || source == null) { throw new EOFException(); } } catch (IOException e) { @@ -125,17 +127,13 @@ public void processReadyKey(SelectionKey key) { buffer.flip(); byte[] data = new byte[read]; System.arraycopy(buffer.array(), 0, data, 0, read); - verboseLog( - "UDP read", - channel.socket().getLocalSocketAddress(), - channel.socket().getRemoteSocketAddress(), - data); + verboseLog("UDP read", channel.socket().getLocalSocketAddress(), source, data); silentCloseChannel(); - f.complete(data); + f.complete(Collections.singletonList(data)); pendingTransactions.remove(this); } - private void silentCloseChannel() { + void silentCloseChannel() { try { channel.disconnect(); channel.close(); @@ -143,11 +141,73 @@ private void silentCloseChannel() { // ignore, we either already have everything we need or can't do anything } } + + void closeTransaction() { + silentCloseChannel(); + f.completeExceptionally(new SocketTimeoutException("Query timed out")); + } + } + + private static class MultiAnswerTransaction extends Transaction { + MultiAnswerTransaction( + byte[] query, + int max, + long endTime, + DatagramChannel channel, + SocketAddress remoteSocketAddress, + CompletableFuture> f) { + super(query, max, endTime, channel, remoteSocketAddress, f); + } + + public void processReadyKey(SelectionKey key) { + if (!key.isReadable()) { + silentCloseChannel(); + f.completeExceptionally(new EOFException("channel not readable")); + pendingTransactions.remove(this); + return; + } + + DatagramChannel channel = (DatagramChannel) key.channel(); + ByteBuffer buffer = ByteBuffer.allocate(max); + SocketAddress source; + int read; + try { + source = channel.receive(buffer); + read = buffer.position(); + if (read <= 0 || source == null) { + return; // ignore this datagram + } + } catch (IOException e) { + silentCloseChannel(); + f.completeExceptionally(e); + pendingTransactions.remove(this); + return; + } + + buffer.flip(); + byte[] data = new byte[read]; + System.arraycopy(buffer.array(), 0, data, 0, read); + verboseLog("UDP read", channel.socket().getLocalSocketAddress(), source, data); + answers.add(data); + } + + private ArrayList answers = new ArrayList<>(); + + @Override + void closeTransaction() { + if (answers.size() > 0) { + silentCloseChannel(); + f.complete(answers); + } else { + // we failed, no answers + super.closeTransaction(); + } + } } - static CompletableFuture sendrecv( + static CompletableFuture> sendrecv( InetSocketAddress local, InetSocketAddress remote, byte[] data, int max, Duration timeout) { - CompletableFuture f = new CompletableFuture<>(); + CompletableFuture> f = new CompletableFuture<>(); try { final Selector selector = selector(); DatagramChannel channel = DatagramChannel.open(); @@ -169,6 +229,9 @@ static CompletableFuture sendrecv( addr = new InetSocketAddress(local.getAddress(), port); } + if (addr.getPort() == SimpleResolver.RESERVED_MDNS_PORT) { + continue; // can't use the mDNS server port, try again + } channel.bind(addr); bound = true; @@ -185,9 +248,15 @@ static CompletableFuture sendrecv( } } - channel.connect(remote); long endTime = System.nanoTime() + timeout.toNanos(); - Transaction t = new Transaction(data, max, endTime, channel, f); + Transaction t; + if (!remote.getAddress().isMulticastAddress()) { + channel.connect(remote); + t = new Transaction(data, max, endTime, channel, remote, f); + } else { + // stop this a little before the timeout so we can report what answers we did get + t = new MultiAnswerTransaction(data, max, endTime - 1000000000L, channel, remote, f); + } pendingTransactions.add(t); registrationQueue.add(t); selector.wakeup(); diff --git a/src/main/java/org/xbill/DNS/SimpleResolver.java b/src/main/java/org/xbill/DNS/SimpleResolver.java index 6a71eef57..eccddac6c 100644 --- a/src/main/java/org/xbill/DNS/SimpleResolver.java +++ b/src/main/java/org/xbill/DNS/SimpleResolver.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.SocketTimeoutException; import java.net.UnknownHostException; import java.time.Duration; import java.util.List; @@ -28,6 +29,9 @@ public class SimpleResolver implements Resolver { /** The default port to send queries to */ public static final int DEFAULT_PORT = 53; + /** The port we can't use as a client because it's reserved for mDNS servers. */ + public static final int RESERVED_MDNS_PORT = 5353; + /** The default EDNS payload size */ public static final int DEFAULT_EDNS_PAYLOADSIZE = 1280; @@ -339,7 +343,7 @@ CompletableFuture sendAsync(Message query, boolean forceTcp) { address.getPort()); log.trace("Query:\n{}", query); - CompletableFuture result; + CompletableFuture> result; if (tcp) { result = NioTcpClient.sendrecv(localAddress, address, query, out, timeoutValue); } else { @@ -347,74 +351,52 @@ CompletableFuture sendAsync(Message query, boolean forceTcp) { } return result.thenComposeAsync( - in -> { + v -> { CompletableFuture f = new CompletableFuture<>(); - // Check that the response is long enough. - if (in.length < Header.LENGTH) { - f.completeExceptionally(new WireParseException("invalid DNS header - too short")); - return f; - } - - // Check that the response ID matches the query ID. We want - // to check this before actually parsing the message, so that - // if there's a malformed response that's not ours, it - // doesn't confuse us. - int id = ((in[0] & 0xFF) << 8) + (in[1] & 0xFF); - if (id != qid) { - f.completeExceptionally( - new WireParseException("invalid message id: expected " + qid + "; got id " + id)); - return f; - } - - Message response; - try { - response = parseMessage(in); - } catch (WireParseException e) { - f.completeExceptionally(e); - return f; - } - - // validate name, class and type (rfc5452#section-9.1) - if (!query.getQuestion().getName().equals(response.getQuestion().getName())) { - f.completeExceptionally( - new WireParseException( - "invalid name in message: expected " - + query.getQuestion().getName() - + "; got " - + response.getQuestion().getName())); - return f; - } - - if (query.getQuestion().getDClass() != response.getQuestion().getDClass()) { - f.completeExceptionally( - new WireParseException( - "invalid class in message: expected " - + DClass.string(query.getQuestion().getDClass()) - + "; got " - + DClass.string(response.getQuestion().getDClass()))); - return f; - } - - if (query.getQuestion().getType() != response.getQuestion().getType()) { - f.completeExceptionally( - new WireParseException( - "invalid type in message: expected " - + Type.string(query.getQuestion().getType()) - + "; got " - + Type.string(response.getQuestion().getType()))); - return f; + VerifyOneResponse verifyOneResponse = new VerifyOneResponse(query, qid); + @SuppressWarnings("unchecked") + Message response = null; + Throwable responseFailureException = null; + for (byte[] in : v) { + Message r = verifyOneResponse.parse(in); + if (r == null) { + responseFailureException = verifyOneResponse.getFailureReason(); + continue; + } + + verifyTSIG(query, r, in, tsig); + if (!tcp && !ignoreTruncation && r.getHeader().getFlag(Flags.TC)) { + log.debug("Got truncated response for id {}, discarding`", qid); + log.trace("Truncated response: {}", r); + continue; + } + + if (response == null) { + response = r; + } else { + // not the first answer, append the results to first response + for (Record rec : r.getSection(Section.ANSWER)) { + response.addRecord(rec, Section.ANSWER); + } + for (Record rec : r.getSection(Section.AUTHORITY)) { + response.addRecord(rec, Section.AUTHORITY); + } + for (Record rec : r.getSection(Section.ADDITIONAL)) { + response.addRecord(rec, Section.ADDITIONAL); + } + } } - verifyTSIG(query, response, in, tsig); - if (!tcp && !ignoreTruncation && response.getHeader().getFlag(Flags.TC)) { - log.debug("Got truncated response for id {}, retrying via TCP", qid); - log.trace("Truncated response: {}", response); - return sendAsync(query, true); + if (response != null) { + response.setResolver(this); + f.complete(response); + } else { + if (responseFailureException == null) { + responseFailureException = new SocketTimeoutException("Query timed out"); + } + f.completeExceptionally(responseFailureException); } - - response.setResolver(this); - f.complete(response); return f; }); } @@ -444,4 +426,90 @@ private Message sendAXFR(Message query) throws IOException { public String toString() { return "SimpleResolver [" + address + "]"; } + + private class VerifyOneResponse { + private Message query; + private int qid; + private Throwable failureReasonException; + + public VerifyOneResponse(Message query, int qid) { + this.query = query; + this.qid = qid; + } + + /** + * Keep the exception that caused this to fail. We don't throw the exception immediately + * because, when using mDNS, we might get both failing and successful answers, in which case we + * just want to discard the bad answers and keep the good ones; the exception only gets thrown + * if only bad answers are received. + * + * @return the Throwable reporting why this parse failed. + */ + public Throwable getFailureReason() { + return failureReasonException; + } + + public Message parse(byte[] in) { + failureReasonException = null; + + // Check that the response is long enough. + if (in.length < Header.LENGTH) { + failureReasonException = new WireParseException("invalid DNS header - too short"); + return null; + } + + // Check that the response ID matches the query ID. We want + // to check this before actually parsing the message, so that + // if there's a malformed response that's not ours, it + // doesn't confuse us. + int id = ((in[0] & 0xFF) << 8) + (in[1] & 0xFF); + if (id != qid) { + failureReasonException = + new WireParseException("invalid message id: expected " + qid + "; got id " + id); + return null; + } + + Message r; + try { + r = parseMessage(in); + } catch (WireParseException e) { + failureReasonException = e; + return null; + } + + // validate name, class and type (rfc5452#section-9.1) + if (!query.getQuestion().getName().equals(r.getQuestion().getName())) { + failureReasonException = + new WireParseException( + "invalid name in message: expected " + + query.getQuestion().getName() + + "; got " + + r.getQuestion().getName()); + return null; + } + + if ((query.getQuestion().getDClass() & ~DClass.UNICAST_RESPONSE) + != (r.getQuestion().getDClass() & ~DClass.UNICAST_RESPONSE)) { + failureReasonException = + new WireParseException( + "invalid class in message: expected " + + DClass.string(query.getQuestion().getDClass()) + + "; got " + + DClass.string(r.getQuestion().getDClass())); + return null; + } + + if (query.getQuestion().getType() != r.getQuestion().getType()) { + failureReasonException = + new WireParseException( + "invalid type in message: expected " + + Type.string(query.getQuestion().getType()) + + "; got " + + Type.string(r.getQuestion().getType())); + return null; + } + + return r; + } + } } diff --git a/src/test/java/org/xbill/DNS/NioTcpClientTest.java b/src/test/java/org/xbill/DNS/NioTcpClientTest.java index 1b2e009ed..c99d936a4 100644 --- a/src/test/java/org/xbill/DNS/NioTcpClientTest.java +++ b/src/test/java/org/xbill/DNS/NioTcpClientTest.java @@ -67,7 +67,7 @@ void testResponseStream() throws InterruptedException, IOException { .thenAccept( d -> { try { - clientReceivedAnswers[jj] = new Message(d); + clientReceivedAnswers[jj] = new Message(d.get(0)); cdl2.countDown(); } catch (IOException e) { fail(e);