NLP (17) Deploying a kashgari model with tensorflow-serving
The data for this project comes from a time-expression dataset the author annotated previously: time expressions in the text are labeled using the BIO tagging scheme. The chinese_wwm_ext folder contains the pre-trained BERT model files released by HIT (Harbin Institute of Technology).
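As a small illustration of the BIO scheme (the actual dataset contents are not shown here, so this sample is hypothetical), each character gets one tag per line: B-TIME marks the start of a time expression, I-TIME its continuation, and O everything else. A minimal parser for this one-token-per-line format:

```python
# Hypothetical sample in the CoNLL-style format that DataReader expects:
# one token and one BIO tag per line, separated by whitespace.
sample = """今 B-TIME
天 I-TIME
去 O
北 O
京 O"""

def read_conll_sentence(text):
    """Parse one sentence: returns (tokens, tags)."""
    tokens, tags = [], []
    for line in text.splitlines():
        token, tag = line.split()
        tokens.append(token)
        tags.append(tag)
    return tokens, tags

x, y = read_conll_sentence(sample)
print(x)  # ['今', '天', '去', '北', '京']
print(y)  # ['B-TIME', 'I-TIME', 'O', 'O', 'O']
```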
model_train.py contains the training code; it trains the sequence-labeling model for time expressions. The full code is as follows:
```python
# -*- coding: utf-8 -*-
# time: 2019-09-12
# place: Huangcun Beijing
import kashgari
from kashgari import utils
from kashgari.corpus import DataReader
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model

# model training
train_x, train_y = DataReader().read_conll_format_file('./data/time.train')
valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev')
test_x, test_y = DataReader().read_conll_format_file('./data/time.test')

bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',
                               task=kashgari.LABELING,
                               sequence_length=128)
model = BiLSTM_CRF_Model(bert_embedding)
model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1)

# save model in SavedModel format for tensorflow/serving
utils.convert_to_saved_model(model,
                             model_path='saved_model/time_entity',
                             version=1)
```
After running this code, training produces a saved_model folder containing the exported model files, which makes deployment with tensorflow/serving straightforward. Next we deploy the model with tensorflow/serving, using the following command:
```shell
docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME=time_entity tensorflow/serving
```
Note that the mounted model path must be an absolute path, and that the model name (MODEL_NAME) must match the one used in the training code (model_train.py), which saved the model under saved_model/time_entity.
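tensorflow/serving's REST API serves each model under /v1/models/&lt;MODEL_NAME&gt; on the published port (8501 in the docker command above). A small sketch of how the prediction URL used later by the client is formed:

```python
# TensorFlow Serving REST API: each model lives under /v1/models/<MODEL_NAME>;
# appending ":predict" gives the inference endpoint.
MODEL_NAME = "time_entity"
PORT = 8501
predict_url = f"http://localhost:{PORT}/v1/models/{MODEL_NAME}:predict"
print(predict_url)  # http://localhost:8501/v1/models/time_entity:predict
```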
Next we use tornado to build an HTTP service, which lets us run model predictions conveniently. The full code of runServer.py is as follows:
```python
# -*- coding: utf-8 -*-
import json
import traceback

import numpy as np
import requests
from kashgari import utils
from model_predict import get_predict

import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options

# tornado concurrency support
import tornado.gen
import tornado.concurrent
from concurrent.futures import ThreadPoolExecutor

# define the port as 16016
define("port", default=16016, help="run on the given port", type=int)


# model prediction handler
class ModelPredictHandler(tornado.web.RequestHandler):
    executor = ThreadPoolExecutor(max_workers=5)

    # GET method
    @tornado.gen.coroutine
    def get(self):
        origin_text = self.get_argument('text')
        result = yield self.function(origin_text)
        self.write(json.dumps(result, ensure_ascii=False))

    @tornado.concurrent.run_on_executor
    def function(self, text):
        try:
            text = text.replace(' ', '')
            x = [_ for _ in text]

            # pre-process the input with the saved processor
            processor = utils.load_processor(model_path='saved_model/time_entity/1')
            tensor = processor.process_x_dataset([x])

            # only for BERT embedding: add segment inputs
            tensor = [{
                "Input-Token:0": i.tolist(),
                "Input-Segment:0": np.zeros(i.shape).tolist()
            } for i in tensor]

            # predict via tensorflow/serving's REST API
            r = requests.post("http://localhost:8501/v1/models/time_entity:predict",
                              json={"instances": tensor})
            preds = r.json()['predictions']

            # decode the predicted label ids back into BIO tags
            y = np.array(preds).argmax(-1)
            labels = processor.reverse_numerize_label_sequences(y, lengths=[len(x)])
            return labels[0]
        except Exception:
            return traceback.format_exc()


# main function: register the handler and start the server
# (the route name /model_predict is illustrative)
def main():
    tornado.options.parse_command_line()
    app = tornado.web.Application(handlers=[(r'/model_predict', ModelPredictHandler)])
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()


if __name__ == '__main__':
    main()
```
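Once runServer.py and tensorflow/serving are both running, the service can be called over HTTP. A minimal client sketch, assuming the handler is registered at a route such as /model_predict (the route name is an assumption, since the routing code is not shown in the snippet above):

```python
import urllib.parse

# build the request URL; the /model_predict route name is a hypothetical choice
text = "我明天上午九点开会"
query = urllib.parse.urlencode({"text": text})
url = f"http://localhost:16016/model_predict?{query}"
print(url)

# with both servers running, the actual call would be:
# import requests
# labels = requests.get(url).json()  # one BIO tag per input character
```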