1package muduo.rpc;
2
3import java.util.Map;
4import java.util.concurrent.ConcurrentHashMap;
5import java.util.concurrent.atomic.AtomicLong;
6
7import muduo.rpc.proto.RpcProto.ErrorCode;
8import muduo.rpc.proto.RpcProto.MessageType;
9import muduo.rpc.proto.RpcProto.RpcMessage;
10import muduo.rpc.proto.RpcProto.RpcMessage.Builder;
11
12import org.jboss.netty.channel.Channel;
13import org.jboss.netty.channel.ChannelHandlerContext;
14import org.jboss.netty.channel.MessageEvent;
15
16import com.google.protobuf.BlockingRpcChannel;
17import com.google.protobuf.ByteString;
18import com.google.protobuf.Descriptors.MethodDescriptor;
19import com.google.protobuf.Message;
20import com.google.protobuf.RpcCallback;
21import com.google.protobuf.RpcController;
22import com.google.protobuf.Service;
23import com.google.protobuf.ServiceException;
24
25public class RpcChannel implements com.google.protobuf.RpcChannel, BlockingRpcChannel {
26
27    private final static class BlockingRpcCallback implements RpcCallback<Message> {
28        public Message response;
29
30        @Override
31        public void run(Message response) {
32            synchronized (this) {
33                this.response = response;
34                notify();
35            }
36        }
37    }
38
39    private final static class Outstanding {
40
41        public Message responsePrototype;
42        public RpcCallback<Message> done;
43
44        public Outstanding(Message responsePrototype, RpcCallback<Message> done) {
45            this.responsePrototype = responsePrototype;
46            this.done = done;
47        }
48    }
49
50    private Channel channel;
51    private AtomicLong id = new AtomicLong(1);
52    private Map<Long, Outstanding> outstandings = new ConcurrentHashMap<Long, Outstanding>();
53    private Map<String, Service> services;
54
55    public RpcChannel(Channel channel) {
56        this.channel = channel;
57    }
58
59    public void setServiceMap(Map<String, Service> services) {
60        this.services = services;
61    }
62
63    public Channel getChannel() {
64        return channel;
65    }
66
67    public void disconnect() {
68        channel.disconnect();
69    }
70
71    public void messageReceived(ChannelHandlerContext ctx, final MessageEvent e) {
72        RpcMessage message = (RpcMessage) e.getMessage();
73        assert e.getChannel() == channel;
74        // System.out.println(message);
75        if (message.getType() == MessageType.REQUEST) {
76            doRequest(message);
77        } else if (message.getType() == MessageType.RESPONSE) {
78            Outstanding o = outstandings.remove(message.getId());
79            // System.err.println("messageReceived " + this);
80            if (o != null) {
81                Message resp = fromByteString(o.responsePrototype, message.getResponse());
82                o.done.run(resp);
83            } else {
84                System.err.println("Unknown id " + message.getId());
85            }
86        }
87    }
88
89    private void doRequest(RpcMessage message) {
90        Service service = services.get(message.getService());
91        Builder errorBuilder = RpcMessage.newBuilder().setType(MessageType.ERROR);
92        boolean succeed = false;
93        if (service != null) {
94            MethodDescriptor method = service.getDescriptorForType()
95                    .findMethodByName(message.getMethod());
96            if (method != null) {
97                Message request = fromByteString(service.getRequestPrototype(method),
98                        message.getRequest());
99                if (request != null) {
100                    final long id = message.getId();
101                    RpcCallback<Message> done = new RpcCallback<Message>() {
102                        @Override
103                        public void run(Message response) {
104                            done(response, id);
105                        }
106                    };
107                    succeed = doCall(request, service, method, done);
108                } else {
109                    errorBuilder.setError(ErrorCode.INVALID_REQUEST);
110                }
111            } else {
112                errorBuilder.setError(ErrorCode.NO_METHOD);
113            }
114        } else {
115            errorBuilder.setError(ErrorCode.NO_SERVICE);
116        }
117        if (!succeed) {
118            RpcMessage resp = errorBuilder.build();
119            channel.write(resp);
120        }
121    }
122
123    private Message fromByteString(Message prototype, ByteString bytes) {
124        Message message = null;
125        try {
126            message = prototype.toBuilder().mergeFrom(bytes).build();
127        } catch (Exception e) {
128        }
129        return message;
130    }
131
132    private boolean doCall(Message request, Service service, MethodDescriptor method,
133            RpcCallback<Message> done) {
134        service.callMethod(method, null, request, done);
135        return true;
136    }
137
138    protected void done(Message response, long id) {
139        if (response != null) {
140            RpcMessage resp = RpcMessage.newBuilder()
141                    .setType(MessageType.RESPONSE)
142                    .setId(id)
143                    .setResponse(response.toByteString())
144                    .build();
145            channel.write(resp);
146        } else {
147            RpcMessage resp = RpcMessage.newBuilder()
148                    .setType(MessageType.ERROR)
149                    .setId(id)
150                    .setError(ErrorCode.INVALID_RESPONSE)
151                    .build();
152            channel.write(resp);
153        }
154    }
155
156    @Override
157    public void callMethod(MethodDescriptor method, RpcController controller, Message request,
158            Message responsePrototype, RpcCallback<Message> done) {
159        long callId = id.getAndIncrement();
160        RpcMessage message = RpcMessage.newBuilder()
161                .setType(MessageType.REQUEST)
162                .setId(callId)
163                .setService(method.getService().getFullName())
164                .setMethod(method.getName())
165                .setRequest(request.toByteString())
166                .build();
167        outstandings.put(callId, new Outstanding(responsePrototype, done));
168        channel.write(message);
169    }
170
171    @Override
172    public Message callBlockingMethod(MethodDescriptor method, RpcController controller,
173            Message request, Message responsePrototype) throws ServiceException {
174        BlockingRpcCallback done = new BlockingRpcCallback();
175        callMethod(method, controller, request, responsePrototype, done);
176        // if (channel instanceof NioClientSocketChannel)
177        // channel.get
178        // assert
179        synchronized (done) {
180            while (done.response == null) {
181                try {
182                    done.wait();
183                } catch (InterruptedException e) {
184                }
185            }
186        }
187        return done.response;
188    }
189}
190