diff --git a/drivers/usb/gadget/function/f_mtp.c b/drivers/usb/gadget/function/f_mtp.c index 0aa87bca51e7..02705e1aa3af 100644 --- a/drivers/usb/gadget/function/f_mtp.c +++ b/drivers/usb/gadget/function/f_mtp.c @@ -74,6 +74,9 @@ #define MTP_RESPONSE_DEVICE_BUSY 0x2019 #define DRIVER_NAME "mtp" +static unsigned int mtp_rx_req_len = MTP_BULK_BUFFER_SIZE; +module_param(mtp_rx_req_len, uint, 0644); + static const char mtp_shortname[] = DRIVER_NAME "_usb"; struct mtp_dev { @@ -510,10 +513,27 @@ static int mtp_create_bulk_endpoints(struct mtp_dev *dev, req->complete = mtp_complete_in; mtp_req_put(dev, &dev->tx_idle, req); } + + /* + * The RX buffer should be aligned to EP max packet for + * some controllers. At bind time, we don't know the + * operational speed. Hence assuming super speed max + * packet size. + */ + if (mtp_rx_req_len % 1024) + mtp_rx_req_len = MTP_BULK_BUFFER_SIZE; + +retry_rx_alloc: for (i = 0; i < RX_REQ_MAX; i++) { - req = mtp_request_new(dev->ep_out, MTP_BULK_BUFFER_SIZE); - if (!req) - goto fail; + req = mtp_request_new(dev->ep_out, mtp_rx_req_len); + if (!req) { + if (mtp_rx_req_len <= MTP_BULK_BUFFER_SIZE) + goto fail; + for (--i; i >= 0; i--) + mtp_request_free(dev->rx_req[i], dev->ep_out); + mtp_rx_req_len = MTP_BULK_BUFFER_SIZE; + goto retry_rx_alloc; + } req->complete = mtp_complete_out; dev->rx_req[i] = req; } @@ -561,7 +581,7 @@ static ssize_t mtp_read(struct file *fp, char __user *buf, spin_lock_irq(&dev->lock); if (dev->ep_out->desc) { len = usb_ep_align_maybe(cdev->gadget, dev->ep_out, count); - if (len > MTP_BULK_BUFFER_SIZE) { + if (len > mtp_rx_req_len) { spin_unlock_irq(&dev->lock); return -EINVAL; } @@ -853,8 +873,8 @@ static void receive_file_work(struct work_struct *data) read_req = dev->rx_req[cur_buf]; cur_buf = (cur_buf + 1) % RX_REQ_MAX; - read_req->length = (count > MTP_BULK_BUFFER_SIZE - ? MTP_BULK_BUFFER_SIZE : count); + read_req->length = (count > mtp_rx_req_len + ? mtp_rx_req_len : count); dev->rx_done = 0; ret = usb_ep_queue(dev->ep_out, read_req, GFP_KERNEL); if (ret < 0) {