TensorFlow与消息队列–服务器

消息队列(Message Queue,简称 MQ),可以将一些费时的任务放入队列,慢慢处理,改善客户端的体验。与TensorFlow服务结合,可以调高服务器的计算能力,将费时的数据传输与相对较快的预测计算分开来

一.数据流图

说明:
应用场景是收集客户端中的图片,进行分类处理,结果保留到服务器与数据库。客户端不需要实时知道图片的分类结果。如果客户端要知道实时结果,就不能这样设计。
该设计主要解决了2个问题,第一 加快了客户端上传图片的速度;第二 将图像预测的工作分离了出来,单独在GPU上运行,实现最大的效率。
详细流程:
1.客户端将图片通过HTTP协议post到“生产者”
2.“生产者”将收到的图片与其他信息打包发个消息队列
3.“消费者”拿到消息后,先解包,再进行图片预处理,随后将图片张量(数值)发个tensorflow搭建的服务进行分类。
4“消费者”收到结果后,保存结果数据。

二.各部分介绍

客户端:
客户端使用java程序编写,使用java主要看中了其稳定性。核心上传代码如下,就是简单的http post

URL url = new URL("http://109.120.57.85:7000/dg?filename=abc");
URLConnection urlcon=url.openConnection();
urlcon.setRequestProperty("accept", "*/*");
urlcon.setRequestProperty("connection", "Keep-Alive");
urlcon.setRequestProperty("user-agent","Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:65.0) Gecko/20100101 Firefox/65.0");

urlcon.setDoOutput(true);
urlcon.setDoInput(true);

FileInputStream in = new FileInputStream(new File("abc.jpg"));
byte[] tempbyte = new byte[10240];
int n = in.read(tempbyte);
while(n > 0) {
	urlcon.getOutputStream().write(tempbyte, 0, n);
	n = in.read(tempbyte);
}
urlcon.getOutputStream().flush();
in.close();

BufferedReader br = new BufferedReader(new InputStreamReader(urlcon.getInputStream(), "utf-8"));
String ret =  br.readLine();
System.out.println(ret);
br.close();

生产者:
生产者使用python编写,数据打包使用了protobuf

# coding=utf-8
#!/usr/bin/env python
import pika
import AIMessage_pb2

from flask import Flask,request
app = Flask(__name__)

class Publisher:
    def __init__(self):
        self._conn = None
        self._channel = None

    def connect(self):
        if not self._conn or self._conn.is_closed:
            self._conn = pika.BlockingConnection(pika.ConnectionParameters(host='localhost'))
            self._channel = self._conn.channel()
            self._channel.queue_declare(queue='task_queue', durable=True, arguments = {'x-message-ttl' : 600000})

    def _publish(self, msg):
        self._channel.basic_publish(exchange='',
                              routing_key='task_queue',
                              body=msg,
                              properties=pika.BasicProperties(
                                 delivery_mode = 2, # make message persistent
                              ))

    def publish(self, msg):
        """Publish msg, reconnecting if necessary."""

        try:
            self._publish(msg)
        except pika.exceptions.ConnectionClosed:
            self.connect()
            self._publish(msg)

    def close(self):
        if self._conn and self._conn.is_open:
            self._conn.close()


publisher = Publisher()
publisher.connect()

@app.route('/')
def hello_world():
    return 'Hello World!'

@app.route('/dg',methods=['GET','POST'])
def dg():
    if request.method == 'POST':
        ai_message = AIMessage_pb2.AIMessage()
        ai_message.url = request.full_path
        ai_message.imgbuffer = request.stream.read()

        publisher.publish(ai_message.SerializeToString())
    
    return 'ok!'

if __name__ == '__main__':
    app.run(host = '0.0.0.0')

消息队列:
消息队列使用了RabbitMQ。对比Redis,其功能多,最重要的是图片数据大,使用Redis性能差。对比activemq,其安装维护简单一点。

消费者:
消费者也是python编写的。以下消费者的代码并没有真的将图片处理的运算包含进来,这部分运算放到了TensorFlow搭建的服务中去了。性能并没有搭到最大,因为图片预处理与神经网络预测是同步顺序运行,妨碍了gpu的发挥。但修改也简单的

# coding=utf-8
#!/usr/bin/env python
import pika
import AIMessage_pb2
import http.client

headers = {"User-agent":"Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1"}

connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost'))
channel = connection.channel()

channel.queue_declare(queue='task_queue', durable=True , arguments = {'x-message-ttl' : 600000})
print(' [*] Waiting for messages. To exit press CTRL+C')

def callback(ch, method, properties, body):
    ai_message = AIMessage_pb2.AIMessage()
    ai_message.ParseFromString(body)

    conn = http.client.HTTPConnection("127.0.0.1", 7000, timeout=30)
    conn.request('POST', ai_message.url, ai_message.imgbuffer, headers)
    resp = conn.getresponse()
    content = resp.read()
    ch.basic_ack(delivery_tag = method.delivery_tag)

channel.basic_qos(prefetch_count=1)
channel.basic_consume(callback,
                      queue='task_queue')

channel.start_consuming()

TensorFlow服务:
这个可以参考之前的文章https://www.yangyouji.info/archives/319 这里就不在赘述


发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注