组与组员的查找

Posted by Shallow Dreameron June 14, 2025

下面是一个完整的 Django 单元测试代码示例,用于测试你这个“找出标签集合完全一致的 TagGroup”逻辑,不使用原生 SQL,仅使用 Django ORM(依赖 PostgreSQL 的 ArrayAgg)。


✅ 模型假设(用于测试)

# models.py
from django.db import models

class TagGroup(models.Model):
    name = models.CharField(max_length=100)

class TagDetail(models.Model):
    tag_group = models.ForeignKey(TagGroup, on_delete=models.CASCADE, related_name='details')
    name = models.CharField(max_length=100)

✅ 查询函数(需先引入)

# utils.py 或 services.py
from django.db.models import ArrayAgg
from yourapp.models import TagGroup, TagDetail

def get_same_tag_groups(group_id):
    target_names = list(
        TagDetail.objects
        .filter(tag_group_id=group_id)
        .order_by('name')
        .values_list('name', flat=True)
    )

    if not target_names:
        return []

    groups = (
        TagGroup.objects
        .annotate(tag_names=ArrayAgg('details__name', ordering='details__name'))
        .filter(tag_names=target_names)
        .values_list('id', flat=True)
    )

    return list(groups)

✅ 单元测试代码

# tests.py
from django.test import TestCase
from yourapp.models import TagGroup, TagDetail
from yourapp.utils import get_same_tag_groups

class TagGroupMatchingTest(TestCase):
    def setUp(self):
        # Group A: apple, banana
        self.group_a = TagGroup.objects.create(name='Group A')
        TagDetail.objects.bulk_create([
            TagDetail(tag_group=self.group_a, name='apple'),
            TagDetail(tag_group=self.group_a, name='banana'),
        ])

        # Group B: banana, apple (same as A but unordered)
        self.group_b = TagGroup.objects.create(name='Group B')
        TagDetail.objects.bulk_create([
            TagDetail(tag_group=self.group_b, name='banana'),
            TagDetail(tag_group=self.group_b, name='apple'),
        ])

        # Group C: apple only
        self.group_c = TagGroup.objects.create(name='Group C')
        TagDetail.objects.create(tag_group=self.group_c, name='apple')

        # Group D: apple, banana, cherry (extra tag)
        self.group_d = TagGroup.objects.create(name='Group D')
        TagDetail.objects.bulk_create([
            TagDetail(tag_group=self.group_d, name='apple'),
            TagDetail(tag_group=self.group_d, name='banana'),
            TagDetail(tag_group=self.group_d, name='cherry'),
        ])

    def test_matching_groups(self):
        matching_ids = get_same_tag_groups(self.group_a.id)
        expected_ids = {self.group_a.id, self.group_b.id}

        self.assertSetEqual(set(matching_ids), expected_ids)

    def test_no_match_for_partial(self):
        matching_ids = get_same_tag_groups(self.group_c.id)
        self.assertSetEqual(set(matching_ids), {self.group_c.id})

    def test_no_match_for_extra(self):
        matching_ids = get_same_tag_groups(self.group_d.id)
        self.assertSetEqual(set(matching_ids), {self.group_d.id})

✅ 如何运行测试

确保你使用的是 PostgreSQL,并已配置好 Django 数据库:

python manage.py test yourapp

✅ 可选:只查其他 group(不含自身)

如果你不想让结果包含自己(即只查“其它等价组”),在 filter 后加上:

.exclude(id=group_id)

需要我给你这个测试场景的迁移脚本、pytest 版本或更大规模压力测试脚本吗?

你的当前实现已经是 纯 ORM + 高效内存集合对比,适合中等规模(上万)数据使用。但如果你希望进一步优化性能减少数据库扫描量或内存消耗,可以考虑如下几个方向。


✅ 当前优化程度回顾

你当前方案已经有以下优化点:

