from django.db import models
# Create your models here.
class OptimisticLockException(Exception):
    """Raised when an optimistic-lock check finds rows that no longer match.

    Attributes:
        conflict_ids: The value passed to the constructor, i.e. the ids of
            the rows reported as conflicting.
    """

    def __init__(self, conflict_ids):
        """Store *conflict_ids* and build the human-readable message from it."""
        self.conflict_ids = conflict_ids
        super().__init__(f"乐观锁冲突,以下 ID 已被修改: {conflict_ids}")

    def __reduce__(self):
        """Rebuild copies/unpickled instances from the ids, not the message.

        The inherited ``__reduce__`` re-calls the constructor with
        ``self.args`` (the already formatted message), which would pass the
        message in as ``conflict_ids`` and wrap the prefix around it a second
        time. That matters as soon as the exception crosses a process
        boundary or is copied.
        """
        return (type(self), (self.conflict_ids,), vars(self))
class OptimisticQuerySet(models.QuerySet):
    """QuerySet adding a snapshot-then-conditional-update optimistic lock.

    Intended use: ``qs.with_lock().update_safely(field=value, ...)``.
    """

    def with_lock(self, lock_fields=None):
        """Snapshot the rows currently matched so later changes can be detected.

        Args:
            lock_fields: Names of the fields whose values act as the row
                version. Defaults to ``['updated_at']``.

        Returns:
            This same queryset instance. The snapshot is stored as attributes
            on the instance, so ``update_safely()`` has to be called on the
            object returned here.
        """
        self._lock_fields = list(lock_fields or ['updated_at'])
        # Evaluated immediately: these objects are the "old" state to compare with.
        self._snapshot = list(self.all())
        return self

    def update_safely(self, **kwargs):
        """Update every snapshotted row, but only if none changed meanwhile.

        Each row is written with an UPDATE whose WHERE clause contains the
        primary key *and* the lock-field values seen in the snapshot, so the
        version check and the write are a single statement rather than a
        separate read followed by an unconditional write. A row that was
        modified or deleted since ``with_lock()`` matches nothing and is
        reported as a conflict. All updates run in one transaction, so on any
        conflict nothing is written.

        This issues one UPDATE per snapshotted row.

        For the lock to protect later writers, ``kwargs`` should change at
        least one lock field (for example, set a new ``updated_at``).

        Args:
            **kwargs: Field values, passed through to ``QuerySet.update()``.

        Returns:
            The number of rows updated.

        Raises:
            ValueError: If ``with_lock()`` was not called first.
            OptimisticLockException: If at least one row no longer matches
                its snapshot; ``conflict_ids`` lists the affected primary keys.
        """
        if not hasattr(self, '_snapshot'):
            raise ValueError("必须先调用 with_lock() 获取旧快照")

        # Imported here so the method does not depend on the module's import list.
        from django.db import router, transaction

        lock_fields = getattr(self, '_lock_fields', ['updated_at'])
        # Keyed by pk so a row that appears twice in the snapshot is written once.
        snapshot = {obj.pk: obj for obj in self._snapshot}

        # Honour an explicit .using(); otherwise let the router pick the write
        # database, and open the transaction on that same alias.
        alias = self._db or router.db_for_write(self.model)
        rows = self.model._default_manager.using(alias)

        conflict_ids = []
        updated = 0
        with transaction.atomic(using=alias):
            for pk, old in snapshot.items():
                unchanged = {field: getattr(old, field) for field in lock_fields}
                matched = rows.filter(pk=pk, **unchanged).update(**kwargs)
                if matched:
                    updated += matched
                else:
                    conflict_ids.append(pk)
            if conflict_ids:
                # Raising inside the atomic block rolls back the rows already written.
                raise OptimisticLockException(conflict_ids)
        return updated
class OptimisticManager(models.Manager):
    """Manager whose base queryset is an ``OptimisticQuerySet``."""

    def get_queryset(self):
        """Return an ``OptimisticQuerySet`` for this manager's model and database."""
        return OptimisticQuerySet(model=self.model, using=self._db)
class Product(models.Model):
    """A named product with a stock count and a caller-maintained timestamp."""

    name = models.CharField(max_length=100)
    stock = models.IntegerField()
    # Declared without a default or auto-now option, so writers supply the value.
    updated_at = models.DateTimeField()

    objects = OptimisticManager()

    def __str__(self):
        """Return the product name followed by its stock count."""
        label, quantity = self.name, self.stock
        return f"{label} - Stock: {quantity}"
