在这一期中,我们来体验两个知名的 NLP 预训练类库 flair 和 transformer
的 zero-shot 短文本分类。所谓zero-shot
的意思是完全不需要数据集来训练,直接掉包解决问题。和以往一样,本期的
docker 镜像已经预装了 flair,transformer,pytorch,jupyter
notebook等包依赖 ,并且还预先下载了 flair 和 transformer
的两个预训练模型 和 yahoo
短文本主题数据集 ,整个 docker
镜像达到12GB,为了就是让大家无需翻墙下载额外数据或者模型,并且使用CPU就能体验最新的NLP
zero shot 文本分类。
Docker 镜像获取方式
关注 MyEncyclopedia
公众号后回复
docker-transformer-zero-shot
即可获取镜像地址和启动命令。
Flair zero shot
先来看一个 flair 短文本 zero shot 短文本分类的例子。下面的代码将句子
Spain beat Swiss for first Nations League win 归类到
politics ,
sports ,health 之一。
1 2 3 4 5 6 7 8 9 10 11 12 13 from flair.models import TARSClassifierfrom flair.data import Sentenceimport flair, torchflair.device = torch.device('cpu' ) text = 'Spain beat Swiss for first Nations League win' tars = TARSClassifier.load('tars-base' ) sentence = Sentence(text) classes = ['politics' , 'sports' , 'health' ] tars.predict_zero_shot(sentence, classes) print (sentence)print (sentence.to_dict())
最后两行输出如下,all labels
字段显示概率最高的是
sports
类别,达到 0.99+。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 Sentence: "Spain beat Swiss for first Nations League win" → sports (0.9952) { 'text' : 'Spain beat Swiss for first Nations League win' , 'all labels' : [{'value' : 'sports' , 'confidence' : 0.9952359795570374}] } 注意,在上面的代码中,`flair.device = torch.device('cpu' )` 强制使用了 cpu 资源,否则 flair 默认使用 gpu 会报错。 再来看看大名鼎鼎的 transformer zero shot 的结果。这里使用了默认的 transformer zero shot 分类的模型 Transformer Bart,小伙伴们可以使用其他模型,但是有些不兼容 zero shot 分类。代码如下 ```python from transformers import pipeline text = 'Spain beat Swiss for first Nations League win' classes = ['politics' , 'sports' , 'health' ] classifier = pipeline("zero-shot-classification" , device=-1) result = classifier(text, classes, multi_label=False) print (result)print (result['labels' ][0])
最后两行输出为
1 2 3 4 5 6 { 'sequence' : 'Spain beat Swiss for first Nations League win' , 'labels' : ['sports' , 'health' , 'politics' ], 'scores' : [0.9476209878921509, 0.03594793379306793, 0.016431059688329697] } sports
result
的
labels
中会按照最大概率排序输出类别和对应的分数。对于这句句子,也分的相当准确,sports
为 0.94+。
也注意到 pipeline("zero-shot-classification", device=-1)
语句中 -1 表示强制使用 cpu。
Yahoo 短文本主题数据分类效果
最后,来看一个真实数据集中这两者的实际效果,yahoo_answers_topics
是
huggingface
的一个短文本分类数据集,可以通过以下命令下载并加载
1 yahoo = load_dataset('yahoo_answers_topics' )
它的具体类别为
1 2 3 4 5 6 7 8 9 10 11 12 [ 'Society & Culture' , 'Science & Mathematics' , 'Health' , 'Education & Reference' , 'Computers & Internet' , 'Sports' , 'Business & Finance' , 'Entertainment & Music' , 'Family & Relationships' , 'Politics & Government' ]
由于数量比较大,这里只取随机的1000个来测试,一些数据点如下
A Permanent resident of Canada may stay out of Canada 3 years
without losing status.
Politics & Government
The official major league opening game occurred on April 10, 2006,
as the Cardinals defeated the Milwaukee Brewers 6-4. (Day Game)
Sports
Hold down the Command key while dragging and dropping files.
Computers & Internet
接着,对于每条短文本用 flair 和 transformer
来预测类别,最终统计准确率。
结果是 flair 准确率为 0.275 ,Transformer Bart 为
0.392 ,果然 transformer 显著胜出。其实,在
Yahoo数据集上取得 0.3 - 0.4
左右的效果已经不错了,毕竟有十个类别,全随机的准确率是
0.1。如果大家觉得这个效果一般的话,可以试试 tweet
情感分类数据集(具体在下面的链接中),Transformer 能达到惊人的
0.73。
下面附部分代码,完整代码可以从镜像中获得,或者感兴趣的小伙伴也可以访问
https://github.com/nlptown/nlp-notebooks/blob/master/Zero-Shot%20Text%20Classification.ipynb
获取所有五个数据集的代码,不过由于类库版本的关系,部分代码和模型或数据无法兼容,需要自行调试。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 def evaluate_flair (dataset, default_name='neutral' ): classifier = TARSClassifier.load('tars-base' ) total, correct = 0 , 0 for item, gold_label_idx in tqdm(zip (dataset["test_texts" ], dataset["test_labels" ]), total=len (dataset["test_texts" ])): sentence = Sentence(item) classifier.predict_zero_shot(sentence, dataset["class_names" ]) sorted_labels = sorted (sentence.to_dict()['all labels' ], key=lambda k: k['confidence' ], reverse=True ) gold_label = dataset["class_names" ][gold_label_idx] if len (sorted_labels) > 0 : predicted_label = sorted_labels[0 ]['value' ] else : predicted_label = default_name if predicted_label == gold_label: correct += 1 total += 1 return correct / total
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def evaluate_huggingface (dataset ): classifier = pipeline("zero-shot-classification" , device=-1 ) correct = 0 predictions, gold_labels = [], [] for text, gold_label_idx in tqdm(zip (dataset["test_texts" ], dataset["test_labels" ]), total=len (dataset["test_texts" ])): result = classifier(text, dataset["class_names" ], multi_label=False ) predicted_label = result['labels' ][0 ] gold_label = dataset["class_names" ][gold_label_idx] predictions.append(predicted_label) gold_labels.append(gold_label) if predicted_label == gold_label: correct += 1 accuracy = correct / len (predictions) return accuracy
评论
shortname
for Disqus. Please set it in_config.yml
.