diff --git a/kexts/drivers/bus/pci/driver.c b/kexts/drivers/bus/pci/driver.c index 8c0768c..afd7057 100644 --- a/kexts/drivers/bus/pci/driver.c +++ b/kexts/drivers/bus/pci/driver.c @@ -1,4 +1,5 @@ #include +#include #include static struct vm_cache pci_driver_cache = { @@ -70,24 +71,37 @@ kern_status_t pci_driver_unregister(struct pci_driver *driver) return driver_unregister(&driver->pci_base); } -static bool scan_device_id_list(const struct pci_device_id *device_ids, uint16_t vendor_id, uint16_t device_id) +static int scan_device_id_list(const struct pci_device_id *device_ids, const struct pci_device_id *query) { + int best_score = 0; + for (unsigned int i = 0; ; i++) { - if (device_ids[i].pci_device_id == PCI_NONE && device_ids[i].pci_vendor_id == PCI_NONE) { + if (device_ids[i].pci_device_id == PCI_NONE && device_ids[i].pci_vendor_id == PCI_NONE && device_ids[i].pci_class_id == PCI_CLASS_ANY && device_ids[i].pci_subclass_id == PCI_CLASS_ANY) { break; } - if (device_ids[i].pci_device_id == device_id && device_ids[i].pci_vendor_id == vendor_id) { - return true; + int score = 0; + + if (device_ids[i].pci_class_id == query->pci_class_id && device_ids[i].pci_subclass_id == query->pci_subclass_id && device_ids[i].pci_vendor_id == PCI_NONE && device_ids[i].pci_device_id == PCI_NONE) { + score = 1; + } + + if (device_ids[i].pci_device_id == query->pci_device_id && device_ids[i].pci_vendor_id == query->pci_vendor_id) { + score = 2; + } + + if (score > best_score) { + best_score = score; } } - return false; + return best_score; } -struct pci_driver *find_driver_for_pci_device(uint16_t vendor_id, uint16_t device_id) +struct pci_driver *find_driver_for_pci_device(const struct pci_device_id *query) { - struct pci_driver *out = NULL; + struct pci_driver *best_match = NULL; + int best_score = 0; unsigned long flags; spin_lock_irqsave(&pci_drivers_lock, &flags); @@ -98,12 +112,13 @@ struct pci_driver *find_driver_for_pci_device(uint16_t vendor_id, uint16_t devic continue; } - if (scan_device_id_list(device_ids, vendor_id, device_id)) { - out = driver; - break; + int score = scan_device_id_list(device_ids, query); + if (score > best_score) { + best_score = score; + best_match = driver; } } spin_unlock_irqrestore(&pci_drivers_lock, flags); - return out; + return best_match; } diff --git a/kexts/drivers/bus/pci/include/socks/pci.h b/kexts/drivers/bus/pci/include/socks/pci.h index 1df5784..305ac39 100644 --- a/kexts/drivers/bus/pci/include/socks/pci.h +++ b/kexts/drivers/bus/pci/include/socks/pci.h @@ -40,15 +40,20 @@ #define PCI_VALUE_PORT 0xCFC #define PCI_NONE 0xFFFF +#define PCI_CLASS_ANY 0xFF #define PCI_SUBSYSTEM_KEXT_ID "net.doorstuck.socks.pci" -#define PCI_DEVICE_ID(vid, did) { .pci_vendor_id = (vid), .pci_device_id = (did) } -#define PCI_DEVICE_ID_INVALID { .pci_vendor_id = PCI_NONE, .pci_device_id = PCI_NONE } +#define PCI_DEVICE_ID(vid, did) { .pci_vendor_id = (vid), .pci_device_id = (did), .pci_class_id = PCI_CLASS_ANY, .pci_subclass_id = PCI_CLASS_ANY } +#define PCI_CLASS_ID(cid, scid) { .pci_vendor_id = PCI_NONE, .pci_device_id = PCI_NONE, .pci_class_id = (cid), .pci_subclass_id = (scid) } +#define PCI_DEVICE_ID_FULL(vid, did, cid, scid) { .pci_vendor_id = (vid), .pci_device_id = (did), .pci_class_id = (cid), .pci_subclass_id = (scid) } +#define PCI_DEVICE_ID_INVALID PCI_DEVICE_ID_FULL(PCI_NONE, PCI_NONE, PCI_CLASS_ANY, PCI_CLASS_ANY) struct pci_device_id { uint16_t pci_vendor_id; uint16_t pci_device_id; + uint8_t pci_class_id; + uint8_t pci_subclass_id; }; struct pci_device { diff --git a/kexts/drivers/bus/pci/main.c b/kexts/drivers/bus/pci/main.c index 1cd8976..a0be667 100644 --- a/kexts/drivers/bus/pci/main.c +++ b/kexts/drivers/bus/pci/main.c @@ -18,7 +18,11 @@ static void init_pci_device(uint32_t device, uint16_t vendid, uint16_t devid, vo pci_get_slot(device), pci_get_func(device)); - printk("pci: found device %s (vend:%04x, dev:%04x)", dev->dev_name, vendid, devid); + uint8_t c = pci_read_field(device, PCI_REG_CLASS, 1); + uint8_t sc = pci_read_field(device, PCI_REG_SUBCLASS, 1); + + printk("pci: found device %s (vend:%04x, dev:%04x, class:%02x, subclass:%02x)", + dev->dev_name, vendid, devid, c, sc); struct pci_device *pci_dev = kmalloc(sizeof *pci_dev, VM_NORMAL); if (!pci_dev) { @@ -28,6 +32,8 @@ static void init_pci_device(uint32_t device, uint16_t vendid, uint16_t devid, vo pci_dev->pci_id.pci_vendor_id = vendid; pci_dev->pci_id.pci_device_id = devid; + pci_dev->pci_id.pci_class_id = c; + pci_dev->pci_id.pci_subclass_id = sc; pci_dev->pci_bus = pci_get_bus(device); pci_dev->pci_slot = pci_get_slot(device); pci_dev->pci_func = pci_get_func(device); @@ -38,7 +44,7 @@ static void init_pci_device(uint32_t device, uint16_t vendid, uint16_t devid, vo if we find a suitable driver for this device, that device will re-register it as theirs. */ device_register(dev, pci_driver, bus_device_base(pci_bus)); - struct pci_driver *driver = find_driver_for_pci_device(vendid, devid); + struct pci_driver *driver = find_driver_for_pci_device(&pci_dev->pci_id); if (driver && driver->probe) { driver->probe(driver, dev); } diff --git a/kexts/drivers/bus/pci/pci.h b/kexts/drivers/bus/pci/pci.h index 7406511..1e831fd 100644 --- a/kexts/drivers/bus/pci/pci.h +++ b/kexts/drivers/bus/pci/pci.h @@ -4,12 +4,14 @@ #include #include +struct pci_device_id; + typedef void (*pci_func_t)(uint32_t device, uint16_t vendor_id, uint16_t device_id, void *arg); extern struct driver *pci_driver; extern struct bus_device *pci_bus; -extern struct pci_driver *find_driver_for_pci_device(unsigned int vendor_id, unsigned int device_id); +extern struct pci_driver *find_driver_for_pci_device(const struct pci_device_id *query); extern kern_status_t init_pci_driver_cache(void); extern void pci_enumerate_devices(pci_func_t f, int type, void *arg);