如何在 Django 中的任意位置安全地获取 request

在 Django 中，request 包含了一次请求的全部信息，后端处理逻辑经常需要用到其中的数据。比如，在 DRF 框架中希望随时都能获取到 request，或者需要在全局范围内传递一些参数。Django 的一些第三方 App 提供了满足这类需求的工具，但它们并不安全可靠：如果 Django 以多线程或协程方式运行，获取 request 时可能会出错（例如取到的是另一个并发请求的 request）。这显然是不能接受的。下面给出一个安全可靠的实现，让你在任意位置都能获取到当前请求的 request 对象。

1. 实现

utils/local.py 文件

相关推荐

站点声明:本站部分内容转载自网络,作品版权归原作者及来源网站所有,任何内容转载、商业用途等均须联系原作者并注明来源。

相关侵权、举报、投诉及建议等,请发邮件至E-mail:service@mryunwei.com

回到顶部
  1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# -*- coding: utf-8 -*-

"""Thread-local/Greenlet-local objects

Thread-local/Greenlet-local objects support the management of
thread-local/greenlet-local data. If you have data that you want
to be local to a thread/greenlet, simply create a
thread-local/greenlet-local object and use its attributes:

  >>> mydata = Local()
  >>> mydata.number = 42
  >>> mydata.number
  42
  >>> hasattr(mydata, 'number')
  True
  >>> hasattr(mydata, 'username')
  False

  Reference :
  from threading import local
"""
try:
    from greenlet import getcurrent as get_ident
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident

__all__ = ["local", "Local"]

class Localbase(object):
    """Allocate the per-context storage used by :class:`Local`.

    The storage dict and the identity function live in ``__slots__`` so they
    are stored on the instance itself and never go through the per-context
    attribute dictionaries managed by ``Local``.
    """

    __slots__ = ('__storage__', '__ident_func__')

    def __new__(cls, *args, **kwargs):
        # Do not forward *args/**kwargs: because __new__ is overridden here,
        # object.__new__ raises TypeError on extra arguments, which would make
        # subclasses defining __init__(self, ...) impossible to construct.
        self = object.__new__(cls)
        # Maps a context identity (greenlet or thread id) -> {name: value}.
        object.__setattr__(self, '__storage__', {})
        # Callable returning the identity of the current execution context.
        object.__setattr__(self, '__ident_func__', get_ident)
        return self


