Python 对象拷贝机制

在 Python 编程中,掌握对象拷贝的细节至关重要。虽然对象拷贝看似是一个简单的操作,但在背后涉及到多个层次的实现机制。尤其是拷贝协议,它们为我们提供了强大的工具来控制对象的复制行为。通过合理地使用这些机制,我们不仅能够避免常见的编程陷阱,还能提升程序的效率和稳定性。

本文将带你一步步深入了解 Python 中的对象拷贝,包括 深浅拷贝拷贝协议 的实现,帮助你在实际开发中准确地选择合适的拷贝策略,避免潜在的问题。

1. 深浅拷贝

浅层拷贝(Shallow Copy):拷贝后的对象与原对象共享子对象,存在数据共享的情况。
深层拷贝(Deep Copy):拷贝后的对象与原对象完全独立,子对象也会被一并复制,不存在数据共享。

浅拷贝拷贝速度快,缺点是修改另外一个对象会导致当前对象发生变化。深拷贝拷贝速度慢,优点是对象相互独立。在 Python 中,使用 copy.copy 进行浅拷贝,使用 copy.deepcopy 进行深拷贝。

import copy


class Person:
    def __init__(self):
        self.data = [10, 20, 30]

    def __str__(self):
        return str(self.data)

# 1. 浅拷贝
def demo01():
    person1 = Person()
    print(person1)

    person2 = copy.copy(person1)
    print(person2)

    # 修改其中一个对象,另外一个对象发生变化
    person2.data[0] = 1000
    print(person1)


# 2. 深拷贝
def demo02():
    person1 = Person()
    print(person1)

    person2 = copy.deepcopy(person1)
    print(person2)

    # 两个对象相互独立,一个对象的修改不会影响到另外一个对象
    person2.data[0] = 1000
    print(person1)


if __name__ == '__main__':
    # demo01()
    demo02()

注意:对于不可变类型不涉及到深浅拷贝问题。

2. 拷贝协议

Python 拷贝协议是指 Python 对象在进行拷贝时,如何定义对象的拷贝行为的约定。具体来说,当你想要对自定义的对象进行拷贝时,Python 提供了一些特殊方法,使得对象可以控制它们如何被拷贝。

2.1 拷贝实现

Python 提供了两个特殊的魔术方法用于自定义在深拷贝过程中应该如何处理子对象以及哪些对象需要跳过拷贝,这两个方法分别是:

  • 浅拷贝:__copy__
  • 深拷贝:__deepcopy__

import copy


class Demo:

    def __init__(self):
        self.att1 = [10, 20, 30]
        self.att2 = ['a', 'b', 'c']

    def __copy__(self):
        """浅拷贝"""
        new_demo = Demo()
        new_demo.att1 = self.att1
        new_demo.att2 = self.att2
        return new_demo

    def __deepcopy__(self, memodict={}):
        """深拷贝"""
        new_demo = Demo()
        new_demo.att1 = copy.deepcopy(self.att1)
        new_demo.att2 = copy.deepcopy(self.att2)
        return new_demo

    def __str__(self):
        return f'demo: {id(self)}, demo.att1: {id(self.att1)}, demo.att2: {id(self.att2)}'


def demo():
    demo = Demo()
    print('原对象:', demo)

    # 浅拷贝
    new_demo = copy.copy(demo)  # 自动调用 Team 类的 __copy__ 方法进行浅拷贝
    print('浅拷贝:', new_demo)

    # 深拷贝
    new_demo = copy.deepcopy(demo)  # 动调用 Team 类的 __deepcopy__ 方法进行深拷贝
    print('深拷贝:', new_demo)


if __name__ == '__main__':
    demo()

程序执行结果:

原对象: demo: 1809016891920, demo.att1: 1809019859712, demo.att2: 1809019859840
浅拷贝: demo: 1809016933680, demo.att1: 1809019859712, demo.att2: 1809019859840
深拷贝: demo: 1809019909168, demo.att1: 1809019858240, demo.att2: 1809019859200

通过拷贝协议,我们可以自定义对象的拷贝行为。即在进行对象拷贝时,可以控制哪些属性进行拷贝,哪些属性避免拷贝,或者对部分属性进行深拷贝、部分属性进行浅拷贝。甚至可以选择禁止某些对象的拷贝行为,或者在拷贝过程中对特殊资源(如文件、数据库连接等)进行特定处理。

禁止对象拷贝示例:

import copy

class Node:
    def __init__(self, name):
        self.name = name

    def __deepcopy__(self, memodict={}):
        raise NotImplementedError('对象不支持拷贝')

    def __copy__(self):
        raise NotImplementedError('对象不支持拷贝')

def test():
    node = Node('node1')
    # 以下拷贝失败
    new_node = copy.copy(node)
    new_node = copy.deepcopy(node)

if __name__ == '__main__':
    test()

处理特殊资源拷贝示例:

import copy

class Demo:
    def __init__(self):
        self.file = open('demo.txt', 'w')
        self.attr = [10, 20, 30]

    def __deepcopy__(self, memodict={}):
        new_demo = Demo()
        new_demo.attr = copy.deepcopy(self.attr)
        new_demo.file = open('new_demo.txt', 'w')
        return new_demo

    def __copy__(self):
        new_demo = Demo()
        new_demo.attr = self.attr
        new_demo.file = open('new_demo.txt', 'w')
        return new_demo

def test():
    demo = Demo()
    new_demo = copy.copy(demo)
    new_demo = copy.deepcopy(demo)

if __name__ == '__main__':
    test()

2.2 循环引用

我们发现,深拷贝函数 __deepcopy__(self, memodict) 存在一个名为 memodict 的参数,它是 Python 深拷贝机制中的一个非常重要的”内部缓存”机制,作用是防止循环引用导致无限递归拷贝问题。

当对象之间存在相互引用的情况,在进行对象深拷贝时,会无限的递归拷贝,直至程序报错:

RecursionError: maximum recursion depth exceeded while calling a Python object

如何解决这个问题?

我们可以引入一个额外的容器 memodict 用于缓存或者记录已经进行拷贝过的对象,避免重复拷贝:

  1. 当进行当前对象拷贝时,先看看在 memodict 中是否有当前对象的拷贝,如果有,则直接返回拷贝,不再进行对象拷贝
  2. 如果没有,则进行当前对象拷贝。拷贝之后,将其添加到 memodict 中记录一下。

接下来,我们看个循环引用导致深拷贝发生错误的例子,以及我们是如何在代码中基于 memodict 来解决该拷贝问题。

import copy


class Node:
    def __init__(self, name):
        self.name = name
        self.next = None


    def __deepcopy__(self, memodict={}):
        """错误写法"""
        new_node = Node(self.name)
        new_node.next = copy.deepcopy(self.next)
        return new_node

    # def __deepcopy__(self, memodict={}):
    #     """正确写法"""
    #     # 如果当前对象已经被拷贝过,则直接返回
    #     if id(self) in memodict:
    #         return memodict[id(self)]
    #
    #     new_node = Node(self.name)
    #     # 将当前对象的拷贝存储到 memo 中
    #     memodict[id(self)] = new_node
    #
    #     new_node.next = copy.deepcopy(self.next, memodict)
    #     return new_node


def test():
    node1 = Node('node1')
    node2 = Node('node2')

    node1.next = node2
    node2.next = node1

    new_node = copy.deepcopy(node1)
    print(id(node1), id(node1.next))
    print(id(new_node), id(new_node.next))


if __name__ == '__main__':
    test()

未经允许不得转载:一亩三分地 » Python 对象拷贝机制
评论 (0)

7 + 6 =