From 5a51a39f4f89dcb8198ec12f68810783845c5721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fl=C3=A1vio=20Juvenal?= Date: Thu, 3 Oct 2024 11:51:05 -0300 Subject: [PATCH] Threads per assistant (optional) --- CONTRIBUTING.md | 19 +++++- django_ai_assistant/api/schemas.py | 2 + django_ai_assistant/api/views.py | 9 ++- django_ai_assistant/helpers/use_cases.py | 35 ++++++++-- .../migrations/0006_thread_assistant_id.py | 18 ++++++ django_ai_assistant/models.py | 2 + example/assets/js/components/Chat/Chat.tsx | 2 +- example/demo/templates/demo/chat_thread.html | 1 - example/package.json | 4 +- example/pnpm-lock.yaml | 64 ------------------- frontend/openapi_schema.json | 42 +++++++++++- frontend/src/client/schemas.gen.ts | 23 +++++++ frontend/src/client/services.gen.ts | 11 +++- frontend/src/client/types.gen.ts | 7 ++ frontend/src/hooks/useThreadList.ts | 11 ++-- frontend/tests/useThreadList.test.ts | 53 +++++++++++++++ tests/test_helpers/test_use_cases.py | 24 +++++++ tests/test_views.py | 33 ++++++++++ 18 files changed, 273 insertions(+), 87 deletions(-) create mode 100644 django_ai_assistant/migrations/0006_thread_assistant_id.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5c33dae..ad6a6ac 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -42,7 +42,6 @@ If you encounter an error regarding the Python version required for the project, pyenv install ``` - #### Frontend Go to the frontend directory and install the Node dependencies: @@ -60,6 +59,15 @@ pre-commit install It's critical to run the pre-commit hooks before pushing your code to follow the project's code style, and avoid linting errors. +### Updating the OpenAPI schema + +It's critical to update the OpenAPI schema when you make changes to the `django_ai_assistant/api/views.py` or related files: + +```bash +poetry run python manage.py generate_openapi_schema --output frontend/openapi_schema.json +sh -c 'cd frontend && pnpm run generate-client' +``` + ### Developing with the example project Run the frontend project in `build:watch` mode: @@ -105,6 +113,13 @@ Then, you will run the tests in record mode: poetry run pytest --record-mode=once ``` +To run frontend tests: + +```bash +cd frontend +pnpm run test +``` + ## Documentation We use [mkdocs-material](https://squidfunk.github.io/mkdocs-material/) to generate the documentation from markdown files. @@ -123,7 +138,7 @@ poetry run mkdocs serve To release and publish a new version, follow these steps: -1. Update the version in `pyproject.toml` and `frontend/package.json`. +1. Update the version in `pyproject.toml`, `frontend/package.json` and `example/package.json`. 2. Re-install the local version of the Python project: `poetry install` 3. In the project root, run `poetry run python manage.py generate_openapi_schema --output frontend/openapi_schema.json` to update the OpenAPI schema. 4. Re-install the local version of the frontend project: diff --git a/django_ai_assistant/api/schemas.py b/django_ai_assistant/api/schemas.py index 2ad71ae..c94d1a4 100644 --- a/django_ai_assistant/api/schemas.py +++ b/django_ai_assistant/api/schemas.py @@ -18,6 +18,7 @@ class Meta: fields = ( "id", "name", + "assistant_id", "created_at", "updated_at", ) @@ -25,6 +26,7 @@ class Meta: class ThreadIn(Schema): name: str = Field(default_factory=lambda: timezone.now().strftime("%Y-%m-%d %H:%M")) + assistant_id: str | None = None class ThreadMessageIn(Schema): diff --git a/django_ai_assistant/api/views.py b/django_ai_assistant/api/views.py index 3ae79f2..6e764e0 100644 --- a/django_ai_assistant/api/views.py +++ b/django_ai_assistant/api/views.py @@ -76,14 +76,17 @@ def get_assistant(request, assistant_id: str): @api.get("threads/", response=List[Thread], url_name="threads_list_create") -def list_threads(request): - return list(use_cases.get_threads(user=request.user)) +def list_threads(request, assistant_id: str | None = None): + return list(use_cases.get_threads(user=request.user, assistant_id=assistant_id)) @api.post("threads/", response=Thread, url_name="threads_list_create") def create_thread(request, payload: ThreadIn): name = payload.name - return use_cases.create_thread(name=name, user=request.user, request=request) + assistant_id = payload.assistant_id + return use_cases.create_thread( + name=name, assistant_id=assistant_id, user=request.user, request=request + ) @api.get("threads/{thread_id}/", response=Thread, url_name="thread_detail_update_delete") diff --git a/django_ai_assistant/helpers/use_cases.py b/django_ai_assistant/helpers/use_cases.py index b494808..a0fe386 100644 --- a/django_ai_assistant/helpers/use_cases.py +++ b/django_ai_assistant/helpers/use_cases.py @@ -142,6 +142,7 @@ def create_message( def create_thread( name: str, user: Any, + assistant_id: str | None = None, request: HttpRequest | None = None, ) -> Thread: """Create a thread.\n @@ -149,6 +150,8 @@ def create_thread( Args: name (str): Thread name + assistant_id (str | None): Assistant ID to associate the thread with. + If empty or None, the thread is not associated with any assistant. user (Any): Current user request (HttpRequest | None): Current request, if any Returns: @@ -159,7 +162,7 @@ def create_thread( if not can_create_thread(user=user, request=request): raise AIUserNotAllowedError("User is not allowed to create threads") - thread = Thread.objects.create(name=name, created_by=user) + thread = Thread.objects.create(name=name, created_by=user, assistant_id=assistant_id or "") return thread @@ -188,15 +191,37 @@ def get_single_thread( return thread -def get_threads(user: Any) -> list[Thread]: - """Get all user owned threads.\n +def get_threads( + user: Any, + assistant_id: str | None = None, + request: HttpRequest | None = None, +) -> list[Thread]: + """Get all threads for the user.\n + Uses `AI_ASSISTANT_CAN_VIEW_THREAD_FN` permission to check the threads the user can see, + and returns only the ones the user can see. Args: user (Any): Current user + assistant_id (str | None): Assistant ID to filter threads by. + If empty or None, all threads for the user are returned. + request (HttpRequest | None): Current request, if any Returns: - list[Thread]: List of thread model instances + list[Thread]: QuerySet of Thread model instances """ - return list(Thread.objects.filter(created_by=user)) + threads = Thread.objects.filter(created_by=user) + + if assistant_id: + threads = threads.filter(assistant_id=assistant_id) + + return list( + threads.filter( + id__in=[ + thread.id + for thread in threads + if can_view_thread(thread=thread, user=user, request=request) + ] + ) + ) def update_thread( diff --git a/django_ai_assistant/migrations/0006_thread_assistant_id.py b/django_ai_assistant/migrations/0006_thread_assistant_id.py new file mode 100644 index 0000000..1722c89 --- /dev/null +++ b/django_ai_assistant/migrations/0006_thread_assistant_id.py @@ -0,0 +1,18 @@ +# Generated by Django 5.1.1 on 2024-10-03 13:22 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('django_ai_assistant', '0005_alter_message_options'), + ] + + operations = [ + migrations.AddField( + model_name='thread', + name='assistant_id', + field=models.CharField(blank=True, max_length=255), + ), + ] diff --git a/django_ai_assistant/models.py b/django_ai_assistant/models.py index 510cbcd..b7360ab 100644 --- a/django_ai_assistant/models.py +++ b/django_ai_assistant/models.py @@ -20,6 +20,8 @@ class Thread(models.Model): null=True, ) """User who created the thread. Can be null. Set to null/None when user is deleted.""" + assistant_id = models.CharField(max_length=255, blank=True) + """Associated assistant ID. Can be empty.""" created_at = models.DateTimeField(auto_now_add=True) """Date and time when the thread was created. Automatically set when the thread is created.""" diff --git a/example/assets/js/components/Chat/Chat.tsx b/example/assets/js/components/Chat/Chat.tsx index be6fa6c..361bdd6 100644 --- a/example/assets/js/components/Chat/Chat.tsx +++ b/example/assets/js/components/Chat/Chat.tsx @@ -120,7 +120,7 @@ export function Chat({ assistantId }: { assistantId: string }) { const [activeThread, setActiveThread] = useState(null); const [inputValue, setInputValue] = useState(""); - const { fetchThreads, threads, createThread, deleteThread } = useThreadList(); + const { fetchThreads, threads, createThread, deleteThread } = useThreadList({ assistantId }); const { fetchMessages, messages, diff --git a/example/demo/templates/demo/chat_thread.html b/example/demo/templates/demo/chat_thread.html index 8cd25ba..53b1dd9 100644 --- a/example/demo/templates/demo/chat_thread.html +++ b/example/demo/templates/demo/chat_thread.html @@ -20,7 +20,6 @@ {% if message.type == "ai" %}AI{% else %}User{% endif %} - {{ message.content|markdown }} {% endfor %} diff --git a/example/package.json b/example/package.json index 381fa9e..8cc6626 100644 --- a/example/package.json +++ b/example/package.json @@ -44,9 +44,9 @@ "@mantine/notifications": "^7.11.0", "@tabler/icons-react": "^3.7.0", "cookie": "^0.6.0", - "django-ai-assistant-client": "0.0.1", + "django-ai-assistant-client": "0.0.4", "modern-normalize": "^2.0.0", "react-markdown": "^9.0.1", "react-router-dom": "^6.24.0" } -} +} \ No newline at end of file diff --git a/example/pnpm-lock.yaml b/example/pnpm-lock.yaml index c5518f5..9e9924c 100644 --- a/example/pnpm-lock.yaml +++ b/example/pnpm-lock.yaml @@ -23,9 +23,6 @@ importers: cookie: specifier: ^0.6.0 version: 0.6.0 - django-ai-assistant-client: - specifier: 0.0.1 - version: 0.0.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) modern-normalize: specifier: ^2.0.0 version: 2.0.0 @@ -1351,9 +1348,6 @@ packages: array-flatten@1.1.1: resolution: {integrity: sha512-PCVAQswWemu6UdxsDFFX/+gVeYqKAod3D3UVm91jHwynguOwAvYPhx8nNlM++NqRcK6CxxpUafjmhIdKiHibqg==} - asynckit@0.4.0: - resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} - autoprefixer@10.4.19: resolution: {integrity: sha512-BaENR2+zBZ8xXhM4pUaKUxlVdxZ0EZhjvbopwnXmxRUfqDmwSpC2lAi/QXvx7NRdPCo1WKEcEF6mV64si1z4Ew==} engines: {node: ^10 || ^12 || >=14} @@ -1361,9 +1355,6 @@ packages: peerDependencies: postcss: ^8.1.0 - axios@1.7.5: - resolution: {integrity: sha512-fZu86yCo+svH3uqJ/yTdQ0QHpQu5oL+/QE+QPSv6BZSkDAoky9vytxp7u5qk83OJFS3kEBcesWni9WTZAv3tSw==} - babel-loader@9.1.3: resolution: {integrity: sha512-xG3ST4DglodGf8qSwv0MdeWLhrDsw/32QMdTO5T1ZIp9gQur0HkCyFs7Awskr10JKXFXwpAhiCuYX5oGXnRGbw==} engines: {node: '>= 14.15.0'} @@ -1499,10 +1490,6 @@ packages: colorette@2.0.20: resolution: {integrity: sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==} - combined-stream@1.0.8: - resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} - engines: {node: '>= 0.8'} - comma-separated-tokens@2.0.3: resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==} @@ -1646,10 +1633,6 @@ packages: resolution: {integrity: sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==} engines: {node: '>=12'} - delayed-stream@1.0.0: - resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==} - engines: {node: '>=0.4.0'} - depd@1.1.2: resolution: {integrity: sha512-7emPTl6Dpo6JRXOXjLRxck+FlLRX5847cLKEn00PLAgc3g2hTZZgr+e4c2v6QpSmLeFP3n5yUo7ft6avBK/5jQ==} engines: {node: '>= 0.6'} @@ -1675,13 +1658,6 @@ packages: devlop@1.1.0: resolution: {integrity: sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==} - django-ai-assistant-client@0.0.1: - resolution: {integrity: sha512-QGek+qNkbqIX+i6jHp/nGmopEMV1IV6wRMFdMq/7dQxp82atautnDdhXMJR+peQT7jvRjNiBjfDn0iKbhQ1lzQ==} - engines: {node: '>=20 <21'} - peerDependencies: - react: ^18.3.1 - react-dom: ^18.3.1 - dns-packet@5.6.1: resolution: {integrity: sha512-l4gcSouhcgIKRvyy99RNVOgxXiicE+2jZoNmaNmZ6JXiGajBOJAesk1OBlJuM5k2c+eudGdLxDqXuPCKIj6kpw==} engines: {node: '>=6'} @@ -1849,10 +1825,6 @@ packages: resolution: {integrity: sha512-TMKDUnIte6bfb5nWv7V/caI169OHgvwjb7V4WkeUvbQQdjr5rWKqHFiKWb/fcOwB+CzBT+qbWjvj+DVwRskpIg==} engines: {node: '>=14'} - form-data@4.0.0: - resolution: {integrity: sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==} - engines: {node: '>= 6'} - forwarded@0.2.0: resolution: {integrity: sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==} engines: {node: '>= 0.6'} @@ -2760,9 +2732,6 @@ packages: resolution: {integrity: sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==} engines: {node: '>= 0.10'} - proxy-from-env@1.1.0: - resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==} - punycode@2.3.1: resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==} engines: {node: '>=6'} @@ -4932,8 +4901,6 @@ snapshots: array-flatten@1.1.1: {} - asynckit@0.4.0: {} - autoprefixer@10.4.19(postcss@8.4.38): dependencies: browserslist: 4.23.0 @@ -4944,14 +4911,6 @@ snapshots: postcss: 8.4.38 postcss-value-parser: 4.2.0 - axios@1.7.5: - dependencies: - follow-redirects: 1.15.6 - form-data: 4.0.0 - proxy-from-env: 1.1.0 - transitivePeerDependencies: - - debug - babel-loader@9.1.3(@babel/core@7.24.7)(webpack@5.92.1(webpack-cli@5.1.4)): dependencies: '@babel/core': 7.24.7 @@ -5104,10 +5063,6 @@ snapshots: colorette@2.0.20: {} - combined-stream@1.0.8: - dependencies: - delayed-stream: 1.0.0 - comma-separated-tokens@2.0.3: {} commander@10.0.1: {} @@ -5233,8 +5188,6 @@ snapshots: define-lazy-prop@3.0.0: {} - delayed-stream@1.0.0: {} - depd@1.1.2: {} depd@2.0.0: {} @@ -5251,15 +5204,6 @@ snapshots: dependencies: dequal: 2.0.3 - django-ai-assistant-client@0.0.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1): - dependencies: - axios: 1.7.5 - cookie: 0.6.0 - react: 18.3.1 - react-dom: 18.3.1(react@18.3.1) - transitivePeerDependencies: - - debug - dns-packet@5.6.1: dependencies: '@leichtgewicht/ip-codec': 2.0.5 @@ -5443,12 +5387,6 @@ snapshots: cross-spawn: 7.0.3 signal-exit: 4.1.0 - form-data@4.0.0: - dependencies: - asynckit: 0.4.0 - combined-stream: 1.0.8 - mime-types: 2.1.35 - forwarded@0.2.0: {} fraction.js@4.3.7: {} @@ -6465,8 +6403,6 @@ snapshots: forwarded: 0.2.0 ipaddr.js: 1.9.1 - proxy-from-env@1.1.0: {} - punycode@2.3.1: {} qs@6.11.0: diff --git a/frontend/openapi_schema.json b/frontend/openapi_schema.json index 1af3897..4f0d936 100644 --- a/frontend/openapi_schema.json +++ b/frontend/openapi_schema.json @@ -72,7 +72,24 @@ "get": { "operationId": "ai_list_threads", "summary": "List Threads", - "parameters": [], + "parameters": [ + { + "in": "query", + "name": "assistant_id", + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Assistant Id" + }, + "required": false + } + ], "responses": { "200": { "description": "OK", @@ -377,6 +394,18 @@ ], "title": "Name" }, + "assistant_id": { + "anyOf": [ + { + "maxLength": 255, + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Assistant Id" + }, "created_at": { "format": "date-time", "title": "Created At", @@ -400,6 +429,17 @@ "name": { "title": "Name", "type": "string" + }, + "assistant_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Assistant Id" } }, "title": "ThreadIn", diff --git a/frontend/src/client/schemas.gen.ts b/frontend/src/client/schemas.gen.ts index 555d5d8..3e10493 100644 --- a/frontend/src/client/schemas.gen.ts +++ b/frontend/src/client/schemas.gen.ts @@ -41,6 +41,18 @@ export const $Thread = { ], title: 'Name' }, + assistant_id: { + anyOf: [ + { + maxLength: 255, + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Assistant Id' + }, created_at: { format: 'date-time', title: 'Created At', @@ -62,6 +74,17 @@ export const $ThreadIn = { name: { title: 'Name', type: 'string' + }, + assistant_id: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Assistant Id' } }, title: 'ThreadIn', diff --git a/frontend/src/client/services.gen.ts b/frontend/src/client/services.gen.ts index 3caeffe..2f27930 100644 --- a/frontend/src/client/services.gen.ts +++ b/frontend/src/client/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { AiListAssistantsResponse, AiGetAssistantData, AiGetAssistantResponse, AiListThreadsResponse, AiCreateThreadData, AiCreateThreadResponse, AiGetThreadData, AiGetThreadResponse, AiUpdateThreadData, AiUpdateThreadResponse, AiDeleteThreadData, AiDeleteThreadResponse, AiListThreadMessagesData, AiListThreadMessagesResponse, AiCreateThreadMessageData, AiCreateThreadMessageResponse, AiDeleteThreadMessageData, AiDeleteThreadMessageResponse } from './types.gen'; +import type { AiListAssistantsResponse, AiGetAssistantData, AiGetAssistantResponse, AiListThreadsData, AiListThreadsResponse, AiCreateThreadData, AiCreateThreadResponse, AiGetThreadData, AiGetThreadResponse, AiUpdateThreadData, AiUpdateThreadResponse, AiDeleteThreadData, AiDeleteThreadResponse, AiListThreadMessagesData, AiListThreadMessagesResponse, AiCreateThreadMessageData, AiCreateThreadMessageResponse, AiDeleteThreadMessageData, AiDeleteThreadMessageResponse } from './types.gen'; /** * List Assistants @@ -32,12 +32,17 @@ export const aiGetAssistant = (data: AiGetAssistantData): CancelablePromise => { return __request(OpenAPI, { +export const aiListThreads = (data: AiListThreadsData = {}): CancelablePromise => { return __request(OpenAPI, { method: 'GET', - url: '/threads/' + url: '/threads/', + query: { + assistant_id: data.assistantId + } }); }; /** diff --git a/frontend/src/client/types.gen.ts b/frontend/src/client/types.gen.ts index 5610a4a..ff0dfa7 100644 --- a/frontend/src/client/types.gen.ts +++ b/frontend/src/client/types.gen.ts @@ -8,12 +8,14 @@ export type Assistant = { export type Thread = { id?: number | null; name?: string | null; + assistant_id?: string | null; created_at: string; updated_at: string; }; export type ThreadIn = { name?: string; + assistant_id?: string | null; }; export type ThreadMessage = { @@ -37,6 +39,10 @@ export type AiGetAssistantData = { export type AiGetAssistantResponse = Assistant; +export type AiListThreadsData = { + assistantId?: string | null; +}; + export type AiListThreadsResponse = Array; export type AiCreateThreadData = { @@ -108,6 +114,7 @@ export type $OpenApiTs = { }; '/threads/': { get: { + req: AiListThreadsData; res: { /** * OK diff --git a/frontend/src/hooks/useThreadList.ts b/frontend/src/hooks/useThreadList.ts index bc5df6f..1c8e374 100644 --- a/frontend/src/hooks/useThreadList.ts +++ b/frontend/src/hooks/useThreadList.ts @@ -9,8 +9,9 @@ import { /** * React hook to manage the list, create, and delete of Threads. + * @param assistantId Optional assistant ID to filter threads */ -export function useThreadList() { +export function useThreadList({ assistantId }: { assistantId?: string } = {}) { const [threads, setThreads] = useState(null); const [loadingFetchThreads, setLoadingFetchThreads] = useState(false); @@ -27,13 +28,13 @@ export function useThreadList() { const fetchThreads = useCallback(async (): Promise => { try { setLoadingFetchThreads(true); - const fetchedThreads = await aiListThreads(); + const fetchedThreads = await aiListThreads({ assistantId }); setThreads(fetchedThreads); return fetchedThreads; } finally { setLoadingFetchThreads(false); } - }, []); + }, [assistantId]); /** * Creates a new thread. @@ -45,7 +46,7 @@ export function useThreadList() { try { setLoadingCreateThread(true); const thread = await aiCreateThread({ - requestBody: { name: name }, + requestBody: { name, assistant_id: assistantId }, }); await fetchThreads(); return thread; @@ -53,7 +54,7 @@ export function useThreadList() { setLoadingCreateThread(false); } }, - [fetchThreads] + [fetchThreads, assistantId] ); /** diff --git a/frontend/tests/useThreadList.test.ts b/frontend/tests/useThreadList.test.ts index 29067d5..abb763a 100644 --- a/frontend/tests/useThreadList.test.ts +++ b/frontend/tests/useThreadList.test.ts @@ -33,6 +33,22 @@ describe("useThreadList", () => { updated_at: "2024-06-09T00:00:00Z", }, ]; + const mockThreadsWithAssistantId = [ + { + id: 3, + name: "Thread 3", + created_at: "2024-06-11T00:00:00Z", + updated_at: "2024-06-11T00:00:00Z", + assistant_id: "test_assistant", + }, + { + id: 4, + name: "Thread 4", + created_at: "2024-06-12T00:00:00Z", + updated_at: "2024-06-12T00:00:00Z", + assistant_id: "test_assistant", + }, + ]; beforeEach(() => { jest.clearAllMocks(); @@ -65,6 +81,20 @@ describe("useThreadList", () => { expect(result.current.loadingFetchThreads).toBe(false); }); + it("should fetch threads with assistantId when provided", async () => { + const assistantId = "test_assistant"; + (aiListThreads as jest.Mock).mockResolvedValue(mockThreadsWithAssistantId); + + const { result } = renderHook(() => useThreadList({ assistantId })); + + await act(async () => { + await result.current.fetchThreads(); + }); + + expect(aiListThreads).toHaveBeenCalledWith({ assistantId }); + expect(result.current.threads).toEqual(mockThreadsWithAssistantId); + }); + it("should set loading to false if fetch fails", async () => { (aiListThreads as jest.Mock).mockRejectedValue( new Error("Failed to fetch") @@ -118,6 +148,29 @@ describe("useThreadList", () => { expect(result.current.loadingCreateThread).toBe(false); }); + it("should create a thread with assistantId when provided", async () => { + const assistantId = "test_assistant"; + const mockNewThread = { + id: 3, + name: "Thread 3", + assistant_id: assistantId, + created_at: "2024-06-11T00:00:00Z", + updated_at: "2024-06-11T00:00:00Z", + }; + (aiCreateThread as jest.Mock).mockResolvedValue(mockNewThread); + (aiListThreads as jest.Mock).mockResolvedValue([mockNewThread, ...mockThreads]); + + const { result } = renderHook(() => useThreadList({ assistantId })); + + await act(async () => { + const newThread = await result.current.createThread({ name: "Thread 3" }); + expect(newThread).toEqual(mockNewThread); + }); + + expect(result.current.threads).toEqual([mockNewThread, ...mockThreads]); + expect(result.current.loadingCreateThread).toBe(false); + }); + it("should create a thread with no name and update state correctly", async () => { const mockNewThread = { id: 3, diff --git a/tests/test_helpers/test_use_cases.py b/tests/test_helpers/test_use_cases.py index 3ad088f..8b6e511 100644 --- a/tests/test_helpers/test_use_cases.py +++ b/tests/test_helpers/test_use_cases.py @@ -184,6 +184,18 @@ def test_create_thread(): assert response.name == "My thread" assert response.created_by == user + assert response.assistant_id is None + + +@pytest.mark.django_db(transaction=True) +def test_create_thread_with_assistant_id(): + user = baker.make(User) + assistant_id = "temperature_assistant" + response = use_cases.create_thread("My thread", user, assistant_id) + + assert response.name == "My thread" + assert response.created_by == user + assert response.assistant_id == assistant_id @pytest.mark.django_db(transaction=True) @@ -225,6 +237,18 @@ def test_get_threads(): assert len(response) == 3 +@pytest.mark.django_db(transaction=True) +def test_get_threads_with_assistant_id(): + user = baker.make(User) + assistant_id = "temperature_assistant" + baker.make(Thread, created_by=user, _quantity=2) + baker.make(Thread, created_by=user, assistant_id=assistant_id, _quantity=3) + response = use_cases.get_threads(user, assistant_id) + + assert len(response) == 3 + assert all(thread.assistant_id == assistant_id for thread in response) + + @pytest.mark.django_db(transaction=True) def test_get_threads_does_not_list_other_users_threads(): user = baker.make(User) diff --git a/tests/test_views.py b/tests/test_views.py index a621859..f3e78cc 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -136,6 +136,22 @@ def test_list_threads_with_results(authenticated_client): assert len(response.json()) == 2 +@pytest.mark.django_db(transaction=True) +def test_list_threads_with_assistant_id(authenticated_client): + user = User.objects.first() + assistant_id = "temperature_assistant" + baker.make(Thread, created_by=user, _quantity=2) + baker.make(Thread, created_by=user, assistant_id=assistant_id, _quantity=3) + response = authenticated_client.get( + reverse("django_ai_assistant:threads_list_create"), + data={"assistant_id": assistant_id}, + ) + + assert response.status_code == HTTPStatus.OK + assert len(response.json()) == 3 + assert all(thread["assistant_id"] == assistant_id for thread in response.json()) + + @pytest.mark.django_db(transaction=True) def test_does_not_list_other_users_threads(authenticated_client): baker.make(Thread) @@ -176,6 +192,23 @@ def test_create_thread(authenticated_client): assert response.json()["id"] == thread.id +@pytest.mark.django_db(transaction=True) +def test_create_thread_with_assistant_id(authenticated_client): + assistant_id = "temperature_assistant" + response = authenticated_client.post( + reverse("django_ai_assistant:threads_list_create"), + data={"assistant_id": assistant_id}, + content_type="application/json", + ) + + thread = Thread.objects.first() + + assert response.status_code == HTTPStatus.OK + response_data = response.json() + assert response_data["id"] == thread.id + assert response_data["assistant_id"] == assistant_id + + def test_cannot_create_thread_if_unauthorized(): # TODO: Implement this test once permissions are in place pass