贝叶斯算法可以用来做拼写检查、文本分类、垃圾邮件过滤等工作,前面我们用贝叶斯做了文本分类,这次用它来做拼写检查,参考:How to Write a Spelling Corrector
给定一个单词, 我们的任务是选择和它最相似的拼写正确的单词.
对应的贝叶斯问题就是, 给定一个词 w, 在所有正确的拼写词中, 我们想要找一个正确的词 c, 使得对于 w 的条件概率最大, 也就是说:
$\arg\max_c P(c \mid w)$
按照贝叶斯理论上面的式子等价于:
$\arg\max_c \dfrac{P(w \mid c)\,P(c)}{P(w)}$
因为用户可以输错任何词, 因此对于任何 c 来讲, 出现 w 的概率 P(w) 都是一样的, 从而我们在上式中忽略它, 写成:
$\arg\max_c P(w \mid c)\,P(c)$
因此, 我们要最大化的就是 P(w|c) 与 P(c) 的乘积: 其中 P(w|c) 用编辑距离来近似(编辑距离越小, 概率越大), P(c) 是词 c 在训练语料中出现的先验概率.
其中编辑距离:两个词之间的编辑距离定义为使用了几次插入(在词中插入一个单字母), 删除(删除一个单字母), 交换(交换相邻两个字母), 替换(把一个字母换成另一个)的操作从一个词变到另一个词.
一般情况下,编辑距离为2时已经可以覆盖大部分情况
为了尽量覆盖较多的词语,首先从词典中读入常见的英文单词
从en-US读取词语【词语开始[Words]】
然后,从训练语料(训练语料在此下载 big.txt)训练我们的词典(语言模型,得到词语概率,出现频率越高的词语越常见)
/// <summary>
/// Trains the word-frequency table from a plain-text corpus: every run of
/// lower-case ASCII letters found in the file counts as one occurrence of a word.
/// </summary>
/// <param name="trainingFile">Path of the corpus file to read.</param>
/// <param name="ht">Word-to-count table updated in place; a word not yet present starts at 1.</param>
public static void TrainDic(string trainingFile, Dictionary<string, int> ht)
{
    Regex wordPattern = new Regex(@"[a-z]+"); // one match per word

    // The using block closes the file even when reading or counting throws.
    using (StreamReader reader = new StreamReader(trainingFile))
    {
        string line;
        while ((line = reader.ReadLine()) != null)
        {
            // Invariant lower-casing keeps the result independent of the current culture.
            // Apostrophes need no special handling: they are not matched by [a-z]+,
            // so they already split a token such as "cat's" into "cat" and "s".
            foreach (Match match in wordPattern.Matches(line.ToLowerInvariant()))
            {
                int seen;
                ht.TryGetValue(match.Value, out seen); // seen is 0 when the word is new
                ht[match.Value] = seen + 1;
            }
        }
    }
}
为了复用,可以将训练后的词典保存起来
// Serialize the table as one line per word (word, a tab, then its count)
// and write the whole text to the dictionary file in a single call.
var dicText = new StringBuilder();
foreach (KeyValuePair<string, int> entry in Dic)
{
    dicText.Append(entry.Key).Append("\t").Append(entry.Value).AppendLine();
}
File.WriteAllText(dicFile, dicText.ToString());
我们定义优先级: 编辑距离为 1 的词语 > 编辑距离为 2 的词语
首先,找到编辑距离为1的词语
/// <summary>
/// Generates every distinct string obtained by applying a single edit to
/// <paramref name="word"/>: deleting one character, swapping two adjacent
/// characters, replacing one character with a different letter a-z, or
/// inserting one letter a-z at any position.
/// </summary>
/// <param name="word">The word to edit; may be empty.</param>
/// <returns>
/// The candidates in generation order (deletions, transpositions,
/// replacements, insertions), each listed once.
/// </returns>
public static List<string> GetEdits1(string word)
{
    var n = word.Length;
    var editsWords = new List<string>();
    // The set answers "already generated?" in O(1); the list keeps first-seen order.
    var seen = new HashSet<string>();
    string candidate;

    for (int i = 0; i < n; i++) // delete one character
    {
        candidate = word.Substring(0, i) + word.Substring(i + 1);
        if (seen.Add(candidate))
        {
            editsWords.Add(candidate);
        }
    }

    for (int i = 0; i < n - 1; i++) // transpose two adjacent characters
    {
        candidate = word.Substring(0, i) + word[i + 1] + word[i] + word.Substring(i + 2);
        if (seen.Add(candidate))
        {
            editsWords.Add(candidate);
        }
    }

    for (int i = 0; i < n; i++) // replace one character
    {
        string head = word.Substring(0, i);
        string tail = word.Substring(i + 1);
        for (char ch = 'a'; ch <= 'z'; ch++)
        {
            if (ch == word[i])
            {
                continue; // replacing a letter with itself is not an edit
            }
            candidate = head + ch + tail;
            if (seen.Add(candidate))
            {
                editsWords.Add(candidate);
            }
        }
    }

    for (int i = 0; i <= n; i++) // insert one character (i == n appends)
    {
        string head = word.Substring(0, i);
        string tail = word.Substring(i);
        for (char ch = 'a'; ch <= 'z'; ch++)
        {
            candidate = head + ch + tail;
            if (seen.Add(candidate))
            {
                editsWords.Add(candidate);
            }
        }
    }

    return editsWords;
}
如果编辑距离为1的词语中没有正确的词语,则继续寻找编辑距离为2的词语;为了控制规模,只选取正确的词语
/// <summary>
/// Collects the known words (keys of <c>Dic</c>) that can be reached from
/// <paramref name="word"/> with at most two single edits.
/// </summary>
/// <param name="word">The word to start from.</param>
/// <returns>
/// Each reachable known word exactly once, in the order it was first
/// generated; empty when no known word is within two edits.
/// </returns>
public static List<string> GetEdits2(string word)
{
    var result = new List<string>();
    // A word is usually reachable through several edit paths; list it only once.
    var seen = new HashSet<string>();

    foreach (var edit in GetEdits1(word))
    {
        if (Dic.ContainsKey(edit) && seen.Add(edit))
        {
            result.Add(edit);
        }

        foreach (var second in GetEdits1(edit))
        {
            // Only known words are kept, which bounds the size of the result.
            if (Dic.ContainsKey(second) && seen.Add(second))
            {
                result.Add(second);
            }
        }
    }

    return result;
}
最后是获取建议词语的代码,最后的结果按照概率大小倒排序,取前5个
/// <summary>
/// Suggests corrections for <paramref name="word"/>. Known words one edit away
/// are preferred; only when there are none are known words two edits away
/// considered.
/// </summary>
/// <param name="word">The (presumably misspelled) input word.</param>
/// <returns>
/// Up to five distinct known words, most frequent first; a list holding only
/// <paramref name="word"/> itself when nothing known is within two edits.
/// </returns>
public static List<string> GetSuggestWords(string word)
{
    var candidates = GetEdits1(word).Where(w => Dic.ContainsKey(w)).ToList();
    if (candidates.Count == 0)
    {
        // Filter and de-duplicate here as well, so the suggestions never contain
        // non-words or repeats whatever GetEdits2 chooses to return.
        candidates = GetEdits2(word).Where(w => Dic.ContainsKey(w)).Distinct().ToList();
    }

    if (candidates.Count == 0)
    {
        return new List<string> { word };
    }

    // Rank by prior: the more often a word was seen, the more likely it was intended.
    return candidates.OrderByDescending(w => Dic[w]).Take(5).ToList();
}
// Word-to-frequency table used by every lookup; filled in by Main.
static Dictionary<string, int> Dic;
// Cache of the trained table, one line per word: the word, a tab, then its count.
static string dicFile = "dic.txt";
// Corpus used to train the table when no cache file exists.
static string trainingFile = "training.txt";

