update乐观锁测试

Posted by Shallow Dreamer on April 29, 2025
from django.db import models

# Create your models here.
class OptimisticLockException(Exception):
    """Signal that one or more rows changed after their snapshot was taken.

    Attributes:
        conflict_ids: The value passed to the constructor, identifying the
            rows that were found to be modified.
    """

    def __init__(self, conflict_ids):
        """Build the exception message from ``conflict_ids`` and store them."""
        message = f"乐观锁冲突,以下 ID 已被修改: {conflict_ids}"
        super().__init__(message)
        self.conflict_ids = conflict_ids


class OptimisticQuerySet(models.QuerySet):
    """QuerySet offering snapshot-based optimistic locking for updates.

    Typical use is ``qs.with_lock().update_safely(field=value)``:
    ``with_lock`` records the rows as they are now, and ``update_safely``
    only writes to rows whose lock fields still hold the recorded values.
    """

    def with_lock(self, lock_fields=None):
        """Take a snapshot of the rows currently matched by this queryset.

        Args:
            lock_fields: Field name, or iterable of field names, used as the
                version marker(s). Defaults to ``['updated_at']``.

        Returns:
            This same queryset instance so ``update_safely`` can be chained.
            The snapshot is stored on this instance only; calling another
            queryset method afterwards (``filter``, ``exclude``, ...) returns
            a clone that does not carry it.
        """
        if not lock_fields:
            lock_fields = ['updated_at']
        elif isinstance(lock_fields, str):
            # A bare string would otherwise be iterated character by character.
            lock_fields = [lock_fields]
        self._lock_fields = list(lock_fields)
        self._snapshot = list(self.all())  # evaluated now: this is the "old" state
        return self

    def update_safely(self, **kwargs):
        """Apply ``kwargs`` to every snapshotted row that is still unchanged.

        Each row is updated with a conditional ``UPDATE`` whose WHERE clause
        contains the primary key *and* the snapshotted lock-field values, so
        the version check and the write are a single statement and a
        concurrent writer cannot slip in between them. All rows are processed
        inside one transaction: if any row was modified or deleted since the
        snapshot, nothing is committed.

        This issues one UPDATE per snapshotted row.

        Args:
            **kwargs: Field/value pairs forwarded to ``QuerySet.update``.

        Returns:
            The number of rows updated.

        Raises:
            ValueError: If ``with_lock()`` was not called on this queryset.
            OptimisticLockException: If at least one snapshotted row no longer
                matches its lock-field values; ``conflict_ids`` lists the
                primary keys of those rows.
        """
        # Imported locally so the block does not rely on extra file-level imports.
        from django.db import router, transaction

        if not hasattr(self, '_snapshot'):
            raise ValueError("必须先调用 with_lock() 获取旧快照")

        lock_fields = getattr(self, '_lock_fields', ['updated_at'])
        # Keyed by pk so a row that appears twice in the snapshot is written once.
        snapshot = {obj.pk: obj for obj in self._snapshot}

        alias = self._db or router.db_for_write(self.model)
        writer = self.model._default_manager.using(alias)

        conflict_ids = []
        updated = 0
        with transaction.atomic(using=alias):
            for pk, old in snapshot.items():
                unchanged = {field: getattr(old, field) for field in lock_fields}
                rows = writer.filter(pk=pk, **unchanged).update(**kwargs)
                if rows:
                    updated += rows
                else:
                    # Zero rows matched: the row was modified or deleted.
                    conflict_ids.append(pk)

            if conflict_ids:
                # Raising inside the atomic block rolls back the rows already written.
                raise OptimisticLockException(conflict_ids)

        return updated

class OptimisticManager(models.Manager):
    """Manager whose querysets are ``OptimisticQuerySet`` instances."""

    def get_queryset(self):
        """Return an ``OptimisticQuerySet`` bound to this manager's model and database."""
        queryset = OptimisticQuerySet(model=self.model, using=self._db)
        return queryset


class Product(models.Model):
    """Product row with a stock count and an explicit modification timestamp."""

    name = models.CharField(max_length=100)
    stock = models.IntegerField()
    # Declared without auto_now/auto_now_add, so the value must be supplied
    # explicitly whenever a row is created or updated.
    updated_at = models.DateTimeField()

    # Custom manager: querysets obtained through it are OptimisticQuerySet instances.
    objects = OptimisticManager()

    def __str__(self):
        """Return the product name followed by its current stock."""
        return f"{self.name} - Stock: {self.stock}"

# tests.py
from django.test import TransactionTestCase
from multiprocessing import Process, Manager

from django.utils import timezone

from django.db import connections
import time

