LCOV - code coverage report
Current view: top level - include/linux - skmsg.h (source / functions) Hit Total Coverage
Test: landlock.info Lines: 0 16 0.0 %
Date: 2021-04-22 12:43:58 Functions: 0 1 0.0 %

          Line data    Source code
       1             : /* SPDX-License-Identifier: GPL-2.0 */
       2             : /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
       3             : 
       4             : #ifndef _LINUX_SKMSG_H
       5             : #define _LINUX_SKMSG_H
       6             : 
       7             : #include <linux/bpf.h>
       8             : #include <linux/filter.h>
       9             : #include <linux/scatterlist.h>
      10             : #include <linux/skbuff.h>
      11             : 
      12             : #include <net/sock.h>
      13             : #include <net/tcp.h>
      14             : #include <net/strparser.h>
      15             : 
      16             : #define MAX_MSG_FRAGS                   MAX_SKB_FRAGS
      17             : #define NR_MSG_FRAG_IDS                 (MAX_MSG_FRAGS + 1)
      18             : 
      19             : enum __sk_action {
      20             :         __SK_DROP = 0,
      21             :         __SK_PASS,
      22             :         __SK_REDIRECT,
      23             :         __SK_NONE,
      24             : };
      25             : 
      26             : struct sk_msg_sg {
      27             :         u32                             start;
      28             :         u32                             curr;
      29             :         u32                             end;
      30             :         u32                             size;
      31             :         u32                             copybreak;
      32             :         unsigned long                   copy;
      33             :         /* The extra two elements:
      34             :          * 1) used for chaining the front and sections when the list becomes
      35             :          *    partitioned (e.g. end < start). The crypto APIs require the
      36             :          *    chaining;
      37             :          * 2) to chain tailer SG entries after the message.
      38             :          */
      39             :         struct scatterlist              data[MAX_MSG_FRAGS + 2];
      40             : };
      41             : static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
      42             : 
      43             : /* UAPI in filter.c depends on struct sk_msg_sg being first element. */
      44             : struct sk_msg {
      45             :         struct sk_msg_sg                sg;
      46             :         void                            *data;
      47             :         void                            *data_end;
      48             :         u32                             apply_bytes;
      49             :         u32                             cork_bytes;
      50             :         u32                             flags;
      51             :         struct sk_buff                  *skb;
      52             :         struct sock                     *sk_redir;
      53             :         struct sock                     *sk;
      54             :         struct list_head                list;
      55             : };
      56             : 
      57             : struct sk_psock_progs {
      58             :         struct bpf_prog                 *msg_parser;
      59             :         struct bpf_prog                 *skb_parser;
      60             :         struct bpf_prog                 *skb_verdict;
      61             : };
      62             : 
      63             : enum sk_psock_state_bits {
      64             :         SK_PSOCK_TX_ENABLED,
      65             : };
      66             : 
      67             : struct sk_psock_link {
      68             :         struct list_head                list;
      69             :         struct bpf_map                  *map;
      70             :         void                            *link_raw;
      71             : };
      72             : 
      73             : struct sk_psock_parser {
      74             :         struct strparser                strp;
      75             :         bool                            enabled;
      76             :         void (*saved_data_ready)(struct sock *sk);
      77             : };
      78             : 
      79             : struct sk_psock_work_state {
      80             :         struct sk_buff                  *skb;
      81             :         u32                             len;
      82             :         u32                             off;
      83             : };
      84             : 
      85             : struct sk_psock {
      86             :         struct sock                     *sk;
      87             :         struct sock                     *sk_redir;
      88             :         u32                             apply_bytes;
      89             :         u32                             cork_bytes;
      90             :         u32                             eval;
      91             :         struct sk_msg                   *cork;
      92             :         struct sk_psock_progs           progs;
      93             :         struct sk_psock_parser          parser;
      94             :         struct sk_buff_head             ingress_skb;
      95             :         struct list_head                ingress_msg;
      96             :         unsigned long                   state;
      97             :         struct list_head                link;
      98             :         spinlock_t                      link_lock;
      99             :         refcount_t                      refcnt;
     100             :         void (*saved_unhash)(struct sock *sk);
     101             :         void (*saved_close)(struct sock *sk, long timeout);
     102             :         void (*saved_write_space)(struct sock *sk);
     103             :         struct proto                    *sk_proto;
     104             :         struct sk_psock_work_state      work_state;
     105             :         struct work_struct              work;
     106             :         union {
     107             :                 struct rcu_head         rcu;
     108             :                 struct work_struct      gc;
     109             :         };
     110             : };
     111             : 
     112             : int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
     113             :                  int elem_first_coalesce);
     114             : int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
     115             :                  u32 off, u32 len);
     116             : void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
     117             : int sk_msg_free(struct sock *sk, struct sk_msg *msg);
     118             : int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
     119             : void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
     120             : void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
     121             :                                   u32 bytes);
     122             : 
     123             : void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
     124             : void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
     125             : 
     126             : int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
     127             :                               struct sk_msg *msg, u32 bytes);
     128             : int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
     129             :                              struct sk_msg *msg, u32 bytes);
     130             : 
     131             : static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
     132             : {
     133             :         WARN_ON(i == msg->sg.end && bytes);
     134             : }
     135             : 
     136             : static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
     137             : {
     138             :         if (psock->apply_bytes) {
     139             :                 if (psock->apply_bytes < bytes)
     140             :                         psock->apply_bytes = 0;
     141             :                 else
     142             :                         psock->apply_bytes -= bytes;
     143             :         }
     144             : }
     145             : 
     146           0 : static inline u32 sk_msg_iter_dist(u32 start, u32 end)
     147             : {
     148           0 :         return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
     149             : }
     150             : 
     151             : #define sk_msg_iter_var_prev(var)                       \
     152             :         do {                                            \
     153             :                 if (var == 0)                           \
     154             :                         var = NR_MSG_FRAG_IDS - 1;      \
     155             :                 else                                    \
     156             :                         var--;                          \
     157             :         } while (0)
     158             : 
     159             : #define sk_msg_iter_var_next(var)                       \
     160             :         do {                                            \
     161             :                 var++;                                  \
     162             :                 if (var == NR_MSG_FRAG_IDS)             \
     163             :                         var = 0;                        \
     164             :         } while (0)
     165             : 
     166             : #define sk_msg_iter_prev(msg, which)                    \
     167             :         sk_msg_iter_var_prev(msg->sg.which)
     168             : 
     169             : #define sk_msg_iter_next(msg, which)                    \
     170             :         sk_msg_iter_var_next(msg->sg.which)
     171             : 
     172             : static inline void sk_msg_clear_meta(struct sk_msg *msg)
     173             : {
     174             :         memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
     175             : }
     176             : 
     177             : static inline void sk_msg_init(struct sk_msg *msg)
     178             : {
     179             :         BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
     180             :         memset(msg, 0, sizeof(*msg));
     181             :         sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
     182             : }
     183             : 
     184             : static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
     185             :                                int which, u32 size)
     186             : {
     187             :         dst->sg.data[which] = src->sg.data[which];
     188             :         dst->sg.data[which].length  = size;
     189             :         dst->sg.size            += size;
     190             :         src->sg.size            -= size;
     191             :         src->sg.data[which].length -= size;
     192             :         src->sg.data[which].offset += size;
     193             : }
     194             : 
     195             : static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
     196             : {
     197             :         memcpy(dst, src, sizeof(*src));
     198             :         sk_msg_init(src);
     199             : }
     200             : 
     201             : static inline bool sk_msg_full(const struct sk_msg *msg)
     202             : {
     203             :         return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
     204             : }
     205             : 
     206           0 : static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
     207             : {
     208           0 :         return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
     209             : }
     210             : 
     211           0 : static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
     212             : {
     213           0 :         return &msg->sg.data[which];
     214             : }
     215             : 
     216           0 : static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
     217             : {
     218           0 :         return msg->sg.data[which];
     219             : }
     220             : 
     221             : static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
     222             : {
     223             :         return sg_page(sk_msg_elem(msg, which));
     224             : }
     225             : 
     226             : static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
     227             : {
     228             :         return msg->flags & BPF_F_INGRESS;
     229             : }
     230             : 
     231           0 : static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
     232             : {
     233           0 :         struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
     234             : 
     235           0 :         if (test_bit(msg->sg.start, &msg->sg.copy)) {
     236           0 :                 msg->data = NULL;
     237           0 :                 msg->data_end = NULL;
     238             :         } else {
     239           0 :                 msg->data = sg_virt(sge);
     240           0 :                 msg->data_end = msg->data + sge->length;
     241             :         }
     242           0 : }
     243             : 
     244             : static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
     245             :                                    u32 len, u32 offset)
     246             : {
     247             :         struct scatterlist *sge;
     248             : 
     249             :         get_page(page);
     250             :         sge = sk_msg_elem(msg, msg->sg.end);
     251             :         sg_set_page(sge, page, len, offset);
     252             :         sg_unmark_end(sge);
     253             : 
     254             :         __set_bit(msg->sg.end, &msg->sg.copy);
     255             :         msg->sg.size += len;
     256             :         sk_msg_iter_next(msg, end);
     257             : }
     258             : 
     259             : static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
     260             : {
     261             :         do {
     262             :                 if (copy_state)
     263             :                         __set_bit(i, &msg->sg.copy);
     264             :                 else
     265             :                         __clear_bit(i, &msg->sg.copy);
     266             :                 sk_msg_iter_var_next(i);
     267             :                 if (i == msg->sg.end)
     268             :                         break;
     269             :         } while (1);
     270             : }
     271             : 
     272             : static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
     273             : {
     274             :         sk_msg_sg_copy(msg, start, true);
     275             : }
     276             : 
     277             : static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
     278             : {
     279             :         sk_msg_sg_copy(msg, start, false);
     280             : }
     281             : 
     282             : static inline struct sk_psock *sk_psock(const struct sock *sk)
     283             : {
     284             :         return rcu_dereference_sk_user_data(sk);
     285             : }
     286             : 
     287             : static inline void sk_psock_queue_msg(struct sk_psock *psock,
     288             :                                       struct sk_msg *msg)
     289             : {
     290             :         list_add_tail(&msg->list, &psock->ingress_msg);
     291             : }
     292             : 
     293             : static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
     294             : {
     295             :         return psock ? list_empty(&psock->ingress_msg) : true;
     296             : }
     297             : 
     298             : static inline void sk_psock_report_error(struct sk_psock *psock, int err)
     299             : {
     300             :         struct sock *sk = psock->sk;
     301             : 
     302             :         sk->sk_err = err;
     303             :         sk->sk_error_report(sk);
     304             : }
     305             : 
     306             : struct sk_psock *sk_psock_init(struct sock *sk, int node);
     307             : 
     308             : int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
     309             : void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
     310             : void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
     311             : void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
     312             : void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);
     313             : 
     314             : int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
     315             :                          struct sk_msg *msg);
     316             : 
     317             : static inline struct sk_psock_link *sk_psock_init_link(void)
     318             : {
     319             :         return kzalloc(sizeof(struct sk_psock_link),
     320             :                        GFP_ATOMIC | __GFP_NOWARN);
     321             : }
     322             : 
     323             : static inline void sk_psock_free_link(struct sk_psock_link *link)
     324             : {
     325             :         kfree(link);
     326             : }
     327             : 
     328             : struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
     329             : 
     330             : void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
     331             : 
     332             : static inline void sk_psock_cork_free(struct sk_psock *psock)
     333             : {
     334             :         if (psock->cork) {
     335             :                 sk_msg_free(psock->sk, psock->cork);
     336             :                 kfree(psock->cork);
     337             :                 psock->cork = NULL;
     338             :         }
     339             : }
     340             : 
     341             : static inline void sk_psock_update_proto(struct sock *sk,
     342             :                                          struct sk_psock *psock,
     343             :                                          struct proto *ops)
     344             : {
     345             :         /* Pairs with lockless read in sk_clone_lock() */
     346             :         WRITE_ONCE(sk->sk_prot, ops);
     347             : }
     348             : 
     349             : static inline void sk_psock_restore_proto(struct sock *sk,
     350             :                                           struct sk_psock *psock)
     351             : {
     352             :         sk->sk_prot->unhash = psock->saved_unhash;
     353             :         if (inet_csk_has_ulp(sk)) {
     354             :                 tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
     355             :         } else {
     356             :                 sk->sk_write_space = psock->saved_write_space;
     357             :                 /* Pairs with lockless read in sk_clone_lock() */
     358             :                 WRITE_ONCE(sk->sk_prot, psock->sk_proto);
     359             :         }
     360             : }
     361             : 
     362             : static inline void sk_psock_set_state(struct sk_psock *psock,
     363             :                                       enum sk_psock_state_bits bit)
     364             : {
     365             :         set_bit(bit, &psock->state);
     366             : }
     367             : 
     368             : static inline void sk_psock_clear_state(struct sk_psock *psock,
     369             :                                         enum sk_psock_state_bits bit)
     370             : {
     371             :         clear_bit(bit, &psock->state);
     372             : }
     373             : 
     374             : static inline bool sk_psock_test_state(const struct sk_psock *psock,
     375             :                                        enum sk_psock_state_bits bit)
     376             : {
     377             :         return test_bit(bit, &psock->state);
     378             : }
     379             : 
     380             : static inline struct sk_psock *sk_psock_get(struct sock *sk)
     381             : {
     382             :         struct sk_psock *psock;
     383             : 
     384             :         rcu_read_lock();
     385             :         psock = sk_psock(sk);
     386             :         if (psock && !refcount_inc_not_zero(&psock->refcnt))
     387             :                 psock = NULL;
     388             :         rcu_read_unlock();
     389             :         return psock;
     390             : }
     391             : 
     392             : void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
     393             : void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
     394             : 
     395             : static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
     396             : {
     397             :         if (refcount_dec_and_test(&psock->refcnt))
     398             :                 sk_psock_drop(sk, psock);
     399             : }
     400             : 
     401             : static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
     402             : {
     403             :         if (psock->parser.enabled)
     404             :                 psock->parser.saved_data_ready(sk);
     405             :         else
     406             :                 sk->sk_data_ready(sk);
     407             : }
     408             : 
     409             : static inline void psock_set_prog(struct bpf_prog **pprog,
     410             :                                   struct bpf_prog *prog)
     411             : {
     412             :         prog = xchg(pprog, prog);
     413             :         if (prog)
     414             :                 bpf_prog_put(prog);
     415             : }
     416             : 
     417             : static inline int psock_replace_prog(struct bpf_prog **pprog,
     418             :                                      struct bpf_prog *prog,
     419             :                                      struct bpf_prog *old)
     420             : {
     421             :         if (cmpxchg(pprog, old, prog) != old)
     422             :                 return -ENOENT;
     423             : 
     424             :         if (old)
     425             :                 bpf_prog_put(old);
     426             : 
     427             :         return 0;
     428             : }
     429             : 
     430             : static inline void psock_progs_drop(struct sk_psock_progs *progs)
     431             : {
     432             :         psock_set_prog(&progs->msg_parser, NULL);
     433             :         psock_set_prog(&progs->skb_parser, NULL);
     434             :         psock_set_prog(&progs->skb_verdict, NULL);
     435             : }
     436             : 
     437             : int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
     438             : 
     439             : static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
     440             : {
     441             :         if (!psock)
     442             :                 return false;
     443             :         return psock->parser.enabled;
     444             : }
     445             : #endif /* _LINUX_SKMSG_H */

Generated by: LCOV version 1.14