class Local(Localbase):
    """Attribute container whose values are private to the current context.

    The "context" is whatever ``get_ident`` identifies: the current greenlet
    when ``greenlet`` is importable, otherwise the current thread.  A
    context's entry in the storage is dropped only when its last attribute is
    deleted or ``__release_local__`` is called.
    """

    def __iter__(self):
        """Iterate over ``(name, value)`` pairs set in the current context.

        Yields nothing when the current context has not set any attribute.
        """
        ident = self.__ident_func__()
        return iter(self.__storage__.get(ident, {}).items())

    def __release_local__(self):
        """Drop every attribute stored for the current context (if any)."""
        self.__storage__.pop(self.__ident_func__(), None)

    def __getattr__(self, name):
        """Return *name* for the current context.

        Raises:
            AttributeError: if *name* was not set in the current context.
        """
        ident = self.__ident_func__()
        try:
            return self.__storage__[ident][name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        """Set *name* for the current context only.

        Raises:
            AttributeError: when trying to overwrite the internal slots.
        """
        if name in ('__storage__', '__ident_func__'):
            raise AttributeError(
                "%r object attribute '%s' is read-only"
                % (self.__class__.__name__, name))

        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            storage[ident][name] = value
        except KeyError:
            # First attribute for this context: create its dict.
            storage[ident] = {name: value}

    def __delattr__(self, name):
        """Delete *name* for the current context.

        Releases the context's storage entry once it becomes empty.

        Raises:
            AttributeError: for the internal slots, or if *name* is unset.
        """
        if name in ('__storage__', '__ident_func__'):
            raise AttributeError(
                "%r object attribute '%s' is read-only"
                % (self.__class__.__name__, name))

        ident = self.__ident_func__()
        try:
            context = self.__storage__[ident]
            del context[name]
        except KeyError:
            raise AttributeError(name)
        if not context:
            # Avoid keeping an empty dict (and its key) alive forever.
            self.__release_local__()

# Shared module-level instance used by the request middleware.
local = Local()

if __name__ == '__main__':
    def display(ident):
        """Store *ident* in the shared ``local`` and print it back 3 times.

        Each greenlet/thread should only ever see the value it stored itself.
        """
        # import time
        local.id = ident
        for _ in range(3):
            print(get_ident(), local.id, "\n")
            # time.sleep(1)

    def gree(ident):
        """Run ``display`` in 10 greenlets and wait for all of them."""
        import gevent
        jobs = [gevent.spawn(display, "%s-%s" % (ident, i)) for i in range(10)]
        gevent.joinall(jobs)

    # test one: values are per instance as well as per context
    # l1 = Local()
    # l2 = Local()
    # l1.xxx = 1
    # print(l1.xxx)
    # print(l2.xxx)  # AttributeError: xxx was never set on l2

    # test two: greenlets in a single thread
    # import gevent
    # t = []
    # for i in range(10):
    #     g = gevent.spawn(display, i)
    #     t.append(g)
    # gevent.joinall(t)

    # test three: 10 threads, each running 10 greenlets
    import threading
    threads = [threading.Thread(target=gree, args=(i,)) for i in range(10)]

    for th in threads:
        th.start()
    for th in threads:
        th.join()
# -*- coding: utf-8 -*-
from django.dispatch import Signal
from django.conf import settings
from utils.local import local

class AccessorSignal(Signal):
    """Signal that accepts a single ``RequestProvider`` receiver.

    Only receivers whose class path equals ``allowed_receiver`` may connect
    (anything else raises), and only the first such receiver is actually
    registered; later ``connect`` calls are silently ignored.
    """
    allowed_receiver = 'utils.request_middlewares.RequestProvider'

    def __init__(self, providing_args=None):
        # Never pass providing_args positionally: Signal's first positional
        # parameter is not providing_args on every Django release
        # (NOTE(review): newer releases removed it), so a positional value
        # could silently land in a different parameter. Forward it, by
        # keyword, only when a caller actually supplied it.
        if providing_args is None:
            Signal.__init__(self)
        else:
            Signal.__init__(self, providing_args=providing_args)

    def connect(self, receiver, sender=None, weak=True, dispatch_uid=None):
        """Connect *receiver* if it is the allowed provider and none exists.

        Raises:
            Exception: if *receiver*'s class is not ``allowed_receiver``.
        """
        receiver_name = '.'.join(
            [receiver.__class__.__module__, receiver.__class__.__name__]
        )
        if receiver_name != self.allowed_receiver:
            raise Exception(
                u"%s is not allowed to connect" % receiver_name)
        # Keep at most one provider connected.
        if not self.receivers:
            Signal.connect(self, receiver, sender, weak, dispatch_uid)

request_accessor = AccessorSignal()

class RequestProvider(object):
    """Middleware that binds the current request to ``local.current_request``.

    ``process_request`` stores the request on the shared ``local`` object and
    ``process_response`` removes it again, so code running in the same
    thread/greenlet can look it up in between. Instances are also the single
    receiver of ``request_accessor`` (see ``__call__``).

    NOTE(review): this implements the old-style middleware hooks with an
    ``__init__`` that takes no ``get_response`` argument; confirm it matches
    how the Django version in use instantiates middleware.
    """

    def __init__(self):
        # Register this instance as the one allowed request_accessor receiver.
        request_accessor.connect(self)

    def process_request(self, request):
        """Bind *request* to the current thread/greenlet.

        Custom data or processing can be attached to *request* here.
        """
        local.current_request = request
        return None

    def process_view(self, request, view_func, view_args, view_kwargs):
        """Expose ``your_args`` as ``request.your_args``.

        Looked up in the URL kwargs first, then POST, then GET; defaults to
        the empty string.
        """
        your_args = view_kwargs.get("your_args", "")
        if not your_args:
            # NOTE(review): reading request.POST here parses the request body
            # before the view runs; confirm no view depends on reading the
            # raw body or changing upload handlers itself.
            your_args = (request.POST.get('your_args') or
                         request.GET.get('your_args')) or ""
        request.your_args = your_args

    def process_response(self, request, response):
        """Unbind the request from the current thread/greenlet.

        The binding is cleared unconditionally. An identity ``assert`` here
        would skip the ``del`` when it fails (leaving a stale request bound
        to this context) and is removed entirely under ``python -O``.
        """
        if hasattr(local, 'current_request'):
            del local.current_request
        return response

    def __call__(self, **kwargs):
        """Signal receiver: return the request bound to this context.

        Raises:
            Exception: if no request is bound to the current context.
        """
        if not hasattr(local, 'current_request'):
            raise Exception(
                u"get_request can't be called in a new thread.")
        return local.current_request

def get_request():
    """Return the request bound to the current thread/greenlet.

    Raises:
        Exception: when no request is currently bound (for example, code
            running outside the request/response cycle).
    """
    try:
        return local.current_request
    except AttributeError:
        raise Exception(u"get_request: current thread hasn't request.")

def get_x_request_id():
    """Return the ``X-Request-ID`` header of the current request.

    Falls back to ``''`` when the request has no ``META`` dict or the header
    is absent. Propagates the exception from ``get_request`` when no request
    is bound.
    """
    meta = getattr(get_request(), 'META', None)
    if not isinstance(meta, dict):
        return ''
    return meta.get('HTTP_X_REQUEST_ID', '')
# settings.py: register RequestProvider so each request is bound to the
# current thread/greenlet (process_request) and unbound afterwards
# (process_response).
# NOTE(review): MIDDLEWARE_CLASSES is the old-style middleware setting;
# newer Django versions read MIDDLEWARE instead, and RequestProvider's
# __init__ takes no get_response argument - confirm against the Django
# version in use.
MIDDLEWARE_CLASSES = (
    ...
    'utils.request_middlewares.RequestProvider',
    ... )
from utils.request_middlewares import local

def my_function():
    """Example: fetch the current request from anywhere in the code base.

    Returns:
        The request bound to the current thread/greenlet by RequestProvider.

    Raises:
        AttributeError: when no request is bound to the current context
            (``Local.__getattr__`` raises it for unset attributes).
    """
    request = local.current_request
    return request