The task was to R-ise and vectorise the TODO block in `draw_keypoints()`
(vision_utils.R, lines 476–492), implementing connectivity line drawing
between keypoints.
# Hard test 1: exercise draw_keypoints() with skeleton connectivity ----
library(torch)
library(magick)
library(png)
source("hard_test_1_draw_keypoints/vision_utils_patched.R")

set.seed(42)

# Bright random background image, uint8, CHW layout.
image <- torch_randint(190L, 255L, size = c(3L, 360L, 360L))$to(torch_uint8())
# 4 instances x 5 keypoints x (x, y), all placed well inside the canvas.
keypoints <- torch_randint(low = 60L, high = 300L, size = c(4L, 5L, 2L))
# Chain the five keypoints into a simple skeleton: 1-2, 2-3, 3-4, 4-5.
connectivity <- list(c(1L, 2L), c(2L, 3L), c(3L, 4L), c(4L, 5L))

result <- draw_keypoints(
  image,
  keypoints,
  connectivity = connectivity,
  colors = c("red", "blue", "green", "orange"),
  radius = 4,
  width = 2
)

cat("Output shape:", paste(result$shape, collapse = "x"), "\n")
## Output shape: 3x360x360
cat("Output dtype: uint8 =", result$dtype == torch_uint8(), "\n")
## Output dtype: uint8 = TRUE

# Convert CHW uint8 -> HWC float in [0, 1] for plotting.
arr <- as.array(
  result$to(dtype = torch_float())$div(255)$permute(c(2L, 3L, 1L))$to(device = "cpu")
)
grid::grid.raster(arr)
Figure: 4 pose instances with keypoints and skeleton connectivity lines.
# Shape preserved by drawing
stopifnot(all(result$shape == c(3L, 360L, 360L)))
# dtype is uint8
stopifnot(result$dtype == torch_uint8())
# Pixels were actually drawn (output differs from input).
# `pixel_diff` rather than `diff`, so base::diff() is not masked.
pixel_diff <- (result$to(dtype = torch_float()) - image$to(dtype = torch_float()))$abs()$sum()$item()
cat("Total pixels modified:", pixel_diff, "\n")
## Total pixels modified: 1523616
stopifnot(pixel_diff > 0)
# Float input in [0, 1] also works; output is still uint8
image_f <- image$to(dtype = torch_float())$div(255)
result_f <- draw_keypoints(image_f, keypoints, connectivity = connectivity)
stopifnot(result_f$dtype == torch_uint8())
cat("All Hard Test 1 assertions PASSED\n")
## All Hard Test 1 assertions PASSED
The task was to demonstrate exposing a C++ function to R via Rcpp.
Two functions were implemented: nms_cpp() (Non-Maximum
Suppression) and keypoint_distances_cpp() (pairwise
keypoint distances).
# Hard test 2: Rcpp-backed non-maximum suppression ----
library(Rcpp)
Rcpp::sourceCpp("hard_test_2_rcpp/keypoint_ops.cpp")
cat("C++ compiled and loaded successfully\n")
## C++ compiled and loaded successfully

# Boxes as (x1, y1, x2, y2): two overlapping clusters plus one isolate.
boxes <- rbind(
  c(10, 10, 60, 60),
  c(12, 12, 62, 62),
  c(15, 15, 55, 55),
  c(200, 200, 250, 250),
  c(202, 202, 252, 252),
  c(400, 100, 480, 180)
)
scores <- c(0.90, 0.75, 0.60, 0.95, 0.85, 0.70)

kept <- nms_cpp(boxes, scores, iou_threshold = 0.5)
cat("Kept box indices:", kept, "\n")
## Kept box indices: 4 1 6
cat("Suppressed:", setdiff(seq_len(nrow(boxes)), kept), "\n")
## Suppressed: 2 3 5
stopifnot(setequal(kept, c(1L, 4L, 6L)))
cat("NMS assertion PASSED\n")
## NMS assertion PASSED
# Pairwise Euclidean keypoint distances via C++ ----
kpts <- matrix(c(0, 0, 0, 80, 60, 80, 30, 160), ncol = 2, byrow = TRUE)
rownames(kpts) <- c("head", "l_shoulder", "r_shoulder", "hip")

D <- keypoint_distances_cpp(kpts)
round(D, 2)
##              head l_shoulder r_shoulder    hip
## head         0.00      80.00     100.00 162.79
## l_shoulder  80.00       0.00      60.00  85.44
## r_shoulder 100.00      60.00       0.00  85.44
## hip        162.79      85.44      85.44   0.00

