// Java source  |  133 lines  |  2.89 KB

// Copyright (c) 2005 Brian Wellington (bwelling@xbill.org)

package org.xbill.DNS;

import java.io.*;
import java.net.*;
import java.nio.*;
import java.nio.channels.*;

/**
 * Package-private helper that performs DNS message exchanges over a TCP
 * {@link SocketChannel}.
 *
 * <p>Every message on the wire is framed with a two-byte, high-byte-first
 * length prefix followed by the message bytes; {@link #send} writes that
 * framing and {@link #recv} reads it back.
 *
 * <p>The selection key ({@code key}), the deadline ({@code endTime}) and the
 * helpers {@code blockUntil}, {@code verboseLog} and {@code cleanup} are
 * inherited from {@code Client}, which is not shown in this file.
 * NOTE(review): the connect/read/write loops below are written for a
 * non-blocking channel; presumably {@code Client} configures it that way
 * and {@code blockUntil} waits for readiness or throws once {@code endTime}
 * has passed - confirm against {@code Client}.
 */
final class TCPClient extends Client {

/** Number of bytes in the length prefix that precedes every message. */
private static final int LENGTH_PREFIX_SIZE = 2;

/** Largest message length that fits in the two-byte length prefix. */
private static final int MAX_MESSAGE_LENGTH = 0xFFFF;

/**
 * Creates a client backed by a newly opened socket channel.
 *
 * @param endTime deadline for all operations; it is compared against
 *                {@link System#currentTimeMillis()} in this class
 * @throws IOException if the channel cannot be opened or the superclass
 *                     constructor fails
 */
public
TCPClient(long endTime) throws IOException {
	super(SocketChannel.open(), endTime);
}

/**
 * Binds the underlying socket to a local address.
 *
 * @param addr the local address to bind to
 * @throws IOException if the bind fails
 */
void
bind(SocketAddress addr) throws IOException {
	SocketChannel channel = (SocketChannel) key.channel();
	channel.socket().bind(addr);
}

/**
 * Connects the channel to a remote address, waiting for completion if the
 * connection is not established immediately.
 *
 * @param addr the remote address to connect to
 * @throws IOException if the connection attempt fails, or if
 *                     {@code blockUntil} throws while waiting
 */
void
connect(SocketAddress addr) throws IOException {
	SocketChannel channel = (SocketChannel) key.channel();
	// connect() returns true when the connection completed right away.
	if (channel.connect(addr))
		return;
	key.interestOps(SelectionKey.OP_CONNECT);
	try {
		while (!channel.finishConnect()) {
			if (!key.isConnectable())
				blockUntil(key, endTime);
		}
	}
	finally {
		// Always clear the interest set, unless the key was cancelled.
		if (key.isValid())
			key.interestOps(0);
	}
}

/**
 * Writes one message to the channel, preceded by its two-byte length.
 *
 * @param data the message bytes to send; must be at most 65535 bytes long
 * @throws IOException if {@code data} is too long to be described by the
 *                     two-byte length prefix (nothing is written in that
 *                     case), or if the write fails
 * @throws SocketTimeoutException if the deadline passes before the whole
 *                                message has been written
 */
void
send(byte [] data) throws IOException {
	// A longer payload would have its length prefix silently truncated,
	// leaving the receiver unable to find the end of the message.
	if (data.length > MAX_MESSAGE_LENGTH)
		throw new IOException("message too large for TCP: " +
				      data.length + " bytes (maximum " +
				      MAX_MESSAGE_LENGTH + ")");
	SocketChannel channel = (SocketChannel) key.channel();
	verboseLog("TCP write", data);
	byte [] lengthArray = new byte[LENGTH_PREFIX_SIZE];
	lengthArray[0] = (byte)(data.length >>> 8);
	lengthArray[1] = (byte)(data.length & 0xFF);
	// Gathering write: length prefix first, then the payload.
	ByteBuffer [] buffers = new ByteBuffer[2];
	buffers[0] = ByteBuffer.wrap(lengthArray);
	buffers[1] = ByteBuffer.wrap(data);
	int total = data.length + LENGTH_PREFIX_SIZE;
	int nsent = 0;
	key.interestOps(SelectionKey.OP_WRITE);
	try {
		while (nsent < total) {
			if (key.isWritable()) {
				long n = channel.write(buffers);
				if (n < 0)
					throw new EOFException();
				nsent += (int) n;
				// Only a partial write can time out; a completed
				// write succeeds even if the deadline has passed.
				if (nsent < total &&
				    System.currentTimeMillis() > endTime)
					throw new SocketTimeoutException();
			} else
				blockUntil(key, endTime);
		}
	}
	finally {
		if (key.isValid())
			key.interestOps(0);
	}
}

/**
 * Reads exactly {@code length} bytes from the channel.
 *
 * @param length the number of bytes to read
 * @return a new array of {@code length} bytes filled from the channel
 * @throws EOFException if the channel reaches end-of-stream before
 *                      {@code length} bytes have been read
 * @throws SocketTimeoutException if the deadline passes before all bytes
 *                                have been read
 * @throws IOException if the read fails
 */
private byte []
_recv(int length) throws IOException {
	SocketChannel channel = (SocketChannel) key.channel();
	int nrecvd = 0;
	byte [] data = new byte[length];
	// The buffer wraps data, so bytes read land directly in the result.
	ByteBuffer buffer = ByteBuffer.wrap(data);
	key.interestOps(SelectionKey.OP_READ);
	try {
		while (nrecvd < length) {
			if (key.isReadable()) {
				long n = channel.read(buffer);
				if (n < 0)
					throw new EOFException();
				nrecvd += (int) n;
				if (nrecvd < length &&
				    System.currentTimeMillis() > endTime)
					throw new SocketTimeoutException();
			} else
				blockUntil(key, endTime);
		}
	}
	finally {
		if (key.isValid())
			key.interestOps(0);
	}
	return data;
}

/**
 * Reads one length-prefixed message from the channel.
 *
 * @return the message bytes, without the two-byte length prefix
 * @throws IOException if reading the prefix or the message fails
 */
byte []
recv() throws IOException {
	byte [] buf = _recv(LENGTH_PREFIX_SIZE);
	// Mask each byte to undo sign extension before combining.
	int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF);
	byte [] data = _recv(length);
	verboseLog("TCP read", data);
	return data;
}

/**
 * Sends one message to a remote address over a new connection and returns
 * the single message read back; the client is always cleaned up afterwards.
 *
 * @param local   the local address to bind to, or {@code null} to skip
 *                the explicit bind
 * @param remote  the remote address to connect to
 * @param data    the message bytes to send
 * @param endTime deadline for the whole exchange, compared against
 *                {@link System#currentTimeMillis()}
 * @return the bytes of the response message
 * @throws IOException if any step of the exchange fails
 */
static byte []
sendrecv(SocketAddress local, SocketAddress remote, byte [] data, long endTime)
throws IOException
{
	TCPClient client = new TCPClient(endTime);
	try {
		if (local != null)
			client.bind(local);
		client.connect(remote);
		client.send(data);
		return client.recv();
	}
	finally {
		client.cleanup();
	}
}

/**
 * Same as {@link #sendrecv(SocketAddress, SocketAddress, byte[], long)}
 * with no explicit local address.
 *
 * @param addr    the remote address to connect to
 * @param data    the message bytes to send
 * @param endTime deadline for the whole exchange
 * @return the bytes of the response message
 * @throws IOException if any step of the exchange fails
 */
static byte []
sendrecv(SocketAddress addr, byte [] data, long endTime) throws IOException {
	return sendrecv(null, addr, data, endTime);
}

}