kernel/drivers/block/
virtio_blk.rs1use alloc::{boxed::Box, vec::Vec};
27use alloc::vec;
28
29use core::{mem, ptr};
30
31use crate::defer;
32use crate::device::{Device, DeviceType};
33use crate::{
34 device::block::{request::{BlockIORequest, BlockIORequestType, BlockIOResult}, BlockDevice}, drivers::virtio::{device::{Register, VirtioDevice}, queue::{DescriptorFlag, VirtQueue}}
35};
36
// Request types written into `VirtioBlkReqHeader::type_`.
const VIRTIO_BLK_T_IN: u32 = 0;
const VIRTIO_BLK_T_OUT: u32 = 1;

// Values the device writes into the trailing status byte of a request.
const VIRTIO_BLK_S_OK: u8 = 0;
const VIRTIO_BLK_S_IOERR: u8 = 1;
const VIRTIO_BLK_S_UNSUPP: u8 = 2;

// Feature bit positions (used as `1 << BIT` against the feature word).
const VIRTIO_BLK_F_RO: u32 = 5;
const VIRTIO_BLK_F_BLK_SIZE: u32 = 6;
const VIRTIO_BLK_F_SCSI: u32 = 7;
const VIRTIO_BLK_F_CONFIG_WCE: u32 = 11;
const VIRTIO_BLK_F_MQ: u32 = 12;

// Transport / ring feature bit positions.
const VIRTIO_F_ANY_LAYOUT: u32 = 27;
const VIRTIO_RING_F_INDIRECT_DESC: u32 = 28;
const VIRTIO_RING_F_EVENT_IDX: u32 = 29;
60
/// Device-specific configuration space of a virtio-blk device (virtio spec, block device).
///
/// `#[repr(C)]` mirrors the device layout field-for-field; `VirtioBlockDevice::new`
/// reads `capacity` at offset 0 and `blk_size` at offset 20 of this space, so the field
/// order and types must not change.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkConfig {
    /// Device capacity, in 512-byte sectors (offset 0).
    pub capacity: u64,
    /// Maximum segment size (offset 8).
    pub size_max: u32,
    /// Maximum number of segments in a request (offset 12).
    pub seg_max: u32,
    /// Cylinder/head/sector geometry (offset 16).
    pub geometry: VirtioBlkGeometry,
    /// Block size hint; only meaningful when `VIRTIO_BLK_F_BLK_SIZE` is offered (offset 20).
    pub blk_size: u32,
    /// I/O topology hints (offset 24).
    pub topology: VirtioBlkTopology,
    /// Write-cache mode (offset 32).
    pub writeback: u8,
}

/// Cylinder/head/sector geometry embedded in [`VirtioBlkConfig`] (4 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkGeometry {
    pub cylinders: u16,
    pub heads: u8,
    pub sectors: u8,
}

/// I/O topology hints embedded in [`VirtioBlkConfig`] (8 bytes).
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkTopology {
    pub physical_block_exp: u8,
    pub alignment_offset: u8,
    pub min_io_size: u16,
    pub opt_io_size: u32,
}

// The nested structs' sizes fix every offset in `VirtioBlkConfig` (capacity 0,
// size_max 8, seg_max 12, geometry 16, blk_size 20, topology 24, writeback 32);
// fail the build if an edit changes them.
const _: () = assert!(core::mem::size_of::<VirtioBlkGeometry>() == 4);
const _: () = assert!(core::mem::size_of::<VirtioBlkTopology>() == 8);
94
/// Header that starts every virtio-blk request descriptor chain.
///
/// `#[repr(C)]` because the device reads it straight from memory:
/// `type_` at byte 0, `reserved` at byte 4, `sector` at byte 8, 16 bytes in total.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct VirtioBlkReqHeader {
    /// Request type (`VIRTIO_BLK_T_IN` for reads, `VIRTIO_BLK_T_OUT` for writes).
    pub type_: u32,
    /// Reserved; this driver always writes 0.
    pub reserved: u32,
    /// First sector of the transfer (512-byte units per the virtio spec).
    pub sector: u64,
}

