Генератор текста на основе триграмм

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3.  
  4. import re
  5. from random import uniform
  6. from collections import defaultdict
  7.  
  8. r_alphabet = re.compile(u'[а-яА-Я0-9-]+|[.,:;?!]+')
  9.  
  10. def gen_lines(corpus):
  11.     data = open(corpus)
  12.     for line in data:
  13.         yield line.decode('utf-8').lower()
  14.  
  15. def gen_tokens(lines):
  16.     for line in lines:
  17.         for token in r_alphabet.findall(line):
  18.             yield token
  19.  
  20. def gen_trigrams(tokens):
  21.     t0, t1 = '$', '$'
  22.     for t2 in tokens:
  23.         yield t0, t1, t2
  24.         if t2 in '.!?':
  25.             yield t1, t2, '$'
  26.             yield t2, '$','$'
  27.             t0, t1 = '$', '$'
  28.         else:
  29.             t0, t1 = t1, t2
  30.  
  31. def train(corpus):
  32.     lines = gen_lines(corpus)
  33.     tokens = gen_tokens(lines)
  34.     trigrams = gen_trigrams(tokens)
  35.  
  36.     bi, tri = defaultdict(lambda: 0.0), defaultdict(lambda: 0.0)
  37.  
  38.     for t0, t1, t2 in trigrams:
  39.         bi[t0, t1] += 1
  40.         tri[t0, t1, t2] += 1
  41.  
  42.     model = {}
  43.     for (t0, t1, t2), freq in tri.iteritems():
  44.         if (t0, t1) in model:
  45.             model[t0, t1].append((t2, freq/bi[t0, t1]))
  46.         else:
  47.             model[t0, t1] = [(t2, freq/bi[t0, t1])]
  48.     return model
  49.  
  50. def generate_sentence(model):
  51.     phrase = ''
  52.     t0, t1 = '$', '$'
  53.     while 1:
  54.         t0, t1 = t1, unirand(model[t0, t1])
  55.         if t1 == '$': break
  56.         if t1 in ('.!?,;:') or t0 == '$':
  57.             phrase += t1
  58.         else:
  59.             phrase += ' ' + t1
  60.     return phrase.capitalize()
  61.  
  62. def unirand(seq):
  63.     sum_, freq_ = 0, 0
  64.     for item, freq in seq:
  65.         sum_ += freq
  66.     rnd = uniform(0, sum_)
  67.     for token, freq in seq:
  68.         freq_ += freq
  69.         if rnd < freq_:
  70.             return token
  71.  
  72. if __name__ == '__main__':
  73.     model = train('tolstoy.txt')
  74.     for i in range(10):
  75.         print generate_sentence(model)

Реклама

Мы в соцсетях

tw tg yt gt