前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >java实现Word2Vec计算语义相似度,AI入门,附源码,分步骤详细注释版

java实现Word2Vec计算语义相似度,AI入门,附源码,分步骤详细注释版

原创
作者头像
用户3992092
发布2024-08-12 20:53:15
1300
发布2024-08-12 20:53:15

1.准备工作,源码以及预训练文件

源码以及预训练文件比较大 下载地址https://pan.quark.cn/s/aeb85eaf95e2

2.核心代码Main函数

代码语言:java
复制
public class Main {

    public static void main(String[] args) throws IOException {

        // 输入的两个句子

        String input1 = "一寸光阴一寸金,寸金难买寸光阴。";

        String input2 = "光阴似箭";

        // 词向量模型文件路径

        String modelFile = "baike\_26g\_news\_13g\_novel\_229g.bin";

        // 读取词向量模型文件

        InputStream is = ClassLoader.getSystemResourceAsStream(modelFile);

        BufferedInputStream bufferedInputStream = new BufferedInputStream(Objects.requireNonNull(is), 1024 \* 1024);

        Word2Vec word2VecModel = WordVectorSerializer.readBinaryModel(bufferedInputStream, false, true);

        // 计算并输出两个句子的相似度

        System.out.println(sentenceSimilarity(input1, input2, word2VecModel));

    }

3.根据文本内容获取对应的词向量列表

代码语言:java
复制
    /**

     * 根据文本内容获取对应的词向量列表

     * @param text 文本内容

     * @param model 词向量模型

     * @return 词向量列表

     */

    private static List<INDArray> getWordVectors(String text, Word2Vec model) {

        // 将文本分词

        List<String> words = segmentWords(text.toLowerCase(Locale.getDefault()));

        // 创建一个列表来存储词向量

        List<INDArray> wordVectors = new ArrayList<>(words.size());

        for (String word : words) {

            if (model.hasWord(word)) {

                wordVectors.add(model.getWordVectorMatrix(word));

            } else {

                // 如果单词不在词汇表中,使用默认向量(这里使用零向量)

                int vectorSize = model.getLayerSize(); // 获取词向量的大小

                INDArray defaultVector = Nd4j.zeros(1, vectorSize); // 创建零向量

                wordVectors.add(defaultVector);

            }

        }

        return wordVectors;

    }

4.对句子进行分词处理

代码语言:java
复制
    /**

     * 对句子进行分词处理

     * @param sentence 待分词的句子

     * @return 分词后的词语列表

     */

    private static List<String> segmentWords(String sentence) {

        JiebaSegmenter segmenter = new JiebaSegmenter();

        return segmenter.sentenceProcess(sentence).stream()

                .filter(e -> !" ".equals(e) && !e.isEmpty())

                .collect(Collectors.toList());

    }

5.计算两个向量的余弦相似度

代码语言:java
复制
    /**

     * 计算两个向量的余弦相似度

     * @param vec1 第一个向量

     * @param vec2 第二个向量

     * @return 余弦相似度值

     */

    private static double cosineSimilarity(INDArray vec1, INDArray vec2) {

        // 计算两个向量的点积

        double dotProduct = vec1.mulRowVector(vec2).sumNumber().doubleValue();

        // 计算两个向量的模长

        double norm1 = vec1.norm2Number().doubleValue();

        double norm2 = vec2.norm2Number().doubleValue();

        // 计算余弦相似度

        return dotProduct / (norm1 \* norm2);

    }

6.计算两个句子的相似度

代码语言:java
复制
    /**

     * 计算两个句子的相似度

     * @param sentence1 第一个句子

     * @param sentence2 第二个句子

     * @param model 词向量模型

     * @return 句子相似度值

     */

    private static double sentenceSimilarity(String sentence1, String sentence2, Word2Vec model) {

        List<INDArray> vectors1 = getWordVectors(sentence1, model);

        List<INDArray> vectors2 = getWordVectors(sentence2, model);

        INDArray avgVector1 = getAverageVector(vectors1, model.getLayerSize());

        INDArray avgVector2 = getAverageVector(vectors2, model.getLayerSize());

        return cosineSimilarity(avgVector1, avgVector2);

    }

7.计算一组向量的平均值向量

代码语言:java
复制
    /**

     * 计算一组向量的平均值向量

     * @param vectors 向量列表

     * @param modelSize 向量维度大小

     * @return 平均向量

     */

    private static INDArray getAverageVector(List<INDArray> vectors, int modelSize) {

        INDArray sumVector = Nd4j.zeros(1, modelSize); // 创建一个与第一个向量形状相同的零向量

        for (INDArray vector : vectors) {

            sumVector.addiRowVector(vector); // 使用addi进行原地操作

        }

        INDArray indArray = sumVector.div(vectors.size());

        sumVector.close();

        return indArray; // 将总和除以向量数量以获得平均值

    }

}

8.依赖

代码语言:xml
复制
<?xml version="1.0" encoding="UTF-8"?>

<project xmlns="http://maven.apache.org/POM/4.0.0"

         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"

         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>



    <groupId>org.example</groupId>

    <artifactId>java-nlp</artifactId>

    <version>1.0-SNAPSHOT</version>



    <properties>

        <maven.compiler.source>17</maven.compiler.source>

        <maven.compiler.target>17</maven.compiler.target>

        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

    </properties>

    <dependencies>

        <dependency>

            <groupId>org.deeplearning4j</groupId>

            <artifactId>deeplearning4j-nlp</artifactId>

            <version>1.0.0-M2.1</version>

        </dependency>

        <dependency>

            <groupId>nz.ac.waikato.cms.weka</groupId>

            <artifactId>weka-stable</artifactId>

            <version>3.8.6</version>

        </dependency>

        <dependency>

            <groupId>com.huaban</groupId>

            <artifactId>jieba-analysis</artifactId>

            <version>1.0.2</version>

        </dependency>

        <dependency>

            <groupId>ch.qos.logback</groupId>

            <artifactId>logback-classic</artifactId>

            <version>1.5.6</version>

        </dependency>

        <dependency>

            <groupId>org.nd4j</groupId>

            <artifactId>nd4j-native-platform</artifactId>

            <version>1.0.0-M2.1</version>

        </dependency>

    </dependencies>



</project>

9. 预训练文件

预训练文件
预训练文件

运行结果

运行结果
运行结果

觉得有用请点赞,有问题请在评论区留言

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1.准备工作,源码以及预训练文件
  • 2.核心代码Main函数
  • 3.根据文本内容获取对应的词向量列表
  • 4.对句子进行分词处理
  • 5.计算两个向量的余弦相似度
  • 6.计算两个句子的相似度
  • 7.计算一组向量的平均值向量
  • 8.依赖
  • 9. 预训练文件
  • 运行结果
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档