Springboot集成MongoDB以及常用方法封装

作者:码路星河 发布时间: 2025-12-08 阅读量:14 评论数:0

一、MongoDB简介

MongoDB是一款开源文档型数据库,专为Web应用提供高性能数据存储解决方案。

1、主要特性:

面向文档存储:使用BSON(Binary-JSON)格式,支持复杂数据类型。

模式自由:无固定Schema,灵活存储结构。

高性能:高效数据持久化和查询能力。

水平扩展:通过分片技术实现数据分布。

高可用性:副本集提供自动故障转移和数据冗余。

丰富的查询语言:支持文本搜索和地理位置查询。

跨平台支持:多语言SDK适配不同编程语言。

2、应用场景:

游戏‌:存储玩家信息、装备和积分,用内嵌文档高效管理,方便快速查询和更新。
社交网络‌:存用户资料、朋友圈内容,还能用地理位置索引做“附近的人”这类功能。
物联网‌:接入海量设备,存储设备信息和日志,方便做多维度的数据分析。
物流‌:存订单信息,状态变更用内嵌数组记录,一次查询就能看到所有更新。
日志分析‌:存系统日志、用户行为日志,结构不固定也能轻松应对。
内容管理‌:存文章、评论等半结构化数据,模式自由,扩展方便。
缓存‌:高性能,适合做缓存层,系统重启后还能从持久化缓存恢复。

3、对照图

在MongoDB中有几个比较核心的概念:文档、集合、数据库。以下是mongodb和关系型数据库的一个比照图:

SQL术语/概念

MongoDB术语/概念

database(数据库)

database(数据库)

table(表)

collection(集合)

row(数据记录行)

document(文档)

column(数据字段)

field(域)

index(索引)

index(索引)

二、SpringBoot集成

1、添加依赖

在SpringBoot项目的pom.xml文件中添加以下依赖:

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-data-mongodb</artifactId>
</dependency>

这个依赖将会自动配置MongoDB的相关组件,包括MongoDB驱动和Spring Data MongoDB。

2、配置MongoDB连接

在application.yml中添加MongoDB的连接配置信息:

spring:
  data:
    mongodb:
      host: 127.0.0.1
      port:27017
      database:mongo
      username:admin
      password:123456

三、常用方法封装

1、接口定义:

import org.bson.Document;
import org.springframework.data.geo.Point;
import org.springframework.data.geo.Polygon;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.index.IndexInfo;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;

import java.util.List;
import java.util.function.Function;

/**
 * @author 码路星河
 * @version 1.0
 * @Description MongoDB数据库操作接口定义
 */
public interface MongoDbService {

    /**
     * 检查集合是否存在
     * @param collectionName 集合名称
     * @return 存在返回true,否则返回false
     */
    boolean collectionExists(String collectionName);

    /**
     * 删除集合并重建
     * @param collectionName 集合名称
     */
    void recreateCollection(String collectionName);

    /**
     * 获取数据库统计信息
     * @return 数据库统计信息
     */
    Document getDatabaseStats();

    /**
     * 获取集合统计信息
     * @param collectionName 集合名称
     * @return 集合统计信息
     */
    Document getCollectionStats(String collectionName);

    /**
     * 插入单个文档
     * @param <T> 文档类型
     * @param object 要插入的对象
     * @param collectionName 集合名称
     */
    <T> void insert(T object, String collectionName);

    /**
     * 保存文档(存在则更新,不存在则插入)
     * @param <T> 文档类型
     * @param object 要保存的对象
     * @param collectionName 集合名称
     */
    <T> void save(T object, String collectionName);

    /**
     * 根据ID查询文档
     * @param <T> 文档类型
     * @param id 文档ID
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 查询到的文档对象,未找到返回null
     */
    <T> T findById(String id, Class<T> entityClass, String collectionName);

    /**
     * 查询集合中所有文档
     * @param <T> 文档类型
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 文档列表,集合为空时返回空列表
     */
    <T> List<T> findAll(Class<T> entityClass, String collectionName);

    /**
     * 根据查询条件获取文档列表
     * @param <T> 文档类型
     * @param query 查询条件
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 匹配查询条件的文档列表,无匹配时返回空列表
     */
    <T> List<T> find(Query query, Class<T> entityClass, String collectionName);

    /**
     * 批量根据ID查询文档
     * @param <T> 文档类型
     * @param ids ID列表
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 文档列表
     */
    <T> List<T> findByIds(List<String> ids, Class<T> entityClass, String collectionName);

