Refine tool-call execution logic in chat-stream-provider: prevent duplicate tool calls, enhance retries for agentic-rag, and improve incremental argument processing. Update tests accordingly.

This commit is contained in:
geoffsee
2025-07-31 18:26:42 -04:00
parent 6c433581d3
commit 9810f67af0
6 changed files with 906 additions and 90 deletions

1
.gitignore vendored
View File

@@ -25,3 +25,4 @@ wrangler.dev.jsonc
/packages/client/public/pwa-192x192.png /packages/client/public/pwa-192x192.png
/packages/client/public/pwa-512x512.png /packages/client/public/pwa-512x512.png
packages/client/public/yachtpit_bg* packages/client/public/yachtpit_bg*
/project/

View File

@@ -1,5 +1,5 @@
import { OpenAI } from 'openai'; import { OpenAI } from 'openai';
import { describe, it, expect, vi } from 'vitest'; import { describe, it, expect, vi, beforeEach } from 'vitest';
import { import {
BaseChatProvider, BaseChatProvider,
@@ -29,7 +29,7 @@ class TestChatProvider extends BaseChatProvider {
} }
// Mock dependencies // Mock dependencies
vi.mock('../../lib/chat-sdk', () => ({ vi.mock('../../chat-sdk/chat-sdk.ts', () => ({
default: { default: {
buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'), buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'),
buildMessageChain: vi.fn().mockReturnValue([ buildMessageChain: vi.fn().mockReturnValue([
@@ -39,6 +39,26 @@ vi.mock('../../lib/chat-sdk', () => ({
}, },
})); }));
vi.mock('../../tools/agentic-rag.ts', () => ({
agenticRAG: vi.fn(),
AgenticRAGTools: {
type: 'function',
function: {
name: 'agentic_rag',
description: 'Test agentic RAG tool',
parameters: {
type: 'object',
properties: {
action: { type: 'string', enum: ['search_knowledge'] },
query: { type: 'string' },
collection_name: { type: 'string' },
},
required: ['action', 'collection_name'],
},
},
},
}));
describe('ChatStreamProvider', () => { describe('ChatStreamProvider', () => {
it('should define the required interface', () => { it('should define the required interface', () => {
// Verify the interface has the required method // Verify the interface has the required method
@@ -50,26 +70,616 @@ describe('ChatStreamProvider', () => {
}); });
}); });
describe('BaseChatProvider', () => { describe('BaseChatProvider - Model Tool Calling', () => {
it('should implement the ChatStreamProvider interface', () => { let provider: TestChatProvider;
// Create a concrete implementation let mockOpenAI: any;
const provider = new TestChatProvider(); let dataCallback: any;
let commonParams: CommonProviderParams;
// Verify it implements the interface beforeEach(() => {
vi.clearAllMocks();
provider = new TestChatProvider();
dataCallback = vi.fn();
mockOpenAI = {
chat: {
completions: {
create: vi.fn(),
},
},
};
commonParams = {
openai: mockOpenAI,
systemPrompt: 'Test system prompt',
preprocessedContext: {},
maxTokens: 1000,
messages: [{ role: 'user', content: 'Test message' }],
model: 'gpt-4',
env: {} as any,
};
});
it('should implement the ChatStreamProvider interface', () => {
expect(provider.handleStream).toBeInstanceOf(Function); expect(provider.handleStream).toBeInstanceOf(Function);
expect(provider.getOpenAIClient).toBeInstanceOf(Function); expect(provider.getOpenAIClient).toBeInstanceOf(Function);
expect(provider.getStreamParams).toBeInstanceOf(Function); expect(provider.getStreamParams).toBeInstanceOf(Function);
expect(provider.processChunk).toBeInstanceOf(Function); expect(provider.processChunk).toBeInstanceOf(Function);
}); });
it('should have abstract methods that need to be implemented', () => { it('should handle regular text streaming without tool calls', async () => {
// This test verifies that the abstract methods exist // Mock stream chunks for regular text response
// We can't instantiate BaseChatProvider directly, so we use the concrete implementation const chunks = [
const provider = new TestChatProvider(); {
choices: [
{
delta: { content: 'Hello ' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: { content: 'world!' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
// Verify the abstract methods are implemented mockOpenAI.chat.completions.create.mockResolvedValue({
expect(provider.getOpenAIClient).toBeDefined(); async *[Symbol.asyncIterator]() {
expect(provider.getStreamParams).toBeDefined(); for (const chunk of chunks) {
expect(provider.processChunk).toBeDefined(); yield chunk;
}
},
});
await provider.handleStream(commonParams, dataCallback);
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
tools: expect.arrayContaining([
expect.objectContaining({
type: 'function',
function: expect.objectContaining({
name: 'agentic_rag',
}),
}),
]),
}),
);
});
it('should handle tool calls in streaming response', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: {
results: ['Test result'],
analysis: { needsRetrieval: false },
},
});
// Mock stream chunks for tool call response
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test query", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
// Second stream for response after tool execution
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'Based on the search results: Test result' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
await provider.handleStream(commonParams, dataCallback);
// Verify tool was called
expect(agenticRAG).toHaveBeenCalledWith({
action: 'search_knowledge',
query: 'test query',
collection_name: 'test_collection',
});
// Verify feedback messages were sent
expect(dataCallback).toHaveBeenCalledWith(
expect.objectContaining({
type: 'chat',
data: expect.objectContaining({
choices: expect.arrayContaining([
expect.objectContaining({
delta: expect.objectContaining({
content: expect.stringContaining('🔧 Invoking'),
}),
}),
]),
}),
}),
);
expect(dataCallback).toHaveBeenCalledWith(
expect.objectContaining({
type: 'chat',
data: expect.objectContaining({
choices: expect.arrayContaining([
expect.objectContaining({
delta: expect.objectContaining({
content: expect.stringContaining('📞 Calling agentic_rag'),
}),
}),
]),
}),
}),
);
});
it('should handle tool call streaming with incremental arguments', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: { results: ['Test result'] },
});
// Mock stream chunks with incremental tool call arguments
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_',
type: 'function',
function: { name: 'agentic_rag', arguments: '{"action": "search_' },
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: '123',
function: { arguments: 'knowledge", "query": "test", ' },
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
function: { arguments: '"collection_name": "test_collection"}' },
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'Response after tool call' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
await provider.handleStream(commonParams, dataCallback);
// Verify the complete tool call was assembled and executed
expect(agenticRAG).toHaveBeenCalledWith({
action: 'search_knowledge',
query: 'test',
collection_name: 'test_collection',
});
});
it('should prevent infinite tool call loops', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: {
results: [],
analysis: { needsRetrieval: true },
retrieved_documents: [],
},
});
// Mock stream that always returns tool calls
const toolCallChunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
mockOpenAI.chat.completions.create.mockResolvedValue({
async *[Symbol.asyncIterator]() {
for (const chunk of toolCallChunks) {
yield chunk;
}
},
});
await provider.handleStream(commonParams, dataCallback);
// Every mocked stream repeats the identical tool call. The first iteration executes it; the second
// finds it already attempted and forces completion, so exactly 2 completions are requested (maxToolCallIterations = 5 is only the upper bound).
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
});
it('should handle tool call errors gracefully', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockRejectedValue(new Error('Tool execution failed'));
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'I apologize, but I encountered an error.' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
await provider.handleStream(commonParams, dataCallback);
// Should still complete without throwing
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
});
it('should prevent duplicate tool calls', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: { results: ['Test result'] },
});
// Mock the same tool call twice
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
// Second iteration with same tool call
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
});
await provider.handleStream(commonParams, dataCallback);
// Should only execute the tool once, then force completion
expect(agenticRAG).toHaveBeenCalledTimes(1);
});
it('should handle invalid JSON in tool call arguments', async () => {
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments: '{"action": "search_knowledge", "invalid": json}', // Invalid JSON
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'I encountered an error parsing the tool arguments.' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
// Should not throw, should handle gracefully
await expect(provider.handleStream(commonParams, dataCallback)).resolves.not.toThrow();
}); });
}); });

