Line data Source code
1 : // SPDX-License-Identifier: GPL-2.0
2 : #include <linux/init.h>
3 : #include <linux/static_call.h>
4 : #include <linux/bug.h>
5 : #include <linux/smp.h>
6 : #include <linux/sort.h>
7 : #include <linux/slab.h>
8 : #include <linux/module.h>
9 : #include <linux/cpu.h>
10 : #include <linux/processor.h>
11 : #include <asm/sections.h>
12 :
13 : extern struct static_call_site __start_static_call_sites[],
14 : __stop_static_call_sites[];
15 : extern struct static_call_tramp_key __start_static_call_tramp_key[],
16 : __stop_static_call_tramp_key[];
17 :
18 : static bool static_call_initialized;
19 :
20 : /* mutex to protect key modules/sites */
21 : static DEFINE_MUTEX(static_call_mutex);
22 :
23 20 : static void static_call_lock(void)
24 : {
25 20 : mutex_lock(&static_call_mutex);
26 : }
27 :
28 20 : static void static_call_unlock(void)
29 : {
30 20 : mutex_unlock(&static_call_mutex);
31 : }
32 :
33 551 : static inline void *static_call_addr(struct static_call_site *site)
34 : {
35 551 : return (void *)((long)site->addr + (long)&site->addr);
36 : }
37 :
38 :
39 5559 : static inline struct static_call_key *static_call_key(const struct static_call_site *site)
40 : {
41 5559 : return (struct static_call_key *)
42 5559 : (((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
43 : }
44 :
45 : /* These assume the key is word-aligned. */
46 20 : static inline bool static_call_is_init(struct static_call_site *site)
47 : {
48 20 : return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
49 : }
50 :
51 551 : static inline bool static_call_is_tail(struct static_call_site *site)
52 : {
53 551 : return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
54 : }
55 :
56 5 : static inline void static_call_set_init(struct static_call_site *site)
57 : {
58 5 : site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
59 : (long)&site->key;
60 5 : }
61 :
/*
 * sort() comparator: order call sites by the address of their key, so
 * that all sites sharing a key end up adjacent in the table.
 */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_key *ka = static_call_key(_a);
	const struct static_call_key *kb = static_call_key(_b);

	if (ka == kb)
		return 0;

	return ka < kb ? -1 : 1;
}
77 :
78 4542 : static void static_call_site_swap(void *_a, void *_b, int size)
79 : {
80 4542 : long delta = (unsigned long)_a - (unsigned long)_b;
81 4542 : struct static_call_site *a = _a;
82 4542 : struct static_call_site *b = _b;
83 4542 : struct static_call_site tmp = *a;
84 :
85 4542 : a->addr = b->addr - delta;
86 4542 : a->key = b->key - delta;
87 :
88 4542 : b->addr = tmp.addr + delta;
89 4542 : b->key = tmp.key + delta;
90 4542 : }
91 :
92 1 : static inline void static_call_sort_entries(struct static_call_site *start,
93 : struct static_call_site *stop)
94 : {
95 1 : sort(start, stop - start, sizeof(struct static_call_site),
96 : static_call_site_cmp, static_call_site_swap);
97 1 : }
98 :
99 36 : static inline bool static_call_key_has_mods(struct static_call_key *key)
100 : {
101 36 : return !(key->type & 1);
102 : }
103 :
104 18 : static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
105 : {
106 18 : if (!static_call_key_has_mods(key))
107 : return NULL;
108 :
109 0 : return key->mods;
110 : }
111 :
112 18 : static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
113 : {
114 18 : if (static_call_key_has_mods(key))
115 : return NULL;
116 :
117 18 : return (struct static_call_site *)(key->type & ~1);
118 : }
119 :
120 19 : void __static_call_update(struct static_call_key *key, void *tramp, void *func)
121 : {
122 19 : struct static_call_site *site, *stop;
123 19 : struct static_call_mod *site_mod, first;
124 :
125 19 : cpus_read_lock();
126 19 : static_call_lock();
127 :
128 19 : if (key->func == func)
129 1 : goto done;
130 :
131 18 : key->func = func;
132 :
133 18 : arch_static_call_transform(NULL, tramp, func, false);
134 :
135 : /*
136 : * If uninitialized, we'll not update the callsites, but they still
137 : * point to the trampoline and we just patched that.
138 : */
139 18 : if (WARN_ON_ONCE(!static_call_initialized))
140 0 : goto done;
141 :
142 18 : first = (struct static_call_mod){
143 18 : .next = static_call_key_next(key),
144 : .mod = NULL,
145 18 : .sites = static_call_key_sites(key),
146 : };
147 :
148 36 : for (site_mod = &first; site_mod; site_mod = site_mod->next) {
149 18 : struct module *mod = site_mod->mod;
150 :
151 18 : if (!site_mod->sites) {
152 : /*
153 : * This can happen if the static call key is defined in
154 : * a module which doesn't use it.
155 : *
156 : * It also happens in the has_mods case, where the
157 : * 'first' entry has no sites associated with it.
158 : */
159 1 : continue;
160 : }
161 :
162 37 : stop = __stop_static_call_sites;
163 :
164 : #ifdef CONFIG_MODULES
165 : if (mod) {
166 : stop = mod->static_call_sites +
167 : mod->num_static_call_sites;
168 : }
169 : #endif
170 :
171 37 : for (site = site_mod->sites;
172 37 : site < stop && static_call_key(site) == key; site++) {
173 20 : void *site_addr = static_call_addr(site);
174 :
175 20 : if (static_call_is_init(site)) {
176 : /*
177 : * Don't write to call sites which were in
178 : * initmem and have since been freed.
179 : */
180 0 : if (!mod && system_state >= SYSTEM_RUNNING)
181 0 : continue;
182 0 : if (mod && !within_module_init((unsigned long)site_addr, mod))
183 0 : continue;
184 : }
185 :
186 20 : if (!kernel_text_address((unsigned long)site_addr)) {
187 0 : WARN_ONCE(1, "can't patch static call site at %pS",
188 : site_addr);
189 0 : continue;
190 : }
191 :
192 20 : arch_static_call_transform(site_addr, NULL, func,
193 20 : static_call_is_tail(site));
194 : }
195 : }
196 :
197 18 : done:
198 19 : static_call_unlock();
199 19 : cpus_read_unlock();
200 19 : }
201 : EXPORT_SYMBOL_GPL(__static_call_update);
202 :
203 1 : static int __static_call_init(struct module *mod,
204 : struct static_call_site *start,
205 : struct static_call_site *stop)
206 : {
207 1 : struct static_call_site *site;
208 1 : struct static_call_key *key, *prev_key = NULL;
209 1 : struct static_call_mod *site_mod;
210 :
211 1 : if (start == stop)
212 : return 0;
213 :
214 1 : static_call_sort_entries(start, stop);
215 :
216 533 : for (site = start; site < stop; site++) {
217 531 : void *site_addr = static_call_addr(site);
218 :
219 531 : if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
220 531 : (!mod && init_section_contains(site_addr, 1)))
221 5 : static_call_set_init(site);
222 :
223 531 : key = static_call_key(site);
224 531 : if (key != prev_key) {
225 508 : prev_key = key;
226 :
227 : /*
228 : * For vmlinux (!mod) avoid the allocation by storing
229 : * the sites pointer in the key itself. Also see
230 : * __static_call_update()'s @first.
231 : *
232 : * This allows architectures (eg. x86) to call
233 : * static_call_init() before memory allocation works.
234 : */
235 508 : if (!mod) {
236 508 : key->sites = site;
237 508 : key->type |= 1;
238 508 : goto do_transform;
239 : }
240 :
241 0 : site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
242 0 : if (!site_mod)
243 : return -ENOMEM;
244 :
245 : /*
246 : * When the key has a direct sites pointer, extract
247 : * that into an explicit struct static_call_mod, so we
248 : * can have a list of modules.
249 : */
250 0 : if (static_call_key_sites(key)) {
251 0 : site_mod->mod = NULL;
252 0 : site_mod->next = NULL;
253 0 : site_mod->sites = static_call_key_sites(key);
254 :
255 0 : key->mods = site_mod;
256 :
257 0 : site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
258 0 : if (!site_mod)
259 : return -ENOMEM;
260 : }
261 :
262 0 : site_mod->mod = mod;
263 0 : site_mod->sites = site;
264 0 : site_mod->next = static_call_key_next(key);
265 0 : key->mods = site_mod;
266 : }
267 :
268 23 : do_transform:
269 531 : arch_static_call_transform(site_addr, NULL, key->func,
270 531 : static_call_is_tail(site));
271 : }
272 :
273 : return 0;
274 : }
275 :
276 0 : static int addr_conflict(struct static_call_site *site, void *start, void *end)
277 : {
278 0 : unsigned long addr = (unsigned long)static_call_addr(site);
279 :
280 0 : if (addr <= (unsigned long)end &&
281 0 : addr + CALL_INSN_SIZE > (unsigned long)start)
282 : return 1;
283 :
284 : return 0;
285 : }
286 :
287 0 : static int __static_call_text_reserved(struct static_call_site *iter_start,
288 : struct static_call_site *iter_stop,
289 : void *start, void *end)
290 : {
291 0 : struct static_call_site *iter = iter_start;
292 :
293 0 : while (iter < iter_stop) {
294 0 : if (addr_conflict(iter, start, end))
295 : return 1;
296 0 : iter++;
297 : }
298 :
299 : return 0;
300 : }
301 :
302 : #ifdef CONFIG_MODULES
303 :
304 : static int __static_call_mod_text_reserved(void *start, void *end)
305 : {
306 : struct module *mod;
307 : int ret;
308 :
309 : preempt_disable();
310 : mod = __module_text_address((unsigned long)start);
311 : WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
312 : if (!try_module_get(mod))
313 : mod = NULL;
314 : preempt_enable();
315 :
316 : if (!mod)
317 : return 0;
318 :
319 : ret = __static_call_text_reserved(mod->static_call_sites,
320 : mod->static_call_sites + mod->num_static_call_sites,
321 : start, end);
322 :
323 : module_put(mod);
324 :
325 : return ret;
326 : }
327 :
328 : static unsigned long tramp_key_lookup(unsigned long addr)
329 : {
330 : struct static_call_tramp_key *start = __start_static_call_tramp_key;
331 : struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
332 : struct static_call_tramp_key *tramp_key;
333 :
334 : for (tramp_key = start; tramp_key != stop; tramp_key++) {
335 : unsigned long tramp;
336 :
337 : tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
338 : if (tramp == addr)
339 : return (long)tramp_key->key + (long)&tramp_key->key;
340 : }
341 :
342 : return 0;
343 : }
344 :
345 : static int static_call_add_module(struct module *mod)
346 : {
347 : struct static_call_site *start = mod->static_call_sites;
348 : struct static_call_site *stop = start + mod->num_static_call_sites;
349 : struct static_call_site *site;
350 :
351 : for (site = start; site != stop; site++) {
352 : unsigned long s_key = (long)site->key + (long)&site->key;
353 : unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
354 : unsigned long key;
355 :
356 : /*
357 : * Is the key is exported, 'addr' points to the key, which
358 : * means modules are allowed to call static_call_update() on
359 : * it.
360 : *
361 : * Otherwise, the key isn't exported, and 'addr' points to the
362 : * trampoline so we need to lookup the key.
363 : *
364 : * We go through this dance to prevent crazy modules from
365 : * abusing sensitive static calls.
366 : */
367 : if (!kernel_text_address(addr))
368 : continue;
369 :
370 : key = tramp_key_lookup(addr);
371 : if (!key) {
372 : pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
373 : static_call_addr(site));
374 : return -EINVAL;
375 : }
376 :
377 : key |= s_key & STATIC_CALL_SITE_FLAGS;
378 : site->key = key - (long)&site->key;
379 : }
380 :
381 : return __static_call_init(mod, start, stop);
382 : }
383 :
384 : static void static_call_del_module(struct module *mod)
385 : {
386 : struct static_call_site *start = mod->static_call_sites;
387 : struct static_call_site *stop = mod->static_call_sites +
388 : mod->num_static_call_sites;
389 : struct static_call_key *key, *prev_key = NULL;
390 : struct static_call_mod *site_mod, **prev;
391 : struct static_call_site *site;
392 :
393 : for (site = start; site < stop; site++) {
394 : key = static_call_key(site);
395 : if (key == prev_key)
396 : continue;
397 :
398 : prev_key = key;
399 :
400 : for (prev = &key->mods, site_mod = key->mods;
401 : site_mod && site_mod->mod != mod;
402 : prev = &site_mod->next, site_mod = site_mod->next)
403 : ;
404 :
405 : if (!site_mod)
406 : continue;
407 :
408 : *prev = site_mod->next;
409 : kfree(site_mod);
410 : }
411 : }
412 :
413 : static int static_call_module_notify(struct notifier_block *nb,
414 : unsigned long val, void *data)
415 : {
416 : struct module *mod = data;
417 : int ret = 0;
418 :
419 : cpus_read_lock();
420 : static_call_lock();
421 :
422 : switch (val) {
423 : case MODULE_STATE_COMING:
424 : ret = static_call_add_module(mod);
425 : if (ret) {
426 : WARN(1, "Failed to allocate memory for static calls");
427 : static_call_del_module(mod);
428 : }
429 : break;
430 : case MODULE_STATE_GOING:
431 : static_call_del_module(mod);
432 : break;
433 : }
434 :
435 : static_call_unlock();
436 : cpus_read_unlock();
437 :
438 : return notifier_from_errno(ret);
439 : }
440 :
441 : static struct notifier_block static_call_module_nb = {
442 : .notifier_call = static_call_module_notify,
443 : };
444 :
445 : #else
446 :
/* Without CONFIG_MODULES there are no module call sites to overlap. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	(void)start;
	(void)end;
	return 0;
}
451 :
452 : #endif /* CONFIG_MODULES */
453 :
454 0 : int static_call_text_reserved(void *start, void *end)
455 : {
456 0 : int ret = __static_call_text_reserved(__start_static_call_sites,
457 : __stop_static_call_sites, start, end);
458 :
459 0 : if (ret)
460 : return ret;
461 :
462 0 : return __static_call_mod_text_reserved(start, end);
463 : }
464 :
465 2 : int __init static_call_init(void)
466 : {
467 2 : int ret;
468 :
469 2 : if (static_call_initialized)
470 : return 0;
471 :
472 1 : cpus_read_lock();
473 1 : static_call_lock();
474 1 : ret = __static_call_init(NULL, __start_static_call_sites,
475 : __stop_static_call_sites);
476 1 : static_call_unlock();
477 1 : cpus_read_unlock();
478 :
479 1 : if (ret) {
480 0 : pr_err("Failed to allocate memory for static_call!\n");
481 0 : BUG();
482 : }
483 :
484 1 : static_call_initialized = true;
485 :
486 : #ifdef CONFIG_MODULES
487 : register_module_notifier(&static_call_module_nb);
488 : #endif
489 1 : return 0;
490 : }
491 : early_initcall(static_call_init);
492 :
/* Trivial target function: takes no arguments and returns 0. */
long __static_call_return0(void)
{
	return 0L;
}
497 :
498 : #ifdef CONFIG_STATIC_CALL_SELFTEST
499 :
/* Selftest target: increments its argument by one. */
static int func_a(int x)
{
	return 1 + x;
}
504 :
/* Selftest target: increments its argument by two. */
static int func_b(int x)
{
	return 2 + x;
}
509 :
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest script, consumed in order by test_static_call_init(). @func is
 * the new target for sc_selftest (NULL keeps the current one); the call
 * with @val must return @expect.
 */
struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
};

static struct static_call_data static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 },
};
521 :
522 : static int __init test_static_call_init(void)
523 : {
524 : int i;
525 :
526 : for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
527 : struct static_call_data *scd = &static_call_data[i];
528 :
529 : if (scd->func)
530 : static_call_update(sc_selftest, scd->func);
531 :
532 : WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
533 : }
534 :
535 : return 0;
536 : }
537 : early_initcall(test_static_call_init);
538 :
539 : #endif /* CONFIG_STATIC_CALL_SELFTEST */
|