#nlp

斯坦福大学的 CS 课程上线了一门经典课程:《CS 25: Transformers United》。自 2017 年推出以来,Transformer 彻底改变了自然语言处理 (NLP)领域。现在,Transformer 在深度学习中被广泛使用,无论是计算机视觉 (CV)、强化学习 (RL)、生成对抗网络 (GAN)、语音甚至是生物学。除此之外,Transformer 还能够创建强大的语言模型(如 GPT-3),并在 AlphaFold2 中发挥了重要作用,该算法解决了蛋白质折叠问题。

目前这门课程在 Youtube 上日更连载中,地址为 https://www.youtube.com/playlist?list=PLoROMvodv4rNiJRchCzutFw5ItR_Z27CM

MyEncyclopedia Bilibili 为大家每日搬运同步视频,至今天7/16日已经有6集

明星讲课阵容

在今天公布的第一节课中,讲师为斯坦福大学硕士生 Divyansh Garg、软件工程师 Chetanya Rastogi(毕业于斯坦福大学)、软件工程师 Advay Pal(毕业于斯坦福大学)。

此外,第一节课的指导教授为 Christopher Manning,他是斯坦福大学计算机与语言学教授,也是将深度学习应用于自然语言处理领域的领军者。

从之前的课程描述来看,CS 25 课程邀请了来自不同领域关于 Transformer 研究的前沿人士进行客座讲座。OpenAI 的研究科学家 Mark Chen,主要介绍基于 Transformers 的 GPT-3、Codex;Google Brain 的科学家 Lucas Beyer,主要介绍 Transformer 在视觉领域的应用;Meta FAIR 科学家 Aditya Grover,主要介绍 RL 中的 Transformer 以及计算引擎等。

值得一提的是,AI 教父 Geoff Hinton 也带来了一次讲座。

课程明细

1. (Sep 20) Introduction to Transformers

Recommended Readings:

2. (Sept 27) Transformers in Language: GPT-3, Codex

Speaker: Mark Chen (OpenAI)

Recommended Readings: - Language Models are Few-Shot Learners
- Evaluating Large Language Models Trained on Code

3. (Oct 4) Applications in Vision

Speaker: Lucas Beyer (Google Brain)

Recommended Readings: - An Image is Worth 16x16 Words (Vision Transfomer)
- Additional Readings:
- How to train your ViT?

4. (Oct 11) Transformers in RL & Universal Compute Engines

Speaker: Aditya Grover (FAIR)

Recommended Readings: - Pretrained Transformers as Universal Computation Engines
- Decision Transformer: Reinforcement Learning via Sequence Modeling

5. (Oct 18) Scaling transformers

Speaker: Barret Zoph (Google Brain) with Irwan Bello and Liam Fedus

Recommended Readings: - Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
- ST-MoE: Designing Stable and Transferable Sparse Expert Models

6. (Oct 25) Perceiver: Arbitrary IO with transformers

Speaker: Andrew Jaegle (DeepMind)

Recommended Readings: - Perceiver: General Perception with Iterative Attention
- Perceiver IO: A General Architecture for Structured Inputs & Outputs

7. (Nov 1) Self Attention & Non-Parametric Transformers

Speaker: Aidan Gomez (University of Oxford)

Recommended Readings: - Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning

8. (Nov 8) GLOM: Representing part-whole hierarchies in a neural network

Speaker: Geoffrey Hinton (UoT) Recommended Readings:

9. (Nov 15) Interpretability with transformers

Speaker: Chris Olah (AnthropicAI)

Recommended Readings: - Multimodal Neurons in Artificial Neural Networks Additional Readings: - The Building Blocks of Interpretability

10. (Nov 29) Transformers for Applications in Audio, Speech and Music: From Language Modeling to Understanding to Synthesis

Speaker: Prateek Verma (Stanford)

在这一期中,我们延续上一期 Bert 中文短句相似度计算 Docker CPU镜像,继续使用 huggingface transformersentence-transformer 类库,并将英语句子生成 bert embedding,然后引入 faiss 类库来建立索引,最后查询最接近的句子。