    /**
     * 分页查询文档
     * @param <T> 文档类型
     * @param query 查询条件
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @param pageNum 页码(从1开始)
     * @param pageSize 每页大小
     * @return 当前页的文档列表,无数据时返回空列表
     */
    <T> List<T> findPage(Query query, Class<T> entityClass, String collectionName, int pageNum, int pageSize);

    /**
     * 更新匹配查询条件的文档
     * @param <T> 文档类型
     * @param query 查询条件
     * @param update 更新内容
     * @param entityClass 实体类
     * @param collectionName 集合名称
     */
    <T> void update(Query query, Update update, Class<T> entityClass, String collectionName);

    /**
     * 根据ID删除文档
     * @param <T> 文档类型
     * @param id 文档ID
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 删除的文档数量
     */
    <T> long deleteById(Object id, Class<T> entityClass, String collectionName);

    /**
     * 批量删除文档
     * @param <T> 文档类型
     * @param ids ID列表
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 删除的文档数量
     */
    <T> long batchDeleteByIds(List<Object> ids, Class<T> entityClass, String collectionName);

    /**
     * 删除匹配查询条件的文档
     * @param <T> 文档类型
     * @param query 查询条件
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 删除的文档数量
     */
    <T> long delete(Query query, Class<T> entityClass, String collectionName);

    /**
     * 统计匹配查询条件的文档数量
     * @param <T> 文档类型
     * @param query 查询条件(null表示统计所有文档)
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 匹配的文档数量
     */
    <T> long count(Query query, Class<T> entityClass, String collectionName);

    /**
     * 批量插入文档
     * @param <T> 文档类型
     * @param objects 要插入的对象列表
     * @param collectionName 集合名称
     */
    <T> void batchInsert(List<T> objects, String collectionName);

    /**
     * 创建2dsphere空间索引
     * @param collectionName 集合名称
     * @param fieldName 地理空间字段名
     */
    void create2DSphereIndex(String collectionName, String fieldName);

    /**
     * 创建复合索引
     * @param collectionName 集合名称
     * @param fieldNames 字段名数组
     */
    void createCompoundIndex(String collectionName, String... fieldNames);

    /**
     * 创建唯一索引
     * @param collectionName 集合名称
     * @param fieldName 字段名
     */
    void createUniqueIndex(String collectionName, String fieldName);

    /**
     * 删除索引
     * @param collectionName 集合名称
     * @param indexName 索引名称
     */
    void dropIndex(String collectionName, String indexName);

    /**
     * 获取集合的所有索引信息
     * @param collectionName 集合名称
     * @return 索引信息列表
     */
    List<IndexInfo> getIndexInfo(String collectionName);

    /**
     * 批量更新文档(按顺序执行)
     * @param <T> 文档类型
     * @param queries 查询条件列表
     * @param updates 更新内容列表
     * @param entityClass 实体类
     * @param collectionName 集合名称
     */
    <T> void batchUpdate(List<Query> queries, List<Update> updates, Class<T> entityClass, String collectionName);

    /**
     * 执行聚合查询
     * @param <T> 输出结果类型
     * @param aggregation 聚合管道
     * @param collectionName 集合名称
     * @param outputType 输出类型
     * @return 聚合结果
     */
    <T> AggregationResults<T> aggregate(Aggregation aggregation, String collectionName, Class<T> outputType);

    /**
     * 执行带分页的聚合查询
     * @param <T> 输出结果类型
     * @param aggregation 聚合管道
     * @param collectionName 集合名称
     * @param outputType 输出类型
     * @param pageNum 页码(从1开始)
     * @param pageSize 每页大小
     * @return 分页聚合结果
     */
    <T> AggregationResults<T> aggregatePage(Aggregation aggregation, String collectionName,
                                            Class<T> outputType, int pageNum, int pageSize);

    /**
     * 获取聚合查询的结果数量
     * @param aggregation 聚合管道
     * @param collectionName 集合名称
     * @return 结果总数
     */
    long countAggregate(Aggregation aggregation, String collectionName);


    /**
     * 查询指定距离内的文档(按距离排序)
     * @param <T> 文档类型
     * @param point 中心点坐标
     * @param distance 最大距离(米)
     * @param fieldName 包含坐标的字段名
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 匹配的文档列表(按距离升序排序),无匹配时返回空列表
     */
    <T> List<T> near(Point point, double distance, String fieldName, Class<T> entityClass, String collectionName);

