diff --git a/.changeset/good-hoops-battle.md b/.changeset/good-hoops-battle.md new file mode 100644 index 00000000..c6e97eaf --- /dev/null +++ b/.changeset/good-hoops-battle.md @@ -0,0 +1,5 @@ +--- +"ai-elements": patch +--- + +Enhance PromptInput with SourceDocument functionality diff --git a/packages/elements/__tests__/prompt-input.test.tsx b/packages/elements/__tests__/prompt-input.test.tsx index b6163539..8ff17fd6 100644 --- a/packages/elements/__tests__/prompt-input.test.tsx +++ b/packages/elements/__tests__/prompt-input.test.tsx @@ -12,6 +12,8 @@ import { PromptInputAttachments, PromptInputBody, PromptInputButton, + PromptInputReferencedSource, + PromptInputReferencedSources, PromptInputSelect, PromptInputSelectContent, PromptInputSelectItem, @@ -22,6 +24,7 @@ import { PromptInputTextarea, PromptInputTools, usePromptInputAttachments, + usePromptInputReferencedSources, } from "../src/prompt-input"; // Mock URL.createObjectURL and URL.revokeObjectURL for tests @@ -651,9 +654,7 @@ describe("PromptInputSelect", () => { - - GPT-4 - + GPT-4 @@ -1395,6 +1396,446 @@ describe("PromptInputAttachment", () => { }); }); +describe("PromptInputReferencedSource", () => { + it("renders referenced source with globe icon", () => { + const onSubmit = vi.fn(); + const source = { + id: "1", + type: "source-document" as const, + sourceId: "source-1", + title: "Test Document", + filename: "doc.pdf", + mediaType: "application/pdf", + }; + + render( + + + + + + ); + + expect(screen.getByText("Test Document")).toBeInTheDocument(); + }); + + it("falls back to filename when title is not provided", () => { + const onSubmit = vi.fn(); + const source = { + id: "1", + type: "source-document" as const, + sourceId: "source-1", + title: "", + filename: "document.pdf", + mediaType: "application/pdf", + }; + + render( + + + + + + ); + + expect(screen.getByText("document.pdf")).toBeInTheDocument(); + }); + + it("removes referenced source when remove button clicked", async () => { + const onSubmit = vi.fn(); + const user = userEvent.setup(); + + const ReferencedSourceConsumer = () => { + const refs = usePromptInputReferencedSources(); + return ( + <> + + + {(source) => ( + + )} + + + ); + }; + + render( + + + + + + + ); + + await user.click(screen.getByTestId("add-source")); + expect(screen.getByText("Test Source")).toBeInTheDocument(); + + const removeButton = screen.getByLabelText("Remove referenced source"); + await user.click(removeButton); + + expect(screen.queryByText("Test Source")).not.toBeInTheDocument(); + }); +}); + +describe("PromptInputReferencedSources", () => { + it("renders multiple referenced sources", async () => { + const onSubmit = vi.fn(); + const user = userEvent.setup(); + + const ReferencedSourceConsumer = () => { + const refs = usePromptInputReferencedSources(); + return ( + <> + + + {(source) =>
{source.title}
} +
+ + ); + }; + + render( + + + + + + + ); + + await user.click(screen.getByTestId("add-sources")); + + expect(screen.getByText("Source 1")).toBeInTheDocument(); + expect(screen.getByText("Source 2")).toBeInTheDocument(); + }); + + it("does not render when no sources exist", () => { + const onSubmit = vi.fn(); + + const ReferencedSourceConsumer = () => { + const refs = usePromptInputReferencedSources(); + return ( + + {(source) =>
{source.title}
} +
+ ); + }; + + render( + + + + + + + ); + + expect(screen.queryByTestId("sources-container")).not.toBeInTheDocument(); + }); + + it("clears referenced sources after successful form submission", async () => { + const onSubmit = vi.fn(() => Promise.resolve()); + const user = userEvent.setup(); + + const ReferencedSourceConsumer = () => { + const refs = usePromptInputReferencedSources(); + return ( + <> + + + {(source) =>
{source.title}
} +
+ + ); + }; + + render( + + + + + + + + ); + + // Add a referenced source + await user.click(screen.getByTestId("add-source")); + expect(screen.getByText("Test Source")).toBeInTheDocument(); + + // Type and submit + const textarea = screen.getByPlaceholderText( + "What would you like to know?" + ) as HTMLTextAreaElement; + await user.type(textarea, "test message"); + await user.keyboard("{Enter}"); + + // Wait for async submission to complete + await vi.waitFor(() => { + expect(onSubmit).toHaveBeenCalledTimes(1); + }); + + // Give time for promise resolution + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Verify referenced source was cleared + expect(screen.queryByText("Test Source")).not.toBeInTheDocument(); + }); + + it("does not clear referenced sources when onSubmit throws an error", async () => { + const onSubmit = vi.fn(() => { + throw new Error("Submission failed"); + }); + const user = userEvent.setup(); + + const ReferencedSourceConsumer = () => { + const refs = usePromptInputReferencedSources(); + return ( + <> + + + {(source) =>
{source.title}
} +
+ + ); + }; + + render( + + + + + + + + ); + + // Add a referenced source + await user.click(screen.getByTestId("add-source")); + expect(screen.getByText("Test Source")).toBeInTheDocument(); + + // Type and submit + const textarea = screen.getByPlaceholderText( + "What would you like to know?" + ) as HTMLTextAreaElement; + await user.type(textarea, "test message"); + await user.keyboard("{Enter}"); + + // Wait for submission attempt + await vi.waitFor(() => { + expect(onSubmit).toHaveBeenCalledTimes(1); + }); + + // Verify referenced source was NOT cleared due to error + expect(screen.getByText("Test Source")).toBeInTheDocument(); + }); + + it("does not clear referenced sources when async onSubmit rejects", async () => { + const onSubmit = vi.fn(() => + Promise.reject(new Error("Async submission failed")) + ); + const user = userEvent.setup(); + + const ReferencedSourceConsumer = () => { + const refs = usePromptInputReferencedSources(); + return ( + <> + + + {(source) =>
{source.title}
} +
+ + ); + }; + + render( + + + + + + + + ); + + // Add a referenced source + await user.click(screen.getByTestId("add-source")); + expect(screen.getByText("Test Source")).toBeInTheDocument(); + + // Type and submit + const textarea = screen.getByPlaceholderText( + "What would you like to know?" + ) as HTMLTextAreaElement; + await user.type(textarea, "test message"); + await user.keyboard("{Enter}"); + + // Wait for async submission attempt + await vi.waitFor(() => { + expect(onSubmit).toHaveBeenCalledTimes(1); + }); + + // Give time for promise rejection + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Verify referenced source was NOT cleared due to rejection + expect(screen.getByText("Test Source")).toBeInTheDocument(); + }); + + it("clears both attachments and referenced sources after successful submission", async () => { + const onSubmit = vi.fn(() => Promise.resolve()); + const user = userEvent.setup(); + + const file = new File(["test"], "test.txt", { type: "text/plain" }); + + const Consumer = () => { + const attachments = usePromptInputAttachments(); + const refs = usePromptInputReferencedSources(); + return ( + <> + + + + {(attachment) => ( +
{attachment.filename}
+ )} +
+ + {(source) =>
{source.title}
} +
+ + ); + }; + + render( + + + + + + + + ); + + // Add both attachment and referenced source + await user.click(screen.getByTestId("add-file")); + await user.click(screen.getByTestId("add-source")); + expect(screen.getByText("test.txt")).toBeInTheDocument(); + expect(screen.getByText("Test Source")).toBeInTheDocument(); + + // Type and submit + const textarea = screen.getByPlaceholderText( + "What would you like to know?" + ) as HTMLTextAreaElement; + await user.type(textarea, "test message"); + await user.keyboard("{Enter}"); + + // Wait for async submission + await vi.waitFor(() => { + expect(onSubmit).toHaveBeenCalledTimes(1); + }); + + // Give time for promise resolution + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Verify both were cleared + expect(screen.queryByText("test.txt")).not.toBeInTheDocument(); + expect(screen.queryByText("Test Source")).not.toBeInTheDocument(); + }); +}); + describe("PromptInputActionAddAttachments", () => { it("opens file dialog when clicked", async () => { const onSubmit = vi.fn(); diff --git a/packages/elements/src/prompt-input.tsx b/packages/elements/src/prompt-input.tsx index d0a2a364..c28f04eb 100644 --- a/packages/elements/src/prompt-input.tsx +++ b/packages/elements/src/prompt-input.tsx @@ -35,9 +35,10 @@ import { SelectValue, } from "@repo/shadcn-ui/components/ui/select"; import { cn } from "@repo/shadcn-ui/lib/utils"; -import type { ChatStatus, FileUIPart } from "ai"; +import type { ChatStatus, FileUIPart, SourceDocumentUIPart } from "ai"; import { CornerDownLeftIcon, + GlobeIcon, ImageIcon, Loader2Icon, MicIcon, @@ -201,15 +202,16 @@ export function PromptInputProvider({ attachmentsRef.current = attachmentFiles; // Cleanup blob URLs on unmount to prevent memory leaks - useEffect(() => { - return () => { + useEffect( + () => () => { for (const f of attachmentsRef.current) { if (f.url) { URL.revokeObjectURL(f.url); } } - }; - }, []); + }, + [] + ); const openFileDialog = useCallback(() => { openRef.current?.(); @@ -276,6 +278,30 @@ export const usePromptInputAttachments = () => { return context; }; +// ============================================================================ +// Referenced Sources (Local to PromptInput) +// ============================================================================ + +export type ReferencedSourcesContext = { + sources: (SourceDocumentUIPart & { id: string })[]; + add: (sources: SourceDocumentUIPart[] | SourceDocumentUIPart) => void; + remove: (id: string) => void; + clear: () => void; +}; + +export const LocalReferencedSourcesContext = + createContext(null); + +export const usePromptInputReferencedSources = () => { + const ctx = useContext(LocalReferencedSourcesContext); + if (!ctx) { + throw new Error( + "usePromptInputReferencedSources must be used within a LocalReferencedSourcesContext.Provider" + ); + } + return ctx; +}; + export type PromptInputAttachmentProps = HTMLAttributes & { data: FileUIPart & { id: string }; className?: string; @@ -392,7 +418,7 @@ export function PromptInputAttachments({ return (
{attachments.files.map((file) => ( @@ -402,6 +428,110 @@ export function PromptInputAttachments({ ); } +export type PromptInputReferencedSourceProps = + HTMLAttributes & { + data: SourceDocumentUIPart & { id: string }; + className?: string; + }; + +export function PromptInputReferencedSource({ + data, + className, + ...props +}: PromptInputReferencedSourceProps) { + const referencedSources = usePromptInputReferencedSources(); + const label = data.title || data.filename || "Source"; + + return ( + + +
+
+
+
+ +
+
+ +
+ + {label} +
+
+ +
+
+
+

+ {label} +

+ {data.mediaType && ( +

+ {data.mediaType} +

+ )} + {data.filename && ( +

+ {data.filename} +

+ )} +
+
+
+
+
+ ); +} + +export type PromptInputReferencedSourcesProps = Omit< + HTMLAttributes, + "children" +> & { + children: (source: SourceDocumentUIPart & { id: string }) => ReactNode; +}; + +export function PromptInputReferencedSources({ + children, + className, + ...props +}: PromptInputReferencedSourcesProps) { + const referencedSources = usePromptInputReferencedSources(); + + if (!referencedSources.sources.length) { + return null; + } + + return ( +
+ {referencedSources.sources.map((source) => ( + {children(source)} + ))} +
+ ); +} + export type PromptInputActionAddAttachmentsProps = ComponentProps< typeof DropdownMenuItem > & { @@ -480,6 +610,11 @@ export const PromptInput = ({ const [items, setItems] = useState<(FileUIPart & { id: string })[]>([]); const files = usingProvider ? controller.attachments.files : items; + // ----- Local referenced sources (always local to PromptInput) + const [referencedSources, setReferencedSources] = useState< + (SourceDocumentUIPart & { id: string })[] + >([]); + // Keep a ref to files for cleanup on unmount (avoids stale closure) const filesRef = useRef(files); filesRef.current = files; @@ -578,13 +713,37 @@ export const PromptInput = ({ [] ); + const clearAttachments = useCallback( + () => + usingProvider + ? controller?.attachments.clear() + : setItems((prev) => { + for (const file of prev) { + if (file.url) { + URL.revokeObjectURL(file.url); + } + } + return []; + }), + [usingProvider, controller] + ); + + const clearReferencedSources = useCallback( + () => setReferencedSources([]), + [] + ); + const add = usingProvider ? controller.attachments.add : addLocal; const remove = usingProvider ? controller.attachments.remove : removeLocal; - const clear = usingProvider ? controller.attachments.clear : clearLocal; const openFileDialog = usingProvider ? controller.attachments.openFileDialog : openFileDialogLocal; + const clear = useCallback(() => { + clearAttachments(); + clearReferencedSources(); + }, [clearAttachments, clearReferencedSources]); + // Let provider know about our hidden file input so external menus can call openFileDialog() useEffect(() => { if (!usingProvider) return; @@ -686,16 +845,33 @@ export const PromptInput = ({ } }; - const ctx = useMemo( + const attachmentsCtx = useMemo( () => ({ files: files.map((item) => ({ ...item, id: item.id })), add, remove, - clear, + clear: clearAttachments, openFileDialog, fileInputRef: inputRef, }), - [files, add, remove, clear, openFileDialog] + [files, add, remove, clearAttachments, openFileDialog] + ); + + const refsCtx = useMemo( + () => ({ + sources: referencedSources, + add: (incoming: SourceDocumentUIPart[] | SourceDocumentUIPart) => { + const array = Array.isArray(incoming) ? incoming : [incoming]; + setReferencedSources((prev) => + prev.concat(array.map((s) => ({ ...s, id: nanoid() }))) + ); + }, + remove: (id: string) => { + setReferencedSources((prev) => prev.filter((s) => s.id !== id)); + }, + clear: clearReferencedSources, + }), + [referencedSources, clearReferencedSources] ); const handleSubmit: FormEventHandler = (event) => { @@ -746,7 +922,7 @@ export const PromptInput = ({ // Don't clear on error - user may want to retry }); } else { - // Sync function completed without throwing, clear attachments + // Sync function completed without throwing, clear inputs clear(); if (usingProvider) { controller.textInput.clear(); @@ -785,11 +961,17 @@ export const PromptInput = ({ ); + const withReferencedSources = ( + + {inner} + + ); + return usingProvider ? ( - inner + withReferencedSources ) : ( - - {inner} + + {withReferencedSources} ); }; diff --git a/packages/examples/src/prompt-input-cursor.tsx b/packages/examples/src/prompt-input-cursor.tsx index 9a826e61..e9383aa5 100644 --- a/packages/examples/src/prompt-input-cursor.tsx +++ b/packages/examples/src/prompt-input-cursor.tsx @@ -18,7 +18,6 @@ import { PromptInputHoverCard, PromptInputHoverCardContent, PromptInputHoverCardTrigger, - type PromptInputMessage, PromptInputProvider, PromptInputSubmit, PromptInputTab, @@ -27,6 +26,10 @@ import { PromptInputTabLabel, PromptInputTextarea, PromptInputTools, + PromptInputReferencedSource, + PromptInputReferencedSources, + usePromptInputReferencedSources, + type PromptInputMessage, } from "@repo/elements/prompt-input"; import { ModelSelector, @@ -51,6 +54,7 @@ import { RulerIcon, } from "lucide-react"; import { useRef, useState } from "react"; +import type { SourceDocumentUIPart } from "ai"; const models = [ { @@ -93,24 +97,11 @@ const models = [ const SUBMITTING_TIMEOUT = 200; const STREAMING_TIMEOUT = 2000; -const sampleFiles = { - activeTabs: [{ path: "prompt-input.tsx", location: "packages/elements/src" }], - recents: [ - { path: "queue.tsx", location: "apps/test/app/examples" }, - { path: "queue.tsx", location: "packages/elements/src" }, - ], - added: [ - { path: "prompt-input.tsx", location: "packages/elements/src" }, - { path: "queue.tsx", location: "apps/test/app/examples" }, - { path: "queue.tsx", location: "packages/elements/src" }, - ], - filesAndFolders: [ - { path: "prompt-input.tsx", location: "packages/elements/src" }, - { path: "queue.tsx", location: "apps/test/app/examples" }, - ], - code: [{ path: "prompt-input.tsx", location: "packages/elements/src" }], - docs: [{ path: "README.md", location: "packages/elements" }], -}; +const sampleSources: SourceDocumentUIPart[] = [ + { type: "source-document", sourceId: "1", title: "prompt-input.tsx", filename: "packages/elements/src", mediaType: "text/plain" }, + { type: "source-document", sourceId: "2", title: "queue.tsx", filename: "apps/test/app/examples", mediaType: "text/plain" }, + { type: "source-document", sourceId: "3", title: "queue.tsx", filename: "packages/elements/src", mediaType: "text/plain" }, +]; const sampleTabs = { active: [{ path: "packages/elements/src/task-queue-panel.tsx" }], @@ -144,8 +135,6 @@ const Example = () => { setStatus("submitted"); - console.log("Submitting message:", message); - setTimeout(() => { setStatus("streaming"); }, SUBMITTING_TIMEOUT); @@ -171,40 +160,7 @@ const Example = () => { - - - - - No results found. - - - - - Active Tabs - - - - - - {sampleFiles.added.map((file, index) => ( - - -
- - {file.path} - - - {file.location} - -
-
- ))} -
-
-
+
@@ -271,6 +227,9 @@ const Example = () => { {(attachment) => } + + {(source) => } + { }; export default Example; + +const SampleFilesMenu = () => { + const refs = usePromptInputReferencedSources(); + + const handleAdd = (source: SourceDocumentUIPart) => { + refs.add(source); + }; + + return ( + + + + + No results found. + + + + + Active Tabs + + + + + + {sampleSources + .filter( + (source) => + !refs.sources.some( + (s) => s.title === source.title && s.filename === source.filename + ) + ) + .map((source, index) => ( + handleAdd(source)} + > + +
+ {source.title} + + {source.filename} + +
+
+ ))} +
+
+
+ ); +};