Docker 镜像获取方式

本期 docker 镜像获取方式为,关注 MyEncyclopedia 公众号后回复 docker-faiss-transformer 即可获取如下完整命令。

1
docker run -p 8888:8888 myencyclopedia/faiss-demo bash -c 'jupyter notebook --allow-root --port 8888 --NotebookApp.token= --ip 0.0.0.0'

然后打开浏览器,输入 http://localhost:8888/notebooks/faiss_demo.ipynb

faiss 简介

Faiss 的全称是Facebook AI Similarity Search,是由 Facebook 开发的适用于稠密向量匹配的开源库,作为向量化检索开山鼻祖,Faiss 提供了一套查询海量高维数据集的解决方案,它从两个方面改善了暴力搜索算法存在的问题:降低空间占用和加快检索速度。此外,Faiss 提供了若干种方法实现数据压缩,包括 PCA、Product-Quantization等。

Faiss 主要特性:

  • 支持相似度检索和聚类;
  • 支持多种索引方式;
  • 支持CPU和GPU计算;
  • 支持Python和C++调用;

Faiss 使用流程

使用 faiss 分成两部,第一步需要对原始向量建立索引文件,第二步再对索引文件进行向量 search 操作。

在第一次建立索引文件的时候,需要经过 trainadd 两个过程;后续如果有新的向量需要被添加到索引文件,只需要一个 add 操作来实现增量索引更新,但是如果增量的量级与原始索引差不多的话,整个向量空间就可能发生了一些变化,这个时候就需要重新建立整个索引文件,也就是再用全部的向量来走一遍 trainadd,至于具体是如何 trainadd的,就和特定的索引类型有关了。

1. IndexFlatL2 & indexFlatIP

对于精确搜索,例如欧式距离 faiss.indexFlatL2 或 内积距离 faiss.indexFlatIP,没有 train 过程,add 完直接可以 search

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import faiss 

# 建立索引, 定义为dimension d = 128
index = faiss.IndexFlatL2(d)

# add vectors, xb 为 (100000,128)大小的numpy
index.add(xb)
print(index.ntotal)
# 索引中向量的数量, 输出100000

# 求4-近邻
k = 4
# xq为query embedding, 大小为(10000,128)
D, I = index.search(xq, k)
## D shape (10000,4),表示每个返回点的embedding 与 query embedding的距离,
## I shape (10000,4),表示和query embedding最接近的k个物品id,
print(I[:5])

2. IndexIVFFlat

IndexFlatL2 的结果虽然精确,但当数据集比较大的时候,暴力搜索的时间复杂度很高,因此我们一般会使用其他方式的索引来加速。比如 IndexIVFFlat,将数据集在 train 阶段分割为几部分,技术术语为 Voronoi Cells,每个数据向量只能落在一个cell中。Search 时只需要查询query向量落在cell中的数据了,降低了距离计算次数。这个过程本质就是高维 KNN 聚类算法。search 阶段使用倒排索引来。

IndexIVFFlat 需要一个训练的阶段,其与另外一个索引 quantizer 有关,通过 quantizer 来判断属于哪个cell。IndexIVFFlat 在搜索阶段,引入了nlist(cell的数量)与nprob(执行搜索的cell数)参数。增大nprobe可以得到与brute-force更为接近的结果,nprobe就是速度与精度的调节器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import faiss
nlist = 100
k = 4

# 建立索引, 定义为dimension d = 128
quantizer = faiss.IndexFlatL2(d)

# 使用欧式距离 L2 建立索引。
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)

## xb: (100000,128)
index.train(xb)
index.add(xb)
index.nprobe = 10 # 默认 nprobe 是 1 ,可以设置的大一些试试
D, I = index.search(xq, k)
print(I[-5:]) # 最后五次查询的结果

3. IndexIVFPQ

IndexFlatL2 和 IndexIVFFlat都要存储所有的向量数据。对于超大规模数据集来说,可能会不大现实。因此IndexIVFPQ 索引可以用来压缩向量,具体的压缩算法就是 Product-Quantization,注意,由于高维向量被压缩,因此 search 时候返回也是近似的结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import faiss