    /**
     * 查询多边形区域内的文档
     * @param <T> 文档类型
     * @param polygon 多边形区域(至少3个点)
     * @param fieldName 包含坐标的字段名
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 位于多边形区域内的文档列表,无匹配时返回空列表
     */
    <T> List<T> within(Polygon polygon, String fieldName, Class<T> entityClass, String collectionName);

    /**
     * 查询多边形区域内并按距离排序的文档(分页)
     * @param <T> 文档类型
     * @param centerPoint 中心点坐标(用于距离计算)
     * @param polygon 多边形区域
     * @param fieldName 包含坐标的字段名
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @param pageNum 页码(从1开始)
     * @param pageSize 每页大小
     * @return 当前页的文档列表(按距离升序排序),无匹配时返回空列表
     */
    <T> List<T> withinWithDistanceSort(Point centerPoint, Polygon polygon, String fieldName,
                                       Class<T> entityClass, String collectionName, int pageNum, int pageSize);

    /**
     * 查询多边形区域内并按距离排序的文档(不分页)
     * @param <T> 文档类型
     * @param centerPoint 中心点坐标(用于距离计算)
     * @param polygon 多边形区域
     * @param fieldName 包含坐标的字段名
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 所有匹配的文档列表(按距离升序排序),无匹配时返回空列表
     */
    <T> List<T> withinWithDistanceSort(Point centerPoint, Polygon polygon, String fieldName,
                                       Class<T> entityClass, String collectionName);

    /**
     * 查询指定矩形区域内的文档
     * @param <T> 文档类型
     * @param southwest 西南角坐标
     * @param northeast 东北角坐标
     * @param fieldName 包含坐标的字段名
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 位于矩形区域内的文档列表
     */
    <T> List<T> withinBox(Point southwest, Point northeast, String fieldName,
                          Class<T> entityClass, String collectionName);

    /**
     * 查询指定圆心和半径内的文档数量
     * @param point 圆心坐标
     * @param radius 半径(米)
     * @param fieldName 包含坐标的字段名
     * @param entityClass 实体类
     * @param collectionName 集合名称
     * @return 区域内文档数量
     */
    <T> long countNear(Point point, double radius, String fieldName,
                       Class<T> entityClass, String collectionName);


    /**
     * 清空集合中的所有文档(保留集合结构)
     * @param collectionName 集合名称
     */
    void clearCollection(String collectionName);

    /**
     * 在事务中执行多个操作
     * @param operations 操作列表
     * @param <T> 返回值类型
     * @return 操作结果
     */
    <T> T executeInTransaction(Function<MongoOperations, T> operations);

}

2、接口实现:

import com.mongodb.BasicDBObject;
import com.mongodb.client.result.DeleteResult;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.bson.Document;
import org.springframework.data.geo.*;
import org.springframework.data.mongodb.core.BulkOperations;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
import org.springframework.data.mongodb.core.index.*;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.NearQuery;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * @author 码路星河
 * @version 1.0
 * @Description MongoDB数据库操作实现类
 */
@Slf4j
@Service
public class MongoDbServiceImpl implements MongoDbService {

    @Resource
    private MongoTemplate mongoTemplate;

    @Override
    public boolean collectionExists(String collectionName) {
        return mongoTemplate.collectionExists(collectionName);
    }

    @Override
    public void recreateCollection(String collectionName) {
        if (mongoTemplate.collectionExists(collectionName)) {
            mongoTemplate.dropCollection(collectionName);
        }
        mongoTemplate.createCollection(collectionName);
    }

    @Override
    public Document getDatabaseStats() {
        return mongoTemplate.executeCommand("{ dbStats: 1 }");
    }

    @Override
    public Document getCollectionStats(String collectionName) {
        return mongoTemplate.executeCommand("{ collStats: \"" + collectionName + "\" }");
    }

    @Override
    public <T> void insert(T object, String collectionName) {
        mongoTemplate.insert(object, collectionName);
    }

    @Override
    public <T> void save(T object, String collectionName) {
        mongoTemplate.save(object, collectionName);
    }

    @Override
    public <T> T findById(String id, Class<T> entityClass, String collectionName) {
        return mongoTemplate.findById(id, entityClass, collectionName);
    }

    @Override
    public <T> List<T> findAll(Class<T> entityClass, String collectionName) {
        return mongoTemplate.findAll(entityClass, collectionName);
    }

    @Override
    public <T> List<T> find(Query query, Class<T> entityClass, String collectionName) {
        return mongoTemplate.find(query, entityClass, collectionName);
    }

    @Override
    public <T> List<T> findByIds(List<String> ids, Class<T> entityClass, String collectionName) {
        Query query = new Query(Criteria.where("id").in(ids));
        return mongoTemplate.find(query, entityClass, collectionName);
    }

