在文章

  本项目的data来自之前笔者标注的时间数据集,即标注出文本中的时间,采用BIO标注系统。chinese_wwm_ext文件夹为哈工大的预训练模型文件。
  model_train.py为模型训练的代码,主要功能是完成时间序列标注模型的训练,完整的代码如下:

# -*- coding: utf-8 -*- # time: 2019-09-12 # place: Huangcun Beijing  import kashgari from kashgari import utils from kashgari.corpus import DataReader from kashgari.embeddings import BERTEmbedding from kashgari.tasks.labeling import BiLSTM_CRF_Model  # 模型训练  train_x, train_y = DataReader().read_conll_format_file('./data/time.train') valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev') test_x, test_y = DataReader().read_conll_format_file('./data/time.test')  bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',                                task=kashgari.LABELING,                                sequence_length=128)  model = BiLSTM_CRF_Model(bert_embedding)  model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1)  # Save model utils.convert_to_saved_model(model,                              model_path='saved_model/time_entity',                              version=1)

  运行该代码,模型训练完后会生成saved_model文件夹,里面含有模型训练好后的文件,方便我们利用tensorflow/serving进行部署。接着我们利用tensorflow/serving来完成模型的部署,命令如下:

docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME=time_entity tensorflow/serving

其中需要注意该模型所在的路径,路径需要写完整路径,以及模型的名称(MODEL_NAME),这在训练代码(train.py)中已经给出(saved_model/time_entity)。

  接着我们使用tornado来搭建HTTP服务,帮助我们方便地进行模型预测,runServer.py的完整代码如下:

# -*- coding: utf-8 -*- import requests from kashgari import utils import numpy as np from model_predict import get_predict  import json import tornado.httpserver import tornado.ioloop import tornado.options import tornado.web from tornado.options import define, options import traceback  # tornado高并发 import tornado.web import tornado.gen import tornado.concurrent from concurrent.futures import ThreadPoolExecutor  # 定义端口为12333 define("port", default=16016, help="run on the given port", type=int)  # 模型预测 class ModelPredictHandler(tornado.web.RequestHandler):     executor = ThreadPoolExecutor(max_workers=5)      # get 函数     @tornado.gen.coroutine     def get(self):         origin_text = self.get_argument('text')         result = yield self.function(origin_text)         self.write(json.dumps(result, ensure_ascii=False))      @tornado.concurrent.run_on_executor     def function(self, text):         try:             text = text.replace(' ', '')             x = [_ for _ in text]              # Pre-processor data             processor = utils.load_processor(model_path='saved_model/time_entity/1')             tensor = processor.process_x_dataset([x])              # only for bert Embedding             tensor = [{                 "Input-Token:0": i.tolist(),                 "Input-Segment:0": np.zeros(i.shape).tolist()             } for i in tensor]              # predict             r = requests.post("http://localhost:8501/v1/models/time_entity:predict", json={"instances": tensor})             preds = r.json()['predictions']