定义Model 首先要定义的是所有ORM映射的基类Model: class Model(dict, metaclass=ModelMetaclass): def __init__(self, **kw): super(Model, self).__init__(**kw) def __getattr__(self, key): try: return self[key] except KeyError: raise AttributeError(r"'Model' object has no attribute '%s'" % key) def __setattr__(self, key, value): self[key] = value def getValue(self, key): return getattr(self, key, None) def getValueOrDefault(self, key): value = getattr(self, key, None) if value is None: field = self.__mappings__[key] if field.default is not None: value = field.default() if callable(field.default) else field.default logging.debug('using default value for %s: %s' % (key, str(value))) setattr(self, key, value) return valueModel从dict继承,所以具备所有dict的功能,同时又实现了特殊方法__getattr__()和__setattr__(),因此又可以像引用普通字段那样写: >>> user['id']123>>> user.id123以及Field和各种Field子类: class Field(object): def __init__(self, name, column_type, primary_key, default): self.name = name self.column_type = column_type self.primary_key = primary_key self.default = default def __str__(self): return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)映射varchar的StringField: class StringField(Field): def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'): super().__init__(name, ddl, primary_key, default)注意到Model只是一个基类,如何将具体的子类如User的映射信息读取出来呢?答案就是通过metaclass:ModelMetaclass: class ModelMetaclass(type): def __new__(cls, name, bases, attrs): # 排除Model类本身: if name=='Model': return type.__new__(cls, name, bases, attrs) # 获取table名称: tableName = attrs.get('__table__', None) or name logging.info('found model: %s (table: %s)' % (name, tableName)) # 获取所有的Field和主键名: mappings = dict() fields = [] primaryKey = None for k, v in attrs.items(): if isinstance(v, Field): logging.info(' found mapping: %s ==> %s' % (k, v)) mappings[k] = v if v.primary_key: # 找到主键: if primaryKey: raise RuntimeError('Duplicate primary key for field: %s' % k) primaryKey = k else: fields.append(k) if not primaryKey: raise RuntimeError('Primary key not found.') for k in mappings.keys(): attrs.pop(k) escaped_fields = list(map(lambda f: '`%s`' % f, fields)) attrs['__mappings__'] = mappings # 保存属性和列的映射关系 attrs['__table__'] = tableName attrs['__primary_key__'] = primaryKey # 主键属性名 attrs['__fields__'] = fields # 除主键外的属性名 # 构造默认的SELECT, INSERT, UPDATE和DELETE语句: attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName) attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1)) attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey) attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey) return type.__new__(cls, name, bases, attrs)这样,任何继承自Model的类(比如User),会自动通过ModelMetaclass扫描映射关系,并存储到自身的类属性如__table__、__mappings__中。 然后,我们往Model类添加class方法,就可以让所有子类调用class方法: class Model(dict): ... @classmethod @asyncio.coroutine def find(cls, pk): ' find object by primary key. ' rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1) if len(rs) == 0: return None return cls(**rs[0])User类现在就可以通过类方法实现主键查找: user = yield from User.find('123')往Model类添加实例方法,就可以让所有子类调用实例方法: class Model(dict): ... @asyncio.coroutine def save(self): args = list(map(self.getValueOrDefault, self.__fields__)) args.append(self.getValueOrDefault(self.__primary_key__)) rows = yield from execute(self.__insert__, args) if rows != 1: logging.warn('failed to insert record: affected rows: %s' % rows)这样,就可以把一个User实例存入数据库: user = User(id=123, name='Michael')yield from user.save()最后一步是完善ORM,对于查找,我们可以实现以下方法: 以及update()和remove()方法。 所有这些方法都必须用@asyncio.coroutine装饰,变成一个协程。 调用时需要特别注意: user.save()没有任何效果,因为调用save()仅仅是创建了一个协程,并没有执行它。一定要用: yield from user.save()才真正执行了INSERT操作。 最后看看我们实现的ORM模块一共多少行代码?累计不到300多行。用Python写一个ORM是不是很容易呢?
|