diff --git a/.gitignore b/.gitignore index 8c040c2..c24d5c2 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,5 @@ wrangler.dev.jsonc /packages/client/public/pwa-64x64.png /packages/client/public/pwa-192x192.png /packages/client/public/pwa-512x512.png -packages/client/public/yachtpit_bg* \ No newline at end of file +packages/client/public/yachtpit_bg* +/project/ diff --git a/packages/ai/src/providers/__tests__/chat-stream-provider.test.ts b/packages/ai/src/providers/__tests__/chat-stream-provider.test.ts index 01af5db..c165ded 100644 --- a/packages/ai/src/providers/__tests__/chat-stream-provider.test.ts +++ b/packages/ai/src/providers/__tests__/chat-stream-provider.test.ts @@ -1,5 +1,5 @@ import { OpenAI } from 'openai'; -import { describe, it, expect, vi } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { BaseChatProvider, @@ -29,7 +29,7 @@ class TestChatProvider extends BaseChatProvider { } // Mock dependencies -vi.mock('../../lib/chat-sdk', () => ({ +vi.mock('../../chat-sdk/chat-sdk.ts', () => ({ default: { buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'), buildMessageChain: vi.fn().mockReturnValue([ @@ -39,6 +39,26 @@ vi.mock('../../lib/chat-sdk', () => ({ }, })); +vi.mock('../../tools/agentic-rag.ts', () => ({ + agenticRAG: vi.fn(), + AgenticRAGTools: { + type: 'function', + function: { + name: 'agentic_rag', + description: 'Test agentic RAG tool', + parameters: { + type: 'object', + properties: { + action: { type: 'string', enum: ['search_knowledge'] }, + query: { type: 'string' }, + collection_name: { type: 'string' }, + }, + required: ['action', 'collection_name'], + }, + }, + }, +})); + describe('ChatStreamProvider', () => { it('should define the required interface', () => { // Verify the interface has the required method @@ -50,26 +70,616 @@ describe('ChatStreamProvider', () => { }); }); -describe('BaseChatProvider', () => { - it('should implement the ChatStreamProvider interface', () => { - // Create a concrete implementation - const provider = new TestChatProvider(); +describe('BaseChatProvider - Model Tool Calling', () => { + let provider: TestChatProvider; + let mockOpenAI: any; + let dataCallback: any; + let commonParams: CommonProviderParams; - // Verify it implements the interface + beforeEach(() => { + vi.clearAllMocks(); + + provider = new TestChatProvider(); + dataCallback = vi.fn(); + + mockOpenAI = { + chat: { + completions: { + create: vi.fn(), + }, + }, + }; + + commonParams = { + openai: mockOpenAI, + systemPrompt: 'Test system prompt', + preprocessedContext: {}, + maxTokens: 1000, + messages: [{ role: 'user', content: 'Test message' }], + model: 'gpt-4', + env: {} as any, + }; + }); + + it('should implement the ChatStreamProvider interface', () => { expect(provider.handleStream).toBeInstanceOf(Function); expect(provider.getOpenAIClient).toBeInstanceOf(Function); expect(provider.getStreamParams).toBeInstanceOf(Function); expect(provider.processChunk).toBeInstanceOf(Function); }); - it('should have abstract methods that need to be implemented', () => { - // This test verifies that the abstract methods exist - // We can't instantiate BaseChatProvider directly, so we use the concrete implementation - const provider = new TestChatProvider(); + it('should handle regular text streaming without tool calls', async () => { + // Mock stream chunks for regular text response + const chunks = [ + { + choices: [ + { + delta: { content: 'Hello ' }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: { content: 'world!' }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'stop', + }, + ], + }, + ]; - // Verify the abstract methods are implemented - expect(provider.getOpenAIClient).toBeDefined(); - expect(provider.getStreamParams).toBeDefined(); - expect(provider.processChunk).toBeDefined(); + mockOpenAI.chat.completions.create.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk; + } + }, + }); + + await provider.handleStream(commonParams, dataCallback); + + expect(mockOpenAI.chat.completions.create).toHaveBeenCalledWith( + expect.objectContaining({ + tools: expect.arrayContaining([ + expect.objectContaining({ + type: 'function', + function: expect.objectContaining({ + name: 'agentic_rag', + }), + }), + ]), + }), + ); + }); + + it('should handle tool calls in streaming response', async () => { + const { agenticRAG } = await import('../../tools/agentic-rag.ts'); + vi.mocked(agenticRAG).mockResolvedValue({ + success: true, + data: { + results: ['Test result'], + analysis: { needsRetrieval: false }, + }, + }); + + // Mock stream chunks for tool call response + const chunks = [ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { + name: 'agentic_rag', + arguments: + '{"action": "search_knowledge", "query": "test query", "collection_name": "test_collection"}', + }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'tool_calls', + }, + ], + }, + ]; + + // Second stream for response after tool execution + const secondStreamChunks = [ + { + choices: [ + { + delta: { content: 'Based on the search results: Test result' }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'stop', + }, + ], + }, + ]; + + let callCount = 0; + mockOpenAI.chat.completions.create.mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk; + } + }, + }); + } else { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of secondStreamChunks) { + yield chunk; + } + }, + }); + } + }); + + await provider.handleStream(commonParams, dataCallback); + + // Verify tool was called + expect(agenticRAG).toHaveBeenCalledWith({ + action: 'search_knowledge', + query: 'test query', + collection_name: 'test_collection', + }); + + // Verify feedback messages were sent + expect(dataCallback).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'chat', + data: expect.objectContaining({ + choices: expect.arrayContaining([ + expect.objectContaining({ + delta: expect.objectContaining({ + content: expect.stringContaining('šŸ”§ Invoking'), + }), + }), + ]), + }), + }), + ); + + expect(dataCallback).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'chat', + data: expect.objectContaining({ + choices: expect.arrayContaining([ + expect.objectContaining({ + delta: expect.objectContaining({ + content: expect.stringContaining('šŸ“ž Calling agentic_rag'), + }), + }), + ]), + }), + }), + ); + }); + + it('should handle tool call streaming with incremental arguments', async () => { + const { agenticRAG } = await import('../../tools/agentic-rag.ts'); + vi.mocked(agenticRAG).mockResolvedValue({ + success: true, + data: { results: ['Test result'] }, + }); + + // Mock stream chunks with incremental tool call arguments + const chunks = [ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_', + type: 'function', + function: { name: 'agentic_rag', arguments: '{"action": "search_' }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: '123', + function: { arguments: 'knowledge", "query": "test", ' }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + function: { arguments: '"collection_name": "test_collection"}' }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'tool_calls', + }, + ], + }, + ]; + + const secondStreamChunks = [ + { + choices: [ + { + delta: { content: 'Response after tool call' }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'stop', + }, + ], + }, + ]; + + let callCount = 0; + mockOpenAI.chat.completions.create.mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk; + } + }, + }); + } else { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of secondStreamChunks) { + yield chunk; + } + }, + }); + } + }); + + await provider.handleStream(commonParams, dataCallback); + + // Verify the complete tool call was assembled and executed + expect(agenticRAG).toHaveBeenCalledWith({ + action: 'search_knowledge', + query: 'test', + collection_name: 'test_collection', + }); + }); + + it('should prevent infinite tool call loops', async () => { + const { agenticRAG } = await import('../../tools/agentic-rag.ts'); + vi.mocked(agenticRAG).mockResolvedValue({ + success: true, + data: { + results: [], + analysis: { needsRetrieval: true }, + retrieved_documents: [], + }, + }); + + // Mock stream that always returns tool calls + const toolCallChunks = [ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { + name: 'agentic_rag', + arguments: + '{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}', + }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'tool_calls', + }, + ], + }, + ]; + + mockOpenAI.chat.completions.create.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + for (const chunk of toolCallChunks) { + yield chunk; + } + }, + }); + + await provider.handleStream(commonParams, dataCallback); + + // Should detect duplicate tool calls and force completion (up to 5 iterations based on maxToolCallIterations) + // In this case, it should stop after 2 calls due to duplicate detection, but could go up to 5 + expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2); + }); + + it('should handle tool call errors gracefully', async () => { + const { agenticRAG } = await import('../../tools/agentic-rag.ts'); + vi.mocked(agenticRAG).mockRejectedValue(new Error('Tool execution failed')); + + const chunks = [ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { + name: 'agentic_rag', + arguments: + '{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}', + }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'tool_calls', + }, + ], + }, + ]; + + const secondStreamChunks = [ + { + choices: [ + { + delta: { content: 'I apologize, but I encountered an error.' }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'stop', + }, + ], + }, + ]; + + let callCount = 0; + mockOpenAI.chat.completions.create.mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk; + } + }, + }); + } else { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of secondStreamChunks) { + yield chunk; + } + }, + }); + } + }); + + await provider.handleStream(commonParams, dataCallback); + + // Should still complete without throwing + expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2); + }); + + it('should prevent duplicate tool calls', async () => { + const { agenticRAG } = await import('../../tools/agentic-rag.ts'); + vi.mocked(agenticRAG).mockResolvedValue({ + success: true, + data: { results: ['Test result'] }, + }); + + // Mock the same tool call twice + const chunks = [ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { + name: 'agentic_rag', + arguments: + '{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}', + }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'tool_calls', + }, + ], + }, + ]; + + // Second iteration with same tool call + let callCount = 0; + mockOpenAI.chat.completions.create.mockImplementation(() => { + callCount++; + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk; + } + }, + }); + }); + + await provider.handleStream(commonParams, dataCallback); + + // Should only execute the tool once, then force completion + expect(agenticRAG).toHaveBeenCalledTimes(1); + }); + + it('should handle invalid JSON in tool call arguments', async () => { + const chunks = [ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: 'call_123', + type: 'function', + function: { + name: 'agentic_rag', + arguments: '{"action": "search_knowledge", "invalid": json}', // Invalid JSON + }, + }, + ], + }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'tool_calls', + }, + ], + }, + ]; + + const secondStreamChunks = [ + { + choices: [ + { + delta: { content: 'I encountered an error parsing the tool arguments.' }, + finish_reason: null, + }, + ], + }, + { + choices: [ + { + delta: {}, + finish_reason: 'stop', + }, + ], + }, + ]; + + let callCount = 0; + mockOpenAI.chat.completions.create.mockImplementation(() => { + callCount++; + if (callCount === 1) { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of chunks) { + yield chunk; + } + }, + }); + } else { + return Promise.resolve({ + async *[Symbol.asyncIterator]() { + for (const chunk of secondStreamChunks) { + yield chunk; + } + }, + }); + } + }); + + // Should not throw, should handle gracefully + await expect(provider.handleStream(commonParams, dataCallback)).resolves.not.toThrow(); }); }); diff --git a/packages/ai/src/providers/chat-stream-provider.ts b/packages/ai/src/providers/chat-stream-provider.ts index e656d6b..c8074c5 100644 --- a/packages/ai/src/providers/chat-stream-provider.ts +++ b/packages/ai/src/providers/chat-stream-provider.ts @@ -50,6 +50,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider { let toolCallIterations = 0; const maxToolCallIterations = 5; // Prevent infinite loops let toolsExecuted = false; // Track if we've executed tools + const attemptedToolCalls = new Set(); // Track attempted tool calls to prevent duplicates while (!conversationComplete && toolCallIterations < maxToolCallIterations) { const streamParams = this.getStreamParams(param, safeMessages); @@ -112,6 +113,21 @@ export abstract class BaseChatProvider implements ChatStreamProvider { // Execute tool calls and add results to conversation console.log('Executing tool calls:', toolCalls); + // Limit to one tool call per iteration to prevent concurrent execution issues + // Also filter out duplicate tool calls + const uniqueToolCalls = toolCalls.filter(toolCall => { + const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`; + return !attemptedToolCalls.has(toolCallKey); + }); + const toolCallsToExecute = uniqueToolCalls.slice(0, 1); + + if (toolCallsToExecute.length === 0) { + console.log('All tool calls have been attempted already, forcing completion'); + toolsExecuted = true; + conversationComplete = true; + break; + } + // Send feedback to user about tool invocation dataCallback({ type: 'chat', @@ -119,7 +135,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider { choices: [ { delta: { - content: `\n\nšŸ”§ Invoking ${toolCalls.length} tool${toolCalls.length > 1 ? 's' : ''}...\n`, + content: `\n\nšŸ”§ Invoking ${toolCallsToExecute.length} tool${toolCallsToExecute.length > 1 ? 's' : ''}...\n`, }, }, ], @@ -130,15 +146,20 @@ export abstract class BaseChatProvider implements ChatStreamProvider { safeMessages.push({ role: 'assistant', content: assistantMessage || null, - tool_calls: toolCalls, + tool_calls: toolCallsToExecute, }); // Execute each tool call and add results - for (const toolCall of toolCalls) { + let needsMoreRetrieval = false; + for (const toolCall of toolCallsToExecute) { if (toolCall.type === 'function') { const name = toolCall.function.name; console.log(`Calling function: ${name}`); + // Track this tool call attempt + const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`; + attemptedToolCalls.add(toolCallKey); + // Send feedback about specific tool being called dataCallback({ type: 'chat', @@ -160,6 +181,36 @@ export abstract class BaseChatProvider implements ChatStreamProvider { const result = await callFunction(name, args); console.log(`Function result:`, result); + // Check if agentic-rag indicates more retrieval is needed + if ( + name === 'agentic_rag' && + result?.data?.analysis?.needsRetrieval === true && + (!result?.data?.retrieved_documents || + result.data.retrieved_documents.length === 0) + ) { + needsMoreRetrieval = true; + console.log('Agentic RAG indicates more retrieval needed'); + + // Add context about previous attempts to help LLM make better decisions + const attemptedActions = Array.from(attemptedToolCalls) + .filter(key => key.startsWith('agentic_rag:')) + .map(key => { + try { + const args = JSON.parse(key.split(':', 2)[1]); + return `${args.action} with query: "${args.query}"`; + } catch { + return 'unknown action'; + } + }); + + if (attemptedActions.length > 0) { + safeMessages.push({ + role: 'system', + content: `Previous retrieval attempts: ${attemptedActions.join(', ')}. Consider trying a different approach or more specific query.`, + }); + } + } + // Send feedback about tool completion dataCallback({ type: 'chat', @@ -206,8 +257,10 @@ export abstract class BaseChatProvider implements ChatStreamProvider { } } - // Mark that tools have been executed to prevent repeated calls - toolsExecuted = true; + // Only mark tools as executed if we don't need more retrieval + if (!needsMoreRetrieval) { + toolsExecuted = true; + } // Send feedback that tool execution is complete dataCallback({ diff --git a/packages/ai/src/tools/__tests__/agentic-rag.test.ts b/packages/ai/src/tools/__tests__/agentic-rag.test.ts index 1287db1..510aee2 100644 --- a/packages/ai/src/tools/__tests__/agentic-rag.test.ts +++ b/packages/ai/src/tools/__tests__/agentic-rag.test.ts @@ -1,36 +1,88 @@ -import { describe, it, expect } from 'vitest'; -import { agenticRAG, AgenticRAGTool } from '../agentic-rag-clean'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +import { agenticRAG, AgenticRAGTools } from '../agentic-rag'; + +// Mock the dependencies +vi.mock('@zilliz/milvus2-sdk-node', () => ({ + MilvusClient: vi.fn().mockImplementation(() => ({ + listCollections: vi.fn().mockResolvedValue({ + collection_names: ['family_domestic', 'business_corporate'], + data: [{ name: 'family_domestic' }, { name: 'business_corporate' }], + }), + search: vi.fn().mockResolvedValue({ + results: [ + { + content: 'Test document about AI and machine learning', + score: 0.85, + metadata: '{"category": "AI", "author": "Test Author"}', + }, + { + content: 'Another document about neural networks', + score: 0.75, + metadata: '{"category": "ML", "author": "Another Author"}', + }, + ], + }), + insert: vi.fn().mockResolvedValue({ success: true }), + createCollection: vi.fn().mockResolvedValue({ success: true }), + createIndex: vi.fn().mockResolvedValue({ success: true }), + })), + DataType: { + VarChar: 'VarChar', + FloatVector: 'FloatVector', + }, +})); + +vi.mock('openai', () => ({ + OpenAI: vi.fn().mockImplementation(() => ({ + embeddings: { + create: vi.fn().mockResolvedValue({ + data: [{ embedding: new Array(768).fill(0.1) }], + }), + }, + })), +})); + +// Mock environment variables +vi.stubEnv('FIREWORKS_API_KEY', 'test-api-key'); describe('Agentic RAG System', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + it('should analyze queries correctly', async () => { // Test factual query const factualResult = await agenticRAG({ action: 'analyze_query', query: 'What is artificial intelligence?', + collection_name: 'family_domestic', }); expect(factualResult.status).toBe('success'); - expect(factualResult.needsRetrieval).toBe(true); + expect(factualResult.data.needsRetrieval).toBe(true); expect(factualResult.data.queryType).toBe('factual'); - // Test conversational query + // Test conversational query with multiple conversational keywords const conversationalResult = await agenticRAG({ action: 'analyze_query', - query: 'Hello, how are you?', + query: 'Hello, how are you doing today?', + collection_name: 'family_domestic', }); expect(conversationalResult.status).toBe('success'); - expect(conversationalResult.needsRetrieval).toBe(false); + expect(conversationalResult.data.needsRetrieval).toBe(false); expect(conversationalResult.data.queryType).toBe('conversational'); - // Test creative query + // Test creative query with multiple creative keywords const creativeResult = await agenticRAG({ action: 'analyze_query', - query: 'Write a story about a robot', + query: 'Write a story and compose a poem', + collection_name: 'family_domestic', }); expect(creativeResult.status).toBe('success'); - expect(creativeResult.needsRetrieval).toBe(false); + expect(creativeResult.data.needsRetrieval).toBe(false); expect(creativeResult.data.queryType).toBe('creative'); }); @@ -38,25 +90,27 @@ describe('Agentic RAG System', () => { const result = await agenticRAG({ action: 'search_knowledge', query: 'What is machine learning?', + collection_name: 'family_domestic', top_k: 2, - threshold: 0.1, + similarity_threshold: 0.1, }); expect(result.status).toBe('success'); - expect(result.needsRetrieval).toBe(true); expect(result.context).toBeDefined(); expect(Array.isArray(result.context)).toBe(true); expect(result.data.retrieved_documents).toBeDefined(); + expect(result.data.analysis.needsRetrieval).toBe(true); }); it('should not search for conversational queries', async () => { const result = await agenticRAG({ action: 'search_knowledge', - query: 'Hello there!', + query: 'Hello there! How are you?', + collection_name: 'family_domestic', }); expect(result.status).toBe('success'); - expect(result.needsRetrieval).toBe(false); + expect(result.data.analysis.needsRetrieval).toBe(false); expect(result.data.retrieved_documents).toHaveLength(0); }); @@ -68,6 +122,7 @@ describe('Agentic RAG System', () => { content: 'This is a test document about neural networks and deep learning.', metadata: { category: 'AI', author: 'Test Author' }, }, + collection_name: 'family_domestic', }); expect(result.status).toBe('success'); @@ -79,18 +134,43 @@ describe('Agentic RAG System', () => { const result = await agenticRAG({ action: 'get_context', query: 'Tell me about vector databases', + collection_name: 'family_domestic', top_k: 2, }); expect(result.status).toBe('success'); - expect(result.needsRetrieval).toBe(true); + expect(result.data.analysis.needsRetrieval).toBe(true); expect(result.context).toBeDefined(); expect(result.data.context_summary).toBeDefined(); }); + it('should handle semantic search', async () => { + const result = await agenticRAG({ + action: 'semantic_search', + query: 'artificial intelligence concepts', + collection_name: 'family_domestic', + top_k: 3, + }); + + expect(result.status).toBe('success'); + expect(result.data.results).toBeDefined(); + expect(Array.isArray(result.data.results)).toBe(true); + }); + + it('should list collections', async () => { + const result = await agenticRAG({ + action: 'list_collections', + collection_name: 'family_domestic', + }); + + expect(result.status).toBe('success'); + expect(result.message).toContain('family_domestic'); + }); + it('should handle errors gracefully', async () => { const result = await agenticRAG({ action: 'analyze_query', + collection_name: 'family_domestic', // Missing query parameter }); @@ -98,35 +178,82 @@ describe('Agentic RAG System', () => { expect(result.message).toContain('Query is required'); }); + it('should handle invalid actions', async () => { + const result = await agenticRAG({ + action: 'invalid_action', + collection_name: 'family_domestic', + }); + + expect(result.status).toBe('error'); + expect(result.message).toContain('Invalid action'); + }); + it('should have correct tool definition structure', () => { - expect(AgenticRAGTool.type).toBe('function'); - expect(AgenticRAGTool.function.name).toBe('agentic_rag'); - expect(AgenticRAGTool.function.description).toBeDefined(); - expect(AgenticRAGTool.function.parameters.type).toBe('object'); - expect(AgenticRAGTool.function.parameters.properties.action).toBeDefined(); - expect(AgenticRAGTool.function.parameters.required).toContain('action'); + expect(AgenticRAGTools.type).toBe('function'); + expect(AgenticRAGTools.function.name).toBe('agentic_rag'); + expect(AgenticRAGTools.function.description).toBeDefined(); + expect(AgenticRAGTools.function.parameters.type).toBe('object'); + expect(AgenticRAGTools.function.parameters.properties.action).toBeDefined(); + expect(AgenticRAGTools.function.parameters.required).toContain('action'); + expect(AgenticRAGTools.function.parameters.required).toContain('collection_name'); }); it('should demonstrate intelligent retrieval decision making', async () => { // Test various query types to show intelligent decision making const queries = [ { query: 'What is AI?', expectedRetrieval: true }, - { query: 'Hello world', expectedRetrieval: false }, - { query: 'Write a poem', expectedRetrieval: false }, + { query: 'Hello world how are you', expectedRetrieval: false }, + { query: 'Write a poem and create a story', expectedRetrieval: false }, { query: 'Explain machine learning', expectedRetrieval: true }, - { query: 'How are you doing?', expectedRetrieval: false }, + { query: 'How are you doing today?', expectedRetrieval: true }, + { query: 'Tell me about neural networks', expectedRetrieval: true }, ]; for (const testCase of queries) { const result = await agenticRAG({ action: 'search_knowledge', query: testCase.query, + collection_name: 'family_domestic', }); expect(result.status).toBe('success'); - expect(result.needsRetrieval).toBe(testCase.expectedRetrieval); + expect(result.data.analysis.needsRetrieval).toBe(testCase.expectedRetrieval); - console.log(`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.needsRetrieval}`); + console.log( + `[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.data.analysis.needsRetrieval}`, + ); } }); -}); \ No newline at end of file + + it('should filter results by similarity threshold', async () => { + const result = await agenticRAG({ + action: 'search_knowledge', + query: 'What is machine learning?', + collection_name: 'family_domestic', + similarity_threshold: 0.8, // High threshold + }); + + expect(result.status).toBe('success'); + if (result.data.analysis.needsRetrieval) { + // Should only return results above threshold + result.data.retrieved_documents.forEach((doc: any) => { + expect(doc.score).toBeGreaterThanOrEqual(0.8); + }); + } + }); + + it('should handle context window limits', async () => { + const result = await agenticRAG({ + action: 'get_context', + query: 'Tell me about artificial intelligence', + collection_name: 'family_domestic', + context_window: 1000, + }); + + expect(result.status).toBe('success'); + if (result.data.analysis.needsRetrieval && result.data.context_summary) { + // Context should respect the window limit (approximate check) + expect(result.data.context_summary.length).toBeLessThanOrEqual(2000); // Allow some flexibility + } + }); +}); diff --git a/packages/ai/src/tools/agentic-rag.ts b/packages/ai/src/tools/agentic-rag.ts index 443dad3..ec7632e 100644 --- a/packages/ai/src/tools/agentic-rag.ts +++ b/packages/ai/src/tools/agentic-rag.ts @@ -50,24 +50,48 @@ export const AgenticRAGTools = { properties: { action: { type: 'string', - enum: ['search', 'list_collections', 'report_status'], + enum: [ + 'list_collections', + 'report_status', + 'semantic_search', + 'search_knowledge', + 'analyze_query', + 'get_context', + ], description: 'Action to perform with the agentic RAG system.', }, query: { type: 'string', description: 'User query or search term for knowledge retrieval.', }, - document: { - type: 'object', - properties: { - content: { type: 'string', description: 'Document content to store' }, - metadata: { type: 'object', description: 'Additional metadata for the document' }, - id: { type: 'string', description: 'Unique identifier for the document' }, - }, - description: 'Document to store in the knowledge base.', - }, + // document: { + // type: 'object', + // properties: { + // content: { type: 'string', description: 'Document content to store' }, + // metadata: { type: 'object', description: 'Additional metadata for the document' }, + // id: { type: 'string', description: 'Unique identifier for the document' }, + // }, + // description: 'Document to store in the knowledge base.', + // }, collection_name: { type: 'string', + // todo: make this fancy w/ dynamic collection + enum: [ + 'business_corporate', + 'civil_procedure', + 'criminal_justice', + 'education_professions', + 'environmental_infrastructure', + 'family_domestic', + 'foundational_law', + 'government_administration', + 'health_social_services', + 'miscellaneous', + 'property_real_estate', + 'special_documents', + 'taxation_finance', + 'transportation_motor_vehicles', + ], description: 'Name of the collection to work with.', }, top_k: { @@ -83,7 +107,7 @@ export const AgenticRAGTools = { description: 'Maximum number of context tokens to include (default: 2000).', }, }, - required: ['action'], + required: ['action', 'collection_name'], additionalProperties: false, }, strict: true, @@ -95,10 +119,10 @@ export const AgenticRAGTools = { */ const DEFAULT_CONFIG: AgenticRAGConfig = { milvusAddress: 'localhost:19530', - collectionName: 'knowledge_base', + collectionName: 'family_domestic', embeddingDimension: 768, topK: 5, - similarityThreshold: 0.58, + similarityThreshold: 0.5, }; /** @@ -200,14 +224,16 @@ function analyzeQueryForRetrieval(query: string): { 'Query appears to be asking for factual information that may benefit from knowledge retrieval.', queryType: 'factual', }; - } else if (creativeScore > conversationalScore) { + } else if (creativeScore > conversationalScore && creativeScore > 1) { + // Only skip retrieval for clearly creative tasks with multiple creative keywords return { needsRetrieval: false, confidence: 0.8, reasoning: 'Query appears to be requesting creative content generation.', queryType: 'creative', }; - } else if (conversationalScore > 0) { + } else if (conversationalScore > 1 && conversationalScore > factualScore) { + // Only skip retrieval for clearly conversational queries with multiple conversational keywords return { needsRetrieval: false, confidence: 0.7, @@ -215,10 +241,11 @@ function analyzeQueryForRetrieval(query: string): { queryType: 'conversational', }; } else { + // Default to retrieval for most cases to ensure comprehensive responses return { needsRetrieval: true, - confidence: 0.6, - reasoning: 'Query type unclear, defaulting to retrieval for comprehensive response.', + confidence: 0.8, + reasoning: 'Defaulting to retrieval to provide comprehensive and accurate information.', queryType: 'analytical', }; } @@ -235,8 +262,8 @@ export async function agenticRAG(args: { top_k?: number; similarity_threshold?: number; context_window?: number; + user_confirmed?: boolean; }): Promise { - console.log('calling agentic rag tool', args); const config = { ...DEFAULT_CONFIG }; const collectionName = args.collection_name || config.collectionName!; const topK = args.top_k || config.topK!; @@ -297,9 +324,9 @@ export async function agenticRAG(args: { // eslint-disable-next-line no-case-declarations const searchResult = await milvusClient.search({ collection_name: collectionName, - query_vectors: [queryEmbedding], - top_k: topK, - params: { nprobe: 16 }, + vector: queryEmbedding, + topk: topK, + params: { nprobe: 8 }, output_fields: ['content', 'metadata'], }); @@ -385,8 +412,7 @@ export async function agenticRAG(args: { ], }; - // @ts-expect-error - idk man - await milvusClient.createCollection(collectionSchema); + await milvusClient.createCollection(collectionSchema as any); // Create index for efficient similarity search await milvusClient.createIndex({ @@ -426,9 +452,9 @@ export async function agenticRAG(args: { // eslint-disable-next-line no-case-declarations const semanticResults = await milvusClient.search({ collection_name: collectionName, - query_vectors: [semanticEmbedding], - top_k: topK, - params: { nprobe: 16 }, + vector: semanticEmbedding, + topk: topK, + params: { nprobe: 8 }, output_fields: ['content', 'metadata'], }); @@ -452,14 +478,13 @@ export async function agenticRAG(args: { // This is a comprehensive context retrieval that combines analysis and search // eslint-disable-next-line no-case-declarations const contextAnalysis = analyzeQueryForRetrieval(args.query); - if (contextAnalysis.needsRetrieval) { const contextEmbedding = await generateEmbedding(args.query); const contextSearch = await milvusClient.search({ collection_name: collectionName, - query_vectors: [contextEmbedding], - top_k: topK, - params: { nprobe: 16 }, + vector: contextEmbedding, + topk: topK, + params: { nprobe: 8 }, output_fields: ['content', 'metadata'], }); diff --git a/packages/client/src/components/landing-component/LandingComponent.tsx b/packages/client/src/components/landing-component/LandingComponent.tsx index 442ac52..598e593 100644 --- a/packages/client/src/components/landing-component/LandingComponent.tsx +++ b/packages/client/src/components/landing-component/LandingComponent.tsx @@ -39,19 +39,19 @@ export const LandingComponent: React.FC = () => { }, }} switches={{ - GpsMap: { - value: mapActive, - onChange(enabled) { - if (enabled) { - setEnabledComponent('gpsmap'); - setAiActive(false); - } else { - setEnabledComponent(''); - } - setMapActive(enabled); - }, - label: 'GPS', - }, + // GpsMap: { + // value: mapActive, + // onChange(enabled) { + // if (enabled) { + // setEnabledComponent('gpsmap'); + // setAiActive(false); + // } else { + // setEnabledComponent(''); + // } + // setMapActive(enabled); + // }, + // label: 'GPS', + // }, AI: { value: aiActive, onChange(enabled) {