nlist = 100
# 每个向量分8段
m = 8
# 求4-近邻
k = 4
quantizer = faiss.IndexFlatL2(d) # 内部的索引方式依然不变
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8) # 每个向量都被编码为8个字节大小
index.train(xb)
index.add(xb)
index.nprobe = 10
D, I = index.search(xq, k) # 检索
print(I[-5:])

在本期中,我们仅使用基本的 IndexIVFFlat 和 IndexFlatIP 完成 bert embedding 的索引和搜索,后续会有篇幅来解读 Product-Quantization 的论文原理和代码实践。

ag_news 新闻数据集

ag_news 新闻数据集 3.0 包含了英语新闻标题,training 部分包含 120000条数据, test 部分包含 7600条数据。

ag_news 可以通过 huggingface datasets API 自动下载

1
2
3
4
5
6
7
8
9
10
def load_dataset(part='test') -> List[str]:
ds = datasets.load_dataset("ag_news")
list_str = [r['text'] for r in ds[part]]
return list_str

list_str = load_dataset(part='train')
print(f'{len(list_str)}')
for s in list_str[:3]:
print(s)
print('\n')

显示前三条新闻标题为

1
2
3
4
5
6
7
8
9
120000
Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.


Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

sentence-transformer

和上一期一样,我们利用sentence-transformer 生成句子级别的embedding。其原理基于 Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (https://arxiv.org/abs/1908.10084)这篇论文。基本思想很直接,将句子中的每个词的 bert embedding ,输进入一个池化层(pooling),例如选择最简单的平均池化层,将所有token embedding 的均值作为输出,便得到跟输入句子长度无关的一个定长的 sentence embedding。

结果展示

数据集 train 部分由于包含的样本比较多,需要一段时间生成 bert embedding,大家可以使用 load_dataset(part='test') 来快速体验。下面我们演示一个查询 how to make money 的最接近结果。

1
2
3
4
index = load_index('news_train.index')
list_id = query(model, index, 'how to make money')
for id in list_id:
print(list_str[id])
1
2
3
4
5
6
7
8
9
Profit From That Traffic Ticket Got a traffic ticket? Can't beat 'em? Join 'em by investing in the company that processes those tickets.

Answers in the Margins By just looking at operating margins, investors can find profitable industry leaders.

Types of Investors: Which Are You? Learn a little about yourself, and it may improve your performance.

Target Can Aim High Target can maintain its discount image while offering pricier services and merchandise.

Finance moves Ford into the black US carmaker Ford Motor returns to profit, as the money it makes from lending to customers outweighs losses from selling vehicles.

核心代码

所有可运行代码和数据都已经包含在 docker 镜像中了,下面列出核心代码

建立索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def train_flat(index_name, id_list, embedding_list, num_clusters):
import numpy as np
import faiss

dim = 768
m = 16

embeddings = np.asarray(embedding_list)

quantiser = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFFlat(quantiser, dim, num_clusters, faiss.METRIC_INNER_PRODUCT)
index.train(embeddings) ## clustering

ids = np.arange(0, len(id_list))
ids = np.asarray(ids.astype('int64'))

index.add_with_ids(embeddings, ids)
print(index.is_trained)
print("Total Number of Embeddings in the index", index.ntotal)
faiss.write_index(index, index_name)

查询结果

1
2
3
4
5
6
7
def query(model, index, query_str: str) -> List[int]:
topk = 5
q_embed = model.encode([query_str])
D, I = index.search(q_embed, topk)
print(D)
print(I)
return I[0].tolist()

在这一期中,我们还是制作了一个集数据,模型,代码一体的 docker 环境,给大家开箱即用体验中文BERT句子embedding体验。具体地,我们基于 BERT-wwm-exthuggingface transformersentence-transformer 把玩中文句子embedding 并寻找和查询短语相似度最接近的句子。

Docker 镜像获取方式

本期 docker 镜像获取方式为,关注 MyEncyclopedia 公众号后回复 docker-sentence-transformer 即可获取镜像地址和启动命令。

哈工大讯飞中文 Bert

在中文预训练领域,哈工大讯飞联合实验室(HFL)发布的基于全词Mask的中文预训练模型 BERT-wwm-ext 是业界的标杆之一。BERT-wwm-ext 支持 Tensorflow, Pytorch (通过 huggingface transformer 接口)以及 PaddleHub 的接口或者类库,使用起来十分方便。下面的代码为官网中通过 huggingface transformer 接口直接下载并加载到 Pytorch 平台中。Github 地址为 https://github.com/ymcui/Chinese-BERT-wwm

1
2
3
4
5
from transformers import BertTokenizer, BertModel

model_name = 'hfl/chinese-bert-wwm'
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)