# tests.py
from django.test import TransactionTestCase
from multiprocessing import Process, Manager
from django.utils import timezone
from django.db import connections
import time
def process_worker(i, results):
    """Attempt one optimistic-lock update of the 'Banana' product.

    Reads the row, sleeps ``0.1 * i`` seconds, then issues an UPDATE that
    only matches if ``updated_at`` still has the value that was read.

    Args:
        i: Worker index; used as the key in *results*, to scale the delay and
            to derive the stock value written.
        results: Mapping that receives ``"Success"``, ``"Conflict"`` or
            ``"Error: <message>"`` under key *i*.
    """
    from .models import Product

    # Each process must connect to the database itself (with Django under
    # multiprocessing the original connections have to be closed first).
    connections.close_all()
    try:
        # Simulate a read followed by a slow operation.
        product = Product.objects.get(name='Banana')
        seen_version = product.updated_at
        time.sleep(0.1 * i)

        # Optimistic-lock conditional update.
        guarded = Product.objects.filter(id=product.id, updated_at=seen_version)
        rows_changed = guarded.update(stock=200 + i, updated_at=timezone.now())

        if rows_changed == 1:
            outcome = "Success"
        else:
            outcome = "Conflict"
        results[i] = outcome
    except Exception as e:
        results[i] = f"Error: {str(e)}"
class OptimisticLockTest(TransactionTestCase):
    """Race several processes on one row; only one guarded UPDATE may win.

    NOTE(review): this file defines a second ``OptimisticLockTest`` further
    down. If both really live in the same module, the later definition
    replaces this one and this test never runs -- confirm.
    """

    def setUp(self):
        """Create the single 'Banana' row that every worker competes for."""
        from .models import Product
        Product.objects.create(name='Banana', stock=100, updated_at=timezone.now())

    def test_concurrent_update_multiprocessing(self):
        """Run five workers; expect one 'Success', four 'Conflict', no errors."""
        worker_count = 5

        # The context manager shuts the manager's server process down again.
        with Manager() as manager:
            results = manager.dict()
            processes = []
            for i in range(worker_count):
                p = Process(target=process_worker, args=(i, results))
                processes.append(p)
                p.start()
            for p in processes:
                p.join()
            # Copy out of the proxy while the manager is still alive.
            outcomes = dict(results)

        print("Results:", outcomes)

        # A worker that crashed must fail the test instead of being ignored.
        errors = {
            worker: outcome
            for worker, outcome in outcomes.items()
            if outcome not in ("Success", "Conflict")
        }
        self.assertEqual(errors, {}, f"workers reported errors: {errors}")
        self.assertEqual(
            len(outcomes), worker_count, f"missing worker results: {outcomes}"
        )

        reported = list(outcomes.values())
        self.assertEqual(reported.count("Success"), 1)
        self.assertEqual(reported.count("Conflict"), worker_count - 1)
import threading
from django.test import TestCase, TransactionTestCase
from django.utils import timezone
# Create your tests here.
from .models import Product, OptimisticLockException
import time
class OptimisticLockTest(TransactionTestCase):
    """Race several threads through ``with_lock().update_safely()`` on one row."""

    # Seconds a thread waits for the others before giving up on the barrier.
    BARRIER_TIMEOUT = 10

    def setUp(self):
        """Create the single 'Banana' row that every thread competes for."""
        Product.objects.create(name='Banana', stock=100, updated_at=timezone.now())

    def thread_worker(self, thread_id, results, barrier=None):
        """Snapshot the row, optionally wait for the other threads, then update.

        Args:
            thread_id: Key under which the outcome is stored; also used to
                derive the stock value written.
            results: Dict receiving ``'Success'``, ``'Conflict'`` or
                ``'Error: <repr>'`` for this thread.
            barrier: Optional ``threading.Barrier``. When given, the thread
                waits on it between taking the snapshot and updating, so all
                participants start from the same row version.
        """
        from django.db import connection

        try:
            locked = Product.objects.filter(name='Banana').with_lock()
            if barrier is not None:
                barrier.wait(timeout=self.BARRIER_TIMEOUT)
            locked.update_safely(
                stock=100 + thread_id, updated_at=timezone.now()
            )
            results[thread_id] = 'Success'
        except OptimisticLockException:
            results[thread_id] = 'Conflict'
        except Exception as exc:
            # Record the failure so the test can report it, and release any
            # threads still waiting on the barrier instead of letting them
            # run into the timeout.
            results[thread_id] = f'Error: {exc!r}'
            if barrier is not None:
                barrier.abort()
        finally:
            # Close the database connection this thread opened.
            connection.close()

    def test_concurrent_updates(self):
        """Run five threads; expect one 'Success', four 'Conflict', no errors."""
        thread_count = 5
        results = {}
        # Without the barrier a late thread would snapshot the row after an
        # earlier thread's write and legitimately succeed as well.
        barrier = threading.Barrier(thread_count)

        threads = []
        for i in range(thread_count):
            t = threading.Thread(
                target=self.thread_worker, args=(i, results, barrier)
            )
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

        print("Results:", results)

        errors = {
            worker: outcome
            for worker, outcome in results.items()
            if outcome not in ('Success', 'Conflict')
        }
        self.assertEqual(errors, {}, f"threads reported errors: {errors}")

        # Exactly one thread succeeds; all others fail with a conflict.
        outcomes = list(results.values())
        self.assertEqual(outcomes.count('Success'), 1)
        self.assertEqual(outcomes.count('Conflict'), thread_count - 1)