kernel/drivers/block/
virtio_blk.rs

//! # VirtIO Block Device Driver
//!
//! This module provides a driver for VirtIO block devices, implementing the
//! `BlockDevice` trait for integration with the kernel's block device subsystem.
//!
//! The driver supports basic block operations (read/write) and manages the
//! VirtIO queue used to submit block device requests.
//!
//! ## Feature Support
//!
//! The driver checks for and handles the following VirtIO block device features:
//! - `VIRTIO_BLK_F_BLK_SIZE`: device-reported block size (`blk_size` config field)
//! - `VIRTIO_BLK_F_RO`: read-only device detection
//!
//! ## Implementation Details
//!
//! The driver uses a single virtqueue for processing block I/O requests. Each request
//! is a chain of three descriptors:
//! 1. Request header (specifying the operation type and starting sector)
//! 2. Data buffer (for read/write content)
//! 3. Status byte (for the operation result)
//!
//! Requests are submitted through the VirtIO descriptor chain mechanism and completed
//! synchronously by polling the used ring. The header, data and status buffers are Box
//! allocations that are kept alive for the whole request, so the memory stays valid
//! while the device transfers data.
25
26use alloc::{boxed::Box, vec::Vec};
27use alloc::vec;
28
29use core::{mem, ptr};
30
31use crate::defer;
32use crate::device::{Device, DeviceType};
33use crate::{
34    device::block::{request::{BlockIORequest, BlockIORequestType, BlockIOResult}, BlockDevice}, drivers::virtio::{device::{Register, VirtioDevice}, queue::{DescriptorFlag, VirtQueue}}
35};
36
// ---------------------------------------------------------------------------
// Request types: the `type_` field of `VirtioBlkReqHeader`.
// The spec also defines VIRTIO_BLK_T_FLUSH (4); this driver does not issue it.
// ---------------------------------------------------------------------------

/// Read sectors from the device into the data buffer.
const VIRTIO_BLK_T_IN: u32 = 0;
/// Write the data buffer to the device.
const VIRTIO_BLK_T_OUT: u32 = 1;

// ---------------------------------------------------------------------------
// Completion status: the single byte the device writes at the end of a chain.
// ---------------------------------------------------------------------------

/// The request completed successfully.
const VIRTIO_BLK_S_OK: u8 = 0;
/// The request failed with an I/O error.
const VIRTIO_BLK_S_IOERR: u8 = 1;
/// The device does not support the request type.
const VIRTIO_BLK_S_UNSUPP: u8 = 2;

// ---------------------------------------------------------------------------
// Feature bit numbers. These are bit *positions* in the feature word, not
// masks; use `1 << BIT` to test or clear them. Other block features defined
// by the spec but not referenced here: SIZE_MAX (1), SEG_MAX (2),
// GEOMETRY (4), FLUSH (9).
// ---------------------------------------------------------------------------

/// Disk is read-only.
const VIRTIO_BLK_F_RO: u32 = 5;
/// The `blk_size` config field is valid.
const VIRTIO_BLK_F_BLK_SIZE: u32 = 6;
/// Supports SCSI command passthrough (legacy).
const VIRTIO_BLK_F_SCSI: u32 = 7;
/// Writeback mode is available through the `writeback` config field.
const VIRTIO_BLK_F_CONFIG_WCE: u32 = 11;
/// Supports more than one virtqueue.
const VIRTIO_BLK_F_MQ: u32 = 12;
/// Transport feature: descriptor layout is independent of message framing.
const VIRTIO_F_ANY_LAYOUT: u32 = 27;
/// Ring feature: indirect descriptor tables.
const VIRTIO_RING_F_INDIRECT_DESC: u32 = 28;
/// Ring feature: event-index based notification suppression.
const VIRTIO_RING_F_EVENT_IDX: u32 = 29;
68
/// Device-specific configuration space of a VirtIO block device.
///
/// `#[repr(C)]` plus the field order reproduce the VirtIO `virtio_blk_config`
/// layout, so each field's offset equals its config-space offset (`capacity`
/// at 0, `blk_size` at 20, `topology` at 24, `writeback` at 32). Only the
/// fields up to `writeback` are modelled.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkConfig {
    /// Device size, in 512-byte sectors.
    pub capacity: u64,
    /// Maximum size of a single segment (valid with `VIRTIO_BLK_F_SIZE_MAX`).
    pub size_max: u32,
    /// Maximum number of segments per request (valid with `VIRTIO_BLK_F_SEG_MAX`).
    pub seg_max: u32,
    /// Legacy disk geometry (valid with `VIRTIO_BLK_F_GEOMETRY`).
    pub geometry: VirtioBlkGeometry,
    /// Block size in bytes (valid with `VIRTIO_BLK_F_BLK_SIZE`).
    pub blk_size: u32,
    /// I/O topology hints (valid with `VIRTIO_BLK_F_TOPOLOGY`).
    pub topology: VirtioBlkTopology,
    /// Cache mode selector (valid with `VIRTIO_BLK_F_CONFIG_WCE`).
    pub writeback: u8,
}