优化点 是否具备 说明
只加载必要字段 .only('name') 限定字段
批量预加载 prefetch_related(Prefetch(...)) 避免 N+1
结构对比优化 set() 减少排序等额外运算
提前排除不等长集合 len(names) != target_count

🚀 可选优化方向(进阶)

1. ✅ 限制候选组范围(用 annotation 聚合过滤)

不用加载所有 group,只加载“可能匹配”的 group,提前用数据库过滤掉数量不对的 group。

示例:

from django.db.models import Count

# 先过滤出标签数量相等的组(减少 prefetch 的量)
groups = TagGroup.objects.annotate(tag_count=Count('details')).filter(tag_count=target_count)
groups = groups.prefetch_related(Prefetch('details', queryset=TagDetail.objects.only('name')))

这样你只 prefetch 数量相同的组,显著减少内存遍历和请求量


2. ✅ 排序 + tuple 比较(避免 set 构造开销)

如果你保证标签名称在每组内唯一,且数量不大,你也可以将标签名排序为元组进行对比(加快 hash 比较):

target_signature = tuple(sorted(target_names))

for group in groups:
    signature = tuple(sorted(d.name for d in group.details.all()))
    if signature == target_signature:
        matched_group_ids.append(group.id)

这种方式对比两个元组比对两个 set 更快,尤其是在标签数量较小时(5-30 个以内),速度可能快一倍左右。


3. ✅ 提前缓存目标签名

避免多次构建 target 数据结构,尤其你后续可能要多次用:

# 已有:target_names = set(...)
target_signature = tuple(sorted(target_names))

4. ✅ 将结果改为 generator(流式遍历)

如果你不需要一次性拿到所有 ID,而是逐个返回,可以将其改为生成器:

def get_groups_with_same_tag_names(group_id):
    ...
    for group in groups:
        ...
        if names == target_names:
            yield group.id

然后调用处使用:

for group_id in get_groups_with_same_tag_names(...):
    print(group_id)

这样可以节省大量内存(适用于百万组以上的极大数据量)。


5. ❌ 不推荐的优化:hash 进数据库

你可能会考虑给每个 TagGroup 的标签名集合加上 hash(如 md5(sorted(tag_names)))来加速过滤,但这属于:

  • 写时增加额外维护负担;
  • 标签变更时同步维护复杂;
  • 不如纯 Python 处理灵活。

仅在超大规模数据(千万以上)需要考虑。


✅ 推荐最终优化版(组合 Count + Prefetch)

from django.db.models import Count, Prefetch
from .models import TagGroup, TagDetail

def get_groups_with_same_tag_names(group_id):
    target_names = set(
        TagDetail.objects.filter(tag_group_id=group_id).values_list('name', flat=True)
    )
    target_count = len(target_names)
    target_signature = tuple(sorted(target_names))

    groups = TagGroup.objects.annotate(tag_count=Count('details')).filter(tag_count=target_count)
    groups = groups.prefetch_related(
        Prefetch('details', queryset=TagDetail.objects.only('name'))
    )

    matched_group_ids = []

    for group in groups:
        signature = tuple(sorted(d.name for d in group.details.all()))
        if signature == target_signature:
            matched_group_ids.append(group.id)

    return matched_group_ids

如你希望把这封装为可链式使用的 ORM 方法(如 .with_same_tag_names_as(group_id)),我也可以帮你做自定义 QuerySet

是否需要我帮你封装为更通用的工具?

是的!你现在这段代码已经是Django ORM 中较高效的实现,但如果要进一步加速,我们可以从以下几个方面做结构性优化,最大程度减少查询量、提高 PostgreSQL 执行效率 —— 特别是在数据集数量和标签组非常庞大的情况下


✅ 你当前做得很好的地方

  • 使用了 Subquery + OuterRef 拿到每个数据集的最新版本;
  • ArrayAgg(... ordering=..., distinct=True) 构造有序标签集合;
  • 避免了 Python 端循环,全部逻辑尽量在数据库完成。

🚀 优化方向(详细说明)

