下面是一个完整的 Django 单元测试代码示例,用于测试你这个“找出标签集合完全一致的 TagGroup
”逻辑,不使用原生 SQL,仅使用 Django ORM(依赖 PostgreSQL 的 ArrayAgg
)。
✅ 模型假设(用于测试)
# models.py
from django.db import models
class TagGroup(models.Model):
name = models.CharField(max_length=100)
class TagDetail(models.Model):
tag_group = models.ForeignKey(TagGroup, on_delete=models.CASCADE, related_name='details')
name = models.CharField(max_length=100)
✅ 查询函数(需先引入)
# utils.py 或 services.py
from django.db.models import ArrayAgg
from yourapp.models import TagGroup, TagDetail
def get_same_tag_groups(group_id):
target_names = list(
TagDetail.objects
.filter(tag_group_id=group_id)
.order_by('name')
.values_list('name', flat=True)
)
if not target_names:
return []
groups = (
TagGroup.objects
.annotate(tag_names=ArrayAgg('details__name', ordering='details__name'))
.filter(tag_names=target_names)
.values_list('id', flat=True)
)
return list(groups)
✅ 单元测试代码
# tests.py
from django.test import TestCase
from yourapp.models import TagGroup, TagDetail
from yourapp.utils import get_same_tag_groups
class TagGroupMatchingTest(TestCase):
def setUp(self):
# Group A: apple, banana
self.group_a = TagGroup.objects.create(name='Group A')
TagDetail.objects.bulk_create([
TagDetail(tag_group=self.group_a, name='apple'),
TagDetail(tag_group=self.group_a, name='banana'),
])
# Group B: banana, apple (same as A but unordered)
self.group_b = TagGroup.objects.create(name='Group B')
TagDetail.objects.bulk_create([
TagDetail(tag_group=self.group_b, name='banana'),
TagDetail(tag_group=self.group_b, name='apple'),
])
# Group C: apple only
self.group_c = TagGroup.objects.create(name='Group C')
TagDetail.objects.create(tag_group=self.group_c, name='apple')
# Group D: apple, banana, cherry (extra tag)
self.group_d = TagGroup.objects.create(name='Group D')
TagDetail.objects.bulk_create([
TagDetail(tag_group=self.group_d, name='apple'),
TagDetail(tag_group=self.group_d, name='banana'),
TagDetail(tag_group=self.group_d, name='cherry'),
])
def test_matching_groups(self):
matching_ids = get_same_tag_groups(self.group_a.id)
expected_ids = {self.group_a.id, self.group_b.id}
self.assertSetEqual(set(matching_ids), expected_ids)
def test_no_match_for_partial(self):
matching_ids = get_same_tag_groups(self.group_c.id)
self.assertSetEqual(set(matching_ids), {self.group_c.id})
def test_no_match_for_extra(self):
matching_ids = get_same_tag_groups(self.group_d.id)
self.assertSetEqual(set(matching_ids), {self.group_d.id})
✅ 如何运行测试
确保你使用的是 PostgreSQL,并已配置好 Django 数据库:
python manage.py test yourapp
✅ 可选:只查其他 group(不含自身)
如果你不想让结果包含自己(即只查“其它等价组”),在 filter
后加上:
.exclude(id=group_id)
需要我给你这个测试场景的迁移脚本、pytest
版本或更大规模压力测试脚本吗?
你的当前实现已经是 纯 ORM + 高效内存集合对比,适合中等规模(上万)数据使用。但如果你希望进一步优化性能、减少数据库扫描量或内存消耗,可以考虑如下几个方向。
✅ 当前优化程度回顾
你当前方案已经有以下优化点:
优化点 | 是否具备 | 说明 |
---|---|---|
只加载必要字段 | ✅ | .only('name') 限定字段 |
批量预加载 | ✅ | prefetch_related(Prefetch(...)) 避免 N+1 |
结构对比优化 | ✅ | 用 set() 减少排序等额外运算 |
提前排除不等长集合 | ✅ | len(names) != target_count |
🚀 可选优化方向(进阶)
1. ✅ 限制候选组范围(用 annotation 聚合过滤)
不用加载所有 group,只加载“可能匹配”的 group,提前用数据库过滤掉数量不对的 group。
示例:
from django.db.models import Count
# 先过滤出标签数量相等的组(减少 prefetch 的量)
groups = TagGroup.objects.annotate(tag_count=Count('details')).filter(tag_count=target_count)
groups = groups.prefetch_related(Prefetch('details', queryset=TagDetail.objects.only('name')))
这样你只 prefetch 数量相同的组,显著减少内存遍历和请求量。
2. ✅ 排序 + tuple 比较(避免 set 构造开销)
如果你保证标签名称在每组内唯一,且数量不大,你也可以将标签名排序为元组进行对比(加快 hash 比较):
target_signature = tuple(sorted(target_names))
for group in groups:
signature = tuple(sorted(d.name for d in group.details.all()))
if signature == target_signature:
matched_group_ids.append(group.id)
这种方式对比两个元组比对两个 set 更快,尤其是在标签数量较小时(5-30 个以内),速度可能快一倍左右。
3. ✅ 提前缓存目标签名
避免多次构建 target 数据结构,尤其你后续可能要多次用:
# 已有:target_names = set(...)
target_signature = tuple(sorted(target_names))
4. ✅ 将结果改为 generator(流式遍历)
如果你不需要一次性拿到所有 ID,而是逐个返回,可以将其改为生成器:
def get_groups_with_same_tag_names(group_id):
...
for group in groups:
...
if names == target_names:
yield group.id
然后调用处使用:
for group_id in get_groups_with_same_tag_names(...):
print(group_id)
这样可以节省大量内存(适用于百万组以上的极大数据量)。
5. ❌ 不推荐的优化:hash 进数据库
你可能会考虑给每个 TagGroup 的标签名集合加上 hash(如 md5(sorted(tag_names))
)来加速过滤,但这属于:
- 写时增加额外维护负担;
- 标签变更时同步维护复杂;
- 不如纯 Python 处理灵活。
仅在超大规模数据(千万以上)需要考虑。
✅ 推荐最终优化版(组合 Count + Prefetch)
from django.db.models import Count, Prefetch
from .models import TagGroup, TagDetail
def get_groups_with_same_tag_names(group_id):
target_names = set(
TagDetail.objects.filter(tag_group_id=group_id).values_list('name', flat=True)
)
target_count = len(target_names)
target_signature = tuple(sorted(target_names))
groups = TagGroup.objects.annotate(tag_count=Count('details')).filter(tag_count=target_count)
groups = groups.prefetch_related(
Prefetch('details', queryset=TagDetail.objects.only('name'))
)
matched_group_ids = []
for group in groups:
signature = tuple(sorted(d.name for d in group.details.all()))
if signature == target_signature:
matched_group_ids.append(group.id)
return matched_group_ids
如你希望把这封装为可链式使用的 ORM 方法(如 .with_same_tag_names_as(group_id)
),我也可以帮你做自定义 QuerySet
。
是否需要我帮你封装为更通用的工具?
是的!你现在这段代码已经是Django ORM 中较高效的实现,但如果要进一步加速,我们可以从以下几个方面做结构性优化,最大程度减少查询量、提高 PostgreSQL 执行效率 —— 特别是在数据集数量和标签组非常庞大的情况下。
✅ 你当前做得很好的地方
- 使用了
Subquery
+OuterRef
拿到每个数据集的最新版本; - 用
ArrayAgg(... ordering=..., distinct=True)
构造有序标签集合; - 避免了 Python 端循环,全部逻辑尽量在数据库完成。
🚀 优化方向(详细说明)
✅ 1. 用 Exists + Filtered Subquery
替代 ArrayAgg
聚合比较(更快)
问题:
你现在用 ArrayAgg
会让 PostgreSQL 做字符串数组构造再比较,这是典型的“聚合 + 过滤比对”,有一定开销。
替代方案:
改用 Exists()
+ Group by + Having
思路(即比对 tag_group 的标签集合是否相同)。
但由于你不想写 SQL,这里给出兼顾 ORM 和性能的可行方案:
✅ 2. 为每个标签组构造“标签签名”(只读优化,不存字段)
我们构造每个标签组的唯一标识(签名),作为可比较对象,避免 ArrayAgg:
# 为每个 TagGroup 构造签名 —— 排序标签名再连接
def get_tag_signature(tag_group_id):
return ','.join(
TagDetail.objects
.filter(tag_group_id=tag_group_id)
.order_by('name')
.values_list('name', flat=True)
)
标签集合
{x, y, z}
与{z, x, y}
生成的签名一样 ⇒ 可直接字符串对比!
✅ 3. 查询阶段避免聚合,对比签名(字符串比较快于数组)
替换原 ArrayAgg
为提前构造签名的方式:
from django.db.models import Subquery, OuterRef
from yourapp.models import Dataset, DatasetVersion, TagDetail
def get_datasets_with_same_latest_tag_signature(dataset_id):
# Step 1: 获取目标标签组 ID
latest_tag_group_id = (
DatasetVersion.objects
.filter(dataset_id=dataset_id)
.order_by('-version')
.values_list('tag_group_id', flat=True)
.first()
)
if not latest_tag_group_id:
return []
# Step 2: 获取目标签名
target_signature = get_tag_signature(latest_tag_group_id)
if not target_signature:
return []
# Step 3: 找出每个数据集的最新版本 tag_group
latest_tag_group_subquery = Subquery(
DatasetVersion.objects
.filter(dataset=OuterRef('pk'))
.order_by('-version')
.values('tag_group_id')[:1]
)
# Step 4: 比对标签签名(在 Python 侧一次性拉取 tag_group_ids 并映射签名)
# 更适合中等规模:几千组时性能优于数据库端聚合
from django.db.models import F
datasets = (
Dataset.objects
.annotate(latest_tag_group_id=latest_tag_group_subquery)
.values('id', 'latest_tag_group_id')
)
# 批量取所有 group_id 的签名(避免循环调用 get_tag_signature)
from collections import defaultdict
tag_group_ids = {d['latest_tag_group_id'] for d in datasets if d['latest_tag_group_id']}
group_signature_map = {
group_id: get_tag_signature(group_id)
for group_id in tag_group_ids
}
matching_ids = [
d['id']
for d in datasets
if group_signature_map.get(d['latest_tag_group_id']) == target_signature
]
return matching_ids
✅ 性能优势说明:
方法 | 优点 | 适用场景 |
---|---|---|
原始 ArrayAgg | 全 ORM,无手动签名逻辑 | 数据集较少(<1000)时足够快 |
签名映射法(推荐) | 极快,无聚合,避免重算 | 数据集成千上万时性能更优,避免大字段聚合 |
预计算字段(下一步) | 查询最快,仅字符串匹配 | 数据不频繁变更,可存 signature 字段并加索引 |
✅ Bonus(额外极致优化)
如你允许为 TagGroup
添加只读字段,可在 TagGroup
中添加一个 signature
字段,用于维护标签集合哈希:
class TagGroup(models.Model):
...
signature = models.CharField(max_length=500, db_index=True) # 可自动更新
并使用 post_save
信号或 save() 内部更新。
查询直接变成:
Dataset.objects.filter(latest_version__tag_group__signature=target_signature)
这个方法最快,适合读多写少场景。
✅ 小结
方案 | 速度 | Django 原生 | 适用场景 |
---|---|---|---|
ArrayAgg + ORM(你当前使用) |
🟡 中 | ✅ 是 | 中等数据集 |
🆕 TagGroup 签名对比(推荐) | 🟢 快 | ✅ 是 | 大数据集 |
添加字段 signature |
🟢 最快 | ✅ 是 | 写少读多场景 |
如果你希望我帮你封装第 2 个方案为可复用函数/管理器,或给出一个 benchmark 测试工具(对比耗时),我可以继续提供。是否需要?
# utils/tag_matching.py
import time
from django.contrib.postgres.aggregates import ArrayAgg
from django.db.models import Subquery, OuterRef, Q
from .models import Dataset, DatasetVersion, TagDetail
def build_tag_group_signature(tag_group_id):
"""构造标签组的签名(按 name 排序连接)"""
return ','.join(
TagDetail.objects
.filter(tag_group_id=tag_group_id)
.order_by('name')
.values_list('name', flat=True)
)
def get_datasets_with_same_latest_tag_signature(dataset_id):
"""
返回所有其最新版本的标签集合与给定 dataset 最新版本标签集合完全一致的 Dataset ID 列表。
不使用 ArrayAgg,使用标签签名字符串比较。
"""
latest_tag_group_id = (
DatasetVersion.objects
.filter(dataset_id=dataset_id)
.order_by('-version')
.values_list('tag_group_id', flat=True)
.first()
)
if not latest_tag_group_id:
return []
target_signature = build_tag_group_signature(latest_tag_group_id)
if not target_signature:
return []
latest_tag_group_subquery = Subquery(
DatasetVersion.objects
.filter(dataset=OuterRef('pk'))
.order_by('-version')
.values('tag_group_id')[:1]
)
annotated_datasets = Dataset.objects.annotate(latest_tag_group_id=latest_tag_group_subquery)
group_ids = set(
annotated_datasets
.values_list('latest_tag_group_id', flat=True)
)
group_signatures = {
gid: build_tag_group_signature(gid)
for gid in group_ids if gid
}
matched_ids = [
ds.id for ds in annotated_datasets
if group_signatures.get(ds.latest_tag_group_id) == target_signature
]
return matched_ids
def benchmark_tag_matching(dataset_id):
"""对比 ArrayAgg 方案与标签签名方案的性能"""
def match_with_arrayagg():
latest_tag_group_id = (
DatasetVersion.objects
.filter(dataset_id=dataset_id)
.order_by('-version')
.values_list('tag_group_id', flat=True)
.first()
)
if not latest_tag_group_id:
return []
target_tags = list(
TagDetail.objects
.filter(tag_group_id=latest_tag_group_id)
.order_by('name')
.values_list('name', flat=True)
)
if not target_tags:
return []
latest_subquery = Subquery(
DatasetVersion.objects
.filter(dataset=OuterRef('pk'))
.order_by('-version')
.values('tag_group_id')[:1]
)
result = Dataset.objects.annotate(
latest_tag_group_id=latest_subquery
).annotate(
latest_tag_names=ArrayAgg(
'versions__tag_group__details__name',
filter=Q(versions__tag_group_id=OuterRef('latest_tag_group_id')),
ordering='versions__tag_group__details__name',
distinct=True
)
).filter(latest_tag_names=target_tags).values_list('id', flat=True)
return list(result)
print("Benchmarking: Dataset ID =", dataset_id)
start = time.time()
result1 = get_datasets_with_same_latest_tag_signature(dataset_id)
end = time.time()
print(f"Signature Match Time: {end - start:.4f}s, Matches: {len(result1)}")
start = time.time()
result2 = match_with_arrayagg()
end = time.time()
print(f"ArrayAgg Match Time: {end - start:.4f}s, Matches: {len(result2)}")
if set(result1) == set(result2):
print("✅ Results match")
else:
print("❌ Results differ")