Chinese NLP Notes: 11. Generating Classical Poetry with LSTM

1. Corpus Preparation

The corpus contains over forty thousand classical poems, one poem per line.
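A minimal sketch for sanity-checking the corpus file (assuming poetry.txt is a UTF-8 file with one poem per line, as described above):

with open('poetry.txt', 'r', encoding='utf-8') as f:
    poems = [line.strip() for line in f if line.strip()]
print(len(poems))   # roughly 40,000
print(poems[0])     # a single poem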

2. Preprocessing

Chinese characters are represented in one-hot form. (Strictly, in the code below only the target character is one-hot encoded; the input characters are mapped to integer IDs and fed to an Embedding layer. A toy illustration follows the next point.)

A ] symbol is appended to the end of each line to mark where the poem ends: the text before a ] has no relation to the text after it, and training windows containing ] are discarded later.
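A toy illustration of the two encodings (vocabulary and values are made up):

import numpy as np

vocab = ['床', '前', '明', '月', '光']
word2num = {c: i + 1 for i, c in enumerate(vocab)}   # ID 0 is reserved for unknown characters
print(word2num['明'])                 # integer ID fed to the Embedding layer
one_hot = np.zeros(len(vocab) + 1, dtype=bool)
one_hot[word2num['明']] = 1           # one-hot target used for the loss
print(one_hot)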

puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》']

def preprocess_file(Config):
    # full corpus text
    files_content = ''
    with open(Config.poetry_file, 'r', encoding='utf-8') as f:
        for line in f:
            # strip punctuation and append "]" to mark the end of each poem
            for char in puncs:
                line = line.replace(char, "")
            files_content += line.strip() + "]"

    # count character frequencies
    counted_words = {}
    for word in files_content:
        if word in counted_words:
            counted_words[word] += 1
        else:
            counted_words[word] = 1

    # drop low-frequency characters
    erase = []
    for key in counted_words:
        if counted_words[key] <= 2:
            erase.append(key)
    for key in erase:
        del counted_words[key]
    del counted_words[']']

    wordPairs = sorted(counted_words.items(), key=lambda x: -x[1])
    words, _ = zip(*wordPairs)

    # character-to-ID mapping; ID 0 is reserved for unknown characters.
    # num2word uses the same 1-based IDs as word2num (the original built it
    # 0-based, so the two maps disagreed by one)
    word2num = dict((c, i + 1) for i, c in enumerate(words))
    num2word = dict((i + 1, c) for i, c in enumerate(words))
    word2numF = lambda x: word2num.get(x, 0)
    return word2numF, num2word, words, files_content
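A quick round-trip check of the returned mappings (a sketch; assumes the Config class defined in the next section):

word2numF, num2word, words, files_content = preprocess_file(Config)
cid = word2numF('月')
print(cid, num2word.get(cid))   # the two maps invert each other
print(word2numF('朤'))          # rare or unseen characters map to 0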

3. Model Parameter Configuration

class Config(object):
    poetry_file = 'poetry.txt'
    weight_file = 'poetry_model.h5'
    # predict the seventh character from the preceding six
    max_len = 6
    batch_size = 512
    learning_rate = 0.001

4. Building the Model

Everything is implemented in a PoetryModel class; the imports assumed by the snippets below are listed first.
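Assumed imports for all following snippets (a sketch targeting standalone Keras 2.x; under tf.keras the keras.* module paths would change accordingly):

import os
import random

import numpy as np
import keras
from keras.callbacks import LambdaCallback
from keras.layers import Bidirectional, Dense, Dropout, Embedding, Flatten, GRU, Input
from keras.models import Model, load_model
from keras.optimizers import Adam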

class PoetryModel(object):
    def __init__(self, config):
        pass

    def build_model(self):
        pass

    def sample(self, preds, temperature=1.0):
        pass

    def generate_sample_result(self, epoch, logs):
        pass

    def predict(self, text):
        pass

    def data_generator(self):
        pass

    def train(self):
        pass

(1) The __init__ function

Loads the Config settings, preprocesses the corpus, and either loads a saved model or trains one from scratch.

def __init__(self, config):
    self.model = None
    self.do_train = True
    self.loaded_model = False
    self.config = config

    # corpus preprocessing
    self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)
    # load the saved model if a weight file exists, otherwise train one
    if os.path.exists(self.config.weight_file):
        self.model = load_model(self.config.weight_file)
        self.model.summary()
    else:
        self.train()
    self.do_train = False
    self.loaded_model = True

(2) The build_model function

Builds the network. Note that despite the note's title, the recurrent layers are bidirectional GRUs rather than LSTMs. With max_len = 6, embedding size 300, and Bi-GRU(128) layers, the shapes flow as (batch, 6) → (batch, 6, 300) → (batch, 6, 256) → (batch, 1536) after flattening → a softmax over the vocabulary.

def build_model(self):
    '''Build the model.'''
    input_tensor = Input(shape=(self.config.max_len,))
    embedd = Embedding(len(self.num2word) + 1, 300, input_length=self.config.max_len)(input_tensor)
    # chain the layers properly: the original connected both GRU layers to
    # the embedding and never used the dropout outputs
    lstm = Bidirectional(GRU(128, return_sequences=True))(embedd)
    dropout = Dropout(0.6)(lstm)
    lstm = Bidirectional(GRU(128, return_sequences=True))(dropout)
    dropout = Dropout(0.6)(lstm)
    flatten = Flatten()(dropout)
    # len(self.words) + 1 output classes, matching the 1-based character IDs
    dense = Dense(len(self.words) + 1, activation='softmax')(flatten)
    self.model = Model(inputs=input_tensor, outputs=dense)
    optimizer = Adam(lr=self.config.learning_rate)
    self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

(3) The sample function

Samples the next character ID from the predicted distribution. A temperature below 1.0 sharpens the distribution (more conservative picks); above 1.0 flattens it (more diverse output).

def sample(self, preds, temperature=1.0):
    # rescale the log-probabilities by the temperature and re-normalize
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    # draw a single sample and return its index
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
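A toy check of how temperature reshapes a distribution (illustrative numbers only):

import numpy as np

preds = np.array([0.1, 0.2, 0.7])
for T in (0.5, 1.0, 1.5):
    p = np.exp(np.log(preds) / T)
    p /= p.sum()
    print(T, p.round(3))
# T=0.5 concentrates mass on the most likely index; T=1.5 spreads it out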

(4) The generate_sample_result callback

Called at the end of every epoch during training to print sample output at several diversity (temperature) settings.

def generate_sample_result(self, epoch, logs):
    print("\n==================Epoch {}=====================".format(epoch))
    for diversity in [0.5, 1.0, 1.5]:
        print("------------Diversity {}--------------".format(diversity))
        # seed with a random max_len-character window from the corpus
        start_index = random.randint(0, len(self.files_content) - self.config.max_len - 1)
        generated = ''
        sentence = self.files_content[start_index: start_index + self.config.max_len]
        generated += sentence
        for i in range(20):
            x_pred = np.zeros((1, self.config.max_len))
            # the original hard-coded 6 here; use max_len instead
            for t, char in enumerate(sentence[-self.config.max_len:]):
                x_pred[0, t] = self.word2numF(char)
            preds = self.model.predict(x_pred, verbose=0)[0]
            next_index = self.sample(preds, diversity)
            # ID 0 means "unknown", so fall back to an empty string
            next_char = self.num2word.get(next_index, '')
            generated += next_char
            sentence = sentence + next_char
        print(sentence)

(5) The predict function

Generates lines of poetry from the given text.

def predict(self, text):
    if not self.loaded_model:
        return
    with open(self.config.poetry_file, 'r', encoding='utf-8') as f:
        file_list = f.readlines()
    random_line = random.choice(file_list)
    # if fewer than four characters are given, pad with random
    # non-punctuation characters (the original indexed num2word with a
    # 0-based random index and could append None; retry instead)
    if text is None:
        text = ''
    while len(text) < 4:
        random_str_index = random.randrange(1, len(self.words) + 1)
        char = self.num2word.get(random_str_index)
        if char not in [',', '。', ',']:
            text += char
    # seed with the last max_len - 1 characters of a random corpus line,
    # left-padded with a dummy character to reach max_len
    seed = random_line[-(self.config.max_len):-1]
    res = ''
    seed = 'c' + seed
    for c in text:
        seed = seed[1:] + c
        # generate five characters after each input character
        for j in range(5):
            x_pred = np.zeros((1, self.config.max_len))
            for t, char in enumerate(seed):
                x_pred[0, t] = self.word2numF(char)
            preds = self.model.predict(x_pred, verbose=0)[0]
            next_index = self.sample(preds, 1.0)
            next_char = self.num2word.get(next_index, '')
            seed = seed[1:] + next_char
        res += seed
    return res

(6) The data_generator function

Yields training samples one at a time for the model to consume during training.

def data_generator(self):
    puncs = [']', '[', '(', ')', '{', '}', ':', '《', '》', ':']
    i = 0
    while 1:
        # wrap around at the end of the corpus instead of running off it
        if i + self.config.max_len >= len(self.files_content):
            i = 0
        x = self.files_content[i: i + self.config.max_len]
        y = self.files_content[i + self.config.max_len]
        # skip windows that cross a poem boundary ("]") or contain punctuation
        if any(p in x for p in puncs) or y in puncs:
            i += 1
            continue
        # one-hot target over len(words) + 1 classes (ID 0 = unknown),
        # matching the size of the softmax output layer
        y_vec = np.zeros(
            shape=(1, len(self.words) + 1),
            dtype=bool
        )
        y_vec[0, self.word2numF(y)] = 1.0
        # integer-ID input window for the Embedding layer
        x_vec = np.zeros(
            shape=(1, self.config.max_len),
            dtype=np.int32
        )
        for t, char in enumerate(x):
            x_vec[0, t] = self.word2numF(char)
        yield x_vec, y_vec
        i += 1
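A quick shape check of the generator output (a sketch; model is a constructed PoetryModel instance):

gen = model.data_generator()
x_vec, y_vec = next(gen)
print(x_vec.shape)   # (1, 6): six character IDs
print(y_vec.shape)   # (1, len(model.words) + 1): one-hot target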

(7) The train function

def train(self):
    # number_of_epoch = len(self.files_content) // self.config.batch_size
    number_of_epoch = 10
    if not self.model:
        self.build_model()
    self.model.summary()
    # note: batch_size is used here as the number of generator steps per
    # epoch, and the generator yields one sample per step
    self.model.fit_generator(
        generator=self.data_generator(),
        verbose=True,
        steps_per_epoch=self.config.batch_size,
        epochs=number_of_epoch,
        callbacks=[
            keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),
            LambdaCallback(on_epoch_end=self.generate_sample_result)
        ]
    )

5. Training the Model

Instantiating the class kicks everything off: the corpus is preprocessed, then a saved poetry_model.h5 is loaded if present, otherwise training starts.

model = PoetryModel(Config)

6. Composing Poems

text = input("text:")
sentence = model.predict(text)
print(sentence)

Learning material:

《中文自然語言處理入門實戰》