# Spot-check known distances within floating-point tolerance.
stopifnot(abs(D["head", "l_shoulder"] - 80) < 1e-9)
stopifnot(abs(D["l_shoulder", "r_shoulder"] - 60) < 1e-9)
stopifnot(abs(D["head", "r_shoulder"] - 100) < 1e-9)
# Self-distances are exactly zero by construction.
stopifnot(all(diag(D) == 0))
# Symmetry: use all.equal() rather than exact `==` on doubles.
stopifnot(isTRUE(all.equal(D, t(D))))
cat("Distance assertions PASSED\n")
## Distance assertions PASSED
# Timing: NMS over 1000 random boxes ----
set.seed(42)
N <- 1000L
# Random top-left corners, with widths/heights in [10, 100].
# (One runif() call per line, same order as before, so the RNG
# stream — and therefore the generated boxes — is unchanged.)
x1 <- runif(N, 0, 500)
y1 <- runif(N, 0, 500)
x2 <- x1 + runif(N, 10, 100)
y2 <- y1 + runif(N, 10, 100)

t_start <- proc.time()
kept <- nms_cpp(cbind(x1, y1, x2, y2), runif(N))
elapsed <- (proc.time() - t_start)[["elapsed"]]
cat(sprintf("NMS on %d boxes: kept %d | elapsed %.4f s\n", N, length(kept), elapsed))
## NMS on 1000 boxes: kept 782 | elapsed 0.0080 s
The task was to write an article demonstrating
model_fasterrcnn_resnet50_fpn_v2() in the style of the
existing fcnresnet vignette. Key discoveries during
development:
- The batch must be built with torch_stack(), not a plain R list.
- Detections come back as output$detections[[i]], each with
  $boxes, $labels, and $scores.
- score_thresh = 0.01 is needed (the model yields low scores
  with the current torch version).
- valid_boxes() is required to filter degenerate boxes
  produced at low confidence.
- safe_labels() uses a tryCatch fallback for when
  coco_label() needs network access.

library(torch)
library(torchvision)

# Two sample images from the torchvision gallery.
url1 <- "https://raw.githubusercontent.com/pytorch/vision/main/gallery/assets/dog1.jpg"
url2 <- "https://raw.githubusercontent.com/pytorch/vision/main/gallery/assets/dog2.jpg"
dog1 <- magick_loader(url1) |> transform_to_tensor()
dog2 <- magick_loader(url2) |> transform_to_tensor()

# Detection models expect a batched CHW tensor built via torch_stack().
dog_batch <- torch_stack(list(dog1, dog2))
cat("Batch shape:", paste(dim(dog_batch), collapse = "x"), "\n")
## Batch shape: 2x3x500x500

# Very low score_thresh: the pretrained v2 weights yield low scores with
# the current torch version, so keep everything and filter downstream.
model <- model_fasterrcnn_resnet50_fpn_v2(
  pretrained = TRUE,
  score_thresh = 0.01,
  nms_thresh = 0.5
)
model$eval()