✅ 1. 用 Exists + Filtered Subquery 替代 ArrayAgg 聚合比较(更快)

问题:

你现在用 ArrayAgg 会让 PostgreSQL 做字符串数组构造再比较,这是典型的“聚合 + 过滤比对”,有一定开销。

替代方案:

改用 Exists() + Group by + Having 思路(即比对 tag_group 的标签集合是否相同)。

但由于你不想写 SQL,这里给出兼顾 ORM 和性能的可行方案:


✅ 2. 为每个标签组构造“标签签名”(只读优化,不存字段)

我们构造每个标签组的唯一标识(签名),作为可比较对象,避免 ArrayAgg:

# 为每个 TagGroup 构造签名 —— 排序标签名再连接
def get_tag_signature(tag_group_id):
    return ','.join(
        TagDetail.objects
        .filter(tag_group_id=tag_group_id)
        .order_by('name')
        .values_list('name', flat=True)
    )

标签集合 {x, y, z}{z, x, y} 生成的签名一样 ⇒ 可直接字符串对比!


✅ 3. 查询阶段避免聚合,对比签名(字符串比较快于数组)

替换原 ArrayAgg 为提前构造签名的方式:

from django.db.models import Subquery, OuterRef
from yourapp.models import Dataset, DatasetVersion, TagDetail

def get_datasets_with_same_latest_tag_signature(dataset_id):
    # Step 1: 获取目标标签组 ID
    latest_tag_group_id = (
        DatasetVersion.objects
        .filter(dataset_id=dataset_id)
        .order_by('-version')
        .values_list('tag_group_id', flat=True)
        .first()
    )

    if not latest_tag_group_id:
        return []

    # Step 2: 获取目标签名
    target_signature = get_tag_signature(latest_tag_group_id)

    if not target_signature:
        return []

    # Step 3: 找出每个数据集的最新版本 tag_group
    latest_tag_group_subquery = Subquery(
        DatasetVersion.objects
        .filter(dataset=OuterRef('pk'))
        .order_by('-version')
        .values('tag_group_id')[:1]
    )

    # Step 4: 比对标签签名(在 Python 侧一次性拉取 tag_group_ids 并映射签名)
    # 更适合中等规模:几千组时性能优于数据库端聚合

    from django.db.models import F

    datasets = (
        Dataset.objects
        .annotate(latest_tag_group_id=latest_tag_group_subquery)
        .values('id', 'latest_tag_group_id')
    )

    # 批量取所有 group_id 的签名(避免循环调用 get_tag_signature)
    from collections import defaultdict
    tag_group_ids = {d['latest_tag_group_id'] for d in datasets if d['latest_tag_group_id']}
    group_signature_map = {
        group_id: get_tag_signature(group_id)
        for group_id in tag_group_ids
    }

    matching_ids = [
        d['id']
        for d in datasets
        if group_signature_map.get(d['latest_tag_group_id']) == target_signature
    ]

    return matching_ids

✅ 性能优势说明:

方法 优点 适用场景
原始 ArrayAgg 全 ORM,无手动签名逻辑 数据集较少(<1000)时足够快
签名映射法(推荐) 极快,无聚合,避免重算 数据集成千上万时性能更优,避免大字段聚合
预计算字段(下一步) 查询最快,仅字符串匹配 数据不频繁变更,可存 signature 字段并加索引

✅ Bonus(额外极致优化)

如你允许为 TagGroup 添加只读字段,可在 TagGroup 中添加一个 signature 字段,用于维护标签集合哈希:

class TagGroup(models.Model):
    ...
    signature = models.CharField(max_length=500, db_index=True)  # 可自动更新

并使用 post_save 信号或 save() 内部更新。

查询直接变成:

Dataset.objects.filter(latest_version__tag_group__signature=target_signature)

这个方法最快,适合读多写少场景。


✅ 小结