/// <summary>
/// Loads the cached dictionary (or trains and caches it when the cache file is
/// missing), then checks words typed on the console until "exit" is entered
/// or the input ends.
/// </summary>
static void Main(string[] args)
{
    if (File.Exists(dicFile))
    {
        Console.WriteLine("加载词典中...");
        LoadDic();
        Console.WriteLine("加载词典完成");
    }
    else
    {
        Console.WriteLine("训练词典中...");
        Dic = LoadUSDic();
        TrainDic(trainingFile, Dic);

        // Cache the trained table so the next run can skip training.
        StringBuilder dicBuilder = new StringBuilder();
        foreach (var item in Dic)
        {
            dicBuilder.AppendLine(item.Key + "\t" + item.Value);
        }
        File.WriteAllText(dicFile, dicBuilder.ToString());
        Console.WriteLine("训练完成...");
    }

    Console.WriteLine("请输入词语...");
    string inputWord;
    // ReadLine returns null once the input stream ends; stop then instead of
    // dereferencing null.
    while ((inputWord = Console.ReadLine()) != null && !inputWord.Equals("exit"))
    {
        if (Dic.ContainsKey(inputWord))
        {
            Console.WriteLine("你输入的词语 【" + inputWord + "】 是正确的!");
        }
        else
        {
            var suggestWords = GetSuggestWords(inputWord);
            Console.WriteLine("候选词语: ");
            foreach (var word in suggestWords)
            {
                Console.WriteLine("\t\t\t " + word);
            }
        }
        Console.WriteLine("请输入词语....");
    }
}
/// <summary>
/// Replaces <c>Dic</c> with the contents of the dictionary file, which holds
/// one entry per line: the word, a tab, then its integer count.
/// Lines that do not have exactly that shape are skipped.
/// </summary>
public static void LoadDic()
{
    Dic = new Dictionary<string, int>();
    foreach (var line in File.ReadAllLines(dicFile))
    {
        var dicItem = line.Split('\t');
        int count;
        // TryParse skips a damaged line instead of aborting the whole load.
        if (dicItem.Length == 2 && int.TryParse(dicItem[1], out count))
        {
            // Indexer assignment: a repeated word overwrites rather than throws.
            Dic[dicItem[0]] = count;
        }
    }
}
完整代码
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Collections;
using System.IO;
using System.Text.RegularExpressions;
using System.Diagnostics;
namespace SpellCheck
{
/// <summary>
/// Console spelling corrector: candidates are generated by edit distance
/// (one edit first, then two) and ranked by how often each word was seen
/// in the training data.
/// </summary>
class Program
{
    // Word-to-frequency table used by every lookup; filled in by Main.
    static Dictionary<string, int> Dic;
    // Cache of the trained table, one line per word: the word, a tab, then its count.
    static string dicFile = "dic.txt";
    // Corpus used to train the table when no cache file exists.
    static string trainingFile = "training.txt";

