kernel/drivers/block/virtio_blk.rs

use alloc::{boxed::Box, vec::Vec};
use alloc::vec;
use spin::{Mutex, RwLock};

use core::{mem, ptr};

use crate::defer;
use crate::device::{Device, DeviceType};
use crate::{
    device::block::{request::{BlockIORequest, BlockIORequestType, BlockIOResult}, BlockDevice},
    drivers::virtio::{device::{Register, VirtioDevice}, queue::{DescriptorFlag, VirtQueue}}
};

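// Request types and completion status codes from the virtio block device specification.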
const VIRTIO_BLK_T_IN: u32 = 0;
const VIRTIO_BLK_T_OUT: u32 = 1;
const VIRTIO_BLK_S_OK: u8 = 0;
const VIRTIO_BLK_S_IOERR: u8 = 1;
const VIRTIO_BLK_S_UNSUPP: u8 = 2;

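// Feature bits defined by the virtio specification for block devices and the virtio ring.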
const VIRTIO_BLK_F_RO: u32 = 5;
const VIRTIO_BLK_F_BLK_SIZE: u32 = 6;
const VIRTIO_BLK_F_SCSI: u32 = 7;
const VIRTIO_BLK_F_CONFIG_WCE: u32 = 11;
const VIRTIO_BLK_F_MQ: u32 = 12;
const VIRTIO_F_ANY_LAYOUT: u32 = 27;
const VIRTIO_RING_F_INDIRECT_DESC: u32 = 28;
const VIRTIO_RING_F_EVENT_IDX: u32 = 29;

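/// Device configuration space layout for virtio-blk. This driver only reads
/// `capacity` (offset 0) and `blk_size` (offset 20).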
#[repr(C)]
pub struct VirtioBlkConfig {
    pub capacity: u64,
    pub size_max: u32,
    pub seg_max: u32,
    pub geometry: VirtioBlkGeometry,
    pub blk_size: u32,
    pub topology: VirtioBlkTopology,
    pub writeback: u8,
}

#[repr(C)]
pub struct VirtioBlkGeometry {
    pub cylinders: u16,
    pub heads: u8,
    pub sectors: u8,
}

#[repr(C)]
pub struct VirtioBlkTopology {
    pub physical_block_exp: u8,
    pub alignment_offset: u8,
    pub min_io_size: u16,
    pub opt_io_size: u32,
}

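/// Header that starts every virtio-blk request: the request type
/// (`VIRTIO_BLK_T_IN` for reads, `VIRTIO_BLK_T_OUT` for writes), a reserved
/// word, and the starting sector.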
#[repr(C)]
pub struct VirtioBlkReqHeader {
    pub type_: u32,
    pub reserved: u32,
    pub sector: u64,
}

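/// Virtio block device driver backed by a single 8-entry virtqueue.
/// Requests are queued with `enqueue_request` and executed synchronously
/// by `process_requests`.
///
/// Usage sketch (mirrors the tests at the bottom of this file; the MMIO base
/// address is platform-specific):
///
/// ```rust,ignore
/// let device = VirtioBlockDevice::new(0x1000_1000);
/// device.enqueue_request(Box::new(request));
/// let results = device.process_requests();
/// ```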
pub struct VirtioBlockDevice {
    base_addr: usize,
    virtqueues: Mutex<[VirtQueue<'static>; 1]>,
    capacity: RwLock<u64>,
    sector_size: RwLock<u32>,
    features: RwLock<u32>,
    read_only: RwLock<bool>,
    request_queue: Mutex<Vec<Box<BlockIORequest>>>,
}

impl VirtioBlockDevice {
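    /// Creates a driver instance for the device at `base_addr`, initializes it,
    /// and caches the capacity, feature bits, optional block size, and read-only
    /// flag from the device. Panics if initialization fails.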
    pub fn new(base_addr: usize) -> Self {
        let mut device = Self {
            base_addr,
            virtqueues: Mutex::new([VirtQueue::new(8)]),
            capacity: RwLock::new(0),
            sector_size: RwLock::new(512),
            features: RwLock::new(0),
            read_only: RwLock::new(false),
            request_queue: Mutex::new(Vec::new()),
        };

        if device.init().is_err() {
            panic!("Failed to initialize Virtio Block Device");
        }

        // Read the capacity from offset 0 of the device configuration space.
        *device.capacity.write() = device.read_config::<u64>(0);
        let features = device.read32_register(Register::DeviceFeatures);
        *device.features.write() = features;

        // Use the device-reported block size if the BLK_SIZE feature is offered.
        if features & (1 << VIRTIO_BLK_F_BLK_SIZE) != 0 {
            *device.sector_size.write() = device.read_config::<u32>(20);
        }

        *device.read_only.write() = features & (1 << VIRTIO_BLK_F_RO) != 0;

        device
    }

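    /// Builds a three-descriptor chain (header -> data -> status), pushes it to
    /// virtqueue 0, notifies the device, and busy-waits for completion. On a
    /// successful read the returned data is copied back into `req.buffer`.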
    fn process_request(&self, req: &mut BlockIORequest) -> Result<(), &'static str> {
        let header = Box::new(VirtioBlkReqHeader {
            type_: match req.request_type {
                BlockIORequestType::Read => VIRTIO_BLK_T_IN,
                BlockIORequestType::Write => VIRTIO_BLK_T_OUT,
            },
            reserved: 0,
            sector: req.sector as u64,
        });
        let data = vec![0u8; req.buffer.len()].into_boxed_slice();
        let status = Box::new(0u8);

        let header_ptr = Box::into_raw(header);
        let data_ptr = Box::into_raw(data) as *mut [u8];
        let status_ptr = Box::into_raw(status);

        // Reclaim the raw allocations when this function returns, regardless of outcome.
        defer! {
            unsafe {
                drop(Box::from_raw(header_ptr));
                drop(Box::from_raw(data_ptr));
                drop(Box::from_raw(status_ptr));
            }
        }

        // For writes, stage the caller's buffer into the buffer handed to the device.
        unsafe {
            if let BlockIORequestType::Write = req.request_type {
                ptr::copy_nonoverlapping(
                    req.buffer.as_ptr(),
                    data_ptr as *mut u8,
                    req.buffer.len()
                );
            }
        }

        let mut virtqueues = self.virtqueues.lock();

        let header_desc = virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
        let data_desc = virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
        let status_desc = virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;

        // Descriptor 1: request header (device-readable).
        virtqueues[0].desc[header_desc].addr = (header_ptr as usize) as u64;
        virtqueues[0].desc[header_desc].len = mem::size_of::<VirtioBlkReqHeader>() as u32;
        virtqueues[0].desc[header_desc].flags = DescriptorFlag::Next as u16;
        virtqueues[0].desc[header_desc].next = data_desc as u16;

        // Descriptor 2: data buffer (device-writable for reads).
        virtqueues[0].desc[data_desc].addr = (data_ptr as *mut u8 as usize) as u64;
        virtqueues[0].desc[data_desc].len = req.buffer.len() as u32;

        match req.request_type {
            BlockIORequestType::Read => {
                DescriptorFlag::Next.set(&mut virtqueues[0].desc[data_desc].flags);
                DescriptorFlag::Write.set(&mut virtqueues[0].desc[data_desc].flags);
            },
            BlockIORequestType::Write => {
                DescriptorFlag::Next.set(&mut virtqueues[0].desc[data_desc].flags);
            }
        }

        virtqueues[0].desc[data_desc].next = status_desc as u16;

        // Descriptor 3: one-byte status written back by the device.
        virtqueues[0].desc[status_desc].addr = (status_ptr as usize) as u64;
        virtqueues[0].desc[status_desc].len = 1;
        virtqueues[0].desc[status_desc].flags |= DescriptorFlag::Write as u16;

        virtqueues[0].push(header_desc)?;

        self.notify(0);

        // Busy-wait until the device has consumed the chain and published a used entry.
        while virtqueues[0].is_busy() {}
        while *virtqueues[0].used.idx as usize == virtqueues[0].last_used_idx {}

        let desc_idx = virtqueues[0].pop().ok_or("No response from device")?;
        if desc_idx != header_desc {
            return Err("Invalid descriptor index");
        }

        let status_val = unsafe { *status_ptr };
        match status_val {
            VIRTIO_BLK_S_OK => {
                // For reads, copy the data returned by the device into the request buffer.
                if let BlockIORequestType::Read = req.request_type {
                    unsafe {
                        req.buffer.clear();
                        req.buffer.extend_from_slice(core::slice::from_raw_parts(
                            data_ptr as *const u8,
                            virtqueues[0].desc[data_desc].len as usize
                        ));
                    }
                }
                Ok(())
            },
            VIRTIO_BLK_S_IOERR => Err("I/O error"),
            VIRTIO_BLK_S_UNSUPP => Err("Unsupported request"),
            _ => Err("Unknown error"),
        }
    }
}

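// Device trait implementation: identifies this driver to the generic device layer
// and exposes it as a block device.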
impl Device for VirtioBlockDevice {
    fn device_type(&self) -> DeviceType {
        DeviceType::Block
    }

    fn name(&self) -> &'static str {
        "virtio-blk"
    }

    fn as_any(&self) -> &dyn core::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
        self
    }

    fn as_block_device(&self) -> Option<&dyn crate::device::block::BlockDevice> {
        Some(self)
    }
}

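// VirtioDevice trait implementation: exposes the MMIO base address, the virtqueue
// ring addresses, and the feature bits this driver is willing to negotiate.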
impl VirtioDevice for VirtioBlockDevice {
    fn get_base_addr(&self) -> usize {
        self.base_addr
    }

    fn get_virtqueue_count(&self) -> usize {
        1
    }

    fn get_supported_features(&self, device_features: u32) -> u32 {
        // Strip the feature bits this driver does not support.
        device_features & !(1 << VIRTIO_BLK_F_RO |
                            1 << VIRTIO_BLK_F_SCSI |
                            1 << VIRTIO_BLK_F_CONFIG_WCE |
                            1 << VIRTIO_BLK_F_MQ |
                            1 << VIRTIO_F_ANY_LAYOUT |
                            1 << VIRTIO_RING_F_EVENT_IDX |
                            1 << VIRTIO_RING_F_INDIRECT_DESC)
    }

    fn get_queue_desc_addr(&self, queue_idx: usize) -> Option<u64> {
        if queue_idx >= 1 {
            return None;
        }

        let virtqueues = self.virtqueues.lock();
        Some(virtqueues[queue_idx].get_raw_ptr() as u64)
    }

    fn get_queue_driver_addr(&self, queue_idx: usize) -> Option<u64> {
        if queue_idx >= 1 {
            return None;
        }

        let virtqueues = self.virtqueues.lock();
        Some(virtqueues[queue_idx].avail.flags as *const _ as u64)
    }

    fn get_queue_device_addr(&self, queue_idx: usize) -> Option<u64> {
        if queue_idx >= 1 {
            return None;
        }

        let virtqueues = self.virtqueues.lock();
        Some(virtqueues[queue_idx].used.flags as *const _ as u64)
    }
}

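// BlockDevice trait implementation: callers enqueue requests and then drive them
// to completion with process_requests(), which handles each request synchronously.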
impl BlockDevice for VirtioBlockDevice {
    fn get_disk_name(&self) -> &'static str {
        "virtio-blk"
    }

    fn get_disk_size(&self) -> usize {
        let capacity = *self.capacity.read();
        let sector_size = *self.sector_size.read();
        (capacity * sector_size as u64) as usize
    }

    fn enqueue_request(&self, request: Box<BlockIORequest>) {
        self.request_queue.lock().push(request);
    }

    fn process_requests(&self) -> Vec<BlockIOResult> {
        let mut results = Vec::new();
        let mut queue = self.request_queue.lock();
        while let Some(mut request) = queue.pop() {
            // Release the queue lock while the request is being processed.
            drop(queue);
            let result = self.process_request(&mut *request);
            results.push(BlockIOResult { request, result });
            queue = self.request_queue.lock();
        }

        results
    }
}

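// These tests assume a virtio-blk device mapped at MMIO address 0x10001000 whose
// backing image begins with the string "Hello, world!".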
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::vec;

    #[test_case]
    fn test_virtio_block_device_init() {
        let base_addr = 0x10001000;
        let device = VirtioBlockDevice::new(base_addr);

        assert_eq!(device.get_disk_name(), "virtio-blk");
        assert_eq!(device.get_disk_size(), (*device.capacity.read() * *device.sector_size.read() as u64) as usize);
    }

    #[test_case]
    fn test_virtio_block_device() {
        let base_addr = 0x10001000;
        let device = VirtioBlockDevice::new(base_addr);

        assert_eq!(device.get_disk_name(), "virtio-blk");
        assert_eq!(device.get_disk_size(), (*device.capacity.read() * *device.sector_size.read() as u64) as usize);

        // Read the first sector and check its contents.
        let sector_size = *device.sector_size.read();
        let request = BlockIORequest {
            request_type: BlockIORequestType::Read,
            sector: 0,
            sector_count: 1,
            head: 0,
            cylinder: 0,
            buffer: vec![0; sector_size as usize],
        };
        device.enqueue_request(Box::new(request));

        let results = device.process_requests();
        assert_eq!(results.len(), 1);

        let result = &results[0];
        assert!(result.result.is_ok());

        let buffer = &result.request.buffer;
        let buffer_str = core::str::from_utf8(buffer).unwrap_or("Invalid UTF-8").trim_matches(char::from(0));
        assert_eq!(buffer_str, "Hello, world!");
    }
}