BERT Chinese Short-Sentence Similarity: Docker CPU Image

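In this installment we again built a docker environment that bundles data, model, and code, so you can try Chinese BERT sentence embeddings out of the box. Specifically, we use BERT-wwm-ext, huggingface transformers, and sentence-transformers to experiment with Chinese sentence embeddings and to find the sentences most similar to a query phrase.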

How to Get the Docker Image

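To get the docker image for this installment, follow the MyEncyclopedia WeChat official account (公众号) and reply docker-sentence-transformer to receive the image address and the startup command.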

HIT-iFLYTEK Chinese BERT

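In Chinese pretraining, BERT-wwm-ext, the whole-word-masking Chinese pretrained model released by the Joint Laboratory of HIT and iFLYTEK Research (HFL), is one of the industry benchmarks. BERT-wwm-ext can be used from Tensorflow, Pytorch (through the huggingface transformers interface), and PaddleHub, which makes it very convenient. The code below, taken from the official repository, downloads the model through the huggingface transformers interface and loads it into Pytorch. The Github repository is https://github.com/ymcui/Chinese-BERT-wwm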

from transformers import BertTokenizer, BertModel

# Hub ID used in this post; the BERT-wwm-ext checkpoint is published as 'hfl/chinese-bert-wwm-ext'
model_name = 'hfl/chinese-bert-wwm'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

The benefit of going through huggingface transformers is that sentence-transformers also supports huggingface. We therefore do not have to wire BERT-wwm-ext and sentence-transformers together by hand, which saves a fair amount of code.

sentence-transformer

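As the name suggests, sentence-transformer uses pretrained transformer models, which produce token-level embeddings, to generate sentence-level embeddings. It is based on the paper Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (https://arxiv.org/abs/1908.10084). The basic idea is straightforward: each token of the sentence is embedded by BERT, and the token embeddings are fed into a pooling layer. With the simplest choice, mean pooling, the output is the average of all token embeddings, which yields a fixed-length sentence embedding that does not depend on the length of the input sentence.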
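To make the pooling step concrete, here is a minimal pure-Python sketch of masked mean pooling. The function name `mean_pooling`, the 2-dimensional toy vectors, and the masks are illustrative assumptions, not real BERT output; the point is only that sentences of different lengths end up as vectors of the same size and that padding positions are ignored.

```python
from typing import List


def mean_pooling(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose mask is 1, ignoring padding positions."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            for i, value in enumerate(vector):
                summed[i] += value
    count = max(sum(attention_mask), 1)  # guard against an all-padding input
    return [s / count for s in summed]


# Toy 2-dimensional token embeddings (made-up numbers, not real BERT output).
short_sentence = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]  # last position is padding
short_mask = [1, 1, 0]
long_sentence = [[1.0, 0.0], [2.0, 0.0], [3.0, 6.0]]   # three real tokens
long_mask = [1, 1, 1]

short_vec = mean_pooling(short_sentence, short_mask)
long_vec = mean_pooling(long_sentence, long_mask)
print(short_vec)  # [2.0, 3.0]
print(long_vec)   # [2.0, 2.0]
```

Both results have two components even though the inputs contain a different number of real tokens, and the padding vector `[9.0, 9.0]` has no effect on the average.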

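The code below is a basic example from the official site. Under the hood it downloads and loads the pretrained model through the huggingface interface, then computes sentence embeddings for three English sentences.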

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Sentences we would like to encode
sentences = ['This framework generates embeddings for each input sentence',
             'Sentences are passed as a list of string.',
             'The quick brown fox jumps over the lazy dog.']

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

# Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

Of course, we can also bypass the sentence-transformer API and implement the mean pooling layer by hand with the pytorch API and huggingface to produce the sentence embeddings.

from transformers import AutoTokenizer, AutoModel
import torch

# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

# Sentences we want sentence embeddings for
sentences = ['This framework generates embeddings for each input sentence',
             'Sentences are passed as a list of string.',
             'The quick brown fox jumps over the lazy dog.']

# Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

# Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

Finding the Most Similar Chinese Sentences

Now that we know how to use each component, let's generate embeddings for the following Chinese sentences.

sentences = [
    '今天晚上想吃牛排',
    'MyEncyclopedia公众号全栈人工智能',
    '人工智能需要懂很多数学么',
    '上海疫情有完没完',
    '教育部:连续7天社会面无疫情 高校可组织校园招聘',
    '福建舰"下水!100秒看中国航母高光时刻',
    '医保承担多少核酸检测费用?压力多大?',
    '张家口过度防疫整改后又被曝光:要证明牛是阴性',
    '上海多家银行天天排队爆满 有老人凌晨2点开始排队',
    'A股不惧海外暴跌!走出独立行情沪指收复3300点',
    '俄方称已准备好重启俄乌和谈',
    '《自然》:奥密克戎感染后嗅觉丧失症状比原来少了'
]

Next we issue the following three query phrases and find the three best-matching sentences for each.

q1 = '码农的春天来了么'
q2 = '国际局势'
q3 = '健康'

The results are as follows.

Query: 码农的春天来了么

Top 3 most similar sentences in corpus:
人工智能需要懂很多数学么 (Cosine Score: 0.7606)
MyEncyclopedia公众号全栈人工智能 (Cosine Score: 0.7498)
上海疫情有完没完 (Cosine Score: 0.7449)

----------------------------------------------
Query: 国际局势

Top 3 most similar sentences in corpus:
俄方称已准备好重启俄乌和谈 (Cosine Score: 0.7041)
MyEncyclopedia公众号全栈人工智能 (Cosine Score: 0.6897)
上海疫情有完没完 (Cosine Score: 0.6888)

----------------------------------------------
Query: 健康

Top 3 most similar sentences in corpus:
上海疫情有完没完 (Cosine Score: 0.5882)
MyEncyclopedia公众号全栈人工智能 (Cosine Score: 0.5870)
今天晚上想吃牛排 (Cosine Score: 0.5815)

It turns out that 上海疫情有完没完 ("will the Shanghai outbreak ever end") is the key to every question...
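The "Cosine Score" in the output above is the cosine similarity between a query embedding and each corpus embedding, and the ranking simply keeps the highest-scoring sentences. The following pure-Python sketch shows that ranking step on made-up 2-dimensional vectors. The helper names `cosine_similarity` and `top_k` and the toy vectors are assumptions for illustration only; real embeddings from the model have several hundred dimensions.

```python
import math
from typing import List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query: List[float], corpus: List[List[float]], k: int) -> List[Tuple[int, float]]:
    """Return (corpus index, score) pairs for the k most similar corpus vectors."""
    scored = [(idx, cosine_similarity(query, vec)) for idx, vec in enumerate(corpus)]
    scored.sort(key=lambda pair: pair[1], reverse=True)  # highest similarity first
    return scored[:k]


# Toy embeddings (made-up numbers standing in for sentence embeddings).
corpus_vectors = [
    [1.0, 0.0],  # index 0: same direction as the query
    [0.0, 1.0],  # index 1: orthogonal to the query
    [1.0, 1.0],  # index 2: 45 degrees away from the query
]
query_vector = [2.0, 0.0]

ranking = top_k(query_vector, corpus_vectors, k=2)
for idx, score in ranking:
    print(idx, "(Cosine Score: %.4f)" % score)
# 0 (Cosine Score: 1.0000)
# 2 (Cosine Score: 0.7071)
```

The full code in this post gets the same ordering by computing cosine distances with scipy, sorting them in ascending order, and reporting 1 - distance as the score.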

Full Code

The full code is attached below.

from sentence_transformers import SentenceTransformer

model_name = 'hfl/chinese-bert-wwm'
model = SentenceTransformer(model_name)

sentences = [
    '今天晚上想吃牛排',
    'MyEncyclopedia公众号全栈人工智能',
    '人工智能需要懂很多数学么',
    '上海疫情有完没完',
    '教育部:连续7天社会面无疫情 高校可组织校园招聘',
    '福建舰"下水!100秒看中国航母高光时刻',
    '医保承担多少核酸检测费用?压力多大?',
    '张家口过度防疫整改后又被曝光:要证明牛是阴性',
    '上海多家银行天天排队爆满 有老人凌晨2点开始排队',
    'A股不惧海外暴跌!走出独立行情沪指收复3300点',
    '俄方称已准备好重启俄乌和谈',
    '《自然》:奥密克戎感染后嗅觉丧失症状比原来少了'
]
sentence_embeddings = model.encode(sentences)

q1 = '码农的春天来了么'
q2 = '国际局势'
q3 = '健康'
queries = [q1, q2, q3]
query_embeddings = model.encode(queries)

# Import cdist explicitly; a bare `import scipy` does not reliably load scipy.spatial
from scipy.spatial.distance import cdist

number_top_matches = 3
for query, query_embedding in zip(queries, query_embeddings):
    # Cosine distance between the query and every corpus sentence (smaller = more similar)
    distances = cdist([query_embedding], sentence_embeddings, "cosine")[0]
    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])
    print("\nQuery:", query)
    print("\nTop {} most similar sentences in corpus:".format(number_top_matches))

    for idx, distance in results[0:number_top_matches]:
        # Cosine similarity = 1 - cosine distance
        print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))

Author and License Contact MyEncyclopedia to Authorize
myencyclopedia.top link https://blog.myencyclopedia.top/zh/2022/docker-sentence-transformer-chinese/
github.io link https://myencyclopedia.github.io/zh/2022/docker-sentence-transformer-chinese/
