Refine tool-call execution logic in chat-stream-provider: prevent duplicate tool calls, enhance retry handling for agentic-rag, and improve incremental argument processing; update the corresponding tests.

This commit is contained in:
geoffsee
2025-07-31 18:26:42 -04:00
parent 6c433581d3
commit 9810f67af0
6 changed files with 906 additions and 90 deletions

1
.gitignore vendored
View File

@@ -25,3 +25,4 @@ wrangler.dev.jsonc
/packages/client/public/pwa-192x192.png
/packages/client/public/pwa-512x512.png
packages/client/public/yachtpit_bg*
/project/

View File

@@ -1,5 +1,5 @@
import { OpenAI } from 'openai';
import { describe, it, expect, vi } from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
BaseChatProvider,
@@ -29,7 +29,7 @@ class TestChatProvider extends BaseChatProvider {
}
// Mock dependencies
vi.mock('../../lib/chat-sdk', () => ({
vi.mock('../../chat-sdk/chat-sdk.ts', () => ({
default: {
buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'),
buildMessageChain: vi.fn().mockReturnValue([
@@ -39,6 +39,26 @@ vi.mock('../../lib/chat-sdk', () => ({
},
}));
vi.mock('../../tools/agentic-rag.ts', () => ({
agenticRAG: vi.fn(),
AgenticRAGTools: {
type: 'function',
function: {
name: 'agentic_rag',
description: 'Test agentic RAG tool',
parameters: {
type: 'object',
properties: {
action: { type: 'string', enum: ['search_knowledge'] },
query: { type: 'string' },
collection_name: { type: 'string' },
},
required: ['action', 'collection_name'],
},
},
},
}));
describe('ChatStreamProvider', () => {
it('should define the required interface', () => {
// Verify the interface has the required method
@@ -50,26 +70,616 @@ describe('ChatStreamProvider', () => {
});
});
describe('BaseChatProvider', () => {
it('should implement the ChatStreamProvider interface', () => {
// Create a concrete implementation
const provider = new TestChatProvider();
describe('BaseChatProvider - Model Tool Calling', () => {
let provider: TestChatProvider;
let mockOpenAI: any;
let dataCallback: any;
let commonParams: CommonProviderParams;
// Verify it implements the interface
beforeEach(() => {
vi.clearAllMocks();
provider = new TestChatProvider();
dataCallback = vi.fn();
mockOpenAI = {
chat: {
completions: {
create: vi.fn(),
},
},
};
commonParams = {
openai: mockOpenAI,
systemPrompt: 'Test system prompt',
preprocessedContext: {},
maxTokens: 1000,
messages: [{ role: 'user', content: 'Test message' }],
model: 'gpt-4',
env: {} as any,
};
});
it('should implement the ChatStreamProvider interface', () => {
expect(provider.handleStream).toBeInstanceOf(Function);
expect(provider.getOpenAIClient).toBeInstanceOf(Function);
expect(provider.getStreamParams).toBeInstanceOf(Function);
expect(provider.processChunk).toBeInstanceOf(Function);
});
it('should have abstract methods that need to be implemented', () => {
// This test verifies that the abstract methods exist
// We can't instantiate BaseChatProvider directly, so we use the concrete implementation
const provider = new TestChatProvider();
it('should handle regular text streaming without tool calls', async () => {
// Mock stream chunks for regular text response
const chunks = [
{
choices: [
{
delta: { content: 'Hello ' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: { content: 'world!' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
// Verify the abstract methods are implemented
expect(provider.getOpenAIClient).toBeDefined();
expect(provider.getStreamParams).toBeDefined();
expect(provider.processChunk).toBeDefined();
mockOpenAI.chat.completions.create.mockResolvedValue({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
await provider.handleStream(commonParams, dataCallback);
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledWith(
expect.objectContaining({
tools: expect.arrayContaining([
expect.objectContaining({
type: 'function',
function: expect.objectContaining({
name: 'agentic_rag',
}),
}),
]),
}),
);
});
it('should handle tool calls in streaming response', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: {
results: ['Test result'],
analysis: { needsRetrieval: false },
},
});
// Mock stream chunks for tool call response
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test query", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
// Second stream for response after tool execution
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'Based on the search results: Test result' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
await provider.handleStream(commonParams, dataCallback);
// Verify tool was called
expect(agenticRAG).toHaveBeenCalledWith({
action: 'search_knowledge',
query: 'test query',
collection_name: 'test_collection',
});
// Verify feedback messages were sent
expect(dataCallback).toHaveBeenCalledWith(
expect.objectContaining({
type: 'chat',
data: expect.objectContaining({
choices: expect.arrayContaining([
expect.objectContaining({
delta: expect.objectContaining({
content: expect.stringContaining('🔧 Invoking'),
}),
}),
]),
}),
}),
);
expect(dataCallback).toHaveBeenCalledWith(
expect.objectContaining({
type: 'chat',
data: expect.objectContaining({
choices: expect.arrayContaining([
expect.objectContaining({
delta: expect.objectContaining({
content: expect.stringContaining('📞 Calling agentic_rag'),
}),
}),
]),
}),
}),
);
});
it('should handle tool call streaming with incremental arguments', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: { results: ['Test result'] },
});
// Mock stream chunks with incremental tool call arguments
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_',
type: 'function',
function: { name: 'agentic_rag', arguments: '{"action": "search_' },
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: '123',
function: { arguments: 'knowledge", "query": "test", ' },
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
function: { arguments: '"collection_name": "test_collection"}' },
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'Response after tool call' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
await provider.handleStream(commonParams, dataCallback);
// Verify the complete tool call was assembled and executed
expect(agenticRAG).toHaveBeenCalledWith({
action: 'search_knowledge',
query: 'test',
collection_name: 'test_collection',
});
});
it('should prevent infinite tool call loops', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: {
results: [],
analysis: { needsRetrieval: true },
retrieved_documents: [],
},
});
// Mock stream that always returns tool calls
const toolCallChunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
mockOpenAI.chat.completions.create.mockResolvedValue({
async *[Symbol.asyncIterator]() {
for (const chunk of toolCallChunks) {
yield chunk;
}
},
});
await provider.handleStream(commonParams, dataCallback);
// Should detect duplicate tool calls and force completion (up to 5 iterations based on maxToolCallIterations)
// In this case, it should stop after 2 calls due to duplicate detection, but could go up to 5
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
});
it('should handle tool call errors gracefully', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockRejectedValue(new Error('Tool execution failed'));
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'I apologize, but I encountered an error.' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
await provider.handleStream(commonParams, dataCallback);
// Should still complete without throwing
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
});
it('should prevent duplicate tool calls', async () => {
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
vi.mocked(agenticRAG).mockResolvedValue({
success: true,
data: { results: ['Test result'] },
});
// Mock the same tool call twice
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments:
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
// Second iteration with same tool call
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
});
await provider.handleStream(commonParams, dataCallback);
// Should only execute the tool once, then force completion
expect(agenticRAG).toHaveBeenCalledTimes(1);
});
it('should handle invalid JSON in tool call arguments', async () => {
const chunks = [
{
choices: [
{
delta: {
tool_calls: [
{
index: 0,
id: 'call_123',
type: 'function',
function: {
name: 'agentic_rag',
arguments: '{"action": "search_knowledge", "invalid": json}', // Invalid JSON
},
},
],
},
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'tool_calls',
},
],
},
];
const secondStreamChunks = [
{
choices: [
{
delta: { content: 'I encountered an error parsing the tool arguments.' },
finish_reason: null,
},
],
},
{
choices: [
{
delta: {},
finish_reason: 'stop',
},
],
},
];
let callCount = 0;
mockOpenAI.chat.completions.create.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
yield chunk;
}
},
});
} else {
return Promise.resolve({
async *[Symbol.asyncIterator]() {
for (const chunk of secondStreamChunks) {
yield chunk;
}
},
});
}
});
// Should not throw, should handle gracefully
await expect(provider.handleStream(commonParams, dataCallback)).resolves.not.toThrow();
});
});

View File

@@ -50,6 +50,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
let toolCallIterations = 0;
const maxToolCallIterations = 5; // Prevent infinite loops
let toolsExecuted = false; // Track if we've executed tools
const attemptedToolCalls = new Set<string>(); // Track attempted tool calls to prevent duplicates
while (!conversationComplete && toolCallIterations < maxToolCallIterations) {
const streamParams = this.getStreamParams(param, safeMessages);
@@ -112,6 +113,21 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
// Execute tool calls and add results to conversation
console.log('Executing tool calls:', toolCalls);
// Limit to one tool call per iteration to prevent concurrent execution issues
// Also filter out duplicate tool calls
const uniqueToolCalls = toolCalls.filter(toolCall => {
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
return !attemptedToolCalls.has(toolCallKey);
});
const toolCallsToExecute = uniqueToolCalls.slice(0, 1);
if (toolCallsToExecute.length === 0) {
console.log('All tool calls have been attempted already, forcing completion');
toolsExecuted = true;
conversationComplete = true;
break;
}
// Send feedback to user about tool invocation
dataCallback({
type: 'chat',
@@ -119,7 +135,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
choices: [
{
delta: {
content: `\n\n🔧 Invoking ${toolCalls.length} tool${toolCalls.length > 1 ? 's' : ''}...\n`,
content: `\n\n🔧 Invoking ${toolCallsToExecute.length} tool${toolCallsToExecute.length > 1 ? 's' : ''}...\n`,
},
},
],
@@ -130,15 +146,20 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
safeMessages.push({
role: 'assistant',
content: assistantMessage || null,
tool_calls: toolCalls,
tool_calls: toolCallsToExecute,
});
// Execute each tool call and add results
for (const toolCall of toolCalls) {
let needsMoreRetrieval = false;
for (const toolCall of toolCallsToExecute) {
if (toolCall.type === 'function') {
const name = toolCall.function.name;
console.log(`Calling function: ${name}`);
// Track this tool call attempt
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
attemptedToolCalls.add(toolCallKey);
// Send feedback about specific tool being called
dataCallback({
type: 'chat',
@@ -160,6 +181,36 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
const result = await callFunction(name, args);
console.log(`Function result:`, result);
// Check if agentic-rag indicates more retrieval is needed
if (
name === 'agentic_rag' &&
result?.data?.analysis?.needsRetrieval === true &&
(!result?.data?.retrieved_documents ||
result.data.retrieved_documents.length === 0)
) {
needsMoreRetrieval = true;
console.log('Agentic RAG indicates more retrieval needed');
// Add context about previous attempts to help LLM make better decisions
const attemptedActions = Array.from(attemptedToolCalls)
.filter(key => key.startsWith('agentic_rag:'))
.map(key => {
try {
const args = JSON.parse(key.split(':', 2)[1]);
return `${args.action} with query: "${args.query}"`;
} catch {
return 'unknown action';
}
});
if (attemptedActions.length > 0) {
safeMessages.push({
role: 'system',
content: `Previous retrieval attempts: ${attemptedActions.join(', ')}. Consider trying a different approach or more specific query.`,
});
}
}
// Send feedback about tool completion
dataCallback({
type: 'chat',
@@ -206,8 +257,10 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
}
}
// Mark that tools have been executed to prevent repeated calls
// Only mark tools as executed if we don't need more retrieval
if (!needsMoreRetrieval) {
toolsExecuted = true;
}
// Send feedback that tool execution is complete
dataCallback({

View File

@@ -1,36 +1,88 @@
import { describe, it, expect } from 'vitest';
import { agenticRAG, AgenticRAGTool } from '../agentic-rag-clean';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { agenticRAG, AgenticRAGTools } from '../agentic-rag';
// Mock the dependencies
vi.mock('@zilliz/milvus2-sdk-node', () => ({
MilvusClient: vi.fn().mockImplementation(() => ({
listCollections: vi.fn().mockResolvedValue({
collection_names: ['family_domestic', 'business_corporate'],
data: [{ name: 'family_domestic' }, { name: 'business_corporate' }],
}),
search: vi.fn().mockResolvedValue({
results: [
{
content: 'Test document about AI and machine learning',
score: 0.85,
metadata: '{"category": "AI", "author": "Test Author"}',
},
{
content: 'Another document about neural networks',
score: 0.75,
metadata: '{"category": "ML", "author": "Another Author"}',
},
],
}),
insert: vi.fn().mockResolvedValue({ success: true }),
createCollection: vi.fn().mockResolvedValue({ success: true }),
createIndex: vi.fn().mockResolvedValue({ success: true }),
})),
DataType: {
VarChar: 'VarChar',
FloatVector: 'FloatVector',
},
}));
vi.mock('openai', () => ({
OpenAI: vi.fn().mockImplementation(() => ({
embeddings: {
create: vi.fn().mockResolvedValue({
data: [{ embedding: new Array(768).fill(0.1) }],
}),
},
})),
}));
// Mock environment variables
vi.stubEnv('FIREWORKS_API_KEY', 'test-api-key');
describe('Agentic RAG System', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('should analyze queries correctly', async () => {
// Test factual query
const factualResult = await agenticRAG({
action: 'analyze_query',
query: 'What is artificial intelligence?',
collection_name: 'family_domestic',
});
expect(factualResult.status).toBe('success');
expect(factualResult.needsRetrieval).toBe(true);
expect(factualResult.data.needsRetrieval).toBe(true);
expect(factualResult.data.queryType).toBe('factual');
// Test conversational query
// Test conversational query with multiple conversational keywords
const conversationalResult = await agenticRAG({
action: 'analyze_query',
query: 'Hello, how are you?',
query: 'Hello, how are you doing today?',
collection_name: 'family_domestic',
});
expect(conversationalResult.status).toBe('success');
expect(conversationalResult.needsRetrieval).toBe(false);
expect(conversationalResult.data.needsRetrieval).toBe(false);
expect(conversationalResult.data.queryType).toBe('conversational');
// Test creative query
// Test creative query with multiple creative keywords
const creativeResult = await agenticRAG({
action: 'analyze_query',
query: 'Write a story about a robot',
query: 'Write a story and compose a poem',
collection_name: 'family_domestic',
});
expect(creativeResult.status).toBe('success');
expect(creativeResult.needsRetrieval).toBe(false);
expect(creativeResult.data.needsRetrieval).toBe(false);
expect(creativeResult.data.queryType).toBe('creative');
});
@@ -38,25 +90,27 @@ describe('Agentic RAG System', () => {
const result = await agenticRAG({
action: 'search_knowledge',
query: 'What is machine learning?',
collection_name: 'family_domestic',
top_k: 2,
threshold: 0.1,
similarity_threshold: 0.1,
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(true);
expect(result.context).toBeDefined();
expect(Array.isArray(result.context)).toBe(true);
expect(result.data.retrieved_documents).toBeDefined();
expect(result.data.analysis.needsRetrieval).toBe(true);
});
it('should not search for conversational queries', async () => {
const result = await agenticRAG({
action: 'search_knowledge',
query: 'Hello there!',
query: 'Hello there! How are you?',
collection_name: 'family_domestic',
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(false);
expect(result.data.analysis.needsRetrieval).toBe(false);
expect(result.data.retrieved_documents).toHaveLength(0);
});
@@ -68,6 +122,7 @@ describe('Agentic RAG System', () => {
content: 'This is a test document about neural networks and deep learning.',
metadata: { category: 'AI', author: 'Test Author' },
},
collection_name: 'family_domestic',
});
expect(result.status).toBe('success');
@@ -79,18 +134,43 @@ describe('Agentic RAG System', () => {
const result = await agenticRAG({
action: 'get_context',
query: 'Tell me about vector databases',
collection_name: 'family_domestic',
top_k: 2,
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(true);
expect(result.data.analysis.needsRetrieval).toBe(true);
expect(result.context).toBeDefined();
expect(result.data.context_summary).toBeDefined();
});
it('should handle semantic search', async () => {
const result = await agenticRAG({
action: 'semantic_search',
query: 'artificial intelligence concepts',
collection_name: 'family_domestic',
top_k: 3,
});
expect(result.status).toBe('success');
expect(result.data.results).toBeDefined();
expect(Array.isArray(result.data.results)).toBe(true);
});
it('should list collections', async () => {
const result = await agenticRAG({
action: 'list_collections',
collection_name: 'family_domestic',
});
expect(result.status).toBe('success');
expect(result.message).toContain('family_domestic');
});
it('should handle errors gracefully', async () => {
const result = await agenticRAG({
action: 'analyze_query',
collection_name: 'family_domestic',
// Missing query parameter
});
@@ -98,35 +178,82 @@ describe('Agentic RAG System', () => {
expect(result.message).toContain('Query is required');
});
it('should handle invalid actions', async () => {
const result = await agenticRAG({
action: 'invalid_action',
collection_name: 'family_domestic',
});
expect(result.status).toBe('error');
expect(result.message).toContain('Invalid action');
});
it('should have correct tool definition structure', () => {
expect(AgenticRAGTool.type).toBe('function');
expect(AgenticRAGTool.function.name).toBe('agentic_rag');
expect(AgenticRAGTool.function.description).toBeDefined();
expect(AgenticRAGTool.function.parameters.type).toBe('object');
expect(AgenticRAGTool.function.parameters.properties.action).toBeDefined();
expect(AgenticRAGTool.function.parameters.required).toContain('action');
expect(AgenticRAGTools.type).toBe('function');
expect(AgenticRAGTools.function.name).toBe('agentic_rag');
expect(AgenticRAGTools.function.description).toBeDefined();
expect(AgenticRAGTools.function.parameters.type).toBe('object');
expect(AgenticRAGTools.function.parameters.properties.action).toBeDefined();
expect(AgenticRAGTools.function.parameters.required).toContain('action');
expect(AgenticRAGTools.function.parameters.required).toContain('collection_name');
});
it('should demonstrate intelligent retrieval decision making', async () => {
// Test various query types to show intelligent decision making
const queries = [
{ query: 'What is AI?', expectedRetrieval: true },
{ query: 'Hello world', expectedRetrieval: false },
{ query: 'Write a poem', expectedRetrieval: false },
{ query: 'Hello world how are you', expectedRetrieval: false },
{ query: 'Write a poem and create a story', expectedRetrieval: false },
{ query: 'Explain machine learning', expectedRetrieval: true },
{ query: 'How are you doing?', expectedRetrieval: false },
{ query: 'How are you doing today?', expectedRetrieval: true },
{ query: 'Tell me about neural networks', expectedRetrieval: true },
];
for (const testCase of queries) {
const result = await agenticRAG({
action: 'search_knowledge',
query: testCase.query,
collection_name: 'family_domestic',
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(testCase.expectedRetrieval);
expect(result.data.analysis.needsRetrieval).toBe(testCase.expectedRetrieval);
console.log(`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.needsRetrieval}`);
console.log(
`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.data.analysis.needsRetrieval}`,
);
}
});
it('should filter results by similarity threshold', async () => {
const result = await agenticRAG({
action: 'search_knowledge',
query: 'What is machine learning?',
collection_name: 'family_domestic',
similarity_threshold: 0.8, // High threshold
});
expect(result.status).toBe('success');
if (result.data.analysis.needsRetrieval) {
// Should only return results above threshold
result.data.retrieved_documents.forEach((doc: any) => {
expect(doc.score).toBeGreaterThanOrEqual(0.8);
});
}
});
it('should handle context window limits', async () => {
const result = await agenticRAG({
action: 'get_context',
query: 'Tell me about artificial intelligence',
collection_name: 'family_domestic',
context_window: 1000,
});
expect(result.status).toBe('success');
if (result.data.analysis.needsRetrieval && result.data.context_summary) {
// Context should respect the window limit (approximate check)
expect(result.data.context_summary.length).toBeLessThanOrEqual(2000); // Allow some flexibility
}
});
});

View File

@@ -50,24 +50,48 @@ export const AgenticRAGTools = {
properties: {
action: {
type: 'string',
enum: ['search', 'list_collections', 'report_status'],
enum: [
'list_collections',
'report_status',
'semantic_search',
'search_knowledge',
'analyze_query',
'get_context',
],
description: 'Action to perform with the agentic RAG system.',
},
query: {
type: 'string',
description: 'User query or search term for knowledge retrieval.',
},
document: {
type: 'object',
properties: {
content: { type: 'string', description: 'Document content to store' },
metadata: { type: 'object', description: 'Additional metadata for the document' },
id: { type: 'string', description: 'Unique identifier for the document' },
},
description: 'Document to store in the knowledge base.',
},
// document: {
// type: 'object',
// properties: {
// content: { type: 'string', description: 'Document content to store' },
// metadata: { type: 'object', description: 'Additional metadata for the document' },
// id: { type: 'string', description: 'Unique identifier for the document' },
// },
// description: 'Document to store in the knowledge base.',
// },
collection_name: {
type: 'string',
// todo: make this fancy w/ dynamic collection
enum: [
'business_corporate',
'civil_procedure',
'criminal_justice',
'education_professions',
'environmental_infrastructure',
'family_domestic',
'foundational_law',
'government_administration',
'health_social_services',
'miscellaneous',
'property_real_estate',
'special_documents',
'taxation_finance',
'transportation_motor_vehicles',
],
description: 'Name of the collection to work with.',
},
top_k: {
@@ -83,7 +107,7 @@ export const AgenticRAGTools = {
description: 'Maximum number of context tokens to include (default: 2000).',
},
},
required: ['action'],
required: ['action', 'collection_name'],
additionalProperties: false,
},
strict: true,
@@ -95,10 +119,10 @@ export const AgenticRAGTools = {
*/
const DEFAULT_CONFIG: AgenticRAGConfig = {
milvusAddress: 'localhost:19530',
collectionName: 'knowledge_base',
collectionName: 'family_domestic',
embeddingDimension: 768,
topK: 5,
similarityThreshold: 0.58,
similarityThreshold: 0.5,
};
/**
@@ -200,14 +224,16 @@ function analyzeQueryForRetrieval(query: string): {
'Query appears to be asking for factual information that may benefit from knowledge retrieval.',
queryType: 'factual',
};
} else if (creativeScore > conversationalScore) {
} else if (creativeScore > conversationalScore && creativeScore > 1) {
// Only skip retrieval for clearly creative tasks with multiple creative keywords
return {
needsRetrieval: false,
confidence: 0.8,
reasoning: 'Query appears to be requesting creative content generation.',
queryType: 'creative',
};
} else if (conversationalScore > 0) {
} else if (conversationalScore > 1 && conversationalScore > factualScore) {
// Only skip retrieval for clearly conversational queries with multiple conversational keywords
return {
needsRetrieval: false,
confidence: 0.7,
@@ -215,10 +241,11 @@ function analyzeQueryForRetrieval(query: string): {
queryType: 'conversational',
};
} else {
// Default to retrieval for most cases to ensure comprehensive responses
return {
needsRetrieval: true,
confidence: 0.6,
reasoning: 'Query type unclear, defaulting to retrieval for comprehensive response.',
confidence: 0.8,
reasoning: 'Defaulting to retrieval to provide comprehensive and accurate information.',
queryType: 'analytical',
};
}
@@ -235,8 +262,8 @@ export async function agenticRAG(args: {
top_k?: number;
similarity_threshold?: number;
context_window?: number;
user_confirmed?: boolean;
}): Promise<AgenticRAGResult> {
console.log('calling agentic rag tool', args);
const config = { ...DEFAULT_CONFIG };
const collectionName = args.collection_name || config.collectionName!;
const topK = args.top_k || config.topK!;
@@ -297,9 +324,9 @@ export async function agenticRAG(args: {
// eslint-disable-next-line no-case-declarations
const searchResult = await milvusClient.search({
collection_name: collectionName,
query_vectors: [queryEmbedding],
top_k: topK,
params: { nprobe: 16 },
vector: queryEmbedding,
topk: topK,
params: { nprobe: 8 },
output_fields: ['content', 'metadata'],
});
@@ -385,8 +412,7 @@ export async function agenticRAG(args: {
],
};
// @ts-expect-error - idk man
await milvusClient.createCollection(collectionSchema);
await milvusClient.createCollection(collectionSchema as any);
// Create index for efficient similarity search
await milvusClient.createIndex({
@@ -426,9 +452,9 @@ export async function agenticRAG(args: {
// eslint-disable-next-line no-case-declarations
const semanticResults = await milvusClient.search({
collection_name: collectionName,
query_vectors: [semanticEmbedding],
top_k: topK,
params: { nprobe: 16 },
vector: semanticEmbedding,
topk: topK,
params: { nprobe: 8 },
output_fields: ['content', 'metadata'],
});
@@ -452,14 +478,13 @@ export async function agenticRAG(args: {
// This is a comprehensive context retrieval that combines analysis and search
// eslint-disable-next-line no-case-declarations
const contextAnalysis = analyzeQueryForRetrieval(args.query);
if (contextAnalysis.needsRetrieval) {
const contextEmbedding = await generateEmbedding(args.query);
const contextSearch = await milvusClient.search({
collection_name: collectionName,
query_vectors: [contextEmbedding],
top_k: topK,
params: { nprobe: 16 },
vector: contextEmbedding,
topk: topK,
params: { nprobe: 8 },
output_fields: ['content', 'metadata'],
});

View File

@@ -39,19 +39,19 @@ export const LandingComponent: React.FC = () => {
},
}}
switches={{
GpsMap: {
value: mapActive,
onChange(enabled) {
if (enabled) {
setEnabledComponent('gpsmap');
setAiActive(false);
} else {
setEnabledComponent('');
}
setMapActive(enabled);
},
label: 'GPS',
},
// GpsMap: {
// value: mapActive,
// onChange(enabled) {
// if (enabled) {
// setEnabledComponent('gpsmap');
// setAiActive(false);
// } else {
// setEnabledComponent('');
// }
// setMapActive(enabled);
// },
// label: 'GPS',
// },
AI: {
value: aiActive,
onChange(enabled) {