1420c9859SShuo Chenpackage muduo.rpc;
2420c9859SShuo Chen
3420c9859SShuo Chenimport java.util.Map;
4420c9859SShuo Chenimport java.util.concurrent.ConcurrentHashMap;
5420c9859SShuo Chenimport java.util.concurrent.atomic.AtomicLong;
6420c9859SShuo Chen
7420c9859SShuo Chenimport muduo.rpc.proto.RpcProto.ErrorCode;
8420c9859SShuo Chenimport muduo.rpc.proto.RpcProto.MessageType;
9420c9859SShuo Chenimport muduo.rpc.proto.RpcProto.RpcMessage;
10420c9859SShuo Chenimport muduo.rpc.proto.RpcProto.RpcMessage.Builder;
11420c9859SShuo Chen
12420c9859SShuo Chenimport org.jboss.netty.channel.Channel;
13420c9859SShuo Chenimport org.jboss.netty.channel.ChannelHandlerContext;
14420c9859SShuo Chenimport org.jboss.netty.channel.MessageEvent;
15420c9859SShuo Chen
16b5a588dfSShuo Chenimport com.google.protobuf.BlockingRpcChannel;
17420c9859SShuo Chenimport com.google.protobuf.ByteString;
18420c9859SShuo Chenimport com.google.protobuf.Descriptors.MethodDescriptor;
19420c9859SShuo Chenimport com.google.protobuf.Message;
20420c9859SShuo Chenimport com.google.protobuf.RpcCallback;
21420c9859SShuo Chenimport com.google.protobuf.RpcController;
22420c9859SShuo Chenimport com.google.protobuf.Service;
23b5a588dfSShuo Chenimport com.google.protobuf.ServiceException;
24420c9859SShuo Chen
25b5a588dfSShuo Chenpublic class RpcChannel implements com.google.protobuf.RpcChannel, BlockingRpcChannel {
26420c9859SShuo Chen
27b5a588dfSShuo Chen    private final static class BlockingRpcCallback implements RpcCallback<Message> {
28b5a588dfSShuo Chen        public Message response;
29420c9859SShuo Chen
30b5a588dfSShuo Chen        @Override
31b5a588dfSShuo Chen        public void run(Message response) {
32b5a588dfSShuo Chen            synchronized (this) {
33b5a588dfSShuo Chen                this.response = response;
34b5a588dfSShuo Chen                notify();
35b5a588dfSShuo Chen            }
36420c9859SShuo Chen        }
37b5a588dfSShuo Chen    }
38b5a588dfSShuo Chen
39b5a588dfSShuo Chen    private final static class Outstanding {
40420c9859SShuo Chen
41420c9859SShuo Chen        public Message responsePrototype;
42420c9859SShuo Chen        public RpcCallback<Message> done;
43b5a588dfSShuo Chen
44b5a588dfSShuo Chen        public Outstanding(Message responsePrototype, RpcCallback<Message> done) {
45b5a588dfSShuo Chen            this.responsePrototype = responsePrototype;
46b5a588dfSShuo Chen            this.done = done;
47b5a588dfSShuo Chen        }
48420c9859SShuo Chen    }
49420c9859SShuo Chen
50420c9859SShuo Chen    private Channel channel;
51420c9859SShuo Chen    private AtomicLong id = new AtomicLong(1);
52420c9859SShuo Chen    private Map<Long, Outstanding> outstandings = new ConcurrentHashMap<Long, Outstanding>();
53420c9859SShuo Chen    private Map<String, Service> services;
54420c9859SShuo Chen
55420c9859SShuo Chen    public RpcChannel(Channel channel) {
56420c9859SShuo Chen        this.channel = channel;
57420c9859SShuo Chen    }
58420c9859SShuo Chen
59420c9859SShuo Chen    public void setServiceMap(Map<String, Service> services) {
60420c9859SShuo Chen        this.services = services;
61420c9859SShuo Chen    }
62420c9859SShuo Chen
63420c9859SShuo Chen    public Channel getChannel() {
64420c9859SShuo Chen        return channel;
65420c9859SShuo Chen    }
66420c9859SShuo Chen
67b5a588dfSShuo Chen    public void disconnect() {
68b5a588dfSShuo Chen        channel.disconnect();
69b5a588dfSShuo Chen    }
70b5a588dfSShuo Chen
71420c9859SShuo Chen    public void messageReceived(ChannelHandlerContext ctx, final MessageEvent e) {
72420c9859SShuo Chen        RpcMessage message = (RpcMessage) e.getMessage();
73420c9859SShuo Chen        assert e.getChannel() == channel;
74b5a588dfSShuo Chen        // System.out.println(message);
75420c9859SShuo Chen        if (message.getType() == MessageType.REQUEST) {
76420c9859SShuo Chen            doRequest(message);
77420c9859SShuo Chen        } else if (message.getType() == MessageType.RESPONSE) {
786b1f253eSShuo Chen            Outstanding o = outstandings.remove(message.getId());
79b5a588dfSShuo Chen            // System.err.println("messageReceived " + this);
80420c9859SShuo Chen            if (o != null) {
81420c9859SShuo Chen                Message resp = fromByteString(o.responsePrototype, message.getResponse());
82420c9859SShuo Chen                o.done.run(resp);
83b5a588dfSShuo Chen            } else {
84b5a588dfSShuo Chen                System.err.println("Unknown id " + message.getId());
85420c9859SShuo Chen            }
86420c9859SShuo Chen        }
87420c9859SShuo Chen    }
88420c9859SShuo Chen
89420c9859SShuo Chen    private void doRequest(RpcMessage message) {
90420c9859SShuo Chen        Service service = services.get(message.getService());
91420c9859SShuo Chen        Builder errorBuilder = RpcMessage.newBuilder().setType(MessageType.ERROR);
92420c9859SShuo Chen        boolean succeed = false;
93420c9859SShuo Chen        if (service != null) {
94420c9859SShuo Chen            MethodDescriptor method = service.getDescriptorForType()
95420c9859SShuo Chen                    .findMethodByName(message.getMethod());
96420c9859SShuo Chen            if (method != null) {
97420c9859SShuo Chen                Message request = fromByteString(service.getRequestPrototype(method),
98420c9859SShuo Chen                        message.getRequest());
99420c9859SShuo Chen                if (request != null) {
100420c9859SShuo Chen                    final long id = message.getId();
101420c9859SShuo Chen                    RpcCallback<Message> done = new RpcCallback<Message>() {
102420c9859SShuo Chen                        @Override
103420c9859SShuo Chen                        public void run(Message response) {
104420c9859SShuo Chen                            done(response, id);
105420c9859SShuo Chen                        }
106420c9859SShuo Chen                    };
107420c9859SShuo Chen                    succeed = doCall(request, service, method, done);
108420c9859SShuo Chen                } else {
109420c9859SShuo Chen                    errorBuilder.setError(ErrorCode.INVALID_REQUEST);
110420c9859SShuo Chen                }
111420c9859SShuo Chen            } else {
112420c9859SShuo Chen                errorBuilder.setError(ErrorCode.NO_METHOD);
113420c9859SShuo Chen            }
114420c9859SShuo Chen        } else {
115420c9859SShuo Chen            errorBuilder.setError(ErrorCode.NO_SERVICE);
116420c9859SShuo Chen        }
117420c9859SShuo Chen        if (!succeed) {
118420c9859SShuo Chen            RpcMessage resp = errorBuilder.build();
119420c9859SShuo Chen            channel.write(resp);
120420c9859SShuo Chen        }
121420c9859SShuo Chen    }
122420c9859SShuo Chen
123420c9859SShuo Chen    private Message fromByteString(Message prototype, ByteString bytes) {
124420c9859SShuo Chen        Message message = null;
125420c9859SShuo Chen        try {
126420c9859SShuo Chen            message = prototype.toBuilder().mergeFrom(bytes).build();
127420c9859SShuo Chen        } catch (Exception e) {
128420c9859SShuo Chen        }
129420c9859SShuo Chen        return message;
130420c9859SShuo Chen    }
131420c9859SShuo Chen
132420c9859SShuo Chen    private boolean doCall(Message request, Service service, MethodDescriptor method,
133420c9859SShuo Chen            RpcCallback<Message> done) {
134420c9859SShuo Chen        service.callMethod(method, null, request, done);
135420c9859SShuo Chen        return true;
136420c9859SShuo Chen    }
137420c9859SShuo Chen
138420c9859SShuo Chen    protected void done(Message response, long id) {
139420c9859SShuo Chen        if (response != null) {
140420c9859SShuo Chen            RpcMessage resp = RpcMessage.newBuilder()
141420c9859SShuo Chen                    .setType(MessageType.RESPONSE)
142420c9859SShuo Chen                    .setId(id)
143420c9859SShuo Chen                    .setResponse(response.toByteString())
144420c9859SShuo Chen                    .build();
145420c9859SShuo Chen            channel.write(resp);
146420c9859SShuo Chen        } else {
147420c9859SShuo Chen            RpcMessage resp = RpcMessage.newBuilder()
148420c9859SShuo Chen                    .setType(MessageType.ERROR)
149420c9859SShuo Chen                    .setId(id)
150420c9859SShuo Chen                    .setError(ErrorCode.INVALID_RESPONSE)
151420c9859SShuo Chen                    .build();
152420c9859SShuo Chen            channel.write(resp);
153420c9859SShuo Chen        }
154420c9859SShuo Chen    }
155420c9859SShuo Chen
156420c9859SShuo Chen    @Override
157420c9859SShuo Chen    public void callMethod(MethodDescriptor method, RpcController controller, Message request,
158420c9859SShuo Chen            Message responsePrototype, RpcCallback<Message> done) {
159420c9859SShuo Chen        long callId = id.getAndIncrement();
160420c9859SShuo Chen        RpcMessage message = RpcMessage.newBuilder()
161420c9859SShuo Chen                .setType(MessageType.REQUEST)
162420c9859SShuo Chen                .setId(callId)
163420c9859SShuo Chen                .setService(method.getService().getFullName())
164420c9859SShuo Chen                .setMethod(method.getName())
165420c9859SShuo Chen                .setRequest(request.toByteString())
166420c9859SShuo Chen                .build();
167420c9859SShuo Chen        outstandings.put(callId, new Outstanding(responsePrototype, done));
168420c9859SShuo Chen        channel.write(message);
169420c9859SShuo Chen    }
170420c9859SShuo Chen
171b5a588dfSShuo Chen    @Override
172b5a588dfSShuo Chen    public Message callBlockingMethod(MethodDescriptor method, RpcController controller,
173b5a588dfSShuo Chen            Message request, Message responsePrototype) throws ServiceException {
174b5a588dfSShuo Chen        BlockingRpcCallback done = new BlockingRpcCallback();
175b5a588dfSShuo Chen        callMethod(method, controller, request, responsePrototype, done);
176b5a588dfSShuo Chen        // if (channel instanceof NioClientSocketChannel)
177b5a588dfSShuo Chen        // channel.get
178b5a588dfSShuo Chen        // assert
179b5a588dfSShuo Chen        synchronized (done) {
180b5a588dfSShuo Chen            while (done.response == null) {
181b5a588dfSShuo Chen                try {
182b5a588dfSShuo Chen                    done.wait();
183b5a588dfSShuo Chen                } catch (InterruptedException e) {
184b5a588dfSShuo Chen                }
185b5a588dfSShuo Chen            }
186b5a588dfSShuo Chen        }
187b5a588dfSShuo Chen        return done.response;
188b5a588dfSShuo Chen    }
189420c9859SShuo Chen}
190