    /// <summary>
    /// Loads the cached dictionary (or trains and caches it when the cache file
    /// is missing), then checks words typed on the console until "exit" is
    /// entered or the input ends.
    /// </summary>
    static void Main(string[] args)
    {
        if (File.Exists(dicFile))
        {
            Console.WriteLine("加载词典中...");
            LoadDic();
            Console.WriteLine("加载词典完成");
        }
        else
        {
            Console.WriteLine("训练词典中...");
            Dic = LoadUSDic();
            TrainDic(trainingFile, Dic);

            // Cache the trained table so the next run can skip training.
            StringBuilder dicBuilder = new StringBuilder();
            foreach (var item in Dic)
            {
                dicBuilder.AppendLine(item.Key + "\t" + item.Value);
            }
            File.WriteAllText(dicFile, dicBuilder.ToString());
            Console.WriteLine("训练完成...");
        }

        Console.WriteLine("请输入词语...");
        string inputWord;
        // ReadLine returns null once the input stream ends; stop then instead of
        // dereferencing null.
        while ((inputWord = Console.ReadLine()) != null && !inputWord.Equals("exit"))
        {
            if (Dic.ContainsKey(inputWord))
            {
                Console.WriteLine("你输入的词语 【" + inputWord + "】 是正确的!");
            }
            else
            {
                var suggestWords = GetSuggestWords(inputWord);
                Console.WriteLine("候选词语: ");
                foreach (var word in suggestWords)
                {
                    Console.WriteLine("\t\t\t " + word);
                }
            }
            Console.WriteLine("请输入词语....");
        }
    }