    @Override
    public <T> List<T> findPage(Query query, Class<T> entityClass, String collectionName, int pageNum, int pageSize) {
        query.skip((pageNum - 1) * pageSize).limit(pageSize);
        return mongoTemplate.find(query, entityClass, collectionName);
    }

    @Override
    public <T> void update(Query query, Update update, Class<T> entityClass, String collectionName) {
        mongoTemplate.updateMulti(query, update, entityClass, collectionName);
    }

    @Override
    public <T> long deleteById(Object id, Class<T> entityClass, String collectionName) {
        Query query = new Query(Criteria.where("id").is(id));
        DeleteResult result = mongoTemplate.remove(query, entityClass, collectionName);
        long deleteCount = result.getDeletedCount();
        return deleteCount;
    }

    @Override
    public <T> long batchDeleteByIds(List<Object> ids, Class<T> entityClass, String collectionName) {
        Query query = new Query(Criteria.where("id").in(ids));
        DeleteResult result = mongoTemplate.remove(query, entityClass, collectionName);
        return result.getDeletedCount();
    }

    @Override
    public <T> long delete(Query query, Class<T> entityClass, String collectionName) {
        DeleteResult result = mongoTemplate.remove(query, entityClass, collectionName);
        long deleteCount = result.getDeletedCount();
        return deleteCount;
    }

    @Override
    public <T> long count(Query query, Class<T> entityClass, String collectionName) {
        return mongoTemplate.count(query, entityClass, collectionName);
    }

    @Override
    public <T> void batchInsert(List<T> objects, String collectionName) {
        mongoTemplate.insert(objects, collectionName);
    }

    @Override
    public void create2DSphereIndex(String collectionName, String fieldName) {
        IndexOperations indexOps = mongoTemplate.indexOps(collectionName);
        IndexDefinition indexDefinition = new GeospatialIndex(fieldName).typed(GeoSpatialIndexType.GEO_2DSPHERE);
        indexOps.ensureIndex(indexDefinition);
    }

    @Override
    public void createCompoundIndex(String collectionName, String... fieldNames) {
        if (fieldNames == null || fieldNames.length == 0) {
            return;
        }

        IndexOperations indexOps = mongoTemplate.indexOps(collectionName);

        // 使用Spring Data MongoDB的标准方式创建复合索引
        org.springframework.data.mongodb.core.index.Index index = new org.springframework.data.mongodb.core.index.Index();
        for (String fieldName : fieldNames) {
            index.on(fieldName, org.springframework.data.domain.Sort.Direction.ASC);
        }

        indexOps.ensureIndex(index);
    }

    @Override
    public void createUniqueIndex(String collectionName, String fieldName) {
        IndexOperations indexOps = mongoTemplate.indexOps(collectionName);
        indexOps.ensureIndex(new org.springframework.data.mongodb.core.index.Index()
                .on(fieldName, org.springframework.data.domain.Sort.Direction.ASC)
                .unique());
    }

    @Override
    public void dropIndex(String collectionName, String indexName) {
        IndexOperations indexOps = mongoTemplate.indexOps(collectionName);
        indexOps.dropIndex(indexName);
    }

    @Override
    public List<IndexInfo> getIndexInfo(String collectionName) {
        IndexOperations indexOps = mongoTemplate.indexOps(collectionName);
        return indexOps.getIndexInfo();
    }

    @Override
    public <T> void batchUpdate(List<Query> queries, List<Update> updates, Class<T> entityClass, String collectionName) {
        BulkOperations bulkOps = mongoTemplate.bulkOps(BulkOperations.BulkMode.ORDERED, entityClass, collectionName);
        for (int i = 0; i < queries.size(); i++) {
            bulkOps.updateOne(queries.get(i), updates.get(i));
        }
        bulkOps.execute();
    }

    @Override
    public <T> AggregationResults<T> aggregate(Aggregation aggregation, String collectionName, Class<T> outputType) {
        return mongoTemplate.aggregate(aggregation, collectionName, outputType);
    }

    @Override
    public <T> AggregationResults<T> aggregatePage(Aggregation aggregation, String collectionName,
                                                   Class<T> outputType, int pageNum, int pageSize) {
        // 获取原始聚合管道操作
        List<AggregationOperation> operations = new ArrayList<>();

        // 通过getPipeline()方法获取聚合操作列表
        aggregation.getPipeline().getOperations().forEach(op -> operations.add(op));

        // 添加分页操作
        operations.add(Aggregation.skip((long) (pageNum - 1) * pageSize));
        operations.add(Aggregation.limit(pageSize));

        Aggregation paginatedAggregation = Aggregation.newAggregation(operations);
        return mongoTemplate.aggregate(paginatedAggregation, collectionName, outputType);
    }

