Line data Source code
1 : /* SPDX-License-Identifier: GPL-2.0 */
2 : /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3 :
4 : #ifndef _LINUX_SKMSG_H
5 : #define _LINUX_SKMSG_H
6 :
7 : #include <linux/bpf.h>
8 : #include <linux/filter.h>
9 : #include <linux/scatterlist.h>
10 : #include <linux/skbuff.h>
11 :
12 : #include <net/sock.h>
13 : #include <net/tcp.h>
14 : #include <net/strparser.h>
15 :
16 : #define MAX_MSG_FRAGS MAX_SKB_FRAGS
17 : #define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
18 :
19 : enum __sk_action {
20 : __SK_DROP = 0,
21 : __SK_PASS,
22 : __SK_REDIRECT,
23 : __SK_NONE,
24 : };
25 :
26 : struct sk_msg_sg {
27 : u32 start;
28 : u32 curr;
29 : u32 end;
30 : u32 size;
31 : u32 copybreak;
32 : unsigned long copy;
33 : /* The extra two elements:
34 : * 1) used for chaining the front and sections when the list becomes
35 : * partitioned (e.g. end < start). The crypto APIs require the
36 : * chaining;
37 : * 2) to chain tailer SG entries after the message.
38 : */
39 : struct scatterlist data[MAX_MSG_FRAGS + 2];
40 : };
41 : static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
42 :
43 : /* UAPI in filter.c depends on struct sk_msg_sg being first element. */
44 : struct sk_msg {
45 : struct sk_msg_sg sg;
46 : void *data;
47 : void *data_end;
48 : u32 apply_bytes;
49 : u32 cork_bytes;
50 : u32 flags;
51 : struct sk_buff *skb;
52 : struct sock *sk_redir;
53 : struct sock *sk;
54 : struct list_head list;
55 : };
56 :
57 : struct sk_psock_progs {
58 : struct bpf_prog *msg_parser;
59 : struct bpf_prog *skb_parser;
60 : struct bpf_prog *skb_verdict;
61 : };
62 :
63 : enum sk_psock_state_bits {
64 : SK_PSOCK_TX_ENABLED,
65 : };
66 :
67 : struct sk_psock_link {
68 : struct list_head list;
69 : struct bpf_map *map;
70 : void *link_raw;
71 : };
72 :
73 : struct sk_psock_parser {
74 : struct strparser strp;
75 : bool enabled;
76 : void (*saved_data_ready)(struct sock *sk);
77 : };
78 :
79 : struct sk_psock_work_state {
80 : struct sk_buff *skb;
81 : u32 len;
82 : u32 off;
83 : };
84 :
85 : struct sk_psock {
86 : struct sock *sk;
87 : struct sock *sk_redir;
88 : u32 apply_bytes;
89 : u32 cork_bytes;
90 : u32 eval;
91 : struct sk_msg *cork;
92 : struct sk_psock_progs progs;
93 : struct sk_psock_parser parser;
94 : struct sk_buff_head ingress_skb;
95 : struct list_head ingress_msg;
96 : unsigned long state;
97 : struct list_head link;
98 : spinlock_t link_lock;
99 : refcount_t refcnt;
100 : void (*saved_unhash)(struct sock *sk);
101 : void (*saved_close)(struct sock *sk, long timeout);
102 : void (*saved_write_space)(struct sock *sk);
103 : struct proto *sk_proto;
104 : struct sk_psock_work_state work_state;
105 : struct work_struct work;
106 : union {
107 : struct rcu_head rcu;
108 : struct work_struct gc;
109 : };
110 : };
111 :
112 : int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
113 : int elem_first_coalesce);
114 : int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
115 : u32 off, u32 len);
116 : void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len);
117 : int sk_msg_free(struct sock *sk, struct sk_msg *msg);
118 : int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg);
119 : void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes);
120 : void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
121 : u32 bytes);
122 :
123 : void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes);
124 : void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes);
125 :
126 : int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
127 : struct sk_msg *msg, u32 bytes);
128 : int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
129 : struct sk_msg *msg, u32 bytes);
130 :
131 : static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
132 : {
133 : WARN_ON(i == msg->sg.end && bytes);
134 : }
135 :
136 : static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
137 : {
138 : if (psock->apply_bytes) {
139 : if (psock->apply_bytes < bytes)
140 : psock->apply_bytes = 0;
141 : else
142 : psock->apply_bytes -= bytes;
143 : }
144 : }
145 :
146 0 : static inline u32 sk_msg_iter_dist(u32 start, u32 end)
147 : {
148 0 : return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
149 : }
150 :
151 : #define sk_msg_iter_var_prev(var) \
152 : do { \
153 : if (var == 0) \
154 : var = NR_MSG_FRAG_IDS - 1; \
155 : else \
156 : var--; \
157 : } while (0)
158 :
159 : #define sk_msg_iter_var_next(var) \
160 : do { \
161 : var++; \
162 : if (var == NR_MSG_FRAG_IDS) \
163 : var = 0; \
164 : } while (0)
165 :
166 : #define sk_msg_iter_prev(msg, which) \
167 : sk_msg_iter_var_prev(msg->sg.which)
168 :
169 : #define sk_msg_iter_next(msg, which) \
170 : sk_msg_iter_var_next(msg->sg.which)
171 :
172 : static inline void sk_msg_clear_meta(struct sk_msg *msg)
173 : {
174 : memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy));
175 : }
176 :
177 : static inline void sk_msg_init(struct sk_msg *msg)
178 : {
179 : BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
180 : memset(msg, 0, sizeof(*msg));
181 : sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
182 : }
183 :
184 : static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
185 : int which, u32 size)
186 : {
187 : dst->sg.data[which] = src->sg.data[which];
188 : dst->sg.data[which].length = size;
189 : dst->sg.size += size;
190 : src->sg.size -= size;
191 : src->sg.data[which].length -= size;
192 : src->sg.data[which].offset += size;
193 : }
194 :
195 : static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
196 : {
197 : memcpy(dst, src, sizeof(*src));
198 : sk_msg_init(src);
199 : }
200 :
201 : static inline bool sk_msg_full(const struct sk_msg *msg)
202 : {
203 : return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
204 : }
205 :
206 0 : static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
207 : {
208 0 : return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
209 : }
210 :
211 0 : static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which)
212 : {
213 0 : return &msg->sg.data[which];
214 : }
215 :
216 0 : static inline struct scatterlist sk_msg_elem_cpy(struct sk_msg *msg, int which)
217 : {
218 0 : return msg->sg.data[which];
219 : }
220 :
221 : static inline struct page *sk_msg_page(struct sk_msg *msg, int which)
222 : {
223 : return sg_page(sk_msg_elem(msg, which));
224 : }
225 :
226 : static inline bool sk_msg_to_ingress(const struct sk_msg *msg)
227 : {
228 : return msg->flags & BPF_F_INGRESS;
229 : }
230 :
231 0 : static inline void sk_msg_compute_data_pointers(struct sk_msg *msg)
232 : {
233 0 : struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start);
234 :
235 0 : if (test_bit(msg->sg.start, &msg->sg.copy)) {
236 0 : msg->data = NULL;
237 0 : msg->data_end = NULL;
238 : } else {
239 0 : msg->data = sg_virt(sge);
240 0 : msg->data_end = msg->data + sge->length;
241 : }
242 0 : }
243 :
244 : static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page,
245 : u32 len, u32 offset)
246 : {
247 : struct scatterlist *sge;
248 :
249 : get_page(page);
250 : sge = sk_msg_elem(msg, msg->sg.end);
251 : sg_set_page(sge, page, len, offset);
252 : sg_unmark_end(sge);
253 :
254 : __set_bit(msg->sg.end, &msg->sg.copy);
255 : msg->sg.size += len;
256 : sk_msg_iter_next(msg, end);
257 : }
258 :
259 : static inline void sk_msg_sg_copy(struct sk_msg *msg, u32 i, bool copy_state)
260 : {
261 : do {
262 : if (copy_state)
263 : __set_bit(i, &msg->sg.copy);
264 : else
265 : __clear_bit(i, &msg->sg.copy);
266 : sk_msg_iter_var_next(i);
267 : if (i == msg->sg.end)
268 : break;
269 : } while (1);
270 : }
271 :
272 : static inline void sk_msg_sg_copy_set(struct sk_msg *msg, u32 start)
273 : {
274 : sk_msg_sg_copy(msg, start, true);
275 : }
276 :
277 : static inline void sk_msg_sg_copy_clear(struct sk_msg *msg, u32 start)
278 : {
279 : sk_msg_sg_copy(msg, start, false);
280 : }
281 :
282 : static inline struct sk_psock *sk_psock(const struct sock *sk)
283 : {
284 : return rcu_dereference_sk_user_data(sk);
285 : }
286 :
287 : static inline void sk_psock_queue_msg(struct sk_psock *psock,
288 : struct sk_msg *msg)
289 : {
290 : list_add_tail(&msg->list, &psock->ingress_msg);
291 : }
292 :
293 : static inline bool sk_psock_queue_empty(const struct sk_psock *psock)
294 : {
295 : return psock ? list_empty(&psock->ingress_msg) : true;
296 : }
297 :
298 : static inline void sk_psock_report_error(struct sk_psock *psock, int err)
299 : {
300 : struct sock *sk = psock->sk;
301 :
302 : sk->sk_err = err;
303 : sk->sk_error_report(sk);
304 : }
305 :
306 : struct sk_psock *sk_psock_init(struct sock *sk, int node);
307 :
308 : int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock);
309 : void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock);
310 : void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock);
311 : void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock);
312 : void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock);
313 :
314 : int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
315 : struct sk_msg *msg);
316 :
317 : static inline struct sk_psock_link *sk_psock_init_link(void)
318 : {
319 : return kzalloc(sizeof(struct sk_psock_link),
320 : GFP_ATOMIC | __GFP_NOWARN);
321 : }
322 :
323 : static inline void sk_psock_free_link(struct sk_psock_link *link)
324 : {
325 : kfree(link);
326 : }
327 :
328 : struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
329 :
330 : void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
331 :
332 : static inline void sk_psock_cork_free(struct sk_psock *psock)
333 : {
334 : if (psock->cork) {
335 : sk_msg_free(psock->sk, psock->cork);
336 : kfree(psock->cork);
337 : psock->cork = NULL;
338 : }
339 : }
340 :
341 : static inline void sk_psock_update_proto(struct sock *sk,
342 : struct sk_psock *psock,
343 : struct proto *ops)
344 : {
345 : /* Pairs with lockless read in sk_clone_lock() */
346 : WRITE_ONCE(sk->sk_prot, ops);
347 : }
348 :
349 : static inline void sk_psock_restore_proto(struct sock *sk,
350 : struct sk_psock *psock)
351 : {
352 : sk->sk_prot->unhash = psock->saved_unhash;
353 : if (inet_csk_has_ulp(sk)) {
354 : tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
355 : } else {
356 : sk->sk_write_space = psock->saved_write_space;
357 : /* Pairs with lockless read in sk_clone_lock() */
358 : WRITE_ONCE(sk->sk_prot, psock->sk_proto);
359 : }
360 : }
361 :
362 : static inline void sk_psock_set_state(struct sk_psock *psock,
363 : enum sk_psock_state_bits bit)
364 : {
365 : set_bit(bit, &psock->state);
366 : }
367 :
368 : static inline void sk_psock_clear_state(struct sk_psock *psock,
369 : enum sk_psock_state_bits bit)
370 : {
371 : clear_bit(bit, &psock->state);
372 : }
373 :
374 : static inline bool sk_psock_test_state(const struct sk_psock *psock,
375 : enum sk_psock_state_bits bit)
376 : {
377 : return test_bit(bit, &psock->state);
378 : }
379 :
380 : static inline struct sk_psock *sk_psock_get(struct sock *sk)
381 : {
382 : struct sk_psock *psock;
383 :
384 : rcu_read_lock();
385 : psock = sk_psock(sk);
386 : if (psock && !refcount_inc_not_zero(&psock->refcnt))
387 : psock = NULL;
388 : rcu_read_unlock();
389 : return psock;
390 : }
391 :
392 : void sk_psock_stop(struct sock *sk, struct sk_psock *psock);
393 : void sk_psock_drop(struct sock *sk, struct sk_psock *psock);
394 :
395 : static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock)
396 : {
397 : if (refcount_dec_and_test(&psock->refcnt))
398 : sk_psock_drop(sk, psock);
399 : }
400 :
401 : static inline void sk_psock_data_ready(struct sock *sk, struct sk_psock *psock)
402 : {
403 : if (psock->parser.enabled)
404 : psock->parser.saved_data_ready(sk);
405 : else
406 : sk->sk_data_ready(sk);
407 : }
408 :
409 : static inline void psock_set_prog(struct bpf_prog **pprog,
410 : struct bpf_prog *prog)
411 : {
412 : prog = xchg(pprog, prog);
413 : if (prog)
414 : bpf_prog_put(prog);
415 : }
416 :
417 : static inline int psock_replace_prog(struct bpf_prog **pprog,
418 : struct bpf_prog *prog,
419 : struct bpf_prog *old)
420 : {
421 : if (cmpxchg(pprog, old, prog) != old)
422 : return -ENOENT;
423 :
424 : if (old)
425 : bpf_prog_put(old);
426 :
427 : return 0;
428 : }
429 :
430 : static inline void psock_progs_drop(struct sk_psock_progs *progs)
431 : {
432 : psock_set_prog(&progs->msg_parser, NULL);
433 : psock_set_prog(&progs->skb_parser, NULL);
434 : psock_set_prog(&progs->skb_verdict, NULL);
435 : }
436 :
437 : int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb);
438 :
439 : static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
440 : {
441 : if (!psock)
442 : return false;
443 : return psock->parser.enabled;
444 : }
445 : #endif /* _LINUX_SKMSG_H */
|