1package muduo.rpc; 2 3import java.util.Map; 4import java.util.concurrent.ConcurrentHashMap; 5import java.util.concurrent.atomic.AtomicLong; 6 7import muduo.rpc.proto.RpcProto.ErrorCode; 8import muduo.rpc.proto.RpcProto.MessageType; 9import muduo.rpc.proto.RpcProto.RpcMessage; 10import muduo.rpc.proto.RpcProto.RpcMessage.Builder; 11 12import org.jboss.netty.channel.Channel; 13import org.jboss.netty.channel.ChannelHandlerContext; 14import org.jboss.netty.channel.MessageEvent; 15 16import com.google.protobuf.BlockingRpcChannel; 17import com.google.protobuf.ByteString; 18import com.google.protobuf.Descriptors.MethodDescriptor; 19import com.google.protobuf.Message; 20import com.google.protobuf.RpcCallback; 21import com.google.protobuf.RpcController; 22import com.google.protobuf.Service; 23import com.google.protobuf.ServiceException; 24 25public class RpcChannel implements com.google.protobuf.RpcChannel, BlockingRpcChannel { 26 27 private final static class BlockingRpcCallback implements RpcCallback<Message> { 28 public Message response; 29 30 @Override 31 public void run(Message response) { 32 synchronized (this) { 33 this.response = response; 34 notify(); 35 } 36 } 37 } 38 39 private final static class Outstanding { 40 41 public Message responsePrototype; 42 public RpcCallback<Message> done; 43 44 public Outstanding(Message responsePrototype, RpcCallback<Message> done) { 45 this.responsePrototype = responsePrototype; 46 this.done = done; 47 } 48 } 49 50 private Channel channel; 51 private AtomicLong id = new AtomicLong(1); 52 private Map<Long, Outstanding> outstandings = new ConcurrentHashMap<Long, Outstanding>(); 53 private Map<String, Service> services; 54 55 public RpcChannel(Channel channel) { 56 this.channel = channel; 57 } 58 59 public void setServiceMap(Map<String, Service> services) { 60 this.services = services; 61 } 62 63 public Channel getChannel() { 64 return channel; 65 } 66 67 public void disconnect() { 68 channel.disconnect(); 69 } 70 71 public void messageReceived(ChannelHandlerContext ctx, final MessageEvent e) { 72 RpcMessage message = (RpcMessage) e.getMessage(); 73 assert e.getChannel() == channel; 74 // System.out.println(message); 75 if (message.getType() == MessageType.REQUEST) { 76 doRequest(message); 77 } else if (message.getType() == MessageType.RESPONSE) { 78 Outstanding o = outstandings.remove(message.getId()); 79 // System.err.println("messageReceived " + this); 80 if (o != null) { 81 Message resp = fromByteString(o.responsePrototype, message.getResponse()); 82 o.done.run(resp); 83 } else { 84 System.err.println("Unknown id " + message.getId()); 85 } 86 } 87 } 88 89 private void doRequest(RpcMessage message) { 90 Service service = services.get(message.getService()); 91 Builder errorBuilder = RpcMessage.newBuilder().setType(MessageType.ERROR); 92 boolean succeed = false; 93 if (service != null) { 94 MethodDescriptor method = service.getDescriptorForType() 95 .findMethodByName(message.getMethod()); 96 if (method != null) { 97 Message request = fromByteString(service.getRequestPrototype(method), 98 message.getRequest()); 99 if (request != null) { 100 final long id = message.getId(); 101 RpcCallback<Message> done = new RpcCallback<Message>() { 102 @Override 103 public void run(Message response) { 104 done(response, id); 105 } 106 }; 107 succeed = doCall(request, service, method, done); 108 } else { 109 errorBuilder.setError(ErrorCode.INVALID_REQUEST); 110 } 111 } else { 112 errorBuilder.setError(ErrorCode.NO_METHOD); 113 } 114 } else { 115 errorBuilder.setError(ErrorCode.NO_SERVICE); 116 } 117 if (!succeed) { 118 RpcMessage resp = errorBuilder.build(); 119 channel.write(resp); 120 } 121 } 122 123 private Message fromByteString(Message prototype, ByteString bytes) { 124 Message message = null; 125 try { 126 message = prototype.toBuilder().mergeFrom(bytes).build(); 127 } catch (Exception e) { 128 } 129 return message; 130 } 131 132 private boolean doCall(Message request, Service service, MethodDescriptor method, 133 RpcCallback<Message> done) { 134 service.callMethod(method, null, request, done); 135 return true; 136 } 137 138 protected void done(Message response, long id) { 139 if (response != null) { 140 RpcMessage resp = RpcMessage.newBuilder() 141 .setType(MessageType.RESPONSE) 142 .setId(id) 143 .setResponse(response.toByteString()) 144 .build(); 145 channel.write(resp); 146 } else { 147 RpcMessage resp = RpcMessage.newBuilder() 148 .setType(MessageType.ERROR) 149 .setId(id) 150 .setError(ErrorCode.INVALID_RESPONSE) 151 .build(); 152 channel.write(resp); 153 } 154 } 155 156 @Override 157 public void callMethod(MethodDescriptor method, RpcController controller, Message request, 158 Message responsePrototype, RpcCallback<Message> done) { 159 long callId = id.getAndIncrement(); 160 RpcMessage message = RpcMessage.newBuilder() 161 .setType(MessageType.REQUEST) 162 .setId(callId) 163 .setService(method.getService().getFullName()) 164 .setMethod(method.getName()) 165 .setRequest(request.toByteString()) 166 .build(); 167 outstandings.put(callId, new Outstanding(responsePrototype, done)); 168 channel.write(message); 169 } 170 171 @Override 172 public Message callBlockingMethod(MethodDescriptor method, RpcController controller, 173 Message request, Message responsePrototype) throws ServiceException { 174 BlockingRpcCallback done = new BlockingRpcCallback(); 175 callMethod(method, controller, request, responsePrototype, done); 176 // if (channel instanceof NioClientSocketChannel) 177 // channel.get 178 // assert 179 synchronized (done) { 180 while (done.response == null) { 181 try { 182 done.wait(); 183 } catch (InterruptedException e) { 184 } 185 } 186 } 187 return done.response; 188 } 189} 190