    @Override
    public long countAggregate(Aggregation aggregation, String collectionName) {
        try {
            // 获取原始聚合管道操作
            List<AggregationOperation> operations = new ArrayList<>();

            // 通过getPipeline()方法获取聚合操作列表
            aggregation.getPipeline().getOperations().forEach(op -> operations.add(op));

            // 添加计数阶段
            operations.add(Aggregation.count().as("totalCount"));

            Aggregation countAggregation = Aggregation.newAggregation(operations);
            AggregationResults<Document> results = mongoTemplate.aggregate(countAggregation, collectionName, Document.class);

            List<Document> mappedResults = results.getMappedResults();
            if (!mappedResults.isEmpty() && mappedResults.get(0).containsKey("totalCount")) {
                return mappedResults.get(0).getLong("totalCount");
            }
            return 0L;
        } catch (Exception e) {
            log.error("Error counting aggregate results", e);
            return 0L;
        }
    }

    @Override
    public <T> List<T> near(Point point, double distance, String fieldName, Class<T> entityClass, String collectionName) {
        NearQuery query = NearQuery.near(point).maxDistance(new Distance(distance, Metrics.KILOMETERS));
        Query geoQuery = new Query(Criteria.where(fieldName).nearSphere(point).maxDistance(distance/6378.1));
        return mongoTemplate.find(geoQuery, entityClass, collectionName);
    }

    @Override
    public <T> List<T> within(Polygon polygon, String fieldName, Class<T> entityClass, String collectionName) {
        Criteria criteria = Criteria.where(fieldName).within(polygon);
        return mongoTemplate.find(new Query(criteria), entityClass, collectionName);
    }

    @Override
    public <T> List<T> withinWithDistanceSort(Point centerPoint, Polygon polygon, String fieldName,
                                              Class<T> entityClass, String collectionName, int pageNum, int pageSize) {
        List<AggregationOperation> aggregations = new ArrayList<>();

        // 使用 geoNear 进行地理空间查询,并增加 distanceField 设置
        aggregations.add(Aggregation.geoNear(
                NearQuery.near(centerPoint)
                        .spherical(true)
                        .distanceMultiplier(6378137) // 地球半径(米),将弧度转换为米
                        .query(new Query(Criteria.where(fieldName).within(polygon))),
                "distance"
        ));
        // 移除冗余的 match 阶段
        aggregations.add(Aggregation.skip((pageNum - 1) * pageSize));
        aggregations.add(Aggregation.limit(pageSize));

        // 执行聚合查询并获取结果
        AggregationResults<T> results = mongoTemplate.aggregate(
                Aggregation.newAggregation(aggregations),
                collectionName,
                entityClass
        );

        return results.getMappedResults();
    }

    @Override
    public <T> List<T> withinWithDistanceSort(Point centerPoint, Polygon polygon, String fieldName,
                                              Class<T> entityClass, String collectionName) {
        return withinWithDistanceSort(centerPoint, polygon, fieldName, entityClass, collectionName, 1, Integer.MAX_VALUE);
    }

    @Override
    public <T> List<T> withinBox(Point southwest, Point northeast, String fieldName,
                                 Class<T> entityClass, String collectionName) {
        Criteria criteria = Criteria.where(fieldName).within(new Box(southwest, northeast));
        Query query = new Query(criteria);
        return mongoTemplate.find(query, entityClass, collectionName);
    }

    @Override
    public <T> long countNear(Point point, double radius, String fieldName,
                              Class<T> entityClass, String collectionName) {
        Query geoQuery = new Query(Criteria.where(fieldName)
                .nearSphere(point)
                .maxDistance(radius / 6378137)); // 转换为弧度
        return mongoTemplate.count(geoQuery, entityClass, collectionName);
    }

    @Override
    public void clearCollection(String collectionName) {
        mongoTemplate.getCollection(collectionName).deleteMany(new BasicDBObject());
    }

    @Override
    public <T> T executeInTransaction(Function<MongoOperations, T> operations) {
        return mongoTemplate.execute(session -> {
            try {
                return operations.apply(mongoTemplate);
            } catch (Exception e) {
                throw new RuntimeException("Transaction execution failed", e);
            }
        });
    }

}

评论