def process_worker(i, results, barrier=None):
    """Read the 'Banana' product, then try a conditional update on it.

    Args:
        i: Worker index; used as the key in ``results``, as the delay factor
            and to derive the stock value written.
        results: Shared mapping receiving "Success", "Conflict" or
            "Error: <message>" under key ``i``.
        barrier: Optional barrier-like object with ``wait(timeout=...)``.
            When given, the worker waits on it after reading the row, so that
            every participant holds the same ``updated_at`` before any of them
            writes. Omitting it keeps the unsynchronised behaviour.
    """
    from .models import Product
    # Each process must open its own database connection (under
    # multiprocessing, Django needs the original connections closed first).
    connections.close_all()

    try:
        # Read phase: remember the version marker seen by this worker.
        product = Product.objects.get(name='Banana')
        old_time = product.updated_at

        if barrier is not None:
            # Bounded wait so a worker that died early cannot hang the others;
            # a broken barrier is reported through the generic handler below.
            barrier.wait(timeout=10)

        # Simulated processing delay.
        time.sleep(0.1 * i)

        # Optimistic-lock conditional update: only matches if the row still
        # carries the timestamp read above.
        updated = Product.objects.filter(
            id=product.id,
            updated_at=old_time
        ).update(stock=200 + i, updated_at=timezone.now())

        results[i] = "Success" if updated == 1 else "Conflict"
    except Exception as e:
        # Process boundary: report the failure to the parent instead of losing it.
        results[i] = f"Error: {e}"

class OptimisticLockTest(TransactionTestCase):
    """Concurrent conditional updates from several processes: exactly one may win."""

    def setUp(self):
        from .models import Product
        Product.objects.create(name='Banana', stock=100, updated_at=timezone.now())

    def test_concurrent_update_multiprocessing(self):
        worker_count = 5
        manager = Manager()
        results = manager.dict()
        # Without this barrier a fast worker can commit before a slow one has
        # even read the row; the slow one then sees the new timestamp and also
        # succeeds, making the assertion below flaky.
        barrier = manager.Barrier(worker_count)
        processes = []

        for i in range(worker_count):
            p = Process(target=process_worker, args=(i, results, barrier))
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        print("Results:", dict(results))
        success_count = list(results.values()).count("Success")
        self.assertEqual(success_count, 1)

import threading

from django.test import TestCase, TransactionTestCase
from django.utils import timezone

# Create your tests here.
from .models import Product, OptimisticLockException

import time

class OptimisticLockTest(TransactionTestCase):
    """Concurrent ``with_lock().update_safely()`` calls from several threads."""

    def setUp(self):
        Product.objects.create(name='Banana', stock=100, updated_at=timezone.now())

    def thread_worker(self, thread_id, results, barrier=None):
        """Snapshot the 'Banana' row, then attempt a safe update.

        Args:
            thread_id: Key under which the outcome is stored in ``results``;
                also used to derive the stock value written.
            results: Dict receiving 'Success', 'Conflict' or
                'Error: <message>' for this thread.
            barrier: Optional ``threading.Barrier``. When given, the thread
                waits on it between taking the snapshot and updating, so all
                threads hold the same snapshot before any of them writes.
        """
        # Imported locally so the block does not rely on extra file-level imports.
        from django.db import connection

        try:
            locked = Product.objects.filter(name='Banana').with_lock()
            if barrier is not None:
                # Bounded wait so one failed thread cannot hang the others.
                barrier.wait(timeout=10)
            locked.update_safely(stock=100 + thread_id, updated_at=timezone.now())
            results[thread_id] = 'Success'
        except OptimisticLockException:
            results[thread_id] = 'Conflict'
        except Exception as exc:
            # Record unexpected failures so they show up in the printed results
            # instead of silently leaving this thread's entry missing.
            results[thread_id] = f"Error: {exc}"
        finally:
            # Each thread gets its own database connection; close it here so it
            # is not left open when the test database is torn down.
            connection.close()

    def test_concurrent_updates(self):
        thread_count = 5
        threads = []
        results = {}
        # Forces every thread to take its snapshot before any thread updates;
        # otherwise the threads tend to run one after another and never conflict.
        barrier = threading.Barrier(thread_count)

        for i in range(thread_count):
            t = threading.Thread(target=self.thread_worker, args=(i, results, barrier))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        # Exactly one thread succeeds; the others fail with a conflict.
        success_count = list(results.values()).count('Success')
        conflict_count = list(results.values()).count('Conflict')

        print("Results:", results)
        self.assertEqual(success_count, 1)
        self.assertEqual(conflict_count, thread_count - 1)