    /// <summary>
    /// Replaces <c>Dic</c> with the contents of the dictionary file, which holds
    /// one entry per line: the word, a tab, then its integer count.
    /// Lines that do not have exactly that shape are skipped.
    /// </summary>
    public static void LoadDic()
    {
        Dic = new Dictionary<string, int>();
        foreach (var line in File.ReadAllLines(dicFile))
        {
            var dicItem = line.Split('\t');
            int count;
            // TryParse skips a damaged line instead of aborting the whole load.
            if (dicItem.Length == 2 && int.TryParse(dicItem[1], out count))
            {
                // Indexer assignment: a repeated word overwrites rather than throws.
                Dic[dicItem[0]] = count;
            }
        }
    }

    /// <summary>
    /// Trains the word-frequency table from a plain-text corpus: every run of
    /// lower-case ASCII letters found in the file counts as one occurrence of a word.
    /// </summary>
    /// <param name="trainingFile">Path of the corpus file to read.</param>
    /// <param name="ht">Word-to-count table updated in place; a word not yet present starts at 1.</param>
    public static void TrainDic(string trainingFile, Dictionary<string, int> ht)
    {
        Regex wordPattern = new Regex(@"[a-z]+"); // one match per word

        // The using block closes the file even when reading or counting throws.
        using (StreamReader reader = new StreamReader(trainingFile))
        {
            string line;
            while ((line = reader.ReadLine()) != null)
            {
                // Invariant lower-casing keeps the result independent of the current culture.
                // Apostrophes need no special handling: they are not matched by [a-z]+,
                // so they already split a token such as "cat's" into "cat" and "s".
                foreach (Match match in wordPattern.Matches(line.ToLowerInvariant()))
                {
                    int seen;
                    ht.TryGetValue(match.Value, out seen); // seen is 0 when the word is new
                    ht[match.Value] = seen + 1;
                }
            }
        }
    }

    /// <summary>
    /// Reads the word list from the file "en-US.dic". Words are the non-blank
    /// lines that follow the "[Words]" marker line; anything after a '/' on a
    /// word line is dropped.
    /// </summary>
    /// <returns>A table mapping every listed word to an initial count of 1.</returns>
    public static Dictionary<string, int> LoadUSDic()
    {
        var dic = new Dictionary<string, int>();
        bool inWords = false;

        // Both streams are released even when reading throws.
        using (FileStream fs = new FileStream("en-US.dic", FileMode.Open, FileAccess.Read, FileShare.Read))
        using (StreamReader sr = new StreamReader(fs, Encoding.UTF8))
        {
            string rawLine;
            while ((rawLine = sr.ReadLine()) != null)
            {
                string tempLine = rawLine.Trim();
                if (tempLine.Length == 0)
                {
                    continue;
                }

                if (tempLine == "[Words]")
                {
                    // NOTE(review): no other section header is recognized, so every
                    // later line is treated as a word - confirm "[Words]" is the
                    // last section of the file.
                    inWords = true;
                    continue;
                }

                if (inWords)
                {
                    // Keep the part before '/'; presumably the rest holds affix
                    // flags - verify against the dictionary format.
                    string word = tempLine.Split('/')[0];
                    // Indexer assignment: a word listed twice no longer throws.
                    dic[word] = 1;
                }
            }
        }

        return dic;
    }

