diff --git a/drivers/misc/ntsync.c b/drivers/misc/ntsync.c index c8beb5504bcc..a2826dbff2f4 100644 --- a/drivers/misc/ntsync.c +++ b/drivers/misc/ntsync.c @@ -25,6 +25,7 @@ enum ntsync_type { NTSYNC_TYPE_SEM, + NTSYNC_TYPE_MUTEX, }; /* @@ -55,6 +56,10 @@ struct ntsync_obj { __u32 count; __u32 max; } sem; + struct { + __u32 count; + pid_t owner; + } mutex; } u; /* @@ -92,6 +97,7 @@ struct ntsync_q_entry { struct ntsync_q { struct task_struct *task; + __u32 owner; /* * Protected via atomic_try_cmpxchg(). Only the thread that wins the @@ -214,13 +220,17 @@ static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj, ((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \ (obj)->dev_locked)) -static bool is_signaled(struct ntsync_obj *obj) +static bool is_signaled(struct ntsync_obj *obj, __u32 owner) { ntsync_assert_held(obj); switch (obj->type) { case NTSYNC_TYPE_SEM: return !!obj->u.sem.count; + case NTSYNC_TYPE_MUTEX: + if (obj->u.mutex.owner && obj->u.mutex.owner != owner) + return false; + return obj->u.mutex.count < UINT_MAX; } WARN(1, "bad object type %#x\n", obj->type); @@ -250,7 +260,7 @@ static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q, } for (i = 0; i < count; i++) { - if (!is_signaled(q->entries[i].obj)) { + if (!is_signaled(q->entries[i].obj, q->owner)) { can_wake = false; break; } @@ -264,6 +274,10 @@ static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q, case NTSYNC_TYPE_SEM: obj->u.sem.count--; break; + case NTSYNC_TYPE_MUTEX: + obj->u.mutex.count++; + obj->u.mutex.owner = q->owner; + break; } } wake_up_process(q->task); @@ -307,6 +321,30 @@ static void try_wake_any_sem(struct ntsync_obj *sem) } } +static void try_wake_any_mutex(struct ntsync_obj *mutex) +{ + struct ntsync_q_entry *entry; + + ntsync_assert_held(mutex); + lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX); + + list_for_each_entry(entry, &mutex->any_waiters, node) { + struct ntsync_q *q = entry->q; + int signaled = -1; + + if (mutex->u.mutex.count == UINT_MAX) + break; + if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner) + continue; + + if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) { + mutex->u.mutex.count++; + mutex->u.mutex.owner = q->owner; + wake_up_process(q->task); + } + } +} + /* * Actually change the semaphore state, returning -EOVERFLOW if it is made * invalid. @@ -451,6 +489,30 @@ static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp) return fd; } +static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp) +{ + struct ntsync_mutex_args args; + struct ntsync_obj *mutex; + int fd; + + if (copy_from_user(&args, argp, sizeof(args))) + return -EFAULT; + + if (!args.owner != !args.count) + return -EINVAL; + + mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX); + if (!mutex) + return -ENOMEM; + mutex->u.mutex.count = args.count; + mutex->u.mutex.owner = args.owner; + fd = ntsync_obj_get_fd(mutex); + if (fd < 0) + kfree(mutex); + + return fd; +} + static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd) { struct file *file = fget(fd); @@ -520,7 +582,7 @@ static int setup_wait(struct ntsync_device *dev, struct ntsync_q *q; __u32 i, j; - if (args->pad[0] || args->pad[1] || args->pad[2] || (args->flags & ~NTSYNC_WAIT_REALTIME)) + if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME)) return -EINVAL; if (args->count > NTSYNC_MAX_WAIT_COUNT) @@ -534,6 +596,7 @@ static int setup_wait(struct ntsync_device *dev, if (!q) return -ENOMEM; q->task = current; + q->owner = args->owner; atomic_set(&q->signaled, -1); q->all = all; q->count = count; @@ -576,6 +639,9 @@ static void try_wake_any_obj(struct ntsync_obj *obj) case NTSYNC_TYPE_SEM: try_wake_any_sem(obj); break; + case NTSYNC_TYPE_MUTEX: + try_wake_any_mutex(obj); + break; } } @@ -765,6 +831,8 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd, void __user *argp = (void __user *)parm; switch (cmd) { + case NTSYNC_IOC_CREATE_MUTEX: + return ntsync_create_mutex(dev, argp); case NTSYNC_IOC_CREATE_SEM: return ntsync_create_sem(dev, argp); case NTSYNC_IOC_WAIT_ALL: diff --git a/include/uapi/linux/ntsync.h b/include/uapi/linux/ntsync.h index d64420ffbcd7..bb7fb94f5856 100644 --- a/include/uapi/linux/ntsync.h +++ b/include/uapi/linux/ntsync.h @@ -15,6 +15,11 @@ struct ntsync_sem_args { __u32 max; }; +struct ntsync_mutex_args { + __u32 owner; + __u32 count; +}; + #define NTSYNC_WAIT_REALTIME 0x1 struct ntsync_wait_args { @@ -23,7 +28,8 @@ struct ntsync_wait_args { __u32 count; __u32 index; __u32 flags; - __u32 pad[3]; + __u32 owner; + __u32 pad[2]; }; #define NTSYNC_MAX_WAIT_COUNT 64 @@ -31,6 +37,7 @@ struct ntsync_wait_args { #define NTSYNC_IOC_CREATE_SEM _IOW ('N', 0x80, struct ntsync_sem_args) #define NTSYNC_IOC_WAIT_ANY _IOWR('N', 0x82, struct ntsync_wait_args) #define NTSYNC_IOC_WAIT_ALL _IOWR('N', 0x83, struct ntsync_wait_args) +#define NTSYNC_IOC_CREATE_MUTEX _IOW ('N', 0x84, struct ntsync_mutex_args) #define NTSYNC_IOC_SEM_RELEASE _IOWR('N', 0x81, __u32)