通过 huggingface transformer 的好处在于 sentence-transformer 也支持 huggingface,因此,通过 huggingface,我们无需手动串联 BERT-wwm-extsentence-transformer,少写了不少代码。

sentence-transformer

sentence-transformer 顾名思义是利用 transformer 词向量的预训练模型来生成句子级别的embedding。原理基于这篇论文 Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (https://arxiv.org/abs/1908.10084)。基本思想直接了当:将句子中的每个词经 bert embedding 后,输入池化层 (pooling),例如选择最简单的平均池化层,再将所有token embedding 的均值作为输出,便得到跟输入句子长度无关的一个定长的 sentence embedding。

下面的代码是其官网的一个基本例子,底层通过 huggingface 接口自动下载并加载 bert 词向量,并计算三句英语句子的 sentence embedding。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

#Our sentences we like to encode
sentences = ['This framework generates embeddings for each input sentence',
'Sentences are passed as a list of string.',
'The quick brown fox jumps over the lazy dog.']

#Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

#Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
print("Sentence:", sentence)
print("Embedding:", embedding)
print("")

当然,我们也可以绕过 sentence-transformer API,直接使用 pytorch API 和 huggingface 手动实现平均池化层,生成句子的 sentence embedding。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import AutoTokenizer, AutoModel
import torch

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask

#Sentences we want sentence embeddings for
sentences = ['This framework generates embeddings for each input sentence',
'Sentences are passed as a list of string.',
'The quick brown fox jumps over the lazy dog.']

#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

#Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')

#Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input)

#Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

中文最相近的句子

有了上面每个组件的使用方法,让我们生成下面中文句子的embedding

1
2
3
4
5
6
7
8
9
10
11
12
13
14
sentences = [
'今天晚上想吃牛排',
'MyEncyclopedia公众号全栈人工智能',
'人工智能需要懂很多数学么',
'上海疫情有完没完',
'教育部:连续7天社会面无疫情 高校可组织校园招聘',
'福建舰"下水!100秒看中国航母高光时刻',
'医保承担多少核酸检测费用?压力多大?',
'张家口过度防疫整改后又被曝光:要证明牛是阴性',
'上海多家银行天天排队爆满 有老人凌晨2点开始排队',
'A股不惧海外暴跌!走出独立行情沪指收复3300点',
'俄方称已准备好重启俄乌和谈',
'《自然》:奥密克戎感染后嗅觉丧失症状比原来少了'
]

接着我们给出如下三个短语的查询,找到和每个查询最匹配的三个句子

1
2
3
q1 = '码农的春天来了么'
q2 = '国际局势'
q3 = '健康'

运行结果如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
Query: 码农的春天来了么

Top 3 most similar sentences in corpus:
人工智能需要懂很多数学么 (Cosine Score: 0.7606)
MyEncyclopedia公众号全栈人工智能 (Cosine Score: 0.7498)
上海疫情有完没完 (Cosine Score: 0.7449)

----------------------------------------------
Query: 国际局势

Top 3 most similar sentences in corpus:
俄方称已准备好重启俄乌和谈 (Cosine Score: 0.7041)
MyEncyclopedia公众号全栈人工智能 (Cosine Score: 0.6897)
上海疫情有完没完 (Cosine Score: 0.6888)

----------------------------------------------
Query: 健康

Top 3 most similar sentences in corpus:
上海疫情有完没完 (Cosine Score: 0.5882)
MyEncyclopedia公众号全栈人工智能 (Cosine Score: 0.5870)
今天晚上想吃牛排 (Cosine Score: 0.5815)

结果发现 上海疫情有完没完 是一切问题的关键。。。

完整代码

附上完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from sentence_transformers import SentenceTransformer

model_name = 'hfl/chinese-bert-wwm'
model = SentenceTransformer(model_name)

sentences = [
'今天晚上想吃牛排',
'MyEncyclopedia公众号全栈人工智能',
'人工智能需要懂很多数学么',
'上海疫情有完没完',
'教育部:连续7天社会面无疫情 高校可组织校园招聘',
'福建舰"下水!100秒看中国航母高光时刻',
'医保承担多少核酸检测费用?压力多大?',
'张家口过度防疫整改后又被曝光:要证明牛是阴性',
'上海多家银行天天排队爆满 有老人凌晨2点开始排队',
'A股不惧海外暴跌!走出独立行情沪指收复3300点',
'俄方称已准备好重启俄乌和谈',
'《自然》:奥密克戎感染后嗅觉丧失症状比原来少了'
]
sentence_embeddings = model.encode(sentences)

q1 = '码农的春天来了么'
q2 = '国际局势'
q3 = '健康'
queries = [q1, q2, q3]
query_embeddings = model.encode(queries)

import scipy

number_top_matches = 3
for query, query_embedding in zip(queries, query_embeddings):
distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine")[0]
results = zip(range(len(distances)), distances)
results = sorted(results, key=lambda x: x[1])
print("\nQuery:", query)
print("\nTop {} most similar sentences in corpus:".format(number_top_matches))

for idx, distance in results[0:number_top_matches]:
print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))