方案 速度 Django 原生 适用场景
ArrayAgg + ORM(你当前使用) 🟡 中 ✅ 是 中等数据集
🆕 TagGroup 签名对比(推荐) 🟢 快 ✅ 是 大数据集
添加字段 signature 🟢 最快 ✅ 是 写少读多场景

如果你希望我帮你封装第 2 个方案为可复用函数/管理器,或给出一个 benchmark 测试工具(对比耗时),我可以继续提供。是否需要?

# utils/tag_matching.py
import time

from django.contrib.postgres.aggregates import ArrayAgg
from django.db.models import Subquery, OuterRef, Q
from .models import Dataset, DatasetVersion, TagDetail

def build_tag_group_signature(tag_group_id):
    """构造标签组的签名(按 name 排序连接)"""
    return ','.join(
        TagDetail.objects
        .filter(tag_group_id=tag_group_id)
        .order_by('name')
        .values_list('name', flat=True)
    )

def get_datasets_with_same_latest_tag_signature(dataset_id):
    """
    返回所有其最新版本的标签集合与给定 dataset 最新版本标签集合完全一致的 Dataset ID 列表。
    不使用 ArrayAgg,使用标签签名字符串比较。
    """
    latest_tag_group_id = (
        DatasetVersion.objects
        .filter(dataset_id=dataset_id)
        .order_by('-version')
        .values_list('tag_group_id', flat=True)
        .first()
    )

    if not latest_tag_group_id:
        return []

    target_signature = build_tag_group_signature(latest_tag_group_id)
    if not target_signature:
        return []

    latest_tag_group_subquery = Subquery(
        DatasetVersion.objects
        .filter(dataset=OuterRef('pk'))
        .order_by('-version')
        .values('tag_group_id')[:1]
    )

    annotated_datasets = Dataset.objects.annotate(latest_tag_group_id=latest_tag_group_subquery)

    group_ids = set(
        annotated_datasets
        .values_list('latest_tag_group_id', flat=True)
    )
    group_signatures = {
        gid: build_tag_group_signature(gid)
        for gid in group_ids if gid
    }

    matched_ids = [
        ds.id for ds in annotated_datasets
        if group_signatures.get(ds.latest_tag_group_id) == target_signature
    ]

    return matched_ids

def benchmark_tag_matching(dataset_id):
    """对比 ArrayAgg 方案与标签签名方案的性能"""
    def match_with_arrayagg():
        latest_tag_group_id = (
            DatasetVersion.objects
            .filter(dataset_id=dataset_id)
            .order_by('-version')
            .values_list('tag_group_id', flat=True)
            .first()
        )

        if not latest_tag_group_id:
            return []

        target_tags = list(
            TagDetail.objects
            .filter(tag_group_id=latest_tag_group_id)
            .order_by('name')
            .values_list('name', flat=True)
        )

        if not target_tags:
            return []

        latest_subquery = Subquery(
            DatasetVersion.objects
            .filter(dataset=OuterRef('pk'))
            .order_by('-version')
            .values('tag_group_id')[:1]
        )

        result = Dataset.objects.annotate(
            latest_tag_group_id=latest_subquery
        ).annotate(
            latest_tag_names=ArrayAgg(
                'versions__tag_group__details__name',
                filter=Q(versions__tag_group_id=OuterRef('latest_tag_group_id')),
                ordering='versions__tag_group__details__name',
                distinct=True
            )
        ).filter(latest_tag_names=target_tags).values_list('id', flat=True)

        return list(result)

    print("Benchmarking: Dataset ID =", dataset_id)

    start = time.time()
    result1 = get_datasets_with_same_latest_tag_signature(dataset_id)
    end = time.time()
    print(f"Signature Match Time: {end - start:.4f}s, Matches: {len(result1)}")

    start = time.time()
    result2 = match_with_arrayagg()
    end = time.time()
    print(f"ArrayAgg Match Time: {end - start:.4f}s, Matches: {len(result2)}")

    if set(result1) == set(result2):
        print("✅ Results match")
    else:
        print("❌ Results differ")