import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data
import jieba
dtype = torch.FloatTensor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(123)
stop_lst = [u'的',u'\u3000',r'“',r'”',u'…',u'!',u'?',u':',u'\n',u',',"还","看",";",";", "但","让","个","也","要","安","星",u'随着', u'对于', u'对',u'等',u'。',' ',u' ',u'、',u'?',u'在',u'了',"我","你","他","她","它"]
word_sequence = [word for word in word_cuts if word not in stop_lst]  # word_cuts: the jieba-segmented token list produced in the previous step
vocab = list(set(word_sequence))
word2idx = {w:i for i,w in enumerate(vocab) if w}
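The steps below also rely on a few hyperparameters (C, embedding_size, batch_size, epochs) and on voc_size, which are not defined in this section. A minimal sketch with assumed values (adjust as needed for your data):
# Assumed hyperparameters -- the concrete values are a design choice, not fixed by the text above
C = 2                    # context window size on each side of the center word
embedding_size = 100     # dimensionality of the learned word vectors
batch_size = 32          # mini-batch size for the DataLoader built below
epochs = 1000            # number of training epochs
voc_size = len(vocab)    # vocabulary size, used for the one-hot encoding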
skip_grams = []
for idx in range(C, len(word_sequence) - C):
    center = word2idx[word_sequence[idx]]  # center word
    context_idx = list(range(idx - C, idx)) + list(range(idx + 1, idx + C + 1))  # context word indices
    context = [word2idx[word_sequence[i]] for i in context_idx]
    for w in context:
        skip_grams.append([center, w])
print(skip_grams[:2])
def make_data(skip_grams):
    input_data = []
    output_data = []
    for i in range(len(skip_grams)):
        input_data.append(np.eye(voc_size)[skip_grams[i][0]])  # center word as a one-hot vector
        output_data.append(skip_grams[i][1])                   # context (background) word index
    return input_data, output_data
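The training loop in step 6 iterates over a loader that is not built in this section. A minimal sketch using make_data and the standard torch.utils.data API (with the batch_size assumed above):
input_data, output_data = make_data(skip_grams)
input_data = torch.Tensor(np.array(input_data))   # [num_pairs, voc_size], one-hot rows as floats
output_data = torch.LongTensor(output_data)       # [num_pairs], class indices for CrossEntropyLoss
dataset = Data.TensorDataset(input_data, output_data)
loader = Data.DataLoader(dataset, batch_size=batch_size, shuffle=True)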
5. Build the network structure (Model)
class Word2Vec(nn.Module):
    def __init__(self):
        super(Word2Vec, self).__init__()
        # W and V are two independent parameter matrices, not transposes of each other
        self.W = nn.Parameter(torch.randn(voc_size, embedding_size).type(dtype))
        self.V = nn.Parameter(torch.randn(embedding_size, voc_size).type(dtype))

    def forward(self, X):
        # X : [batch_size, voc_size], one-hot
        # torch.mm only works on 2-D matrices, while torch.matmul handles tensors of any dimension
        hidden_layer = torch.matmul(X, self.W)             # hidden_layer : [batch_size, embedding_size]
        output_layer = torch.matmul(hidden_layer, self.V)  # output_layer : [batch_size, voc_size]
        return output_layer
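Since X is a one-hot vector, torch.matmul(X, self.W) just selects one row of W. A hypothetical, more memory-friendly variant (not from the original text) that feeds word indices directly instead of one-hot vectors:
class Word2VecLookup(nn.Module):
    # Hypothetical alternative: avoids materializing one-hot inputs by indexing W directly
    def __init__(self):
        super(Word2VecLookup, self).__init__()
        self.W = nn.Parameter(torch.randn(voc_size, embedding_size).type(dtype))
        self.V = nn.Parameter(torch.randn(embedding_size, voc_size).type(dtype))

    def forward(self, X_idx):
        # X_idx : [batch_size] LongTensor of center-word indices
        hidden_layer = self.W[X_idx]                       # [batch_size, embedding_size], same result as one-hot @ W
        output_layer = torch.matmul(hidden_layer, self.V)  # [batch_size, voc_size]
        return output_layer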
model = Word2Vec().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
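nn.CrossEntropyLoss applies log_softmax internally followed by NLLLoss, so the model should return raw logits (as it does above) and batch_y must hold class indices, not one-hot vectors. A self-contained sanity check of that equivalence:
import torch.nn.functional as F
logits = torch.randn(4, 10)           # fake batch: 4 samples, 10 classes
targets = torch.tensor([1, 3, 0, 7])  # class indices
assert torch.allclose(F.cross_entropy(logits, targets),
                      F.nll_loss(F.log_softmax(logits, dim=1), targets))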
6. Training
for epoch in range(epochs):
    for i, (batch_x, batch_y) in enumerate(loader):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        # print('batch_x:', batch_x)
        # print('batch_y:', batch_y)
        pred = model(batch_x)
        loss = criterion(pred, batch_y)
        if (epoch + 1) % 100 == 0:
            print(epoch + 1, i, loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
7. Visualization of the training result (optional step: with a large training set the effect is not obvious in the plot)
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']   # so Chinese labels render correctly
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(10, 8))
W, WT = model.parameters()
for i, label in enumerate(vocab):
    x, y = float(W[i][1]), float(W[i][2])
    plt.scatter(x, y)
    plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
    if i > 50:
        break
plt.show()
Query the training result: the embedding matrix W has shape voc_size * embedding_size, one vector per vocabulary word.
W, WT = model.parameters()
word_vec = {voc: w for voc, w in zip(vocab, W.tolist())}
for name, vector in word_vec.items():
    print("{}==>{}".format(name, vector))