output <- model(dog_batch)
det1 <- output$detections[[1]]
det2 <- output$detections[[2]]
cat("Image 1:", as.integer(det1$boxes$size(1)), "detections\n")
## Image 1: 100 detections
cat("Image 2:", as.integer(det2$boxes$size(1)), "detections\n")
## Image 2: 100 detections
# Keep the first k detections of a detection list (boxes/labels/scores).
# NOTE(review): assumes detections arrive sorted by descending score, so
# "first k" means "top k" — confirm against the model's output contract.
# `drop = FALSE` keeps `boxes` 2-D even when k == 1, matching the
# row-subsetting style used by valid_boxes().
top_k <- function(det, k = 5) {
  n <- as.integer(det$boxes$size(1))
  if (n == 0) return(det)
  idx <- seq_len(min(k, n))
  list(
    boxes = det$boxes[idx, , drop = FALSE],
    labels = det$labels[idx],
    scores = det$scores[idx]
  )
}
# Map integer label ids to class names, optionally appending the
# confidence score to each label (e.g. "dog: 0.045"). Falls back to
# "class_<id>" when coco_label() fails (e.g. it needs network access).
safe_labels <- function(label_tensor, score_tensor = NULL) {
  ids <- as.integer(label_tensor)
  # `class_names`, not `names`: avoid masking base::names().
  class_names <- tryCatch(coco_label(ids), error = function(e) paste0("class_", ids))
  if (is.null(score_tensor)) return(class_names)
  paste0(class_names, ": ", round(as.numeric(score_tensor), 3))
}
# Rejects boxes that are geometrically invalid or too small (area < min_area).
# Note: with torch < 0.17 the pretrained v2 weights produce degenerate
# boxes (x2 < x1) at near-uniform low scores (~0.026); this filters them out.
# Returns list(boxes = NULL, labels = NULL, scores = NULL) for BOTH empty
# cases (no input rows, or everything filtered), so callers only need the
# single is.null() check.
valid_boxes <- function(det, min_area = 400) {
  empty <- list(boxes = NULL, labels = NULL, scores = NULL)
  n <- as.integer(det$boxes$size(1))
  if (n == 0) return(empty)
  # as.array() on a 2-D tensor already yields a base matrix; the former
  # as.matrix() wrapper was redundant.
  b <- as.array(det$boxes$to(device = "cpu"))
  area <- (b[, 3] - b[, 1]) * (b[, 4] - b[, 2])
  keep <- which((b[, 3] > b[, 1]) & (b[, 4] > b[, 2]) & (area >= min_area))
  if (length(keep) == 0) return(empty)
  list(boxes = det$boxes[keep, , drop = FALSE],
       labels = det$labels[keep],
       scores = det$scores[keep])
}
# Extract the top-5 detections for each image.
top1 <- top_k(det1)
top2 <- top_k(det2)
cat("Labels for dog1:", paste(safe_labels(top1$labels), collapse = ", "), "\n")
## Labels for dog1: person, person, person, person, person
# Overlay detections on a float image: drop invalid boxes, then draw the
# survivors with "label: score" annotations. Returns a uint8 image tensor
# either way (unannotated when nothing survives filtering).
draw_detections <- function(image, det) {
  det <- valid_boxes(det)
  # uint8 conversion hoisted: both branches below need it.
  img_u8 <- (image * 255)$to(dtype = torch_uint8())
  n_det <- if (is.null(det$labels)) 0L else length(as.integer(det$labels))
  if (n_det == 0) {
    message("No valid detections after filtering (torch version incompatibility).")
    return(img_u8)
  }
  draw_bounding_boxes(
    img_u8,
    boxes = det$boxes$view(c(-1L, 4L)),
    labels = safe_labels(det$labels, det$scores)
  )
}
result1 <- draw_detections(dog1, top1)
# CHW uint8 -> HWC float in [0, 1] for grid.raster().
arr1 <- as.array(result1$to(dtype = torch_float())$div(255)$permute(c(2L, 3L, 1L)))
grid::grid.raster(arr1)
Dog 1 – top-5 detections (note: degenerate boxes filtered; model requires torch ≥ 0.17 for correct output)
result2 <- draw_detections(dog2, top2)
# Same CHW uint8 -> HWC float conversion as for the first image.
arr2 <- as.array(result2$to(dtype = torch_float())$div(255)$permute(c(2L, 3L, 1L)))
grid::grid.raster(arr2)
Dog 2 – top-5 detections (note: degenerate boxes filtered; model requires torch ≥ 0.17 for correct output)
# Baseline comparison: the v1 backbone with identical thresholds.
model_v1 <- model_fasterrcnn_resnet50_fpn(
  pretrained = TRUE,
  score_thresh = 0.01,
  nms_thresh = 0.5
)
model_v1$eval()

out_v1 <- model_v1(dog_batch)
cat("v1 detections on dog1:", as.integer(out_v1$detections[[1]]$boxes$size(1)), "\n")
## v1 detections on dog1: 100
cat("v2 detections on dog1:", as.integer(det1$boxes$size(1)), "\n")
## v2 detections on dog1: 100
Fixed three documentation bugs in draw_keypoints() in R/vision_utils.R:

- connectivity parameter: changed the type annotation from Vector to
  List, removed the stale "(currently unavailable)" note, and updated
  the description with a concrete usage example
  (list(c(1, 2), c(2, 3))).
- colors parameter: corrected "viridis" to "rainbow" (the function uses
  grDevices::rainbow()) and "boxes" to "keypoints".
- Error message typo: fixed "but is current shape is" to "but current
  shape is".

The fixes were applied to both R/vision_utils.R (roxygen comments) and
the generated man/draw_keypoints.Rd, with a NEWS.md entry.

Pull request: mlverse/torchvision#296