    /// <summary>
    /// Generates every distinct string obtained by applying a single edit to
    /// <paramref name="word"/>: deleting one character, swapping two adjacent
    /// characters, replacing one character with a different letter a-z, or
    /// inserting one letter a-z at any position.
    /// </summary>
    /// <param name="word">The word to edit; may be empty.</param>
    /// <returns>
    /// The candidates in generation order (deletions, transpositions,
    /// replacements, insertions), each listed once.
    /// </returns>
    public static List<string> GetEdits1(string word)
    {
        var n = word.Length;
        var editsWords = new List<string>();
        // The set answers "already generated?" in O(1); the list keeps first-seen order.
        var seen = new HashSet<string>();
        string candidate;

        for (int i = 0; i < n; i++) // delete one character
        {
            candidate = word.Substring(0, i) + word.Substring(i + 1);
            if (seen.Add(candidate))
            {
                editsWords.Add(candidate);
            }
        }

        for (int i = 0; i < n - 1; i++) // transpose two adjacent characters
        {
            candidate = word.Substring(0, i) + word[i + 1] + word[i] + word.Substring(i + 2);
            if (seen.Add(candidate))
            {
                editsWords.Add(candidate);
            }
        }

        for (int i = 0; i < n; i++) // replace one character
        {
            string head = word.Substring(0, i);
            string tail = word.Substring(i + 1);
            for (char ch = 'a'; ch <= 'z'; ch++)
            {
                if (ch == word[i])
                {
                    continue; // replacing a letter with itself is not an edit
                }
                candidate = head + ch + tail;
                if (seen.Add(candidate))
                {
                    editsWords.Add(candidate);
                }
            }
        }

        for (int i = 0; i <= n; i++) // insert one character (i == n appends)
        {
            string head = word.Substring(0, i);
            string tail = word.Substring(i);
            for (char ch = 'a'; ch <= 'z'; ch++)
            {
                candidate = head + ch + tail;
                if (seen.Add(candidate))
                {
                    editsWords.Add(candidate);
                }
            }
        }

        return editsWords;
    }

    /// <summary>
    /// Collects the known words (keys of <c>Dic</c>) that can be reached from
    /// <paramref name="word"/> with at most two single edits.
    /// </summary>
    /// <param name="word">The word to start from.</param>
    /// <returns>
    /// Each reachable known word exactly once, in the order it was first
    /// generated; empty when no known word is within two edits.
    /// </returns>
    public static List<string> GetEdits2(string word)
    {
        var result = new List<string>();
        // A word is usually reachable through several edit paths; list it only once.
        var seen = new HashSet<string>();

        foreach (var edit in GetEdits1(word))
        {
            if (Dic.ContainsKey(edit) && seen.Add(edit))
            {
                result.Add(edit);
            }

            foreach (var second in GetEdits1(edit))
            {
                // Only known words are kept, which bounds the size of the result.
                if (Dic.ContainsKey(second) && seen.Add(second))
                {
                    result.Add(second);
                }
            }
        }

        return result;
    }

    /// <summary>
    /// Suggests corrections for <paramref name="word"/>. Known words one edit
    /// away are preferred; only when there are none are known words two edits
    /// away considered.
    /// </summary>
    /// <param name="word">The (presumably misspelled) input word.</param>
    /// <returns>
    /// Up to five distinct known words, most frequent first; a list holding only
    /// <paramref name="word"/> itself when nothing known is within two edits.
    /// </returns>
    public static List<string> GetSuggestWords(string word)
    {
        var candidates = GetEdits1(word).Where(w => Dic.ContainsKey(w)).ToList();
        if (candidates.Count == 0)
        {
            candidates = GetEdits2(word).Where(w => Dic.ContainsKey(w)).Distinct().ToList();
        }

        if (candidates.Count == 0)
        {
            return new List<string> { word };
        }

        // Rank by prior: the more often a word was seen, the more likely it was intended.
        return candidates.OrderByDescending(w => Dic[w]).Take(5).ToList();
    }

    /// <summary>
    /// Orders words by ascending frequency in <c>Dic</c>; a word missing from
    /// the table is treated as having been seen once. Not referenced by the
    /// rest of this class.
    /// </summary>
    class WordCompare : IComparer<string>
    {
        public int Compare(string x, string y)
        {
            return FrequencyOf(x).CompareTo(FrequencyOf(y));
        }

        // Looks a word up once; unknown words count as 1.
        private static int FrequencyOf(string w)
        {
            int count;
            return Dic.TryGetValue(w, out count) ? count : 1;
        }
    }
}
}