玩转transformer+flair zero shot 短文本分类:无需翻墙或额外下载模型和数据集的CPU docker镜像

在这一期中,我们来体验两个知名的 NLP 预训练类库 flair 和 transformer 的 zero-shot 短文本分类。所谓zero-shot 的意思是完全不需要数据集来训练,直接掉包解决问题。和以往一样,本期的 docker 镜像已经预装了 flair,transformer,pytorch,jupyter notebook等包依赖,并且还预先下载了 flair 和 transformer 的两个预训练模型yahoo 短文本主题数据集,整个 docker 镜像达到12GB,为了就是让大家无需翻墙下载额外数据或者模型,并且使用CPU就能体验最新的NLP zero shot 文本分类。

Docker 镜像获取方式

关注 MyEncyclopedia 公众号后回复 docker-transformer-zero-shot 即可获取镜像地址和启动命令。

Flair zero shot

先来看一个 flair 短文本 zero shot 短文本分类的例子。下面的代码将句子 Spain beat Swiss for first Nations League win 归类到 politics, sportshealth 之一。

1
2
3
4
5
6
7
8
9
10
11
12
13
from flair.models import TARSClassifier
from flair.data import Sentence
import flair, torch
flair.device = torch.device('cpu')

text = 'Spain beat Swiss for first Nations League win'
tars = TARSClassifier.load('tars-base')
sentence = Sentence(text)
classes = ['politics', 'sports', 'health']
tars.predict_zero_shot(sentence, classes)

print(sentence)
print(sentence.to_dict())

最后两行输出如下,all labels 字段显示概率最高的是 sports类别,达到 0.99+。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
Sentence: "Spain beat Swiss for first Nations League win" → sports (0.9952)
{
'text': 'Spain beat Swiss for first Nations League win',
'all labels': [{'value': 'sports', 'confidence': 0.9952359795570374}]
}

注意,在上面的代码中,`flair.device = torch.device('cpu')` 强制使用了 cpu 资源,否则 flair 默认使用 gpu 会报错。


## Transformer zero shot
再来看看大名鼎鼎的 transformer zero shot 的结果。这里使用了默认的 transformer zero shot 分类的模型 Transformer Bart,小伙伴们可以使用其他模型,但是有些不兼容 zero shot 分类。代码如下

​```python
from transformers import pipeline

text = 'Spain beat Swiss for first Nations League win'
classes = ['politics', 'sports', 'health']
classifier = pipeline("zero-shot-classification", device=-1)
result = classifier(text, classes, multi_label=False)

print(result)
print(result['labels'][0])

最后两行输出为

1
2
3
4
5
6
{
'sequence': 'Spain beat Swiss for first Nations League win',
'labels': ['sports', 'health', 'politics'],
'scores': [0.9476209878921509, 0.03594793379306793, 0.016431059688329697]
}
sports

resultlabels中会按照最大概率排序输出类别和对应的分数。对于这句句子,也分的相当准确,sports 为 0.94+。

也注意到 pipeline("zero-shot-classification", device=-1) 语句中 -1 表示强制使用 cpu。

Yahoo 短文本主题数据分类效果

最后,来看一个真实数据集中这两者的实际效果,yahoo_answers_topicshuggingface的一个短文本分类数据集,可以通过以下命令下载并加载

1
yahoo = load_dataset('yahoo_answers_topics')

它的具体类别为

1
2
3
4
5
6
7
8
9
10
11
12
[
'Society & Culture',
'Science & Mathematics',
'Health',
'Education & Reference',
'Computers & Internet',
'Sports',
'Business & Finance',
'Entertainment & Music',
'Family & Relationships',
'Politics & Government'
]

由于数量比较大,这里只取随机的1000个来测试,一些数据点如下

Text Topic
A Permanent resident of Canada may stay out of Canada 3 years without losing status. Politics & Government
The official major league opening game occurred on April 10, 2006, as the Cardinals defeated the Milwaukee Brewers 6-4. (Day Game) Sports
Hold down the Command key while dragging and dropping files. Computers & Internet

接着,对于每条短文本用 flair 和 transformer 来预测类别,最终统计准确率。

结果是 flair 准确率为 0.275,Transformer Bart 为 0.392,果然 transformer 显著胜出。其实,在 Yahoo数据集上取得 0.3 - 0.4 左右的效果已经不错了,毕竟有十个类别,全随机的准确率是 0.1。如果大家觉得这个效果一般的话,可以试试 tweet 情感分类数据集(具体在下面的链接中),Transformer 能达到惊人的 0.73。

下面附部分代码,完整代码可以从镜像中获得,或者感兴趣的小伙伴也可以访问

https://github.com/nlptown/nlp-notebooks/blob/master/Zero-Shot%20Text%20Classification.ipynb 获取所有五个数据集的代码,不过由于类库版本的关系,部分代码和模型或数据无法兼容,需要自行调试。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def evaluate_flair(dataset, default_name='neutral'):
classifier = TARSClassifier.load('tars-base')
total, correct = 0, 0
for item, gold_label_idx in tqdm(zip(dataset["test_texts"], dataset["test_labels"]),
total=len(dataset["test_texts"])):
sentence = Sentence(item)
classifier.predict_zero_shot(sentence, dataset["class_names"])
sorted_labels = sorted(sentence.to_dict()['all labels'], key=lambda k: k['confidence'], reverse=True)
gold_label = dataset["class_names"][gold_label_idx]
if len(sorted_labels) > 0:
predicted_label = sorted_labels[0]['value']
else:
predicted_label = default_name
if predicted_label == gold_label:
correct += 1
total += 1

return correct / total

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def evaluate_huggingface(dataset):
classifier = pipeline("zero-shot-classification", device=-1)
correct = 0
predictions, gold_labels = [], []
for text, gold_label_idx in tqdm(zip(dataset["test_texts"], dataset["test_labels"]),
total=len(dataset["test_texts"])):

result = classifier(text, dataset["class_names"], multi_label=False)
predicted_label = result['labels'][0]

gold_label = dataset["class_names"][gold_label_idx]

predictions.append(predicted_label)
gold_labels.append(gold_label)

if predicted_label == gold_label:
correct += 1

accuracy = correct / len(predictions)
return accuracy
Bert 中文短句相似度计算 Docker CPU镜像 手机和微信中完美重排和阅读 Arxiv 论文

Author and License Contact MyEncyclopedia to Authorize
myencyclopedia.top link https://blog.myencyclopedia.top/zh/2022/docker-flair-transformer-zero-shot/
github.io link https://myencyclopedia.github.io/zh/2022/docker-flair-transformer-zero-shot/

You need to set install_url to use ShareThis. Please set it in _config.yml.

评论

You forgot to set the shortname for Disqus. Please set it in _config.yml.
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×