drivers/vhost/net.c

/* Copyright (C) 2009 Red Hat, Inc.
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * virtio-net server in host kernel.
 */

#include <linux/compat.h>
#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/workqueue.h>
#include <linux/rcupdate.h>
#include <linux/file.h>
#include <linux/slab.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>
#include <linux/if_tun.h>
#include <linux/if_macvlan.h>

#include <net/sock.h>

#include "vhost.h"

/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_NET_WEIGHT 0x80000

enum {
    VHOST_NET_VQ_RX = 0,
    VHOST_NET_VQ_TX = 1,
    VHOST_NET_VQ_MAX = 2,
};

enum vhost_net_poll_state {
    VHOST_NET_POLL_DISABLED = 0,
    VHOST_NET_POLL_STARTED = 1,
    VHOST_NET_POLL_STOPPED = 2,
};

struct vhost_net {
    struct vhost_dev dev;
    struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
    struct vhost_poll poll[VHOST_NET_VQ_MAX];
    /* Tells us whether we are polling a socket for TX.
     * We only do this when socket buffer fills up.
     * Protected by tx vq lock. */
    enum vhost_net_poll_state tx_poll_state;
};

/* Pop first len bytes from iovec. Return number of segments used. */
static int move_iovec_hdr(struct iovec *from, struct iovec *to,
              size_t len, int iov_count)
{
    int seg = 0;
    size_t size;
    while (len && seg < iov_count) {
        size = min(from->iov_len, len);
        to->iov_base = from->iov_base;
        to->iov_len = size;
        from->iov_len -= size;
        from->iov_base += size;
        len -= size;
        ++from;
        ++to;
        ++seg;
    }
    return seg;
}

/* Caller must have TX VQ lock */
static void tx_poll_stop(struct vhost_net *net)
{
    if (likely(net->tx_poll_state != VHOST_NET_POLL_STARTED))
        return;
    vhost_poll_stop(net->poll + VHOST_NET_VQ_TX);
    net->tx_poll_state = VHOST_NET_POLL_STOPPED;
}

/* Caller must have TX VQ lock */
static void tx_poll_start(struct vhost_net *net, struct socket *sock)
{
    if (unlikely(net->tx_poll_state != VHOST_NET_POLL_STOPPED))
        return;
    vhost_poll_start(net->poll + VHOST_NET_VQ_TX, sock->file);
    net->tx_poll_state = VHOST_NET_POLL_STARTED;
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_tx(struct vhost_net *net)
{
    struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
    unsigned head, out, in, s;
    struct msghdr msg = {
        .msg_name = NULL,
        .msg_namelen = 0,
        .msg_control = NULL,
        .msg_controllen = 0,
        .msg_iov = vq->iov,
        .msg_flags = MSG_DONTWAIT,
    };
    size_t len, total_len = 0;
    int err, wmem;
    size_t hdr_size;
    struct socket *sock = rcu_dereference(vq->private_data);
    if (!sock)
        return;

    wmem = atomic_read(&sock->sk->sk_wmem_alloc);
    if (wmem >= sock->sk->sk_sndbuf) {
        mutex_lock(&vq->mutex);
        tx_poll_start(net, sock);
        mutex_unlock(&vq->mutex);
        return;
    }

    use_mm(net->dev.mm);
    mutex_lock(&vq->mutex);
    vhost_disable_notify(vq);

    if (wmem < sock->sk->sk_sndbuf / 2)
        tx_poll_stop(net);
    hdr_size = vq->hdr_size;

    for (;;) {
        head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
                     ARRAY_SIZE(vq->iov),
                     &out, &in,
                     NULL, NULL);
        /* Nothing new? Wait for eventfd to tell us they refilled. */
        if (head == vq->num) {
            wmem = atomic_read(&sock->sk->sk_wmem_alloc);
            if (wmem >= sock->sk->sk_sndbuf * 3 / 4) {
                tx_poll_start(net, sock);
                set_bit(SOCK_ASYNC_NOSPACE, &sock->flags);
                break;
            }
            if (unlikely(vhost_enable_notify(vq))) {
                vhost_disable_notify(vq);
                continue;
            }
            break;
        }
        if (in) {
            vq_err(vq, "Unexpected descriptor format for TX: "
                   "out %d, in %d\n", out, in);
            break;
        }
        /* Skip header. TODO: support TSO. */
        s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, out);
        msg.msg_iovlen = out;
        len = iov_length(vq->iov, out);
        /* Sanity check */
        if (!len) {
            vq_err(vq, "Unexpected header len for TX: "
                   "%zd expected %zd\n",
                   iov_length(vq->hdr, s), hdr_size);
            break;
        }
        /* TODO: Check specific error and bomb out unless ENOBUFS? */
        err = sock->ops->sendmsg(NULL, sock, &msg, len);
        if (unlikely(err < 0)) {
            vhost_discard_vq_desc(vq);
            tx_poll_start(net, sock);
            break;
        }
        if (err != len)
            pr_err("Truncated TX packet: "
                   " len %d != %zd\n", err, len);
        vhost_add_used_and_signal(&net->dev, vq, head, 0);
        total_len += len;
        if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
            vhost_poll_queue(&vq->poll);
            break;
        }
    }

    mutex_unlock(&vq->mutex);
    unuse_mm(net->dev.mm);
}

/* Expects to be always run from workqueue - which acts as
 * read-side critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
{
    struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
    unsigned head, out, in, log, s;
    struct vhost_log *vq_log;
    struct msghdr msg = {
        .msg_name = NULL,
        .msg_namelen = 0,
        .msg_control = NULL, /* FIXME: get and handle RX aux data. */
        .msg_controllen = 0,
        .msg_iov = vq->iov,
        .msg_flags = MSG_DONTWAIT,
    };

    struct virtio_net_hdr hdr = {
        .flags = 0,
        .gso_type = VIRTIO_NET_HDR_GSO_NONE
    };

    size_t len, total_len = 0;
    int err;
    size_t hdr_size;
    struct socket *sock = rcu_dereference(vq->private_data);
    if (!sock || skb_queue_empty(&sock->sk->sk_receive_queue))
        return;

    use_mm(net->dev.mm);
    mutex_lock(&vq->mutex);
    vhost_disable_notify(vq);
    hdr_size = vq->hdr_size;

    vq_log = unlikely(vhost_has_feature(&net->dev, VHOST_F_LOG_ALL)) ?
        vq->log : NULL;

    for (;;) {
        head = vhost_get_vq_desc(&net->dev, vq, vq->iov,
                     ARRAY_SIZE(vq->iov),
                     &out, &in,
                     vq_log, &log);
        /* OK, now we need to know about added descriptors. */
        if (head == vq->num) {
            if (unlikely(vhost_enable_notify(vq))) {
                /* They have slipped one in as we were
                 * doing that: check again. */
                vhost_disable_notify(vq);
                continue;
            }
            /* Nothing new? Wait for eventfd to tell us
             * they refilled. */
            break;
        }
        /* We don't need to be notified again. */
        if (out) {
            vq_err(vq, "Unexpected descriptor format for RX: "
                   "out %d, in %d\n",
                   out, in);
            break;
        }
        /* Skip header. TODO: support TSO/mergeable rx buffers. */
        s = move_iovec_hdr(vq->iov, vq->hdr, hdr_size, in);
        msg.msg_iovlen = in;
        len = iov_length(vq->iov, in);
        /* Sanity check */
        if (!len) {
            vq_err(vq, "Unexpected header len for RX: "
                   "%zd expected %zd\n",
                   iov_length(vq->hdr, s), hdr_size);
            break;
        }
        err = sock->ops->recvmsg(NULL, sock, &msg,
                     len, MSG_DONTWAIT | MSG_TRUNC);
        /* TODO: Check specific error and bomb out unless EAGAIN? */
        if (err < 0) {
            vhost_discard_vq_desc(vq);
            break;
        }
        /* TODO: Should check and handle checksum. */
        if (err > len) {
            pr_err("Discarded truncated rx packet: "
                   " len %d > %zd\n", err, len);
            vhost_discard_vq_desc(vq);
            continue;
        }
        len = err;
        err = memcpy_toiovec(vq->hdr, (unsigned char *)&hdr, hdr_size);
        if (err) {
            vq_err(vq, "Unable to write vnet_hdr at addr %p: %d\n",
                   vq->iov->iov_base, err);
            break;
        }
        len += hdr_size;
        vhost_add_used_and_signal(&net->dev, vq, head, len);
        if (unlikely(vq_log))
            vhost_log_write(vq, vq_log, log, len);
        total_len += len;
        if (unlikely(total_len >= VHOST_NET_WEIGHT)) {
            vhost_poll_queue(&vq->poll);
            break;
        }
    }

    mutex_unlock(&vq->mutex);
    unuse_mm(net->dev.mm);
}

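/* Work handlers. The *_kick handlers run when the guest signals the
 * corresponding virtqueue's kick eventfd; the *_net handlers run when the
 * backend socket becomes writable (TX) or readable (RX). All of them just
 * dispatch to handle_tx()/handle_rx() above. */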
static void handle_tx_kick(struct work_struct *work)
{
    struct vhost_virtqueue *vq;
    struct vhost_net *net;
    vq = container_of(work, struct vhost_virtqueue, poll.work);
    net = container_of(vq->dev, struct vhost_net, dev);
    handle_tx(net);
}

static void handle_rx_kick(struct work_struct *work)
{
    struct vhost_virtqueue *vq;
    struct vhost_net *net;
    vq = container_of(work, struct vhost_virtqueue, poll.work);
    net = container_of(vq->dev, struct vhost_net, dev);
    handle_rx(net);
}

static void handle_tx_net(struct work_struct *work)
{
    struct vhost_net *net;
    net = container_of(work, struct vhost_net, poll[VHOST_NET_VQ_TX].work);
    handle_tx(net);
}

static void handle_rx_net(struct work_struct *work)
{
    struct vhost_net *net;
    net = container_of(work, struct vhost_net, poll[VHOST_NET_VQ_RX].work);
    handle_rx(net);
}

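/* Open of the vhost-net character device: allocate per-device state, hook
 * up the TX/RX kick handlers, and prepare (but do not yet start) polling
 * of the backend sockets for POLLOUT/POLLIN. */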
static int vhost_net_open(struct inode *inode, struct file *f)
{
    struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
    int r;
    if (!n)
        return -ENOMEM;
    n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
    n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
    r = vhost_dev_init(&n->dev, n->vqs, VHOST_NET_VQ_MAX);
    if (r < 0) {
        kfree(n);
        return r;
    }

    vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, POLLOUT);
    vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, POLLIN);
    n->tx_poll_state = VHOST_NET_POLL_DISABLED;

    f->private_data = n;

    return 0;
}

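/* Enable/disable polling of a virtqueue's backend socket. TX polling goes
 * through the tx_poll_state machine above; RX polling is started whenever
 * a backend is attached. Callers hold the vq mutex. */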
static void vhost_net_disable_vq(struct vhost_net *n,
                 struct vhost_virtqueue *vq)
{
    if (!vq->private_data)
        return;
    if (vq == n->vqs + VHOST_NET_VQ_TX) {
        tx_poll_stop(n);
        n->tx_poll_state = VHOST_NET_POLL_DISABLED;
    } else
        vhost_poll_stop(n->poll + VHOST_NET_VQ_RX);
}

static void vhost_net_enable_vq(struct vhost_net *n,
                struct vhost_virtqueue *vq)
{
    struct socket *sock = vq->private_data;
    if (!sock)
        return;
    if (vq == n->vqs + VHOST_NET_VQ_TX) {
        n->tx_poll_state = VHOST_NET_POLL_STOPPED;
        tx_poll_start(n, sock);
    } else
        vhost_poll_start(n->poll + VHOST_NET_VQ_RX, sock->file);
}

static struct socket *vhost_net_stop_vq(struct vhost_net *n,
                    struct vhost_virtqueue *vq)
{
    struct socket *sock;

    mutex_lock(&vq->mutex);
    sock = vq->private_data;
    vhost_net_disable_vq(n, vq);
    rcu_assign_pointer(vq->private_data, NULL);
    mutex_unlock(&vq->mutex);
    return sock;
}

static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
               struct socket **rx_sock)
{
    *tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
    *rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
}

static void vhost_net_flush_vq(struct vhost_net *n, int index)
{
    vhost_poll_flush(n->poll + index);
    vhost_poll_flush(&n->dev.vqs[index].poll);
}

static void vhost_net_flush(struct vhost_net *n)
{
    vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
    vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
}

static int vhost_net_release(struct inode *inode, struct file *f)
{
    struct vhost_net *n = f->private_data;
    struct socket *tx_sock;
    struct socket *rx_sock;

    vhost_net_stop(n, &tx_sock, &rx_sock);
    vhost_net_flush(n);
    vhost_dev_cleanup(&n->dev);
    if (tx_sock)
        fput(tx_sock->file);
    if (rx_sock)
        fput(rx_sock->file);
    /* We do an extra flush before freeing memory,
     * since jobs can re-queue themselves. */
    vhost_net_flush(n);
    kfree(n);
    return 0;
}

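/* Backend lookup helpers: translate a user-supplied fd into the socket of
 * a raw AF_PACKET socket, a tun device or a macvtap device. On success a
 * reference to the underlying file is held. */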
static struct socket *get_raw_socket(int fd)
{
    struct {
        struct sockaddr_ll sa;
        char buf[MAX_ADDR_LEN];
    } uaddr;
    int uaddr_len = sizeof uaddr, r;
    struct socket *sock = sockfd_lookup(fd, &r);
    if (!sock)
        return ERR_PTR(-ENOTSOCK);

    /* Parameter checking */
    if (sock->sk->sk_type != SOCK_RAW) {
        r = -ESOCKTNOSUPPORT;
        goto err;
    }

    r = sock->ops->getname(sock, (struct sockaddr *)&uaddr.sa,
                   &uaddr_len, 0);
    if (r)
        goto err;

    if (uaddr.sa.sll_family != AF_PACKET) {
        r = -EPFNOSUPPORT;
        goto err;
    }
    return sock;
err:
    fput(sock->file);
    return ERR_PTR(r);
}

static struct socket *get_tap_socket(int fd)
{
    struct file *file = fget(fd);
    struct socket *sock;
    if (!file)
        return ERR_PTR(-EBADF);
    sock = tun_get_socket(file);
    if (!IS_ERR(sock))
        return sock;
    sock = macvtap_get_socket(file);
    if (IS_ERR(sock))
        fput(file);
    return sock;
}

static struct socket *get_socket(int fd)
{
    struct socket *sock;
    /* special case to disable backend */
    if (fd == -1)
        return NULL;
    sock = get_raw_socket(fd);
    if (!IS_ERR(sock))
        return sock;
    sock = get_tap_socket(fd);
    if (!IS_ERR(sock))
        return sock;
    return ERR_PTR(-ENOTSOCK);
}

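/* VHOST_NET_SET_BACKEND: attach the socket behind fd as the backend of
 * virtqueue 'index' (fd == -1 detaches the current backend). The previous
 * socket, if any, is flushed and released. */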
static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
{
    struct socket *sock, *oldsock;
    struct vhost_virtqueue *vq;
    int r;

    mutex_lock(&n->dev.mutex);
    r = vhost_dev_check_owner(&n->dev);
    if (r)
        goto err;

    if (index >= VHOST_NET_VQ_MAX) {
        r = -ENOBUFS;
        goto err;
    }
    vq = n->vqs + index;
    mutex_lock(&vq->mutex);

    /* Verify that ring has been setup correctly. */
    if (!vhost_vq_access_ok(vq)) {
        r = -EFAULT;
        goto err_vq;
    }
    sock = get_socket(fd);
    if (IS_ERR(sock)) {
        r = PTR_ERR(sock);
        goto err_vq;
    }

    /* start polling new socket */
    oldsock = vq->private_data;
    if (sock == oldsock)
        goto done;

    vhost_net_disable_vq(n, vq);
    rcu_assign_pointer(vq->private_data, sock);
    vhost_net_enable_vq(n, vq);
done:
    if (oldsock) {
        vhost_net_flush_vq(n, index);
        fput(oldsock->file);
    }

err_vq:
    mutex_unlock(&vq->mutex);
err:
    mutex_unlock(&n->dev.mutex);
    return r;
}

static long vhost_net_reset_owner(struct vhost_net *n)
{
    struct socket *tx_sock = NULL;
    struct socket *rx_sock = NULL;
    long err;
    mutex_lock(&n->dev.mutex);
    err = vhost_dev_check_owner(&n->dev);
    if (err)
        goto done;
    vhost_net_stop(n, &tx_sock, &rx_sock);
    vhost_net_flush(n);
    err = vhost_dev_reset_owner(&n->dev);
done:
    mutex_unlock(&n->dev.mutex);
    if (tx_sock)
        fput(tx_sock->file);
    if (rx_sock)
        fput(rx_sock->file);
    return err;
}

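/* VHOST_SET_FEATURES: record the acked feature bits and propagate the
 * resulting vnet header size to both virtqueues. Enabling dirty logging
 * requires that the log memory is accessible (vhost_log_access_ok). */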
static int vhost_net_set_features(struct vhost_net *n, u64 features)
{
    size_t hdr_size = features & (1 << VHOST_NET_F_VIRTIO_NET_HDR) ?
        sizeof(struct virtio_net_hdr) : 0;
    int i;
    mutex_lock(&n->dev.mutex);
    if ((features & (1 << VHOST_F_LOG_ALL)) &&
        !vhost_log_access_ok(&n->dev)) {
        mutex_unlock(&n->dev.mutex);
        return -EFAULT;
    }
    n->dev.acked_features = features;
    smp_wmb();
    for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
        mutex_lock(&n->vqs[i].mutex);
        n->vqs[i].hdr_size = hdr_size;
        mutex_unlock(&n->vqs[i].mutex);
    }
    vhost_net_flush(n);
    mutex_unlock(&n->dev.mutex);
    return 0;
}

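/* ioctl dispatch: the net-specific requests are handled here, everything
 * else (owner, memory table, vring setup, eventfds) is passed down to the
 * generic vhost_dev_ioctl(). */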
static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
                unsigned long arg)
{
    struct vhost_net *n = f->private_data;
    void __user *argp = (void __user *)arg;
    u64 __user *featurep = argp;
    struct vhost_vring_file backend;
    u64 features;
    int r;
    switch (ioctl) {
    case VHOST_NET_SET_BACKEND:
        /* copy_from_user()/copy_to_user() return the number of bytes
         * left to copy, never a negative errno. */
        if (copy_from_user(&backend, argp, sizeof backend))
            return -EFAULT;
        return vhost_net_set_backend(n, backend.index, backend.fd);
    case VHOST_GET_FEATURES:
        features = VHOST_FEATURES;
        if (copy_to_user(featurep, &features, sizeof features))
            return -EFAULT;
        return 0;
    case VHOST_SET_FEATURES:
        if (copy_from_user(&features, featurep, sizeof features))
            return -EFAULT;
        if (features & ~VHOST_FEATURES)
            return -EOPNOTSUPP;
        return vhost_net_set_features(n, features);
    case VHOST_RESET_OWNER:
        return vhost_net_reset_owner(n);
    default:
        mutex_lock(&n->dev.mutex);
        r = vhost_dev_ioctl(&n->dev, ioctl, arg);
        vhost_net_flush(n);
        mutex_unlock(&n->dev.mutex);
        return r;
    }
}

#ifdef CONFIG_COMPAT
static long vhost_net_compat_ioctl(struct file *f, unsigned int ioctl,
                   unsigned long arg)
{
    return vhost_net_ioctl(f, ioctl, (unsigned long)compat_ptr(arg));
}
#endif

static const struct file_operations vhost_net_fops = {
    .owner = THIS_MODULE,
    .release = vhost_net_release,
    .unlocked_ioctl = vhost_net_ioctl,
#ifdef CONFIG_COMPAT
    .compat_ioctl = vhost_net_compat_ioctl,
#endif
    .open = vhost_net_open,
};

static struct miscdevice vhost_net_misc = {
    VHOST_NET_MINOR,
    "vhost-net",
    &vhost_net_fops,
};

int vhost_net_init(void)
{
    int r = vhost_init();
    if (r)
        goto err_init;
    r = misc_register(&vhost_net_misc);
    if (r)
        goto err_reg;
    return 0;
err_reg:
    vhost_cleanup();
err_init:
    return r;
}
module_init(vhost_net_init);

void vhost_net_exit(void)
{
    misc_deregister(&vhost_net_misc);
    vhost_cleanup();
}
module_exit(vhost_net_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio net");