在这一期中,我们来体验两个知名的 NLP 预训练类库 flair 和 transformer 的 zero-shot 短文本分类。所谓zero-shot 的意思是完全不需要数据集来训练,直接掉包解决问题。和以往一样,本期的 docker 镜像已经预装了 flair,transformer,pytorch,jupyter notebook等包依赖,并且还预先下载了 flair 和 transformer 的两个预训练模型yahoo 短文本主题数据集,整个 docker 镜像达到12GB,为了就是让大家无需翻墙下载额外数据或者模型,并且使用CPU就能体验最新的NLP zero shot 文本分类。

Docker 镜像获取方式

关注 MyEncyclopedia 公众号后回复 docker-transformer-zero-shot 即可获取镜像地址和启动命令。

Flair zero shot

先来看一个 flair 短文本 zero shot 短文本分类的例子。下面的代码将句子 Spain beat Swiss for first Nations League win 归类到 politics, sportshealth 之一。

1
2
3
4
5
6
7
8
9
10
11
12
13
from flair.models import TARSClassifier
from flair.data import Sentence
import flair, torch
flair.device = torch.device('cpu')

text = 'Spain beat Swiss for first Nations League win'
tars = TARSClassifier.load('tars-base')
sentence = Sentence(text)
classes = ['politics', 'sports', 'health']
tars.predict_zero_shot(sentence, classes)

print(sentence)
print(sentence.to_dict())

最后两行输出如下,all labels 字段显示概率最高的是 sports类别,达到 0.99+。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
Sentence: "Spain beat Swiss for first Nations League win" → sports (0.9952)
{
'text': 'Spain beat Swiss for first Nations League win',
'all labels': [{'value': 'sports', 'confidence': 0.9952359795570374}]
}

注意,在上面的代码中,`flair.device = torch.device('cpu')` 强制使用了 cpu 资源,否则 flair 默认使用 gpu 会报错。


## Transformer zero shot
再来看看大名鼎鼎的 transformer zero shot 的结果。这里使用了默认的 transformer zero shot 分类的模型 Transformer Bart,小伙伴们可以使用其他模型,但是有些不兼容 zero shot 分类。代码如下

