diff --git a/server/src/controllers/note/note.post.controller.ts b/server/src/controllers/note/note.post.controller.ts index 0344855..c554037 100644 --- a/server/src/controllers/note/note.post.controller.ts +++ b/server/src/controllers/note/note.post.controller.ts @@ -17,20 +17,19 @@ import { ValidateNested, } from "class-validator"; import prisma from "../../db/client"; -import { createEmbed, EncryptedEmbedDTO } from "../../db/embed.dao"; export class EncryptedEmbedBody { @IsBase64() @IsNotEmpty() - ciphertext: string | undefined; + ciphertext!: string; @IsBase64() @IsNotEmpty() - hmac: string | undefined; + hmac!: string; @IsString() @IsNotEmpty() - embed_id: string | undefined; + embed_id!: string; } /** @@ -115,25 +114,7 @@ export async function postNoteController( // Store note object and possible embeds in database transaction try { - const savedNote = await prisma.$transaction(async () => { - // 1. Save note - const savedNote = await createNote(note); - - // 2. Store embeds - const embeds: EncryptedEmbedDTO[] = noteEmbedRequests.map( - (embed) => - ({ - ...embed, - note_id: savedNote.id, - } as EncryptedEmbedDTO) - ); - embeds.forEach(async (embed) => { - await createEmbed(embed); - }); - - // 3. Finalize transaction - return savedNote; - }); + const savedNote = await createNote(note, noteEmbedRequests); // Log write event event.success = true; @@ -148,28 +129,15 @@ export async function postNoteController( expire_time: savedNote.expire_time, }); } catch (err: any) { + // if the error matches "Duplicate embed", return a 409 conflict event.error = err.toString(); await EventLogger.writeEvent(event); - next(err); + if (err.message.includes("Duplicate embed")) { + res.status(409).send(err.message); + } else { + next(err); + } } - - createNote(note) - .then(async (savedNote) => { - // event.success = true; - // event.note_id = savedNote.id; - // event.size_bytes = savedNote.ciphertext.length + savedNote.hmac.length; - // event.expire_window_days = EXPIRE_WINDOW_DAYS; - // await EventLogger.writeEvent(event); - // res.json({ - // view_url: `${process.env.FRONTEND_URL}/note/${savedNote.id}`, - // expire_time: savedNote.expire_time, - // }); - }) - .catch(async (err) => { - // event.error = err.toString(); - // await EventLogger.writeEvent(event); - // next(err); - }); } /** diff --git a/server/src/controllers/note/note.post.controller.unit.test.ts b/server/src/controllers/note/note.post.controller.unit.test.ts index cbabd64..59d7676 100644 --- a/server/src/controllers/note/note.post.controller.unit.test.ts +++ b/server/src/controllers/note/note.post.controller.unit.test.ts @@ -264,7 +264,7 @@ const TEST_PAYLOADS: TestParams[] = [ }, ], }, - expectedStatus: 400, + expectedStatus: 409, }, ]; @@ -279,15 +279,22 @@ describe("Execute test cases", () => { // database writes always succeed const storedEmbeds: string[] = []; - mockNoteDao.createNote.mockImplementation(async (note) => ({ - ...note, - id: MOCK_NOTE_ID, - insert_time: new Date(), - })); - mockEmbedDao.createEmbed.mockImplementation(async (embed) => { - if (storedEmbeds.find((s) => s === embed.note_id + embed.embed_id)) { - throw new Error("duplicate embed!"); + mockNoteDao.createNote.mockImplementation(async (note, embeds) => { + if (embeds && embeds.length > 0) { + for (const e of embeds) { + if (storedEmbeds.find((s) => s === MOCK_NOTE_ID + e.embed_id)) { + throw new Error("Duplicate embed"); + } + storedEmbeds.push(MOCK_NOTE_ID + e.embed_id); + } } + return { + ...note, + id: MOCK_NOTE_ID, + insert_time: new Date(), + }; + }); + mockEmbedDao.createEmbed.mockImplementation(async (embed) => { storedEmbeds.push(embed.note_id + embed.embed_id); return { ...embed, @@ -338,7 +345,8 @@ describe("Execute test cases", () => { hmac: payload.hmac, crypto_version: payload.crypto_version || "v1", expire_time: expect.any(Date), - }) + }), + expect.arrayContaining(payload?.embeds ?? []) ); } diff --git a/server/src/db/embed.dao.integration.test.ts b/server/src/db/embed.dao.integration.test.ts index 1849af2..12b1d8c 100644 --- a/server/src/db/embed.dao.integration.test.ts +++ b/server/src/db/embed.dao.integration.test.ts @@ -56,7 +56,7 @@ describe("Reading and writing embeds", () => { }; await createEmbed(embed); // embed 1 - await expect(createEmbed(embed)).rejects.toThrowError(); // duplicate embed + await expect(createEmbed(embed)).rejects.toThrowError(/Duplicate embed/g); // duplicate embed }); it("Should read embeds for existing note", async () => { diff --git a/server/src/db/embed.dao.ts b/server/src/db/embed.dao.ts index 2b5d42a..75dd397 100644 --- a/server/src/db/embed.dao.ts +++ b/server/src/db/embed.dao.ts @@ -1,4 +1,4 @@ -import { EncryptedEmbed } from "@prisma/client"; +import { EncryptedEmbed, Prisma, PrismaClient } from "@prisma/client"; import { BufferToBase64, base64ToBuffer } from "../util"; import prisma from "./client"; @@ -9,6 +9,12 @@ export interface EncryptedEmbedDTO { hmac: string; } +/** + * Get an embed for a note by embed_id. + * @param noteId note id + * @param embedId embed id + * @returns encrypted embed (serialized ciphertext to base64) + */ export async function getEmbed( noteId: string, embedId: string @@ -24,8 +30,6 @@ export async function getEmbed( if (!embed) return null; - console.log(embed.ciphertext.byteLength, embed.size_bytes); - return { note_id: embed.note_id, embed_id: embed.embed_id, @@ -34,8 +38,15 @@ export async function getEmbed( }; } +/** + * Create an embed for a note. + * @param embed EncryptedEmbedDTO to serialize and save + * @param transactionClient optionally pass a TransactionClient object when running in a Prisma interactive transaction + * @returns the saved EncryptedEmbed (deserialized ciphertext to Buffer) + */ export async function createEmbed( - embed: EncryptedEmbedDTO + embed: EncryptedEmbedDTO, + transactionClient: Prisma.TransactionClient = prisma ): Promise { const cipher_buf = base64ToBuffer(embed.ciphertext); const data = { @@ -45,5 +56,13 @@ export async function createEmbed( ciphertext: cipher_buf, size_bytes: cipher_buf.byteLength, } as EncryptedEmbed; - return prisma.encryptedEmbed.create({ data }); + return transactionClient.encryptedEmbed.create({ data }).catch((err) => { + if (err instanceof Prisma.PrismaClientKnownRequestError) { + // The .code property can be accessed in a type-safe manner + if (err.code === "P2002") { + throw new Error("Duplicate embed"); + } + } + throw err; + }); } diff --git a/server/src/db/note.dao.integration.test.ts b/server/src/db/note.dao.integration.test.ts index ff2fef9..0a5d9ee 100644 --- a/server/src/db/note.dao.integration.test.ts +++ b/server/src/db/note.dao.integration.test.ts @@ -1,6 +1,8 @@ import { EncryptedNote } from "@prisma/client"; import { describe, it, expect } from "vitest"; +import { getEmbed } from "./embed.dao"; import { createNote, deleteNotes, getExpiredNotes, getNote } from "./note.dao"; +import prisma from "./client"; const VALID_CIPHERTEXT = Buffer.from("sample_ciphertext").toString("base64"); const VALID_HMAC = Buffer.from("sample_hmac").toString("base64"); @@ -12,6 +14,12 @@ const VALID_NOTE = { expire_time: new Date(), } as EncryptedNote; +const VALID_EMBED = { + embed_id: "embed_id", + hmac: VALID_HMAC, + ciphertext: VALID_CIPHERTEXT, +}; + describe("Writes and reads", () => { it("should write a new note", async () => { const res = await createNote(VALID_NOTE); @@ -25,6 +33,50 @@ describe("Writes and reads", () => { expect(res.insert_time.getTime()).toBeLessThanOrEqual(new Date().getTime()); }); + it("should write a new note with one embed", async () => { + const res = await createNote(VALID_NOTE, [VALID_EMBED]); + expect(res.id).not.toBeNull(); + + const res2 = await getEmbed(res.id, VALID_EMBED.embed_id); + expect(res2).not.toBeNull(); + expect(res2?.ciphertext).toStrictEqual(VALID_EMBED.ciphertext); + expect(res2?.hmac).toStrictEqual(VALID_EMBED.hmac); + }); + + it("should write a new note with multiple embeds", async () => { + const res = await createNote(VALID_NOTE, [ + VALID_EMBED, + { ...VALID_EMBED, embed_id: "embed_id2" }, + ]); + expect(res.id).not.toBeNull(); + + const res2 = await getEmbed(res.id, VALID_EMBED.embed_id); + expect(res2?.embed_id).toStrictEqual(VALID_EMBED.embed_id); + + const res3 = await getEmbed(res.id, "embed_id2"); + expect(res3?.embed_id).toStrictEqual("embed_id2"); + }), + it("should fail writing a new note with duplicate embed_ids", async () => { + await expect( + createNote(VALID_NOTE, [VALID_EMBED, VALID_EMBED]) + ).rejects.toThrowError(); + }); + + it("should roll back a failed note with embeds", async () => { + const noteCount = (await prisma.encryptedNote.findMany())?.length; + const embedCount = (await prisma.encryptedEmbed.findMany())?.length; + + await expect( + createNote({ ...VALID_NOTE }, [VALID_EMBED, VALID_EMBED]) + ).rejects.toThrowError(); + + const noteCountAfter = (await prisma.encryptedNote.findMany())?.length; + const embedCountAfter = (await prisma.encryptedEmbed.findMany())?.length; + + expect(noteCountAfter).toStrictEqual(noteCount); + expect(embedCountAfter).toStrictEqual(embedCount); + }); + it("should find an existing note by id", async () => { const note = await createNote(VALID_NOTE); const res = await getNote(note.id); diff --git a/server/src/db/note.dao.ts b/server/src/db/note.dao.ts index e91a573..8059030 100644 --- a/server/src/db/note.dao.ts +++ b/server/src/db/note.dao.ts @@ -1,5 +1,12 @@ import { EncryptedNote } from "@prisma/client"; import prisma from "./client"; +import { createEmbed, EncryptedEmbedDTO } from "./embed.dao"; + +type EncryptedEmbed = { + ciphertext: string; + hmac: string; + embed_id: string; +}; export async function getNote(noteId: string): Promise { return prisma.encryptedNote.findUnique({ @@ -7,9 +14,32 @@ export async function getNote(noteId: string): Promise { }); } -export async function createNote(note: EncryptedNote): Promise { - return prisma.encryptedNote.create({ - data: note, +export async function createNote( + note: EncryptedNote, + embeds: EncryptedEmbed[] = [] +): Promise { + return prisma.$transaction(async (transactionClient) => { + // 1. Save note + const savedNote = await transactionClient.encryptedNote.create({ + data: note, + }); + + // 2. Store embeds + if (embeds.length > 0) { + const _embeds: EncryptedEmbedDTO[] = embeds.map( + (embed) => + ({ + ...embed, + note_id: savedNote.id, + } as EncryptedEmbedDTO) + ); + for (const embed of _embeds) { + await createEmbed(embed, transactionClient); + } + } + + // 3. Finalize transaction + return savedNote; }); }