// Fail the build if the device-visible header layout ever changes size.
const _: () = assert!(core::mem::size_of::<VirtioBlkReqHeader>() == 16);
101
102pub struct VirtioBlockDevice {
103 base_addr: usize,
104 virtqueues: [VirtQueue<'static>; 1], capacity: u64,
106 sector_size: u32,
107 features: u32,
108 read_only: bool,
109 request_queue: Vec<Box<BlockIORequest>>,
110}
111
112impl VirtioBlockDevice {
113 pub fn new(base_addr: usize) -> Self {
114 let mut device = Self {
115 base_addr,
116 virtqueues: [VirtQueue::new(8)],
117 capacity: 0,
118 sector_size: 512, features: 0,
120 read_only: false,
121 request_queue: Vec::new(),
122 };
123
124 if device.init().is_err() {
126 panic!("Failed to initialize Virtio Block Device");
127 }
128
129 device.capacity = device.read_config::<u64>(0); device.features = device.read32_register(Register::DeviceFeatures);
134 device.sector_size = 0;
135
136 if device.features & (1 << VIRTIO_BLK_F_BLK_SIZE) != 0 {
138 device.sector_size = device.read_config::<u32>(20); }
140
141 device.read_only = device.features & (1 << VIRTIO_BLK_F_RO) != 0;
143
144 device
145 }
146
147 fn process_request(&mut self, req: &mut BlockIORequest) -> Result<(), &'static str> {
148 let header = Box::new(VirtioBlkReqHeader {
150 type_: match req.request_type {
151 BlockIORequestType::Read => VIRTIO_BLK_T_IN,
152 BlockIORequestType::Write => VIRTIO_BLK_T_OUT,
153 },
154 reserved: 0,
155 sector: req.sector as u64,
156 });
157 let data = vec![0u8; req.buffer.len()].into_boxed_slice();
158 let status = Box::new(0u8);
159
160 let header_ptr = Box::into_raw(header);
162 let data_ptr = Box::into_raw(data) as *mut [u8];
163 let status_ptr = Box::into_raw(status);
164
165 defer! {
166 unsafe {
168 drop(Box::from_raw(header_ptr));
169 drop(Box::from_raw(data_ptr));
170 drop(Box::from_raw(status_ptr));
171 }
172 }
173
174 unsafe {
176 if let BlockIORequestType::Write = req.request_type {
178 ptr::copy_nonoverlapping(
179 req.buffer.as_ptr(),
180 data_ptr as *mut u8,
181 req.buffer.len()
182 );
183 }
184 }
185
186 let header_desc = self.virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
188 let data_desc = self.virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
189 let status_desc = self.virtqueues[0].alloc_desc().ok_or("Failed to allocate descriptor")?;
190
191 self.virtqueues[0].desc[header_desc].addr = (header_ptr as usize) as u64;
193 self.virtqueues[0].desc[header_desc].len = mem::size_of::<VirtioBlkReqHeader>() as u32;
194 self.virtqueues[0].desc[header_desc].flags = DescriptorFlag::Next as u16;
195 self.virtqueues[0].desc[header_desc].next = data_desc as u16;
196
197 self.virtqueues[0].desc[data_desc].addr = (data_ptr as *mut u8 as usize) as u64;
199 self.virtqueues[0].desc[data_desc].len = req.buffer.len() as u32;
200
201 match req.request_type {
203 BlockIORequestType::Read => {
204 DescriptorFlag::Next.set(&mut self.virtqueues[0].desc[data_desc].flags);
207 DescriptorFlag::Write.set(&mut self.virtqueues[0].desc[data_desc].flags);
208 },
209 BlockIORequestType::Write => {
210 DescriptorFlag::Next.set(&mut self.virtqueues[0].desc[data_desc].flags);
212 }
213 }
214
215 self.virtqueues[0].desc[data_desc].next = status_desc as u16;
216
217 self.virtqueues[0].desc[status_desc].addr = (status_ptr as usize) as u64;
219 self.virtqueues[0].desc[status_desc].len = 1;
220 self.virtqueues[0].desc[status_desc].flags |= DescriptorFlag::Write as u16;
221
222 self.virtqueues[0].push(header_desc)?;
224
225 self.notify(0);
227
228 while self.virtqueues[0].is_busy() {}
230 while *self.virtqueues[0].used.idx as usize == self.virtqueues[0].last_used_idx {}
231
232 let desc_idx = self.virtqueues[0].pop().ok_or("No response from device")?;
234 if desc_idx != header_desc {
235 return Err("Invalid descriptor index");
236 }
237
238 let status_val = unsafe { *status_ptr };
240 match status_val {
241 VIRTIO_BLK_S_OK => {
242 if let BlockIORequestType::Read = req.request_type {
244 unsafe {
245 req.buffer.clear();
246 req.buffer.extend_from_slice(core::slice::from_raw_parts(
247 data_ptr as *const u8,
248 self.virtqueues[0].desc[data_desc].len as usize
249 ));
250 }
251 }
252 Ok(())
253 },
254 VIRTIO_BLK_S_IOERR => Err("I/O error"),
255 VIRTIO_BLK_S_UNSUPP => Err("Unsupported request"),
256 _ => Err("Unknown error"),
257 }
258 }
259}
260
261impl Device for VirtioBlockDevice {
262 fn device_type(&self) -> DeviceType {
263 DeviceType::Block
264 }
265
266 fn name(&self) -> &'static str {
267 "virtio-blk"
268 }
269
270 fn id(&self) -> usize {
271 self.base_addr
272 }
273
274 fn as_any(&self) -> &dyn core::any::Any {
275 self
276 }
277
278 fn as_any_mut(&mut self) -> &mut dyn core::any::Any {
279 self
280 }
281
282 fn as_block_device(&mut self) -> Option<&mut dyn crate::device::block::BlockDevice> {
283 Some(self)
284 }
285}
286
287impl VirtioDevice for VirtioBlockDevice {
288 fn get_base_addr(&self) -> usize {
289 self.base_addr
290 }
291
292 fn get_virtqueue_count(&self) -> usize {
293 self.virtqueues.len()
294 }
295
296 fn get_virtqueue(&self, queue_idx: usize) -> &VirtQueue {
297 &self.virtqueues[queue_idx]
298 }
299
300 fn get_supported_features(&self, device_features: u32) -> u32 {
301 device_features & !(1 << VIRTIO_BLK_F_RO |
313 1 << VIRTIO_BLK_F_SCSI |
314 1 << VIRTIO_BLK_F_CONFIG_WCE |
315 1 << VIRTIO_BLK_F_MQ |
316 1 << VIRTIO_F_ANY_LAYOUT |
317 1 << VIRTIO_RING_F_EVENT_IDX |
318 1 << VIRTIO_RING_F_INDIRECT_DESC)
319 }
320}
321
322impl BlockDevice for VirtioBlockDevice {
323 fn get_id(&self) -> usize {
324 self.base_addr }
326
327 fn get_disk_name(&self) -> &'static str {
328 "virtio-blk"
329 }
330
331 fn get_disk_size(&self) -> usize {
332 (self.capacity * self.sector_size as u64) as usize
333 }
334
335 fn enqueue_request(&mut self, request: Box<BlockIORequest>) {
336 self.request_queue.push(request);
338 }
339
340 fn process_requests(&mut self) -> Vec<BlockIOResult> {
341 let mut results = Vec::new();
342 while let Some(mut request) = self.request_queue.pop() {
343 let result = self.process_request(&mut *request);
344 results.push(BlockIOResult { request, result });
345 }
346
347 results
348 }
349}
350
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::vec;

    /// MMIO base of the virtio-blk device used by these tests.
    /// NOTE(review): assumes the test machine exposes a virtio-blk device here — confirm.
    const TEST_BASE_ADDR: usize = 0x10001000;

    /// Expected size in bytes: virtio-blk `capacity` is always counted in 512-byte sectors.
    fn expected_disk_size(device: &VirtioBlockDevice) -> usize {
        (device.capacity * 512) as usize
    }

    #[test_case]
    fn test_virtio_block_device_init() {
        let device = VirtioBlockDevice::new(TEST_BASE_ADDR);

        assert_eq!(device.get_id(), TEST_BASE_ADDR);
        assert_eq!(device.get_disk_name(), "virtio-blk");
        assert_eq!(device.get_disk_size(), expected_disk_size(&device));
    }

    /// Reads sector 0 and compares it with the test disk's contents.
    ///
    /// NOTE(review): assumes the attached disk image's first sector holds
    /// "Hello, world!" followed by NUL padding — confirm against the test setup.
    #[test_case]
    fn test_virtio_block_device() {
        let mut device = VirtioBlockDevice::new(TEST_BASE_ADDR);

        assert_eq!(device.get_id(), TEST_BASE_ADDR);
        assert_eq!(device.get_disk_name(), "virtio-blk");
        assert_eq!(device.get_disk_size(), expected_disk_size(&device));

        let request = BlockIORequest {
            request_type: BlockIORequestType::Read,
            sector: 0,
            sector_count: 1,
            head: 0,
            cylinder: 0,
            buffer: vec![0; device.sector_size as usize],
        };
        device.enqueue_request(Box::new(request));

        let results = device.process_requests();
        assert_eq!(results.len(), 1);

        let result = &results[0];
        if let Err(err) = &result.result {
            panic!("reading sector 0 failed: {}", err);
        }

        let buffer = &result.request.buffer;
        let text = core::str::from_utf8(buffer)
            .unwrap_or("Invalid UTF-8")
            .trim_matches(char::from(0));
        assert_eq!(text, "Hello, world!");
    }
}
397}