diff --git a/drivers/usb/gadget/function/f_accessory.c b/drivers/usb/gadget/function/f_accessory.c index 832d80a5255b..b3b07b9cceb1 100644 --- a/drivers/usb/gadget/function/f_accessory.c +++ b/drivers/usb/gadget/function/f_accessory.c @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -74,6 +75,7 @@ struct acc_dev { struct usb_function function; struct usb_composite_dev *cdev; spinlock_t lock; + struct acc_dev_ref *ref; struct usb_ep *ep_in; struct usb_ep *ep_out; @@ -274,7 +276,14 @@ static struct usb_gadget_strings *acc_strings[] = { NULL, }; -static struct acc_dev *_acc_dev; +struct acc_dev_ref { + struct kref kref; + struct acc_dev *acc_dev; +}; + +static struct acc_dev_ref _acc_dev_ref = { + .kref = KREF_INIT(0), +}; struct acc_instance { struct usb_function_instance func_inst; @@ -283,11 +292,26 @@ struct acc_instance { static struct acc_dev *get_acc_dev(void) { - return _acc_dev; + struct acc_dev_ref *ref = &_acc_dev_ref; + + return kref_get_unless_zero(&ref->kref) ? ref->acc_dev : NULL; +} + +static void __put_acc_dev(struct kref *kref) +{ + struct acc_dev_ref *ref = container_of(kref, struct acc_dev_ref, kref); + struct acc_dev *dev = ref->acc_dev; + + ref->acc_dev = NULL; + kfree(dev); } static void put_acc_dev(struct acc_dev *dev) { + struct acc_dev_ref *ref = dev->ref; + + WARN_ON(ref->acc_dev != dev); + kref_put(&ref->kref, __put_acc_dev); } static inline struct acc_dev *func_to_dev(struct usb_function *f) @@ -1311,6 +1335,7 @@ static void acc_function_disable(struct usb_function *f) static int acc_setup(void) { + struct acc_dev_ref *ref = &_acc_dev_ref; struct acc_dev *dev; int ret; @@ -1331,7 +1356,9 @@ static int acc_setup(void) INIT_WORK(&dev->getprotocol_work, acc_getprotocol_work); INIT_WORK(&dev->sendstring_work, acc_sendstring_work); - _acc_dev = dev; + dev->ref = ref; + kref_init(&ref->kref); + ref->acc_dev = dev; ret = misc_register(&acc_device); if (ret) @@ -1359,12 +1386,11 @@ EXPORT_SYMBOL_GPL(acc_disconnect); static void acc_cleanup(void) { - struct acc_dev *dev = _acc_dev; + struct acc_dev *dev = get_acc_dev(); misc_deregister(&acc_device); put_acc_dev(dev); - kfree(dev); - _acc_dev = NULL; + put_acc_dev(dev); /* Pairs with kref_init() in acc_setup() */ } static struct acc_instance *to_acc_instance(struct config_item *item) {