PYTORCH - matplotlib를 이용한 셀프 어텐션 visualization
0. 개요
자연어처리 rnn 과 셀프 어텐션(self attention)으로 감성분석 실험 도중에 다른 논문들과 해외 블로그 처럼 셀프 어텐션의 스코어값을 눈으로 확인해 보고 싶어서 visualization을 찾아보게 되었습니다. 그 중에서 pytorch docs 와 zhaocq-nlp님의 깃허브를 참고하여 만들어 적용해 보았습니다.
1. matplotlib
matplotlib는 파이썬 라이브러리 중에서 그래프를 그릴 때 주로 사용되는 패키지 입니다. 2D와 3D 형태로 보여지는 패키지 입니다. matplolib설치는 리눅스 기준으로 아래 코드를 입력하면 설치가 됩니다.
python3 -m pip install -U matplotlib
사실 matplotlib는 많은 사람들이 사용하고 있는 패키지로, pytorch docs에서도 이 패키지를 이용하여 그래프로 나타내었습니다.
2. zhaocq-nlp 코드
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | # -*- coding: utf-8 -*- from __future__ import absolute_import import numpy import matplotlib.pyplot as plt import json import argparse # input: # alignment matrix - numpy array # shape (target tokens + eos, number of hidden source states = source tokens +eos) # one line correpsonds to one decoding step producing one target token # each line has the attention model weights corresponding to that decoding step # each float on a line is the attention model weight for a corresponding source state. # plot: a heat map of the alignment matrix # x axis are the source tokens (alignment is to source hidden state that roughly corresponds to a source token) # y axis are the target tokens # http://stackoverflow.com/questions/14391959/heatmap-in-matplotlib-with-pcolor def plot_head_map(mma, target_labels, source_labels): fig, ax = plt.subplots() heatmap = ax.pcolor(mma, cmap=plt.cm.Blues) # put the major ticks at the middle of each cell ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False) ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False) # without this I get some extra columns rows # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column ax.set_xlim(0, int(mma.shape[1])) ax.set_ylim(0, int(mma.shape[0])) # want a more natural, table-like display ax.invert_yaxis() ax.xaxis.tick_top() # source words -> column labels ax.set_xticklabels(source_labels, minor=False) # target words -> row labels ax.set_yticklabels(target_labels, minor=False) plt.xticks(rotation=45) # plt.tight_layout() plt.show() # column labels -> target words # row labels -> source words def read_alignment_matrix(f): header = f.readline().strip().split('|||') if header[0] == '': return None, None, None, None sid = int(header[0].strip()) # number of tokens in source and translation +1 for eos src_count, trg_count = map(int, header[-1].split()) # source words source_labels = header[3].decode('UTF-8').split() # source_labels.append('</s>') # target words target_labels = header[1].decode('UTF-8').split() target_labels.append('</s>') mm = [] for r in range(trg_count): alignment = map(float, f.readline().strip().split()) mm.append(alignment) mma = numpy.array(mm) return sid, mma, target_labels, source_labels def read_plot_alignment_matrices(f, start=0): attentions = json.load(f, encoding="utf-8") for idx, att in attentions.items(): if idx < start: continue source_labels = att["source"].split() + ["SEQUENCE_END"] target_labels = att["translation"].split() att_list = att["attentions"] assert att_list[0]["type"] == "simple", "Do not use this tool for multihead attention." mma = numpy.array(att_list[0]["value"]) if mma.shape[0] == len(target_labels) + 1: target_labels += ["SEQUENCE_END"] plot_head_map(mma, target_labels, source_labels) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--input', '-i', type=argparse.FileType("rb"), default="trans.att", metavar='PATH', help="Input file (default: standard input)") parser.add_argument('--start', type=int, default=0) args = parser.parse_args() read_plot_alignment_matrices(args.input, args.start) | cs |
100줄 짜리 코드가 있는데요. 사실 저희가 필요한 코드는 몇줄 안됩니다. 우선 필요한 함수를 살펴보면,
def read_plot_alignment_matrices(f, start=0):
def plot_head_map(mma, target_labels, source_labels):
이 두개의 함수만 이용하면 됩니다. read_plot_aligment_matrices 는 json 파일을 받아서, 입력시퀀스, 타겟시퀀스, 스코어 값들을 불러와 변수에 저장해 줍니다. mma는 스코어를 np.array로 바꿔준 값들 입니다.
plot_head_map 함수는 read_plot_aligment_matrices에서 만들어진 변수들을 대입해 주면, matplotlib에서 멋있게 그래프로 표현해 줍니다.
여기서 문제점은 제 모델은 결과를 json으로 저장하지 않는다는 것 입니다. 따라서 json에서 값들을 불러오는것이 아닌 test끝나고 나온 결과값들을 변수에 저장해 주었다가 대입하는 방법을 생각했습니다.
3. 내 코드에 알맞게 수정
아래 코드는 최종적으로 제 모델에 적용할 수 있게끔 짠 코드입니다. 어텐션 그래프를 그리기 위해서는 셀프 어텐션 스코어값, 타겟 시퀀스, 입력 시퀀스 이 세가지만 있으면 됩니다.
1. train.py 에 import하여 read_plot_aligment_matrices을 함수처럼 불러올 수 있게 한다.
2. 셀프 어텐션 스코어값, 타겟 시퀀스, 입력 시퀀스를 저장할 변수를 선언한다.
3. test도중 나오는 값들을 저장한다.
4. 저장된 값들을 read_plot_aligment_matrices에 대입해주어 그래프를 확인한다.
<plot_heatmap.py>
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | # -*- coding: utf-8 -*- from __future__ import absolute_import import numpy import matplotlib.pyplot as plt import json import argparse # input: # alignment matrix - numpy array # shape (target tokens + eos, number of hidden source states = source tokens +eos) # one line correpsonds to one decoding step producing one target token # each line has the attention model weights corresponding to that decoding step # each float on a line is the attention model weight for a corresponding source state. # plot: a heat map of the alignment matrix # x axis are the source tokens (alignment is to source hidden state that roughly corresponds to a source token) # y axis are the target tokens # http://stackoverflow.com/questions/14391959/heatmap-in-matplotlib-with-pcolor def plot_head_map(mma, target_labels, source_labels): fig, ax = plt.subplots() heatmap = ax.pcolor(mma, cmap=plt.cm.Blues) # put the major ticks at the middle of each cell ax.set_xticks(numpy.arange(mma.shape[1]) + 0.5, minor=False) # mma.shape[1] = target seq 길이 ax.set_yticks(numpy.arange(mma.shape[0]) + 0.5, minor=False) # mma.shape[0] = input seq 길이 # without this I get some extra columns rows # http://stackoverflow.com/questions/31601351/why-does-this-matplotlib-heatmap-have-an-extra-blank-column ax.set_xlim(0, int(mma.shape[1])) ax.set_ylim(0, int(mma.shape[0])) # want a more natural, table-like display ax.invert_yaxis() ax.xaxis.tick_top() # source words -> column labels ax.set_xticklabels(source_labels, minor=False) # target words -> row labels ax.set_yticklabels(target_labels, minor=False) plt.xticks(rotation=45) # plt.tight_layout() plt.show() def read_plot_alignment_matrices(source_labels, target_labels, alpha): mma = alpha.cpu().data.numpy() plot_head_map(mma, target_labels, source_labels) | cs |
<test_visual.py> -> train.py 변형
1 2 3 4 5 6 7 8 9 10 11 | print(test[args.seq_num]['x_tokens']) #입력 시퀀스 출력 input_seq = test[args.seq_num]['x_tokens'] # 입력 시퀀스 저장 test_batches = BatchGen([test[args.seq_num]], batch_size=args.batch_size, evaluation=True, gpu=args.cuda, opt=opt, elmo_layer=elmo_layer) predictions, ans_list, attn_score = evaluate(test_batches, model, opt) # attn_score = self attention 스코어 값을 저장 te_acc, te_f1, te_prec, te_rec = f1_score(predictions, ans_list, opt) log.warn("*** TEST_BEST acc: {0:.2f} F1: {1:.2f} ".format(te_acc, te_f1)) write_result(predictions, ans_list) print('predictions = ',predictions, 'ans_list = ', ans_list) # 예측한 값과 정답 비교 read_plot_alignment_matrices(input_seq, input_seq, attn_score[0]) # self attention 이기 때문에 입력과 출력이 같다. | cs |
위의 코드는 train.py 에 test하는 부분입니다. test객체는 test 데이터를 전부 담고 있습니다. batch size를 1로 주어 원하는 문장에 대해 self attention score 그래프를 그릴 수 있게 해주었고, test가 끝난 뒤에 저장된 변수들을 read_plot_aligment_matrices에 넣어주어 그래프를 그려보았습니다. 아래 사진은 그 결과입니다.
input seq : 상당히 시답잖고 어설프네.
명령어를 보시면 batch_size 는 1로 주어 모델이 한문장만 볼 수 있게끔 해주었고, seq_num 578은 test 데이터에서 578번째 문장을 모델에 넣어주라는 의미 입니다. 여기에서 578번째 문장은 '상당히 시답잖고 어설프네'라는 문이네요.
제 모델은 영화평 댓글을 긍정인지, 부정인지 나누어주는 감성분석입니다. 그래프를 보시면, '시답잖' 부분과 '어설프네' 부분이 많은 스코어를 가지고 있는 것으로 보아, 이 부분이 결과에 많이 영향을 주었다는 점을 알 수 있습니다.
4. 적용하려면?
이 코드를 적용하려면 세 가지만 있으면 됩니다.
---------------
첫번째 : self attention 모델에서 뽑아온 score값
두번째 : 입력 시퀀스
세번째 : 아웃풋 시퀀스 (self attention이라 입력 시퀀스와 같다.)
---------------
뽑아오신 값들을 read_plot_aligment_matrices에 넣어주시면 바로 그래프를 볼 수 있습니다.
5. 출처 및 참조
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
https://github.com/zhaocq-nlp/Attention-Visualization