#include #include #include struct heap_block { UINTN size; // includes header; bit 0 = 1 used, 0 free struct heap_block* next; // free list link (only valid when free) }; #define HEAP_ALIGN 16 #define HEADER_SIZE ((UINTN)sizeof(struct heap_block)) #define MIN_BLOCK_SIZE (HEADER_SIZE + HEAP_ALIGN) #define HEAP_INIT_PAGES 4 #define BLOCK_SIZE(block) ((block)->size & ~(UINTN)1) #define IS_USED(block) ((block)->size & 1) #define IS_FREE(block) (!IS_USED(block)) static struct heap_block* g_heap_free_list = NULL; static void* g_heap_start = NULL; static void* g_heap_end = NULL; static UINTN align_up(UINTN val, UINTN align) { return (val + align - 1) & ~(align - 1ULL); } static struct heap_block* next_block(struct heap_block* block) { return (struct heap_block*)((UINT8*)block + BLOCK_SIZE(block)); } static void heap_expand(UINTN min_size) { UINTN pages = (min_size + PAGE_SIZE - 1) / PAGE_SIZE; void* mem = pmm_alloc_pages(pages); if (!mem) { serial_write("HEAP: expand failed!\n"); return; } struct heap_block* new_block = (struct heap_block*)mem; new_block->size = pages * PAGE_SIZE; new_block->next = NULL; // Add to free list (sorted by address for coalescing) struct heap_block** prev = &g_heap_free_list; while (*prev && (UINT8*)*prev < (UINT8*)new_block) { prev = &(*prev)->next; } new_block->next = *prev; *prev = new_block; // Try to merge with the previous free block if adjacent if (prev != &g_heap_free_list) { struct heap_block* prev_block = g_heap_free_list; while (prev_block->next != new_block) { prev_block = prev_block->next; } if ((UINT8*)prev_block + BLOCK_SIZE(prev_block) == (UINT8*)new_block) { prev_block->size += new_block->size; prev_block->next = new_block->next; new_block = prev_block; } } if ((UINT8*)new_block + BLOCK_SIZE(new_block) > (UINT8*)g_heap_end) { g_heap_end = (UINT8*)new_block + BLOCK_SIZE(new_block); } serial_write("HEAP: expanded by "); serial_write_hex(pages * PAGE_SIZE); serial_write(" bytes\n"); } void init_heap() { void* mem = pmm_alloc_pages(HEAP_INIT_PAGES); if (!mem) { serial_write("HEAP: init failed!\n"); return; } g_heap_start = mem; g_heap_end = (void*)((UINT8*)mem + HEAP_INIT_PAGES * PAGE_SIZE); struct heap_block* initial = (struct heap_block*)mem; initial->size = HEAP_INIT_PAGES * PAGE_SIZE; initial->next = NULL; g_heap_free_list = initial; serial_write("HEAP: init OK, "); serial_write_hex(HEAP_INIT_PAGES * PAGE_SIZE); serial_write(" bytes @ "); serial_write_hex((UINTN)mem); serial_write("\n"); } void* kmalloc(UINTN size) { if (size == 0) return NULL; UINTN alloc_size = align_up(size + HEADER_SIZE, HEAP_ALIGN); if (alloc_size < MIN_BLOCK_SIZE) alloc_size = MIN_BLOCK_SIZE; struct heap_block** prev = &g_heap_free_list; while (*prev) { UINTN block_sz = BLOCK_SIZE(*prev); if (block_sz >= alloc_size) { // Found a suitable block struct heap_block* block = *prev; // Split if remaining space is useful if (block_sz >= alloc_size + MIN_BLOCK_SIZE) { struct heap_block* split = (struct heap_block*)((UINT8*)block + alloc_size); split->size = block_sz - alloc_size; // Insert split into free list split->next = block->next; block->size = alloc_size | 1; *prev = split; } else { // Use the whole block *prev = block->next; block->size = block_sz | 1; // mark used } if (size > 1024) { serial_write("HEAP: kmalloc "); serial_write_hex(size); serial_write(" -> "); serial_write_hex((UINTN)(block + 1)); serial_write("\n"); } return (void*)(block + 1); } prev = &(*prev)->next; } // Out of memory in current heap — expand UINTN expand_size = alloc_size > PAGE_SIZE ? alloc_size : PAGE_SIZE; heap_expand(expand_size); // Retry after expansion return kmalloc(size); } void kfree(void* ptr) { if (!ptr) return; struct heap_block* block = (struct heap_block*)ptr - 1; if (IS_FREE(block)) { serial_write("HEAP: double free detected!\n"); return; } // Mark as free block->size &= ~(UINTN)1; // Merge with next block if it's free struct heap_block* next = next_block(block); if ((UINT8*)next < (UINT8*)g_heap_end) { if (IS_FREE(next)) { // Remove next from free list and merge block->size += next->size; struct heap_block** prev = &g_heap_free_list; while (*prev && *prev != next) { prev = &(*prev)->next; } if (*prev) *prev = next->next; } } // Insert block into free list struct heap_block** prev = &g_heap_free_list; while (*prev && (UINT8*)*prev < (UINT8*)block) { prev = &(*prev)->next; } block->next = *prev; *prev = block; serial_write("HEAP: kfree @ "); serial_write_hex((UINTN)ptr); serial_write("\n"); } void* kcalloc(UINTN num, UINTN size) { UINTN total = num * size; void* ptr = kmalloc(total); if (ptr) { UINT8* p = (UINT8*)ptr; for (UINTN i = 0; i < total; i++) { p[i] = 0; } } return ptr; } void* krealloc(void* ptr, UINTN new_size) { if (!ptr) return kmalloc(new_size); if (new_size == 0) { kfree(ptr); return NULL; } struct heap_block* block = (struct heap_block*)ptr - 1; UINTN old_size = BLOCK_SIZE(block) - HEADER_SIZE; if (old_size >= new_size) { // Can we split the shrinkage? UINTN shrink = old_size - new_size; if (shrink >= MIN_BLOCK_SIZE) { block->size = (new_size + HEADER_SIZE) | 1; struct heap_block* split = (struct heap_block*)((UINT8*)ptr + new_size); split->size = shrink; kfree(split + 1); } return ptr; } void* new_ptr = kmalloc(new_size); if (new_ptr) { UINT8* src = (UINT8*)ptr; UINT8* dst = (UINT8*)new_ptr; for (UINTN i = 0; i < old_size; i++) { dst[i] = src[i]; } kfree(ptr); } return new_ptr; }