​```python
from transformers import pipeline

text = 'Spain beat Swiss for first Nations League win'
classes = ['politics', 'sports', 'health']
classifier = pipeline("zero-shot-classification", device=-1)
result = classifier(text, classes, multi_label=False)

print(result)
print(result['labels'][0])

最后两行输出为

1
2
3
4
5
6
{
'sequence': 'Spain beat Swiss for first Nations League win',
'labels': ['sports', 'health', 'politics'],
'scores': [0.9476209878921509, 0.03594793379306793, 0.016431059688329697]
}
sports

resultlabels中会按照最大概率排序输出类别和对应的分数。对于这句句子,也分的相当准确,sports 为 0.94+。

也注意到 pipeline("zero-shot-classification", device=-1) 语句中 -1 表示强制使用 cpu。

Yahoo 短文本主题数据分类效果

最后,来看一个真实数据集中这两者的实际效果,yahoo_answers_topicshuggingface的一个短文本分类数据集,可以通过以下命令下载并加载

1
yahoo = load_dataset('yahoo_answers_topics')

它的具体类别为

1
2
3
4
5
6
7
8
9
10
11
12
[
'Society & Culture',
'Science & Mathematics',
'Health',
'Education & Reference',
'Computers & Internet',
'Sports',
'Business & Finance',
'Entertainment & Music',
'Family & Relationships',
'Politics & Government'
]

由于数量比较大,这里只取随机的1000个来测试,一些数据点如下

Text Topic
A Permanent resident of Canada may stay out of Canada 3 years without losing status. Politics & Government
The official major league opening game occurred on April 10, 2006, as the Cardinals defeated the Milwaukee Brewers 6-4. (Day Game) Sports
Hold down the Command key while dragging and dropping files. Computers & Internet

接着,对于每条短文本用 flair 和 transformer 来预测类别,最终统计准确率。

结果是 flair 准确率为 0.275,Transformer Bart 为 0.392,果然 transformer 显著胜出。其实,在 Yahoo数据集上取得 0.3 - 0.4 左右的效果已经不错了,毕竟有十个类别,全随机的准确率是 0.1。如果大家觉得这个效果一般的话,可以试试 tweet 情感分类数据集(具体在下面的链接中),Transformer 能达到惊人的 0.73。

下面附部分代码,完整代码可以从镜像中获得,或者感兴趣的小伙伴也可以访问

https://github.com/nlptown/nlp-notebooks/blob/master/Zero-Shot%20Text%20Classification.ipynb 获取所有五个数据集的代码,不过由于类库版本的关系,部分代码和模型或数据无法兼容,需要自行调试。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def evaluate_flair(dataset, default_name='neutral'):
classifier = TARSClassifier.load('tars-base')
total, correct = 0, 0
for item, gold_label_idx in tqdm(zip(dataset["test_texts"], dataset["test_labels"]),
total=len(dataset["test_texts"])):
sentence = Sentence(item)
classifier.predict_zero_shot(sentence, dataset["class_names"])
sorted_labels = sorted(sentence.to_dict()['all labels'], key=lambda k: k['confidence'], reverse=True)
gold_label = dataset["class_names"][gold_label_idx]
if len(sorted_labels) > 0:
predicted_label = sorted_labels[0]['value']
else:
predicted_label = default_name
if predicted_label == gold_label:
correct += 1
total += 1

return correct / total

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def evaluate_huggingface(dataset):
classifier = pipeline("zero-shot-classification", device=-1)
correct = 0
predictions, gold_labels = [], []
for text, gold_label_idx in tqdm(zip(dataset["test_texts"], dataset["test_labels"]),
total=len(dataset["test_texts"])):

result = classifier(text, dataset["class_names"], multi_label=False)
predicted_label = result['labels'][0]

gold_label = dataset["class_names"][gold_label_idx]

predictions.append(predicted_label)
gold_labels.append(gold_label)

if predicted_label == gold_label:
correct += 1

accuracy = correct / len(predictions)
return accuracy
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×