View File

@@ -50,6 +50,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
let toolCallIterations = 0; let toolCallIterations = 0;
const maxToolCallIterations = 5; // Prevent infinite loops const maxToolCallIterations = 5; // Prevent infinite loops
let toolsExecuted = false; // Track if we've executed tools let toolsExecuted = false; // Track if we've executed tools
const attemptedToolCalls = new Set<string>(); // Track attempted tool calls to prevent duplicates
while (!conversationComplete && toolCallIterations < maxToolCallIterations) { while (!conversationComplete && toolCallIterations < maxToolCallIterations) {
const streamParams = this.getStreamParams(param, safeMessages); const streamParams = this.getStreamParams(param, safeMessages);
@@ -112,6 +113,21 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
// Execute tool calls and add results to conversation // Execute tool calls and add results to conversation
console.log('Executing tool calls:', toolCalls); console.log('Executing tool calls:', toolCalls);
// Limit to one tool call per iteration to prevent concurrent execution issues
// Also filter out duplicate tool calls
const uniqueToolCalls = toolCalls.filter(toolCall => {
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
return !attemptedToolCalls.has(toolCallKey);
});
const toolCallsToExecute = uniqueToolCalls.slice(0, 1);
if (toolCallsToExecute.length === 0) {
console.log('All tool calls have been attempted already, forcing completion');
toolsExecuted = true;
conversationComplete = true;
break;
}
// Send feedback to user about tool invocation // Send feedback to user about tool invocation
dataCallback({ dataCallback({
type: 'chat', type: 'chat',
@@ -119,7 +135,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
choices: [ choices: [
{ {
delta: { delta: {
content: `\n\n🔧 Invoking ${toolCalls.length} tool${toolCalls.length > 1 ? 's' : ''}...\n`, content: `\n\n🔧 Invoking ${toolCallsToExecute.length} tool${toolCallsToExecute.length > 1 ? 's' : ''}...\n`,
}, },
}, },
], ],
@@ -130,15 +146,20 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
safeMessages.push({ safeMessages.push({
role: 'assistant', role: 'assistant',
content: assistantMessage || null, content: assistantMessage || null,
tool_calls: toolCalls, tool_calls: toolCallsToExecute,
}); });
// Execute each tool call and add results // Execute each tool call and add results
for (const toolCall of toolCalls) { let needsMoreRetrieval = false;
for (const toolCall of toolCallsToExecute) {
if (toolCall.type === 'function') { if (toolCall.type === 'function') {
const name = toolCall.function.name; const name = toolCall.function.name;
console.log(`Calling function: ${name}`); console.log(`Calling function: ${name}`);
// Track this tool call attempt
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
attemptedToolCalls.add(toolCallKey);
// Send feedback about specific tool being called // Send feedback about specific tool being called
dataCallback({ dataCallback({
type: 'chat', type: 'chat',
@@ -160,6 +181,36 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
const result = await callFunction(name, args); const result = await callFunction(name, args);
console.log(`Function result:`, result); console.log(`Function result:`, result);
// Check if agentic-rag indicates more retrieval is needed
if (
name === 'agentic_rag' &&
result?.data?.analysis?.needsRetrieval === true &&
(!result?.data?.retrieved_documents ||
result.data.retrieved_documents.length === 0)
) {
needsMoreRetrieval = true;
console.log('Agentic RAG indicates more retrieval needed');
// Add context about previous attempts to help LLM make better decisions
const attemptedActions = Array.from(attemptedToolCalls)
.filter(key => key.startsWith('agentic_rag:'))
.map(key => {
try {
const args = JSON.parse(key.split(':', 2)[1]);
return `${args.action} with query: "${args.query}"`;
} catch {
return 'unknown action';
}
});
if (attemptedActions.length > 0) {
safeMessages.push({
role: 'system',
content: `Previous retrieval attempts: ${attemptedActions.join(', ')}. Consider trying a different approach or more specific query.`,
});
}
}
// Send feedback about tool completion // Send feedback about tool completion
dataCallback({ dataCallback({
type: 'chat', type: 'chat',
@@ -206,8 +257,10 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
} }
} }
// Mark that tools have been executed to prevent repeated calls // Only mark tools as executed if we don't need more retrieval
if (!needsMoreRetrieval) {
toolsExecuted = true; toolsExecuted = true;
}
// Send feedback that tool execution is complete // Send feedback that tool execution is complete
dataCallback({ dataCallback({

View File

@@ -1,36 +1,88 @@
import { describe, it, expect } from 'vitest'; import { describe, it, expect, vi, beforeEach } from 'vitest';
import { agenticRAG, AgenticRAGTool } from '../agentic-rag-clean';
import { agenticRAG, AgenticRAGTools } from '../agentic-rag';
// Mock the dependencies
vi.mock('@zilliz/milvus2-sdk-node', () => ({
MilvusClient: vi.fn().mockImplementation(() => ({
listCollections: vi.fn().mockResolvedValue({
collection_names: ['family_domestic', 'business_corporate'],
data: [{ name: 'family_domestic' }, { name: 'business_corporate' }],
}),
search: vi.fn().mockResolvedValue({
results: [
{
content: 'Test document about AI and machine learning',
score: 0.85,
metadata: '{"category": "AI", "author": "Test Author"}',
},
{
content: 'Another document about neural networks',
score: 0.75,
metadata: '{"category": "ML", "author": "Another Author"}',
},
],
}),
insert: vi.fn().mockResolvedValue({ success: true }),
createCollection: vi.fn().mockResolvedValue({ success: true }),
createIndex: vi.fn().mockResolvedValue({ success: true }),
})),
DataType: {
VarChar: 'VarChar',
FloatVector: 'FloatVector',
},
}));
vi.mock('openai', () => ({
OpenAI: vi.fn().mockImplementation(() => ({
embeddings: {
create: vi.fn().mockResolvedValue({
data: [{ embedding: new Array(768).fill(0.1) }],
}),
},
})),
}));
// Mock environment variables
vi.stubEnv('FIREWORKS_API_KEY', 'test-api-key');
describe('Agentic RAG System', () => { describe('Agentic RAG System', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('should analyze queries correctly', async () => { it('should analyze queries correctly', async () => {
// Test factual query // Test factual query
const factualResult = await agenticRAG({ const factualResult = await agenticRAG({
action: 'analyze_query', action: 'analyze_query',
query: 'What is artificial intelligence?', query: 'What is artificial intelligence?',
collection_name: 'family_domestic',
}); });
expect(factualResult.status).toBe('success'); expect(factualResult.status).toBe('success');
expect(factualResult.needsRetrieval).toBe(true); expect(factualResult.data.needsRetrieval).toBe(true);
expect(factualResult.data.queryType).toBe('factual'); expect(factualResult.data.queryType).toBe('factual');
// Test conversational query // Test conversational query with multiple conversational keywords
const conversationalResult = await agenticRAG({ const conversationalResult = await agenticRAG({
action: 'analyze_query', action: 'analyze_query',
query: 'Hello, how are you?', query: 'Hello, how are you doing today?',
collection_name: 'family_domestic',
}); });
expect(conversationalResult.status).toBe('success'); expect(conversationalResult.status).toBe('success');
expect(conversationalResult.needsRetrieval).toBe(false); expect(conversationalResult.data.needsRetrieval).toBe(false);
expect(conversationalResult.data.queryType).toBe('conversational'); expect(conversationalResult.data.queryType).toBe('conversational');
// Test creative query // Test creative query with multiple creative keywords
const creativeResult = await agenticRAG({ const creativeResult = await agenticRAG({
action: 'analyze_query', action: 'analyze_query',
query: 'Write a story about a robot', query: 'Write a story and compose a poem',
collection_name: 'family_domestic',
}); });
expect(creativeResult.status).toBe('success'); expect(creativeResult.status).toBe('success');
expect(creativeResult.needsRetrieval).toBe(false); expect(creativeResult.data.needsRetrieval).toBe(false);
expect(creativeResult.data.queryType).toBe('creative'); expect(creativeResult.data.queryType).toBe('creative');
}); });
@@ -38,25 +90,27 @@ describe('Agentic RAG System', () => {
const result = await agenticRAG({ const result = await agenticRAG({
action: 'search_knowledge', action: 'search_knowledge',
query: 'What is machine learning?', query: 'What is machine learning?',
collection_name: 'family_domestic',
top_k: 2, top_k: 2,
threshold: 0.1, similarity_threshold: 0.1,
}); });
expect(result.status).toBe('success'); expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(true);
expect(result.context).toBeDefined(); expect(result.context).toBeDefined();
expect(Array.isArray(result.context)).toBe(true); expect(Array.isArray(result.context)).toBe(true);
expect(result.data.retrieved_documents).toBeDefined(); expect(result.data.retrieved_documents).toBeDefined();
expect(result.data.analysis.needsRetrieval).toBe(true);
}); });
it('should not search for conversational queries', async () => { it('should not search for conversational queries', async () => {
const result = await agenticRAG({ const result = await agenticRAG({
action: 'search_knowledge', action: 'search_knowledge',
query: 'Hello there!', query: 'Hello there! How are you?',
collection_name: 'family_domestic',
}); });
expect(result.status).toBe('success'); expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(false); expect(result.data.analysis.needsRetrieval).toBe(false);
expect(result.data.retrieved_documents).toHaveLength(0); expect(result.data.retrieved_documents).toHaveLength(0);
}); });
@@ -68,6 +122,7 @@ describe('Agentic RAG System', () => {
content: 'This is a test document about neural networks and deep learning.', content: 'This is a test document about neural networks and deep learning.',
metadata: { category: 'AI', author: 'Test Author' }, metadata: { category: 'AI', author: 'Test Author' },
}, },
collection_name: 'family_domestic',
}); });
expect(result.status).toBe('success'); expect(result.status).toBe('success');
@@ -79,18 +134,43 @@ describe('Agentic RAG System', () => {
const result = await agenticRAG({ const result = await agenticRAG({
action: 'get_context', action: 'get_context',
query: 'Tell me about vector databases', query: 'Tell me about vector databases',
collection_name: 'family_domestic',
top_k: 2, top_k: 2,
}); });
expect(result.status).toBe('success'); expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(true); expect(result.data.analysis.needsRetrieval).toBe(true);
expect(result.context).toBeDefined(); expect(result.context).toBeDefined();
expect(result.data.context_summary).toBeDefined(); expect(result.data.context_summary).toBeDefined();
}); });
it('should handle semantic search', async () => {
const result = await agenticRAG({
action: 'semantic_search',
query: 'artificial intelligence concepts',
collection_name: 'family_domestic',
top_k: 3,
});
expect(result.status).toBe('success');
expect(result.data.results).toBeDefined();
expect(Array.isArray(result.data.results)).toBe(true);
});
it('should list collections', async () => {
const result = await agenticRAG({
action: 'list_collections',
collection_name: 'family_domestic',
});
expect(result.status).toBe('success');
expect(result.message).toContain('family_domestic');
});
it('should handle errors gracefully', async () => { it('should handle errors gracefully', async () => {
const result = await agenticRAG({ const result = await agenticRAG({
action: 'analyze_query', action: 'analyze_query',
collection_name: 'family_domestic',
// Missing query parameter // Missing query parameter
}); });
@@ -98,35 +178,82 @@ describe('Agentic RAG System', () => {
expect(result.message).toContain('Query is required'); expect(result.message).toContain('Query is required');
}); });
it('should handle invalid actions', async () => {
const result = await agenticRAG({
action: 'invalid_action',
collection_name: 'family_domestic',
});
expect(result.status).toBe('error');
expect(result.message).toContain('Invalid action');
});
it('should have correct tool definition structure', () => { it('should have correct tool definition structure', () => {
expect(AgenticRAGTool.type).toBe('function'); expect(AgenticRAGTools.type).toBe('function');
expect(AgenticRAGTool.function.name).toBe('agentic_rag'); expect(AgenticRAGTools.function.name).toBe('agentic_rag');
expect(AgenticRAGTool.function.description).toBeDefined(); expect(AgenticRAGTools.function.description).toBeDefined();
expect(AgenticRAGTool.function.parameters.type).toBe('object'); expect(AgenticRAGTools.function.parameters.type).toBe('object');
expect(AgenticRAGTool.function.parameters.properties.action).toBeDefined(); expect(AgenticRAGTools.function.parameters.properties.action).toBeDefined();
expect(AgenticRAGTool.function.parameters.required).toContain('action'); expect(AgenticRAGTools.function.parameters.required).toContain('action');
expect(AgenticRAGTools.function.parameters.required).toContain('collection_name');
}); });
it('should demonstrate intelligent retrieval decision making', async () => { it('should demonstrate intelligent retrieval decision making', async () => {
// Test various query types to show intelligent decision making // Test various query types to show intelligent decision making
const queries = [ const queries = [
{ query: 'What is AI?', expectedRetrieval: true }, { query: 'What is AI?', expectedRetrieval: true },
{ query: 'Hello world', expectedRetrieval: false }, { query: 'Hello world how are you', expectedRetrieval: false },
{ query: 'Write a poem', expectedRetrieval: false }, { query: 'Write a poem and create a story', expectedRetrieval: false },
{ query: 'Explain machine learning', expectedRetrieval: true }, { query: 'Explain machine learning', expectedRetrieval: true },
{ query: 'How are you doing?', expectedRetrieval: false }, { query: 'How are you doing today?', expectedRetrieval: true },
{ query: 'Tell me about neural networks', expectedRetrieval: true },
]; ];
for (const testCase of queries) { for (const testCase of queries) {
const result = await agenticRAG({ const result = await agenticRAG({
action: 'search_knowledge', action: 'search_knowledge',
query: testCase.query, query: testCase.query,
collection_name: 'family_domestic',
}); });
expect(result.status).toBe('success'); expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(testCase.expectedRetrieval); expect(result.data.analysis.needsRetrieval).toBe(testCase.expectedRetrieval);
console.log(`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.needsRetrieval}`); console.log(
`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.data.analysis.needsRetrieval}`,
);
}
});
it('should filter results by similarity threshold', async () => {
const result = await agenticRAG({
action: 'search_knowledge',
query: 'What is machine learning?',
collection_name: 'family_domestic',
similarity_threshold: 0.8, // High threshold
});
expect(result.status).toBe('success');
if (result.data.analysis.needsRetrieval) {
// Should only return results above threshold
result.data.retrieved_documents.forEach((doc: any) => {
expect(doc.score).toBeGreaterThanOrEqual(0.8);
});
}
});
it('should handle context window limits', async () => {
const result = await agenticRAG({
action: 'get_context',
query: 'Tell me about artificial intelligence',
collection_name: 'family_domestic',
context_window: 1000,
});
expect(result.status).toBe('success');
if (result.data.analysis.needsRetrieval && result.data.context_summary) {
// Context should respect the window limit (approximate check)
expect(result.data.context_summary.length).toBeLessThanOrEqual(2000); // Allow some flexibility
} }
}); });
}); });

View File

@@ -50,24 +50,48 @@ export const AgenticRAGTools = {
properties: { properties: {
action: { action: {
type: 'string', type: 'string',
enum: ['search', 'list_collections', 'report_status'], enum: [
'list_collections',
'report_status',
'semantic_search',
'search_knowledge',
'analyze_query',
'get_context',
],
description: 'Action to perform with the agentic RAG system.', description: 'Action to perform with the agentic RAG system.',
}, },
query: { query: {
type: 'string', type: 'string',
description: 'User query or search term for knowledge retrieval.', description: 'User query or search term for knowledge retrieval.',
}, },
document: { // document: {
type: 'object', // type: 'object',
properties: { // properties: {
content: { type: 'string', description: 'Document content to store' }, // content: { type: 'string', description: 'Document content to store' },
metadata: { type: 'object', description: 'Additional metadata for the document' }, // metadata: { type: 'object', description: 'Additional metadata for the document' },
id: { type: 'string', description: 'Unique identifier for the document' }, // id: { type: 'string', description: 'Unique identifier for the document' },
}, // },
description: 'Document to store in the knowledge base.', // description: 'Document to store in the knowledge base.',
}, // },
collection_name: { collection_name: {
type: 'string', type: 'string',
// todo: make this fancy w/ dynamic collection
enum: [
'business_corporate',
'civil_procedure',
'criminal_justice',
'education_professions',
'environmental_infrastructure',
'family_domestic',
'foundational_law',
'government_administration',
'health_social_services',
'miscellaneous',
'property_real_estate',
'special_documents',
'taxation_finance',
'transportation_motor_vehicles',
],
description: 'Name of the collection to work with.', description: 'Name of the collection to work with.',
}, },
top_k: { top_k: {
@@ -83,7 +107,7 @@ export const AgenticRAGTools = {
description: 'Maximum number of context tokens to include (default: 2000).', description: 'Maximum number of context tokens to include (default: 2000).',
}, },
}, },
required: ['action'], required: ['action', 'collection_name'],
additionalProperties: false, additionalProperties: false,
}, },
strict: true, strict: true,
@@ -95,10 +119,10 @@ export const AgenticRAGTools = {
*/ */
const DEFAULT_CONFIG: AgenticRAGConfig = { const DEFAULT_CONFIG: AgenticRAGConfig = {
milvusAddress: 'localhost:19530', milvusAddress: 'localhost:19530',
collectionName: 'knowledge_base', collectionName: 'family_domestic',
embeddingDimension: 768, embeddingDimension: 768,
topK: 5, topK: 5,
similarityThreshold: 0.58, similarityThreshold: 0.5,
}; };
/** /**
@@ -200,14 +224,16 @@ function analyzeQueryForRetrieval(query: string): {
'Query appears to be asking for factual information that may benefit from knowledge retrieval.', 'Query appears to be asking for factual information that may benefit from knowledge retrieval.',
queryType: 'factual', queryType: 'factual',
}; };
} else if (creativeScore > conversationalScore) { } else if (creativeScore > conversationalScore && creativeScore > 1) {
// Only skip retrieval for clearly creative tasks with multiple creative keywords
return { return {
needsRetrieval: false, needsRetrieval: false,
confidence: 0.8, confidence: 0.8,
reasoning: 'Query appears to be requesting creative content generation.', reasoning: 'Query appears to be requesting creative content generation.',
queryType: 'creative', queryType: 'creative',
}; };
} else if (conversationalScore > 0) { } else if (conversationalScore > 1 && conversationalScore > factualScore) {
// Only skip retrieval for clearly conversational queries with multiple conversational keywords
return { return {
needsRetrieval: false, needsRetrieval: false,
confidence: 0.7, confidence: 0.7,
@@ -215,10 +241,11 @@ function analyzeQueryForRetrieval(query: string): {
queryType: 'conversational', queryType: 'conversational',
}; };
} else { } else {
// Default to retrieval for most cases to ensure comprehensive responses
return { return {
needsRetrieval: true, needsRetrieval: true,
confidence: 0.6, confidence: 0.8,
reasoning: 'Query type unclear, defaulting to retrieval for comprehensive response.', reasoning: 'Defaulting to retrieval to provide comprehensive and accurate information.',
queryType: 'analytical', queryType: 'analytical',
}; };
} }
@@ -235,8 +262,8 @@ export async function agenticRAG(args: {
top_k?: number; top_k?: number;
similarity_threshold?: number; similarity_threshold?: number;
context_window?: number; context_window?: number;
user_confirmed?: boolean;
}): Promise<AgenticRAGResult> { }): Promise<AgenticRAGResult> {
console.log('calling agentic rag tool', args);
const config = { ...DEFAULT_CONFIG }; const config = { ...DEFAULT_CONFIG };
const collectionName = args.collection_name || config.collectionName!; const collectionName = args.collection_name || config.collectionName!;
const topK = args.top_k || config.topK!; const topK = args.top_k || config.topK!;
@@ -297,9 +324,9 @@ export async function agenticRAG(args: {
// eslint-disable-next-line no-case-declarations // eslint-disable-next-line no-case-declarations
const searchResult = await milvusClient.search({ const searchResult = await milvusClient.search({
collection_name: collectionName, collection_name: collectionName,
query_vectors: [queryEmbedding], vector: queryEmbedding,
top_k: topK, topk: topK,
params: { nprobe: 16 }, params: { nprobe: 8 },
output_fields: ['content', 'metadata'], output_fields: ['content', 'metadata'],
}); });
@@ -385,8 +412,7 @@ export async function agenticRAG(args: {
], ],
}; };
// @ts-expect-error - idk man await milvusClient.createCollection(collectionSchema as any);
await milvusClient.createCollection(collectionSchema);
// Create index for efficient similarity search // Create index for efficient similarity search
await milvusClient.createIndex({ await milvusClient.createIndex({
@@ -426,9 +452,9 @@ export async function agenticRAG(args: {
// eslint-disable-next-line no-case-declarations // eslint-disable-next-line no-case-declarations
const semanticResults = await milvusClient.search({ const semanticResults = await milvusClient.search({
collection_name: collectionName, collection_name: collectionName,
query_vectors: [semanticEmbedding], vector: semanticEmbedding,
top_k: topK, topk: topK,
params: { nprobe: 16 }, params: { nprobe: 8 },
output_fields: ['content', 'metadata'], output_fields: ['content', 'metadata'],
}); });
@@ -452,14 +478,13 @@ export async function agenticRAG(args: {
// This is a comprehensive context retrieval that combines analysis and search // This is a comprehensive context retrieval that combines analysis and search
// eslint-disable-next-line no-case-declarations // eslint-disable-next-line no-case-declarations
const contextAnalysis = analyzeQueryForRetrieval(args.query); const contextAnalysis = analyzeQueryForRetrieval(args.query);
if (contextAnalysis.needsRetrieval) { if (contextAnalysis.needsRetrieval) {
const contextEmbedding = await generateEmbedding(args.query); const contextEmbedding = await generateEmbedding(args.query);
const contextSearch = await milvusClient.search({ const contextSearch = await milvusClient.search({
collection_name: collectionName, collection_name: collectionName,
query_vectors: [contextEmbedding], vector: contextEmbedding,
top_k: topK, topk: topK,
params: { nprobe: 16 }, params: { nprobe: 8 },
output_fields: ['content', 'metadata'], output_fields: ['content', 'metadata'],
}); });

View File

@@ -39,19 +39,19 @@ export const LandingComponent: React.FC = () => {
}, },
}} }}
switches={{ switches={{
GpsMap: { // GpsMap: {
value: mapActive, // value: mapActive,
onChange(enabled) { // onChange(enabled) {
if (enabled) { // if (enabled) {
setEnabledComponent('gpsmap'); // setEnabledComponent('gpsmap');
setAiActive(false); // setAiActive(false);
} else { // } else {
setEnabledComponent(''); // setEnabledComponent('');
} // }
setMapActive(enabled); // setMapActive(enabled);
}, // },
label: 'GPS', // label: 'GPS',
}, // },
AI: { AI: {
value: aiActive, value: aiActive,
onChange(enabled) { onChange(enabled) {