// kernel/drivers/block/virtio_blk.rs

1//! # VirtIO Block Device Driver
2//! 
3//! This module provides a driver for VirtIO block devices, implementing the
4//! BlockDevice trait for integration with the kernel's block device subsystem.
5//!
6//! The driver supports basic block operations (read/write) and handles the VirtIO
7//! queue management for block device requests.
8//!
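//! ## Feature Support
//!
//! The driver checks for and handles the following VirtIO block device features:
//! - `VIRTIO_BLK_F_BLK_SIZE`: device-reported block size, recorded as the sector size
//!   (request addressing in the VirtIO protocol itself always uses 512-byte sectors)
//! - `VIRTIO_BLK_F_RO`: Read-only device detection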
14//!
15//! ## Implementation Details
16//!
17//! The driver uses a single virtqueue for processing block I/O requests. Each request
18//! consists of three parts:
19//! 1. Request header (specifying operation type and sector)
20//! 2. Data buffer (for read/write content)
21//! 3. Status byte (for operation result)
22//!
//! Each request is submitted to the device as a VirtIO descriptor chain. The header, data
//! buffer and status byte are heap-allocated with `Box`, so they stay valid until the device
//! has completed the transfer; the driver waits for completion by polling the queue.
25
use alloc::{boxed::Box, vec::Vec};
use alloc::vec;
use spin::{Mutex, RwLock};

use core::hint::spin_loop;
use core::sync::atomic::{fence, Ordering};
use core::{mem, ptr};

use crate::defer;
use crate::device::{Device, DeviceType};
use crate::{
    device::block::{request::{BlockIORequest, BlockIORequestType, BlockIOResult}, BlockDevice}, 
    drivers::virtio::{device::{Register, VirtioDevice}, queue::{DescriptorFlag, VirtQueue}}
};
38
// VirtIO block request types (values for `VirtioBlkReqHeader::type_`).

/// Read request: the device fills the data buffer.
const VIRTIO_BLK_T_IN: u32 = 0;
/// Write request: the device consumes the data buffer.
const VIRTIO_BLK_T_OUT: u32 = 1;
/// Flush the device's write cache. Not issued by this driver yet.
#[allow(dead_code)] // spec value kept for when flush support is added
const VIRTIO_BLK_T_FLUSH: u32 = 4;

// VirtIO block status codes (the byte the device writes at the end of each request).

/// Request completed successfully.
const VIRTIO_BLK_S_OK: u8 = 0;
/// The device reported an I/O error.
const VIRTIO_BLK_S_IOERR: u8 = 1;
/// The device does not support the request.
const VIRTIO_BLK_S_UNSUPP: u8 = 2;

// Device feature bit *positions*: test or mask them as `1 << BIT`.
// Items marked `dead_code` are not used by this driver; their spec values are kept for reference.

/// `size_max` in the config space is valid.
#[allow(dead_code)]
const VIRTIO_BLK_F_SIZE_MAX: u32 = 1;
/// `seg_max` in the config space is valid.
#[allow(dead_code)]
const VIRTIO_BLK_F_SEG_MAX: u32 = 2;
/// `geometry` in the config space is valid.
#[allow(dead_code)]
const VIRTIO_BLK_F_GEOMETRY: u32 = 4;
/// Disk is read-only.
const VIRTIO_BLK_F_RO: u32 = 5;
/// `blk_size` in the config space is valid.
const VIRTIO_BLK_F_BLK_SIZE: u32 = 6;
/// Supports SCSI command passthrough.
const VIRTIO_BLK_F_SCSI: u32 = 7;
/// Cache flush command support.
#[allow(dead_code)]
const VIRTIO_BLK_F_FLUSH: u32 = 9;
/// Writeback mode available in config.
const VIRTIO_BLK_F_CONFIG_WCE: u32 = 11;
/// Device supports more than one virtqueue.
const VIRTIO_BLK_F_MQ: u32 = 12;
/// Transport feature: arbitrary descriptor layouts are accepted.
const VIRTIO_F_ANY_LAYOUT: u32 = 27;
/// Transport feature: indirect descriptor tables.
const VIRTIO_RING_F_INDIRECT_DESC: u32 = 28;
/// Transport feature: event-index based notification suppression.
const VIRTIO_RING_F_EVENT_IDX: u32 = 29;
70
/// Layout of the VirtIO block device configuration space, up to the `writeback` field.
///
/// `#[repr(C)]` keeps the field offsets aligned with the config-space layout (the driver reads
/// `capacity` at offset 0 and `blk_size` at offset 20); the `const` checks below this block
/// fail the build if the layout drifts.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtioBlkConfig {
    /// Device size, in 512-byte sectors (offset 0).
    pub capacity: u64,
    /// Maximum size of a single segment; meaningful with `VIRTIO_BLK_F_SIZE_MAX` (offset 8).
    pub size_max: u32,
    /// Maximum segments per request; meaningful with `VIRTIO_BLK_F_SEG_MAX` (offset 12).
    pub seg_max: u32,
    /// Disk geometry; meaningful with `VIRTIO_BLK_F_GEOMETRY` (offset 16).
    pub geometry: VirtioBlkGeometry,
    /// Device block size; meaningful with `VIRTIO_BLK_F_BLK_SIZE` (offset 20).
    pub blk_size: u32,
    /// I/O topology hints (offset 24).
    pub topology: VirtioBlkTopology,
    /// Cache mode (writeback vs. writethrough), used with `VIRTIO_BLK_F_CONFIG_WCE` (offset 32).
    pub writeback: u8,
}

/// Cylinder/head/sector geometry reported by the device (4 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtioBlkGeometry {
    pub cylinders: u16,
    pub heads: u8,
    pub sectors: u8,
}

/// I/O topology hints reported by the device (8 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtioBlkTopology {
    /// log2 of the number of logical blocks per physical block.
    pub physical_block_exp: u8,
    /// Offset of the first aligned logical block.
    pub alignment_offset: u8,
    /// Suggested minimum I/O size, in blocks.
    pub min_io_size: u16,
    /// Suggested optimal (maximum) I/O size, in blocks.
    pub opt_io_size: u32,
}

// Compile-time layout checks: with these sizes, repr(C) puts `blk_size` at offset 20,
// `topology` at 24 and `writeback` at 32, as the VirtIO config space requires.
const _: () = assert!(mem::size_of::<VirtioBlkGeometry>() == 4);
const _: () = assert!(mem::size_of::<VirtioBlkTopology>() == 8);
const _: () = assert!(mem::size_of::<VirtioBlkConfig>() == 40);
96
/// Header that starts every block request; it is placed in the first, device-readable
/// descriptor of the request chain.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtioBlkReqHeader {
    /// Request type; this driver uses `VIRTIO_BLK_T_IN` (read) or `VIRTIO_BLK_T_OUT` (write).
    pub type_: u32,
    /// Reserved; always written as 0.
    pub reserved: u32,
    /// First sector of the transfer, always in 512-byte units per the VirtIO spec.
    pub sector: u64,
}

// The device expects a 16-byte header (u32, u32, u64); catch accidental layout changes.
const _: () = assert!(mem::size_of::<VirtioBlkReqHeader>() == 16);
103
/// Driver state for one VirtIO block device.
///
/// Mutable state sits behind `spin` locks so that requests can be issued through `&self`.
pub struct VirtioBlockDevice {
    /// Register base address given to `new`; returned by `get_base_addr` for the generic
    /// VirtIO layer (presumably a virtio-mmio window — TODO confirm).
    base_addr: usize,
    virtqueues: Mutex<[VirtQueue<'static>; 1]>, // Only one queue for request/response
    /// Capacity read from config offset 0 (the VirtIO spec expresses it in 512-byte sectors).
    capacity: RwLock<u64>,
    /// `blk_size` from config offset 20 when `VIRTIO_BLK_F_BLK_SIZE` is offered, else 512.
    sector_size: RwLock<u32>,
    /// Value of the 32-bit `DeviceFeatures` register as read in `new` (offered bits).
    features: RwLock<u32>,
    /// True when the device offers `VIRTIO_BLK_F_RO`.
    read_only: RwLock<bool>,
    /// Requests added by `enqueue_request`, drained by `process_requests`.
    request_queue: Mutex<Vec<Box<BlockIORequest>>>,
}
113
114impl VirtioBlockDevice {
115    pub fn new(base_addr: usize) -> Self {
116        let mut device = Self {
117            base_addr,
118            virtqueues: Mutex::new([VirtQueue::new(8)]),
119            capacity: RwLock::new(0),
120            sector_size: RwLock::new(512), // Default sector size
121            features: RwLock::new(0),
122            read_only: RwLock::new(false),
123            request_queue: Mutex::new(Vec::new()),
124        };
125        
126        // Initialize the device
127        if device.init().is_err() {
128            panic!("Failed to initialize Virtio Block Device");
129        }
130
131        // Read device configuration
132        *device.capacity.write() = device.read_config::<u64>(0); // Capacity at offset 0
133
134        // Read device features
135        let features = device.read32_register(Register::DeviceFeatures);
136        *device.features.write() = features;
137        
138        // Check if block size feature is supported
139        if features & (1 << VIRTIO_BLK_F_BLK_SIZE) != 0 {
140            *device.sector_size.write() = device.read_config::<u32>(20); // blk_size at offset 20
141        }
142        
143        // Check if device is read-only
144        *device.read_only.write() = features & (1 << VIRTIO_BLK_F_RO) != 0;
145
146        device
147    }
148    
149    fn process_request(&self, req: &mut BlockIORequest) -> Result<(), &'static str> {
150        // Allocate memory for request header, data, and status
151        let header = Box::new(VirtioBlkReqHeader {
152            type_: match req.request_type {
153                BlockIORequestType::Read => VIRTIO_BLK_T_IN,
154                BlockIORequestType::Write => VIRTIO_BLK_T_OUT,
155            },
156            reserved: 0,
157            sector: req.sector as u64,
158        });
159        let data = vec![0u8; req.buffer.len()].into_boxed_slice();
160        let status = Box::new(0u8);
161                
162        // Cast pages to appropriate types
163        let header_ptr = Box::into_raw(header);
164        let data_ptr = Box::into_raw(data) as *mut [u8];
165        let status_ptr = Box::into_raw(status);
166
167        defer! {
168            // Deallocate memory after use
169            unsafe {
170                drop(Box::from_raw(header_ptr));
171                drop(Box::from_raw(data_ptr));
172                drop(Box::from_raw(status_ptr));
173            }
174        }
175
176        // Set up request header
177        unsafe {
178            // Copy data for write requests
179            if let BlockIORequestType::Write = req.request_type {
180                ptr::copy_nonoverlapping(
181                    req.buffer.as_ptr(),
182                    data_ptr as *mut u8,
183                    req.buffer.len()
184                );
185            }
186        }
187        
188        // Lock the virtqueues for processing
189        let mut virtqueues = self.virtqueues.lock();
190        
191        // Allocate descriptors for the request
192        let header_desc = virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
193        let data_desc = virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
194        let status_desc = virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
195        
196        // Set up header descriptor
197        virtqueues[0].desc[header_desc].addr = (header_ptr as usize) as u64;
198        virtqueues[0].desc[header_desc].len = mem::size_of::<VirtioBlkReqHeader>() as u32;
199        virtqueues[0].desc[header_desc].flags = DescriptorFlag::Next as u16;
200        virtqueues[0].desc[header_desc].next = data_desc as u16;
201        
202        // Set up data descriptor
203        virtqueues[0].desc[data_desc].addr = (data_ptr as *mut u8 as usize) as u64;
204        virtqueues[0].desc[data_desc].len = req.buffer.len() as u32;
205        
206        // Set flags based on request type
207        match req.request_type {
208            BlockIORequestType::Read => {
209                DescriptorFlag::Next.set(&mut virtqueues[0].desc[data_desc].flags);
210                DescriptorFlag::Write.set(&mut virtqueues[0].desc[data_desc].flags);
211            },
212            BlockIORequestType::Write => {
213                DescriptorFlag::Next.set(&mut virtqueues[0].desc[data_desc].flags);
214            }
215        }
216        
217        virtqueues[0].desc[data_desc].next = status_desc as u16;
218        
219        // Set up status descriptor
220        virtqueues[0].desc[status_desc].addr = (status_ptr as usize) as u64;
221        virtqueues[0].desc[status_desc].len = 1;
222        virtqueues[0].desc[status_desc].flags |= DescriptorFlag::Write as u16;
223        
224        // Submit the request to the queue
225        virtqueues[0].push(header_desc)?;
226
227        // Notify the device
228        self.notify(0);
229        
230        // Wait for the response (polling)
231        while virtqueues[0].is_busy() {}
232        while *virtqueues[0].used.idx as usize == virtqueues[0].last_used_idx {}
233
234        // Process completed request
235        let desc_idx = virtqueues[0].pop().ok_or("No response from device")?;
236        if desc_idx != header_desc {
237            return Err("Invalid descriptor index");
238        }
239        
240        // Check status
241        let status_val = unsafe { *status_ptr };
242        match status_val {
243            VIRTIO_BLK_S_OK => {
244                // For read requests, copy data to the buffer
245                if let BlockIORequestType::Read = req.request_type {
246                    unsafe {
247                        req.buffer.clear();
248                        req.buffer.extend_from_slice(core::slice::from_raw_parts(
249                            data_ptr as *const u8,
250                            virtqueues[0].desc[data_desc].len as usize
251                        ));
252                    }
253                }
254                Ok(())
255            },
256            VIRTIO_BLK_S_IOERR => Err("I/O error"),
257            VIRTIO_BLK_S_UNSUPP => Err("Unsupported request"),
258            _ => Err("Unknown error"),
259        }
260    }
261}
262
263impl Device for VirtioBlockDevice {
264    fn device_type(&self) -> DeviceType {
265        DeviceType::Block
266    }
267    
268    fn name(&self) -> &'static str {
269        "virtio-blk"
270    }
271    
272    fn as_any(&self) -> &dyn core::any::Any {
273        self
274    }
275    
276    fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
277        self
278    }
279    
280    fn as_block_device(&self) -> Option<&dyn crate::device::block::BlockDevice> {
281        Some(self)
282    }
283}
284
285impl VirtioDevice for VirtioBlockDevice {
286    fn get_base_addr(&self) -> usize {
287        self.base_addr
288    }
289    
290    fn get_virtqueue_count(&self) -> usize {
291        1 // We have one virtqueue
292    }
293    
294    fn get_supported_features(&self, device_features: u32) -> u32 {
295        // Accept most features but we might want to be selective
296        device_features & !(1 << VIRTIO_BLK_F_RO |
297            1 << VIRTIO_BLK_F_SCSI |
298            1 << VIRTIO_BLK_F_CONFIG_WCE |
299            1 << VIRTIO_BLK_F_MQ |
300            1 << VIRTIO_F_ANY_LAYOUT |
301            1 << VIRTIO_RING_F_EVENT_IDX |
302            1 << VIRTIO_RING_F_INDIRECT_DESC)
303    }
304    
305    fn get_queue_desc_addr(&self, queue_idx: usize) -> Option<u64> {
306        if queue_idx >= 1 {
307            return None;
308        }
309        
310        let virtqueues = self.virtqueues.lock();
311        Some(virtqueues[queue_idx].get_raw_ptr() as u64)
312    }
313    
314    fn get_queue_driver_addr(&self, queue_idx: usize) -> Option<u64> {
315        if queue_idx >= 1 {
316            return None;
317        }
318        
319        let virtqueues = self.virtqueues.lock();
320        Some(virtqueues[queue_idx].avail.flags as *const _ as u64)
321    }
322    
323    fn get_queue_device_addr(&self, queue_idx: usize) -> Option<u64> {
324        if queue_idx >= 1 {
325            return None;
326        }
327        
328        let virtqueues = self.virtqueues.lock();
329        Some(virtqueues[queue_idx].used.flags as *const _ as u64)
330    }
331}
332
333impl BlockDevice for VirtioBlockDevice {
334    fn get_disk_name(&self) -> &'static str {
335        "virtio-blk"
336    }
337    
338    fn get_disk_size(&self) -> usize {
339        let capacity = *self.capacity.read();
340        let sector_size = *self.sector_size.read();
341        (capacity * sector_size as u64) as usize
342    }
343    
344    fn enqueue_request(&self, request: Box<BlockIORequest>) {
345        // Enqueue the request
346        self.request_queue.lock().push(request);
347    }
348    
349    fn process_requests(&self) -> Vec<BlockIOResult> {
350        let mut results = Vec::new();
351        let mut queue = self.request_queue.lock();
352        while let Some(mut request) = queue.pop() {
353            drop(queue); // Release the lock before processing
354            let result = self.process_request(&mut *request);
355            results.push(BlockIOResult { request, result });
356            queue = self.request_queue.lock(); // Reacquire the lock
357        }
358        
359        results
360    }
361}
362
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::vec;

    /// Base address of the virtio-blk device used by these tests (presumably the first
    /// virtio-mmio slot of QEMU's `virt` machine — TODO confirm against the test setup).
    const TEST_BASE_ADDR: usize = 0x10001000;

    /// Disk size implied by the config space: VirtIO `capacity` counts 512-byte sectors.
    fn expected_disk_size(device: &VirtioBlockDevice) -> usize {
        (*device.capacity.read() * 512) as usize
    }

    #[test_case]
    fn test_virtio_block_device_init() {
        let device = VirtioBlockDevice::new(TEST_BASE_ADDR);

        assert_eq!(device.get_disk_name(), "virtio-blk");
        assert_eq!(device.get_disk_size(), expected_disk_size(&device));
    }

    #[test_case]
    fn test_virtio_block_device() {
        let device = VirtioBlockDevice::new(TEST_BASE_ADDR);

        assert_eq!(device.get_disk_name(), "virtio-blk");
        assert_eq!(device.get_disk_size(), expected_disk_size(&device));

        // Read the first sector through the enqueue/process path.
        let sector_size = *device.sector_size.read();
        let request = BlockIORequest {
            request_type: BlockIORequestType::Read,
            sector: 0,
            sector_count: 1,
            head: 0,
            cylinder: 0,
            buffer: vec![0; sector_size as usize],
        };
        device.enqueue_request(Box::new(request));

        let results = device.process_requests();
        assert_eq!(results.len(), 1);

        let result = &results[0];
        assert!(result.result.is_ok());

        // The test disk image is expected to start with this NUL-padded string.
        let buffer = &result.request.buffer;
        let buffer_str = core::str::from_utf8(buffer).unwrap_or("Invalid UTF-8").trim_matches(char::from(0));
        assert_eq!(buffer_str, "Hello, world!");
    }
}