| #!/usr/bin/env python3 | 
 | # SPDX-License-Identifier: GPL-2.0 | 
 |  | 
 | import argparse | 
 | import errno | 
 | import logging | 
 | import socket | 
 | import struct | 
 | import time | 
 |  | 
 | import usb.core | 
 | import usb.util | 
 |  | 
 |  | 
 | def path_from_usb_dev(dev): | 
 |     """Takes a pyUSB device as argument and returns a string. | 
 |     The string is a Path representation of the position of the USB device on the USB bus tree. | 
 |  | 
 |     This path is used to find a USB device on the bus or all devices connected to a HUB. | 
 |     The path is made up of the number of the USB controller followed be the ports of the HUB tree.""" | 
 |     if dev.port_numbers: | 
 |         dev_path = ".".join(str(i) for i in dev.port_numbers) | 
 |         return f"{dev.bus}-{dev_path}" | 
 |     return "" | 
 |  | 
 |  | 
 | HEXDUMP_FILTER = "".join(chr(x).isprintable() and chr(x) or "." for x in range(128)) + "." * 128 | 
 |  | 
 |  | 
 | class Forwarder: | 
 |     @staticmethod | 
 |     def _log_hexdump(data): | 
 |         if not logging.root.isEnabledFor(logging.TRACE): | 
 |             return | 
 |         L = 16 | 
 |         for c in range(0, len(data), L): | 
 |             chars = data[c : c + L] | 
 |             dump = " ".join(f"{x:02x}" for x in chars) | 
 |             printable = "".join(HEXDUMP_FILTER[x] for x in chars) | 
 |             line = f"{c:08x}  {dump:{L*3}s} |{printable:{L}s}|" | 
 |             logging.root.log(logging.TRACE, "%s", line) | 
 |  | 
 |     def __init__(self, server, vid, pid, path): | 
 |         self.stats = { | 
 |             "c2s packets": 0, | 
 |             "c2s bytes": 0, | 
 |             "s2c packets": 0, | 
 |             "s2c bytes": 0, | 
 |         } | 
 |         self.stats_logged = time.monotonic() | 
 |  | 
 |         def find_filter(dev): | 
 |             dev_path = path_from_usb_dev(dev) | 
 |             if path is not None: | 
 |                 return dev_path == path | 
 |             return True | 
 |  | 
 |         dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter) | 
 |         if dev is None: | 
 |             raise ValueError("Device not found") | 
 |  | 
 |         logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}") | 
 |  | 
 |         # dev.set_configuration() is not necessary since g_multi has only one | 
 |         usb9pfs = None | 
 |         # g_multi adds 9pfs as last interface | 
 |         cfg = dev.get_active_configuration() | 
 |         for intf in cfg: | 
 |             # we have to detach the usb-storage driver from multi gadget since | 
 |             # stall option could be set, which will lead to spontaneous port | 
 |             # resets and our transfers will run dead | 
 |             if intf.bInterfaceClass == 0x08: | 
 |                 if dev.is_kernel_driver_active(intf.bInterfaceNumber): | 
 |                     dev.detach_kernel_driver(intf.bInterfaceNumber) | 
 |  | 
 |             if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09: | 
 |                 usb9pfs = intf | 
 |         if usb9pfs is None: | 
 |             raise ValueError("Interface not found") | 
 |  | 
 |         logging.info(f"claiming interface:\n{usb9pfs}") | 
 |         usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber) | 
 |         ep_out = usb.util.find_descriptor( | 
 |             usb9pfs, | 
 |             custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT, | 
 |         ) | 
 |         assert ep_out is not None | 
 |         ep_in = usb.util.find_descriptor( | 
 |             usb9pfs, | 
 |             custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN, | 
 |         ) | 
 |         assert ep_in is not None | 
 |         logging.info("interface claimed") | 
 |  | 
 |         self.ep_out = ep_out | 
 |         self.ep_in = ep_in | 
 |         self.dev = dev | 
 |  | 
 |         # create and connect socket | 
 |         self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | 
 |         self.s.connect(server) | 
 |  | 
 |         logging.info("connected to server") | 
 |  | 
 |     def c2s(self): | 
 |         """forward a request from the USB client to the TCP server""" | 
 |         data = None | 
 |         while data is None: | 
 |             try: | 
 |                 logging.log(logging.TRACE, "c2s: reading") | 
 |                 data = self.ep_in.read(self.ep_in.wMaxPacketSize) | 
 |             except usb.core.USBTimeoutError: | 
 |                 logging.log(logging.TRACE, "c2s: reading timed out") | 
 |                 continue | 
 |             except usb.core.USBError as e: | 
 |                 if e.errno == errno.EIO: | 
 |                     logging.debug("c2s: reading failed with %s, retrying", repr(e)) | 
 |                     time.sleep(0.5) | 
 |                     continue | 
 |                 logging.error("c2s: reading failed with %s, aborting", repr(e)) | 
 |                 raise | 
 |         size = struct.unpack("<I", data[:4])[0] | 
 |         while len(data) < size: | 
 |             data += self.ep_in.read(size - len(data)) | 
 |         logging.log(logging.TRACE, "c2s: writing") | 
 |         self._log_hexdump(data) | 
 |         self.s.send(data) | 
 |         logging.debug("c2s: forwarded %i bytes", size) | 
 |         self.stats["c2s packets"] += 1 | 
 |         self.stats["c2s bytes"] += size | 
 |  | 
 |     def s2c(self): | 
 |         """forward a response from the TCP server to the USB client""" | 
 |         logging.log(logging.TRACE, "s2c: reading") | 
 |         data = self.s.recv(4) | 
 |         size = struct.unpack("<I", data[:4])[0] | 
 |         while len(data) < size: | 
 |             data += self.s.recv(size - len(data)) | 
 |         logging.log(logging.TRACE, "s2c: writing") | 
 |         self._log_hexdump(data) | 
 |         while data: | 
 |             written = self.ep_out.write(data) | 
 |             assert written > 0 | 
 |             data = data[written:] | 
 |         if size % self.ep_out.wMaxPacketSize == 0: | 
 |             logging.log(logging.TRACE, "sending zero length packet") | 
 |             self.ep_out.write(b"") | 
 |         logging.debug("s2c: forwarded %i bytes", size) | 
 |         self.stats["s2c packets"] += 1 | 
 |         self.stats["s2c bytes"] += size | 
 |  | 
 |     def log_stats(self): | 
 |         logging.info("statistics:") | 
 |         for k, v in self.stats.items(): | 
 |             logging.info(f"  {k+':':14s} {v}") | 
 |  | 
 |     def log_stats_interval(self, interval=5): | 
 |         if (time.monotonic() - self.stats_logged) < interval: | 
 |             return | 
 |  | 
 |         self.log_stats() | 
 |         self.stats_logged = time.monotonic() | 
 |  | 
 |  | 
 | def try_get_usb_str(dev, name): | 
 |     try: | 
 |         with open(f"/sys/bus/usb/devices/{dev.bus}-{dev.address}/{name}") as f: | 
 |             return f.read().strip() | 
 |     except FileNotFoundError: | 
 |         return None | 
 |  | 
 |  | 
 | def list_usb(args): | 
 |     vid, pid = [int(x, 16) for x in args.id.split(":", 1)] | 
 |  | 
 |     print("Bus | Addr | Manufacturer     | Product          | ID        | Path") | 
 |     print("--- | ---- | ---------------- | ---------------- | --------- | ----") | 
 |     for dev in usb.core.find(find_all=True, idVendor=vid, idProduct=pid): | 
 |         path = path_from_usb_dev(dev) or "" | 
 |         manufacturer = try_get_usb_str(dev, "manufacturer") or "unknown" | 
 |         product = try_get_usb_str(dev, "product") or "unknown" | 
 |         print( | 
 |             f"{dev.bus:3} | {dev.address:4} | {manufacturer:16} | {product:16} | {dev.idVendor:04x}:{dev.idProduct:04x} | {path:18}" | 
 |         ) | 
 |  | 
 |  | 
 | def connect(args): | 
 |     vid, pid = [int(x, 16) for x in args.id.split(":", 1)] | 
 |  | 
 |     f = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path) | 
 |  | 
 |     try: | 
 |         while True: | 
 |             f.c2s() | 
 |             f.s2c() | 
 |             f.log_stats_interval() | 
 |     finally: | 
 |         f.log_stats() | 
 |  | 
 |  | 
 | def main(): | 
 |     parser = argparse.ArgumentParser( | 
 |         description="Forward 9PFS requests from USB to TCP", | 
 |     ) | 
 |  | 
 |     parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device") | 
 |     parser.add_argument("--path", type=str, required=False, help="path of target device") | 
 |     parser.add_argument("-v", "--verbose", action="count", default=0) | 
 |  | 
 |     subparsers = parser.add_subparsers() | 
 |     subparsers.required = True | 
 |     subparsers.dest = "command" | 
 |  | 
 |     parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets") | 
 |     parser_list.set_defaults(func=list_usb) | 
 |  | 
 |     parser_connect = subparsers.add_parser( | 
 |         "connect", help="Forward messages between the usb9pfs gadget and the 9p server" | 
 |     ) | 
 |     parser_connect.set_defaults(func=connect) | 
 |     connect_group = parser_connect.add_argument_group() | 
 |     connect_group.required = True | 
 |     parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname") | 
 |     parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port") | 
 |  | 
 |     args = parser.parse_args() | 
 |  | 
 |     logging.TRACE = logging.DEBUG - 5 | 
 |     logging.addLevelName(logging.TRACE, "TRACE") | 
 |  | 
 |     if args.verbose >= 2: | 
 |         level = logging.TRACE | 
 |     elif args.verbose: | 
 |         level = logging.DEBUG | 
 |     else: | 
 |         level = logging.INFO | 
 |     logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s") | 
 |  | 
 |     args.func(args) | 
 |  | 
 |  | 
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()