/// Legacy cylinder/head/sector geometry of the disk (4 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkGeometry {
    pub cylinders: u16,
    pub heads: u8,
    pub sectors: u8,
}

/// Optimal I/O alignment and size hints reported by the device (8 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkTopology {
    /// Number of logical blocks per physical block, as a power of two.
    pub physical_block_exp: u8,
    /// Offset of the first aligned logical block.
    pub alignment_offset: u8,
    /// Suggested minimum I/O size, in blocks.
    pub min_io_size: u16,
    /// Optimal (suggested maximum) I/O size, in blocks.
    pub opt_io_size: u32,
}
94
/// Header descriptor that starts every VirtIO block request chain (16 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkReqHeader {
    /// Request type (`VIRTIO_BLK_T_*`).
    pub type_: u32,
    /// Reserved; must be zero.
    pub reserved: u32,
    /// Starting sector, in 512-byte units.
    pub sector: u64,
}
101
/// Driver state for one VirtIO block device reached through its register window at `base_addr`.
pub struct VirtioBlockDevice {
    /// Base address of the device's registers; also returned as the device ID.
    base_addr: usize,
    /// The single virtqueue used for both submitting requests and collecting completions.
    virtqueues: [VirtQueue<'static>; 1], // Only one queue for request/response
    /// Capacity read from config offset 0 (the VirtIO spec expresses it in 512-byte sectors).
    capacity: u64,
    /// Block size in bytes as set up by `new()` (taken from `blk_size` when
    /// `VIRTIO_BLK_F_BLK_SIZE` is offered).
    sector_size: u32,
    /// Raw feature word read from the `DeviceFeatures` register in `new()`.
    features: u32,
    /// Set when the device offers `VIRTIO_BLK_F_RO`.
    read_only: bool,
    /// Requests added by `enqueue_request`, waiting for `process_requests`.
    request_queue: Vec<Box<BlockIORequest>>,
}
111
112impl VirtioBlockDevice {
113    pub fn new(base_addr: usize) -> Self {
114        let mut device = Self {
115            base_addr,
116            virtqueues: [VirtQueue::new(8)],
117            capacity: 0,
118            sector_size: 512, // Default sector size
119            features: 0,
120            read_only: false,
121            request_queue: Vec::new(),
122        };
123        
124        // Initialize the device
125        if device.init().is_err() {
126            panic!("Failed to initialize Virtio Block Device");
127        }
128
129        // Read device configuration
130        device.capacity = device.read_config::<u64>(0); // Capacity at offset 0
131
132        // Read device features
133        device.features = device.read32_register(Register::DeviceFeatures);
134        device.sector_size = 0;
135        
136        // Check if block size feature is supported
137        if device.features & (1 << VIRTIO_BLK_F_BLK_SIZE) != 0 {
138            device.sector_size = device.read_config::<u32>(20); // blk_size at offset 20
139        }
140        
141        // Check if device is read-only
142        device.read_only = device.features & (1 << VIRTIO_BLK_F_RO) != 0;
143
144        device
145    }
146    
147    fn process_request(&mut self, req: &mut BlockIORequest) -> Result<(), &'static str> {
148        // Allocate memory for request header, data, and status
149        let header = Box::new(VirtioBlkReqHeader {
150            type_: match req.request_type {
151                BlockIORequestType::Read => VIRTIO_BLK_T_IN,
152                BlockIORequestType::Write => VIRTIO_BLK_T_OUT,
153            },
154            reserved: 0,
155            sector: req.sector as u64,
156        });
157        let data = vec![0u8; req.buffer.len()].into_boxed_slice();
158        let status = Box::new(0u8);
159                
160        // Cast pages to appropriate types
161        let header_ptr = Box::into_raw(header);
162        let data_ptr = Box::into_raw(data) as *mut [u8];
163        let status_ptr = Box::into_raw(status);
164
165        defer! {
166            // Deallocate memory after use
167            unsafe {
168                drop(Box::from_raw(header_ptr));
169                drop(Box::from_raw(data_ptr));
170                drop(Box::from_raw(status_ptr));
171            }
172        }
173
174        // Set up request header
175        unsafe {
176            // Copy data for write requests
177            if let BlockIORequestType::Write = req.request_type {
178                ptr::copy_nonoverlapping(
179                    req.buffer.as_ptr(),
180                    data_ptr as *mut u8,
181                    req.buffer.len()
182                );
183            }
184        }
185        
186        // Allocate descriptors for the request
187        let header_desc = self.virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
188        let data_desc = self.virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
189        let status_desc = self.virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
190        
191        // Set up header descriptor
192        self.virtqueues[0].desc[header_desc].addr = (header_ptr as usize) as u64;
193        self.virtqueues[0].desc[header_desc].len = mem::size_of::<VirtioBlkReqHeader>() as u32;
194        self.virtqueues[0].desc[header_desc].flags = DescriptorFlag::Next as u16;
195        self.virtqueues[0].desc[header_desc].next = data_desc as u16;
196        
197        // Set up data descriptor
198        self.virtqueues[0].desc[data_desc].addr = (data_ptr as *mut u8 as usize) as u64;
199        self.virtqueues[0].desc[data_desc].len = req.buffer.len() as u32;
200        
201        // Set flags based on request type
202        match req.request_type {
203            BlockIORequestType::Read => {
204                // self.virtqueues[0].desc[data_desc].flags = 
205                //     DescriptorFlag::Next as u16 | DescriptorFlag::Write as u16;
206                DescriptorFlag::Next.set(&mut self.virtqueues[0].desc[data_desc].flags);
207                DescriptorFlag::Write.set(&mut self.virtqueues[0].desc[data_desc].flags);
208            },
209            BlockIORequestType::Write => {
210                // self.virtqueues[0].desc[data_desc].flags = DescriptorFlag::Next as u16;
211                DescriptorFlag::Next.set(&mut self.virtqueues[0].desc[data_desc].flags);
212            }
213        }
214        
215        self.virtqueues[0].desc[data_desc].next = status_desc as u16;
216        
217        // Set up status descriptor
218        self.virtqueues[0].desc[status_desc].addr = (status_ptr as usize) as u64;
219        self.virtqueues[0].desc[status_desc].len = 1;
220        self.virtqueues[0].desc[status_desc].flags |= DescriptorFlag::Write as u16;
221        
222        // Submit the request to the queue
223        self.virtqueues[0].push(header_desc)?;
224
225        // Notify the device
226        self.notify(0);
227        
228        // Wait for the response (polling)
229        while self.virtqueues[0].is_busy() {}
230        while *self.virtqueues[0].used.idx as usize == self.virtqueues[0].last_used_idx {}
231
232        // Process completed request
233        let desc_idx = self.virtqueues[0].pop().ok_or("No response from device")?;
234        if desc_idx != header_desc {
235            return Err("Invalid descriptor index");
236        }
237        
238        // Check status
239        let status_val = unsafe { *status_ptr };
240        match status_val {
241            VIRTIO_BLK_S_OK => {
242                // For read requests, copy data to the buffer
243                if let BlockIORequestType::Read = req.request_type {
244                    unsafe {
245                        req.buffer.clear();
246                        req.buffer.extend_from_slice(core::slice::from_raw_parts(
247                            data_ptr as *const u8,
248                            self.virtqueues[0].desc[data_desc].len as usize
249                        ));
250                    }
251                }
252                Ok(())
253            },
254            VIRTIO_BLK_S_IOERR => Err("I/O error"),
255            VIRTIO_BLK_S_UNSUPP => Err("Unsupported request"),
256            _ => Err("Unknown error"),
257        }
258    }
259}
260
261impl Device for VirtioBlockDevice {
262    fn device_type(&self) -> DeviceType {
263        DeviceType::Block
264    }
265    
266    fn name(&self) -> &'static str {
267        "virtio-blk"
268    }
269    
270    fn id(&self) -> usize {
271        self.base_addr
272    }
273    
274    fn as_any(&self) -> &dyn core::any::Any {
275        self
276    }
277    
278    fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
279        self
280    }
281    
282    fn as_block_device(&mut self) -> Option<&mut dyn crate::device::block::BlockDevice> {
283        Some(self)
284    }
285}
286
287impl VirtioDevice for VirtioBlockDevice {
288    fn get_base_addr(&self) -> usize {
289        self.base_addr
290    }
291    
292    fn get_virtqueue_count(&self) -> usize {
293        self.virtqueues.len()
294    }
295
296    fn get_virtqueue(&self, queue_idx: usize) -> &VirtQueue {
297        &self.virtqueues[queue_idx]
298    }
299    
300    fn get_supported_features(&self, device_features: u32) -> u32 {
301        // Accept most features but we might want to be selective
302        // device_features
303        // NOT Negotiated
304        // - VIRTIO_BLK_F_RO
305        // - VIRTIO_BLK_F_SCSI
306        // - VIRTIO_BLK_F_CONFIG_WCE
307        // - VIRTIO_BLK_F_MQ
308        // - VIRTIO_F_ANY_LAYOUT
309        // - VIRTIO_RING_F_EVENT_IDX
310        // - VIRTIO_RING_F_INDIRECT_DESC
311
312        device_features & !(1 << VIRTIO_BLK_F_RO |
313            1 << VIRTIO_BLK_F_SCSI |
314            1 << VIRTIO_BLK_F_CONFIG_WCE |
315            1 << VIRTIO_BLK_F_MQ |
316            1 << VIRTIO_F_ANY_LAYOUT |
317            1 << VIRTIO_RING_F_EVENT_IDX |
318            1 << VIRTIO_RING_F_INDIRECT_DESC)
319    }
320}
321
322impl BlockDevice for VirtioBlockDevice {
323    fn get_id(&self) -> usize {
324        self.base_addr // Use base address as ID
325    }
326    
327    fn get_disk_name(&self) -> &'static str {
328        "virtio-blk"
329    }
330    
331    fn get_disk_size(&self) -> usize {
332        (self.capacity * self.sector_size as u64) as usize
333    }
334    
335    fn enqueue_request(&mut self, request: Box<BlockIORequest>) {
336        // Enqueue the request
337        self.request_queue.push(request);
338    }
339    
340    fn process_requests(&mut self) -> Vec<BlockIOResult> {
341        let mut results = Vec::new();
342        while let Some(mut request) = self.request_queue.pop() {
343            let result = self.process_request(&mut *request);
344            results.push(BlockIOResult { request, result });
345        }
346        
347        results
348    }
349}
350
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::vec;

    /// Register base address of the VirtIO block device on the test machine (example address).
    const TEST_BASE_ADDR: usize = 0x10001000;

    /// Checks the identity and size accessors that every freshly probed device must satisfy.
    fn check_identity(device: &VirtioBlockDevice) {
        assert_eq!(device.get_id(), TEST_BASE_ADDR);
        assert_eq!(device.get_disk_name(), "virtio-blk");
        let expected_size = (device.capacity * device.sector_size as u64) as usize;
        assert_eq!(device.get_disk_size(), expected_size);
    }

    #[test_case]
    fn test_virtio_block_device_init() {
        let device = VirtioBlockDevice::new(TEST_BASE_ADDR);
        check_identity(&device);
    }

    #[test_case]
    fn test_virtio_block_device() {
        let mut device = VirtioBlockDevice::new(TEST_BASE_ADDR);
        check_identity(&device);

        // Read a single sector starting at sector 0.
        let sector_bytes = device.sector_size as usize;
        device.enqueue_request(Box::new(BlockIORequest {
            request_type: BlockIORequestType::Read,
            sector: 0,
            sector_count: 1,
            head: 0,
            cylinder: 0,
            buffer: vec![0; sector_bytes],
        }));

        let results = device.process_requests();
        assert_eq!(results.len(), 1);

        let first = &results[0];
        assert!(first.result.is_ok());

        // Sector 0 of the test disk image is expected to hold a NUL-padded greeting.
        let text = core::str::from_utf8(&first.request.buffer)
            .unwrap_or("Invalid UTF-8")
            .trim_matches(char::from(0));
        assert_eq!(text, "Hello, world!");
    }
}
397}