mirror of
https://github.com/geoffsee/open-gsio.git
synced 2025-09-08 22:56:46 +00:00
Refine tool call execution logic in chat-stream-provider
to prevent duplicates, enhance retries for agentic-rag
, and improve incremental processing, including test updates.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -25,3 +25,4 @@ wrangler.dev.jsonc
|
||||
/packages/client/public/pwa-192x192.png
|
||||
/packages/client/public/pwa-512x512.png
|
||||
packages/client/public/yachtpit_bg*
|
||||
/project/
|
||||
|
@@ -1,5 +1,5 @@
|
||||
import { OpenAI } from 'openai';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
|
||||
import {
|
||||
BaseChatProvider,
|
||||
@@ -29,7 +29,7 @@ class TestChatProvider extends BaseChatProvider {
|
||||
}
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('../../lib/chat-sdk', () => ({
|
||||
vi.mock('../../chat-sdk/chat-sdk.ts', () => ({
|
||||
default: {
|
||||
buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'),
|
||||
buildMessageChain: vi.fn().mockReturnValue([
|
||||
@@ -39,6 +39,26 @@ vi.mock('../../lib/chat-sdk', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('../../tools/agentic-rag.ts', () => ({
|
||||
agenticRAG: vi.fn(),
|
||||
AgenticRAGTools: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'agentic_rag',
|
||||
description: 'Test agentic RAG tool',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
action: { type: 'string', enum: ['search_knowledge'] },
|
||||
query: { type: 'string' },
|
||||
collection_name: { type: 'string' },
|
||||
},
|
||||
required: ['action', 'collection_name'],
|
||||
},
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ChatStreamProvider', () => {
|
||||
it('should define the required interface', () => {
|
||||
// Verify the interface has the required method
|
||||
@@ -50,26 +70,616 @@ describe('ChatStreamProvider', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('BaseChatProvider', () => {
|
||||
it('should implement the ChatStreamProvider interface', () => {
|
||||
// Create a concrete implementation
|
||||
const provider = new TestChatProvider();
|
||||
describe('BaseChatProvider - Model Tool Calling', () => {
|
||||
let provider: TestChatProvider;
|
||||
let mockOpenAI: any;
|
||||
let dataCallback: any;
|
||||
let commonParams: CommonProviderParams;
|
||||
|
||||
// Verify it implements the interface
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
provider = new TestChatProvider();
|
||||
dataCallback = vi.fn();
|
||||
|
||||
mockOpenAI = {
|
||||
chat: {
|
||||
completions: {
|
||||
create: vi.fn(),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
commonParams = {
|
||||
openai: mockOpenAI,
|
||||
systemPrompt: 'Test system prompt',
|
||||
preprocessedContext: {},
|
||||
maxTokens: 1000,
|
||||
messages: [{ role: 'user', content: 'Test message' }],
|
||||
model: 'gpt-4',
|
||||
env: {} as any,
|
||||
};
|
||||
});
|
||||
|
||||
it('should implement the ChatStreamProvider interface', () => {
|
||||
expect(provider.handleStream).toBeInstanceOf(Function);
|
||||
expect(provider.getOpenAIClient).toBeInstanceOf(Function);
|
||||
expect(provider.getStreamParams).toBeInstanceOf(Function);
|
||||
expect(provider.processChunk).toBeInstanceOf(Function);
|
||||
});
|
||||
|
||||
it('should have abstract methods that need to be implemented', () => {
|
||||
// This test verifies that the abstract methods exist
|
||||
// We can't instantiate BaseChatProvider directly, so we use the concrete implementation
|
||||
const provider = new TestChatProvider();
|
||||
it('should handle regular text streaming without tool calls', async () => {
|
||||
// Mock stream chunks for regular text response
|
||||
const chunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: { content: 'Hello ' },
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: { content: 'world!' },
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
// Verify the abstract methods are implemented
|
||||
expect(provider.getOpenAIClient).toBeDefined();
|
||||
expect(provider.getStreamParams).toBeDefined();
|
||||
expect(provider.processChunk).toBeDefined();
|
||||
mockOpenAI.chat.completions.create.mockResolvedValue({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
await provider.handleStream(commonParams, dataCallback);
|
||||
|
||||
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tools: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
type: 'function',
|
||||
function: expect.objectContaining({
|
||||
name: 'agentic_rag',
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle tool calls in streaming response', async () => {
|
||||
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||
vi.mocked(agenticRAG).mockResolvedValue({
|
||||
success: true,
|
||||
data: {
|
||||
results: ['Test result'],
|
||||
analysis: { needsRetrieval: false },
|
||||
},
|
||||
});
|
||||
|
||||
// Mock stream chunks for tool call response
|
||||
const chunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'agentic_rag',
|
||||
arguments:
|
||||
'{"action": "search_knowledge", "query": "test query", "collection_name": "test_collection"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
// Second stream for response after tool execution
|
||||
const secondStreamChunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: { content: 'Based on the search results: Test result' },
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
let callCount = 0;
|
||||
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
} else {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of secondStreamChunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
await provider.handleStream(commonParams, dataCallback);
|
||||
|
||||
// Verify tool was called
|
||||
expect(agenticRAG).toHaveBeenCalledWith({
|
||||
action: 'search_knowledge',
|
||||
query: 'test query',
|
||||
collection_name: 'test_collection',
|
||||
});
|
||||
|
||||
// Verify feedback messages were sent
|
||||
expect(dataCallback).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'chat',
|
||||
data: expect.objectContaining({
|
||||
choices: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
delta: expect.objectContaining({
|
||||
content: expect.stringContaining('🔧 Invoking'),
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(dataCallback).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'chat',
|
||||
data: expect.objectContaining({
|
||||
choices: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
delta: expect.objectContaining({
|
||||
content: expect.stringContaining('📞 Calling agentic_rag'),
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle tool call streaming with incremental arguments', async () => {
|
||||
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||
vi.mocked(agenticRAG).mockResolvedValue({
|
||||
success: true,
|
||||
data: { results: ['Test result'] },
|
||||
});
|
||||
|
||||
// Mock stream chunks with incremental tool call arguments
|
||||
const chunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_',
|
||||
type: 'function',
|
||||
function: { name: 'agentic_rag', arguments: '{"action": "search_' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: '123',
|
||||
function: { arguments: 'knowledge", "query": "test", ' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
function: { arguments: '"collection_name": "test_collection"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const secondStreamChunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: { content: 'Response after tool call' },
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
let callCount = 0;
|
||||
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
} else {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of secondStreamChunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
await provider.handleStream(commonParams, dataCallback);
|
||||
|
||||
// Verify the complete tool call was assembled and executed
|
||||
expect(agenticRAG).toHaveBeenCalledWith({
|
||||
action: 'search_knowledge',
|
||||
query: 'test',
|
||||
collection_name: 'test_collection',
|
||||
});
|
||||
});
|
||||
|
||||
it('should prevent infinite tool call loops', async () => {
|
||||
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||
vi.mocked(agenticRAG).mockResolvedValue({
|
||||
success: true,
|
||||
data: {
|
||||
results: [],
|
||||
analysis: { needsRetrieval: true },
|
||||
retrieved_documents: [],
|
||||
},
|
||||
});
|
||||
|
||||
// Mock stream that always returns tool calls
|
||||
const toolCallChunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'agentic_rag',
|
||||
arguments:
|
||||
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
mockOpenAI.chat.completions.create.mockResolvedValue({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of toolCallChunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
await provider.handleStream(commonParams, dataCallback);
|
||||
|
||||
// Should detect duplicate tool calls and force completion (up to 5 iterations based on maxToolCallIterations)
|
||||
// In this case, it should stop after 2 calls due to duplicate detection, but could go up to 5
|
||||
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should handle tool call errors gracefully', async () => {
|
||||
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||
vi.mocked(agenticRAG).mockRejectedValue(new Error('Tool execution failed'));
|
||||
|
||||
const chunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'agentic_rag',
|
||||
arguments:
|
||||
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const secondStreamChunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: { content: 'I apologize, but I encountered an error.' },
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
let callCount = 0;
|
||||
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
} else {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of secondStreamChunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
await provider.handleStream(commonParams, dataCallback);
|
||||
|
||||
// Should still complete without throwing
|
||||
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should prevent duplicate tool calls', async () => {
|
||||
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||
vi.mocked(agenticRAG).mockResolvedValue({
|
||||
success: true,
|
||||
data: { results: ['Test result'] },
|
||||
});
|
||||
|
||||
// Mock the same tool call twice
|
||||
const chunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'agentic_rag',
|
||||
arguments:
|
||||
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
// Second iteration with same tool call
|
||||
let callCount = 0;
|
||||
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
await provider.handleStream(commonParams, dataCallback);
|
||||
|
||||
// Should only execute the tool once, then force completion
|
||||
expect(agenticRAG).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should handle invalid JSON in tool call arguments', async () => {
|
||||
const chunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
tool_calls: [
|
||||
{
|
||||
index: 0,
|
||||
id: 'call_123',
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'agentic_rag',
|
||||
arguments: '{"action": "search_knowledge", "invalid": json}', // Invalid JSON
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'tool_calls',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const secondStreamChunks = [
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: { content: 'I encountered an error parsing the tool arguments.' },
|
||||
finish_reason: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
choices: [
|
||||
{
|
||||
delta: {},
|
||||
finish_reason: 'stop',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
let callCount = 0;
|
||||
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
} else {
|
||||
return Promise.resolve({
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of secondStreamChunks) {
|
||||
yield chunk;
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Should not throw, should handle gracefully
|
||||
await expect(provider.handleStream(commonParams, dataCallback)).resolves.not.toThrow();
|
||||
});
|
||||
});
|
||||
|
@@ -50,6 +50,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
let toolCallIterations = 0;
|
||||
const maxToolCallIterations = 5; // Prevent infinite loops
|
||||
let toolsExecuted = false; // Track if we've executed tools
|
||||
const attemptedToolCalls = new Set<string>(); // Track attempted tool calls to prevent duplicates
|
||||
|
||||
while (!conversationComplete && toolCallIterations < maxToolCallIterations) {
|
||||
const streamParams = this.getStreamParams(param, safeMessages);
|
||||
@@ -112,6 +113,21 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
// Execute tool calls and add results to conversation
|
||||
console.log('Executing tool calls:', toolCalls);
|
||||
|
||||
// Limit to one tool call per iteration to prevent concurrent execution issues
|
||||
// Also filter out duplicate tool calls
|
||||
const uniqueToolCalls = toolCalls.filter(toolCall => {
|
||||
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
|
||||
return !attemptedToolCalls.has(toolCallKey);
|
||||
});
|
||||
const toolCallsToExecute = uniqueToolCalls.slice(0, 1);
|
||||
|
||||
if (toolCallsToExecute.length === 0) {
|
||||
console.log('All tool calls have been attempted already, forcing completion');
|
||||
toolsExecuted = true;
|
||||
conversationComplete = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// Send feedback to user about tool invocation
|
||||
dataCallback({
|
||||
type: 'chat',
|
||||
@@ -119,7 +135,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: `\n\n🔧 Invoking ${toolCalls.length} tool${toolCalls.length > 1 ? 's' : ''}...\n`,
|
||||
content: `\n\n🔧 Invoking ${toolCallsToExecute.length} tool${toolCallsToExecute.length > 1 ? 's' : ''}...\n`,
|
||||
},
|
||||
},
|
||||
],
|
||||
@@ -130,15 +146,20 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
safeMessages.push({
|
||||
role: 'assistant',
|
||||
content: assistantMessage || null,
|
||||
tool_calls: toolCalls,
|
||||
tool_calls: toolCallsToExecute,
|
||||
});
|
||||
|
||||
// Execute each tool call and add results
|
||||
for (const toolCall of toolCalls) {
|
||||
let needsMoreRetrieval = false;
|
||||
for (const toolCall of toolCallsToExecute) {
|
||||
if (toolCall.type === 'function') {
|
||||
const name = toolCall.function.name;
|
||||
console.log(`Calling function: ${name}`);
|
||||
|
||||
// Track this tool call attempt
|
||||
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
|
||||
attemptedToolCalls.add(toolCallKey);
|
||||
|
||||
// Send feedback about specific tool being called
|
||||
dataCallback({
|
||||
type: 'chat',
|
||||
@@ -160,6 +181,36 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
const result = await callFunction(name, args);
|
||||
console.log(`Function result:`, result);
|
||||
|
||||
// Check if agentic-rag indicates more retrieval is needed
|
||||
if (
|
||||
name === 'agentic_rag' &&
|
||||
result?.data?.analysis?.needsRetrieval === true &&
|
||||
(!result?.data?.retrieved_documents ||
|
||||
result.data.retrieved_documents.length === 0)
|
||||
) {
|
||||
needsMoreRetrieval = true;
|
||||
console.log('Agentic RAG indicates more retrieval needed');
|
||||
|
||||
// Add context about previous attempts to help LLM make better decisions
|
||||
const attemptedActions = Array.from(attemptedToolCalls)
|
||||
.filter(key => key.startsWith('agentic_rag:'))
|
||||
.map(key => {
|
||||
try {
|
||||
const args = JSON.parse(key.split(':', 2)[1]);
|
||||
return `${args.action} with query: "${args.query}"`;
|
||||
} catch {
|
||||
return 'unknown action';
|
||||
}
|
||||
});
|
||||
|
||||
if (attemptedActions.length > 0) {
|
||||
safeMessages.push({
|
||||
role: 'system',
|
||||
content: `Previous retrieval attempts: ${attemptedActions.join(', ')}. Consider trying a different approach or more specific query.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Send feedback about tool completion
|
||||
dataCallback({
|
||||
type: 'chat',
|
||||
@@ -206,8 +257,10 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
||||
}
|
||||
}
|
||||
|
||||
// Mark that tools have been executed to prevent repeated calls
|
||||
// Only mark tools as executed if we don't need more retrieval
|
||||
if (!needsMoreRetrieval) {
|
||||
toolsExecuted = true;
|
||||
}
|
||||
|
||||
// Send feedback that tool execution is complete
|
||||
dataCallback({
|
||||
|
@@ -1,36 +1,88 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { agenticRAG, AgenticRAGTool } from '../agentic-rag-clean';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
|
||||
import { agenticRAG, AgenticRAGTools } from '../agentic-rag';
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock('@zilliz/milvus2-sdk-node', () => ({
|
||||
MilvusClient: vi.fn().mockImplementation(() => ({
|
||||
listCollections: vi.fn().mockResolvedValue({
|
||||
collection_names: ['family_domestic', 'business_corporate'],
|
||||
data: [{ name: 'family_domestic' }, { name: 'business_corporate' }],
|
||||
}),
|
||||
search: vi.fn().mockResolvedValue({
|
||||
results: [
|
||||
{
|
||||
content: 'Test document about AI and machine learning',
|
||||
score: 0.85,
|
||||
metadata: '{"category": "AI", "author": "Test Author"}',
|
||||
},
|
||||
{
|
||||
content: 'Another document about neural networks',
|
||||
score: 0.75,
|
||||
metadata: '{"category": "ML", "author": "Another Author"}',
|
||||
},
|
||||
],
|
||||
}),
|
||||
insert: vi.fn().mockResolvedValue({ success: true }),
|
||||
createCollection: vi.fn().mockResolvedValue({ success: true }),
|
||||
createIndex: vi.fn().mockResolvedValue({ success: true }),
|
||||
})),
|
||||
DataType: {
|
||||
VarChar: 'VarChar',
|
||||
FloatVector: 'FloatVector',
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
OpenAI: vi.fn().mockImplementation(() => ({
|
||||
embeddings: {
|
||||
create: vi.fn().mockResolvedValue({
|
||||
data: [{ embedding: new Array(768).fill(0.1) }],
|
||||
}),
|
||||
},
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock environment variables
|
||||
vi.stubEnv('FIREWORKS_API_KEY', 'test-api-key');
|
||||
|
||||
describe('Agentic RAG System', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should analyze queries correctly', async () => {
|
||||
// Test factual query
|
||||
const factualResult = await agenticRAG({
|
||||
action: 'analyze_query',
|
||||
query: 'What is artificial intelligence?',
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(factualResult.status).toBe('success');
|
||||
expect(factualResult.needsRetrieval).toBe(true);
|
||||
expect(factualResult.data.needsRetrieval).toBe(true);
|
||||
expect(factualResult.data.queryType).toBe('factual');
|
||||
|
||||
// Test conversational query
|
||||
// Test conversational query with multiple conversational keywords
|
||||
const conversationalResult = await agenticRAG({
|
||||
action: 'analyze_query',
|
||||
query: 'Hello, how are you?',
|
||||
query: 'Hello, how are you doing today?',
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(conversationalResult.status).toBe('success');
|
||||
expect(conversationalResult.needsRetrieval).toBe(false);
|
||||
expect(conversationalResult.data.needsRetrieval).toBe(false);
|
||||
expect(conversationalResult.data.queryType).toBe('conversational');
|
||||
|
||||
// Test creative query
|
||||
// Test creative query with multiple creative keywords
|
||||
const creativeResult = await agenticRAG({
|
||||
action: 'analyze_query',
|
||||
query: 'Write a story about a robot',
|
||||
query: 'Write a story and compose a poem',
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(creativeResult.status).toBe('success');
|
||||
expect(creativeResult.needsRetrieval).toBe(false);
|
||||
expect(creativeResult.data.needsRetrieval).toBe(false);
|
||||
expect(creativeResult.data.queryType).toBe('creative');
|
||||
});
|
||||
|
||||
@@ -38,25 +90,27 @@ describe('Agentic RAG System', () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'search_knowledge',
|
||||
query: 'What is machine learning?',
|
||||
collection_name: 'family_domestic',
|
||||
top_k: 2,
|
||||
threshold: 0.1,
|
||||
similarity_threshold: 0.1,
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
expect(result.needsRetrieval).toBe(true);
|
||||
expect(result.context).toBeDefined();
|
||||
expect(Array.isArray(result.context)).toBe(true);
|
||||
expect(result.data.retrieved_documents).toBeDefined();
|
||||
expect(result.data.analysis.needsRetrieval).toBe(true);
|
||||
});
|
||||
|
||||
it('should not search for conversational queries', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'search_knowledge',
|
||||
query: 'Hello there!',
|
||||
query: 'Hello there! How are you?',
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
expect(result.needsRetrieval).toBe(false);
|
||||
expect(result.data.analysis.needsRetrieval).toBe(false);
|
||||
expect(result.data.retrieved_documents).toHaveLength(0);
|
||||
});
|
||||
|
||||
@@ -68,6 +122,7 @@ describe('Agentic RAG System', () => {
|
||||
content: 'This is a test document about neural networks and deep learning.',
|
||||
metadata: { category: 'AI', author: 'Test Author' },
|
||||
},
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
@@ -79,18 +134,43 @@ describe('Agentic RAG System', () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'get_context',
|
||||
query: 'Tell me about vector databases',
|
||||
collection_name: 'family_domestic',
|
||||
top_k: 2,
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
expect(result.needsRetrieval).toBe(true);
|
||||
expect(result.data.analysis.needsRetrieval).toBe(true);
|
||||
expect(result.context).toBeDefined();
|
||||
expect(result.data.context_summary).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle semantic search', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'semantic_search',
|
||||
query: 'artificial intelligence concepts',
|
||||
collection_name: 'family_domestic',
|
||||
top_k: 3,
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
expect(result.data.results).toBeDefined();
|
||||
expect(Array.isArray(result.data.results)).toBe(true);
|
||||
});
|
||||
|
||||
it('should list collections', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'list_collections',
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
expect(result.message).toContain('family_domestic');
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'analyze_query',
|
||||
collection_name: 'family_domestic',
|
||||
// Missing query parameter
|
||||
});
|
||||
|
||||
@@ -98,35 +178,82 @@ describe('Agentic RAG System', () => {
|
||||
expect(result.message).toContain('Query is required');
|
||||
});
|
||||
|
||||
it('should handle invalid actions', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'invalid_action',
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(result.status).toBe('error');
|
||||
expect(result.message).toContain('Invalid action');
|
||||
});
|
||||
|
||||
it('should have correct tool definition structure', () => {
|
||||
expect(AgenticRAGTool.type).toBe('function');
|
||||
expect(AgenticRAGTool.function.name).toBe('agentic_rag');
|
||||
expect(AgenticRAGTool.function.description).toBeDefined();
|
||||
expect(AgenticRAGTool.function.parameters.type).toBe('object');
|
||||
expect(AgenticRAGTool.function.parameters.properties.action).toBeDefined();
|
||||
expect(AgenticRAGTool.function.parameters.required).toContain('action');
|
||||
expect(AgenticRAGTools.type).toBe('function');
|
||||
expect(AgenticRAGTools.function.name).toBe('agentic_rag');
|
||||
expect(AgenticRAGTools.function.description).toBeDefined();
|
||||
expect(AgenticRAGTools.function.parameters.type).toBe('object');
|
||||
expect(AgenticRAGTools.function.parameters.properties.action).toBeDefined();
|
||||
expect(AgenticRAGTools.function.parameters.required).toContain('action');
|
||||
expect(AgenticRAGTools.function.parameters.required).toContain('collection_name');
|
||||
});
|
||||
|
||||
it('should demonstrate intelligent retrieval decision making', async () => {
|
||||
// Test various query types to show intelligent decision making
|
||||
const queries = [
|
||||
{ query: 'What is AI?', expectedRetrieval: true },
|
||||
{ query: 'Hello world', expectedRetrieval: false },
|
||||
{ query: 'Write a poem', expectedRetrieval: false },
|
||||
{ query: 'Hello world how are you', expectedRetrieval: false },
|
||||
{ query: 'Write a poem and create a story', expectedRetrieval: false },
|
||||
{ query: 'Explain machine learning', expectedRetrieval: true },
|
||||
{ query: 'How are you doing?', expectedRetrieval: false },
|
||||
{ query: 'How are you doing today?', expectedRetrieval: true },
|
||||
{ query: 'Tell me about neural networks', expectedRetrieval: true },
|
||||
];
|
||||
|
||||
for (const testCase of queries) {
|
||||
const result = await agenticRAG({
|
||||
action: 'search_knowledge',
|
||||
query: testCase.query,
|
||||
collection_name: 'family_domestic',
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
expect(result.needsRetrieval).toBe(testCase.expectedRetrieval);
|
||||
expect(result.data.analysis.needsRetrieval).toBe(testCase.expectedRetrieval);
|
||||
|
||||
console.log(`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.needsRetrieval}`);
|
||||
console.log(
|
||||
`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.data.analysis.needsRetrieval}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
it('should filter results by similarity threshold', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'search_knowledge',
|
||||
query: 'What is machine learning?',
|
||||
collection_name: 'family_domestic',
|
||||
similarity_threshold: 0.8, // High threshold
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
if (result.data.analysis.needsRetrieval) {
|
||||
// Should only return results above threshold
|
||||
result.data.retrieved_documents.forEach((doc: any) => {
|
||||
expect(doc.score).toBeGreaterThanOrEqual(0.8);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('should handle context window limits', async () => {
|
||||
const result = await agenticRAG({
|
||||
action: 'get_context',
|
||||
query: 'Tell me about artificial intelligence',
|
||||
collection_name: 'family_domestic',
|
||||
context_window: 1000,
|
||||
});
|
||||
|
||||
expect(result.status).toBe('success');
|
||||
if (result.data.analysis.needsRetrieval && result.data.context_summary) {
|
||||
// Context should respect the window limit (approximate check)
|
||||
expect(result.data.context_summary.length).toBeLessThanOrEqual(2000); // Allow some flexibility
|
||||
}
|
||||
});
|
||||
});
|
@@ -50,24 +50,48 @@ export const AgenticRAGTools = {
|
||||
properties: {
|
||||
action: {
|
||||
type: 'string',
|
||||
enum: ['search', 'list_collections', 'report_status'],
|
||||
enum: [
|
||||
'list_collections',
|
||||
'report_status',
|
||||
'semantic_search',
|
||||
'search_knowledge',
|
||||
'analyze_query',
|
||||
'get_context',
|
||||
],
|
||||
description: 'Action to perform with the agentic RAG system.',
|
||||
},
|
||||
query: {
|
||||
type: 'string',
|
||||
description: 'User query or search term for knowledge retrieval.',
|
||||
},
|
||||
document: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
content: { type: 'string', description: 'Document content to store' },
|
||||
metadata: { type: 'object', description: 'Additional metadata for the document' },
|
||||
id: { type: 'string', description: 'Unique identifier for the document' },
|
||||
},
|
||||
description: 'Document to store in the knowledge base.',
|
||||
},
|
||||
// document: {
|
||||
// type: 'object',
|
||||
// properties: {
|
||||
// content: { type: 'string', description: 'Document content to store' },
|
||||
// metadata: { type: 'object', description: 'Additional metadata for the document' },
|
||||
// id: { type: 'string', description: 'Unique identifier for the document' },
|
||||
// },
|
||||
// description: 'Document to store in the knowledge base.',
|
||||
// },
|
||||
collection_name: {
|
||||
type: 'string',
|
||||
// todo: make this fancy w/ dynamic collection
|
||||
enum: [
|
||||
'business_corporate',
|
||||
'civil_procedure',
|
||||
'criminal_justice',
|
||||
'education_professions',
|
||||
'environmental_infrastructure',
|
||||
'family_domestic',
|
||||
'foundational_law',
|
||||
'government_administration',
|
||||
'health_social_services',
|
||||
'miscellaneous',
|
||||
'property_real_estate',
|
||||
'special_documents',
|
||||
'taxation_finance',
|
||||
'transportation_motor_vehicles',
|
||||
],
|
||||
description: 'Name of the collection to work with.',
|
||||
},
|
||||
top_k: {
|
||||
@@ -83,7 +107,7 @@ export const AgenticRAGTools = {
|
||||
description: 'Maximum number of context tokens to include (default: 2000).',
|
||||
},
|
||||
},
|
||||
required: ['action'],
|
||||
required: ['action', 'collection_name'],
|
||||
additionalProperties: false,
|
||||
},
|
||||
strict: true,
|
||||
@@ -95,10 +119,10 @@ export const AgenticRAGTools = {
|
||||
*/
|
||||
const DEFAULT_CONFIG: AgenticRAGConfig = {
|
||||
milvusAddress: 'localhost:19530',
|
||||
collectionName: 'knowledge_base',
|
||||
collectionName: 'family_domestic',
|
||||
embeddingDimension: 768,
|
||||
topK: 5,
|
||||
similarityThreshold: 0.58,
|
||||
similarityThreshold: 0.5,
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -200,14 +224,16 @@ function analyzeQueryForRetrieval(query: string): {
|
||||
'Query appears to be asking for factual information that may benefit from knowledge retrieval.',
|
||||
queryType: 'factual',
|
||||
};
|
||||
} else if (creativeScore > conversationalScore) {
|
||||
} else if (creativeScore > conversationalScore && creativeScore > 1) {
|
||||
// Only skip retrieval for clearly creative tasks with multiple creative keywords
|
||||
return {
|
||||
needsRetrieval: false,
|
||||
confidence: 0.8,
|
||||
reasoning: 'Query appears to be requesting creative content generation.',
|
||||
queryType: 'creative',
|
||||
};
|
||||
} else if (conversationalScore > 0) {
|
||||
} else if (conversationalScore > 1 && conversationalScore > factualScore) {
|
||||
// Only skip retrieval for clearly conversational queries with multiple conversational keywords
|
||||
return {
|
||||
needsRetrieval: false,
|
||||
confidence: 0.7,
|
||||
@@ -215,10 +241,11 @@ function analyzeQueryForRetrieval(query: string): {
|
||||
queryType: 'conversational',
|
||||
};
|
||||
} else {
|
||||
// Default to retrieval for most cases to ensure comprehensive responses
|
||||
return {
|
||||
needsRetrieval: true,
|
||||
confidence: 0.6,
|
||||
reasoning: 'Query type unclear, defaulting to retrieval for comprehensive response.',
|
||||
confidence: 0.8,
|
||||
reasoning: 'Defaulting to retrieval to provide comprehensive and accurate information.',
|
||||
queryType: 'analytical',
|
||||
};
|
||||
}
|
||||
@@ -235,8 +262,8 @@ export async function agenticRAG(args: {
|
||||
top_k?: number;
|
||||
similarity_threshold?: number;
|
||||
context_window?: number;
|
||||
user_confirmed?: boolean;
|
||||
}): Promise<AgenticRAGResult> {
|
||||
console.log('calling agentic rag tool', args);
|
||||
const config = { ...DEFAULT_CONFIG };
|
||||
const collectionName = args.collection_name || config.collectionName!;
|
||||
const topK = args.top_k || config.topK!;
|
||||
@@ -297,9 +324,9 @@ export async function agenticRAG(args: {
|
||||
// eslint-disable-next-line no-case-declarations
|
||||
const searchResult = await milvusClient.search({
|
||||
collection_name: collectionName,
|
||||
query_vectors: [queryEmbedding],
|
||||
top_k: topK,
|
||||
params: { nprobe: 16 },
|
||||
vector: queryEmbedding,
|
||||
topk: topK,
|
||||
params: { nprobe: 8 },
|
||||
output_fields: ['content', 'metadata'],
|
||||
});
|
||||
|
||||
@@ -385,8 +412,7 @@ export async function agenticRAG(args: {
|
||||
],
|
||||
};
|
||||
|
||||
// @ts-expect-error - idk man
|
||||
await milvusClient.createCollection(collectionSchema);
|
||||
await milvusClient.createCollection(collectionSchema as any);
|
||||
|
||||
// Create index for efficient similarity search
|
||||
await milvusClient.createIndex({
|
||||
@@ -426,9 +452,9 @@ export async function agenticRAG(args: {
|
||||
// eslint-disable-next-line no-case-declarations
|
||||
const semanticResults = await milvusClient.search({
|
||||
collection_name: collectionName,
|
||||
query_vectors: [semanticEmbedding],
|
||||
top_k: topK,
|
||||
params: { nprobe: 16 },
|
||||
vector: semanticEmbedding,
|
||||
topk: topK,
|
||||
params: { nprobe: 8 },
|
||||
output_fields: ['content', 'metadata'],
|
||||
});
|
||||
|
||||
@@ -452,14 +478,13 @@ export async function agenticRAG(args: {
|
||||
// This is a comprehensive context retrieval that combines analysis and search
|
||||
// eslint-disable-next-line no-case-declarations
|
||||
const contextAnalysis = analyzeQueryForRetrieval(args.query);
|
||||
|
||||
if (contextAnalysis.needsRetrieval) {
|
||||
const contextEmbedding = await generateEmbedding(args.query);
|
||||
const contextSearch = await milvusClient.search({
|
||||
collection_name: collectionName,
|
||||
query_vectors: [contextEmbedding],
|
||||
top_k: topK,
|
||||
params: { nprobe: 16 },
|
||||
vector: contextEmbedding,
|
||||
topk: topK,
|
||||
params: { nprobe: 8 },
|
||||
output_fields: ['content', 'metadata'],
|
||||
});
|
||||
|
||||
|
@@ -39,19 +39,19 @@ export const LandingComponent: React.FC = () => {
|
||||
},
|
||||
}}
|
||||
switches={{
|
||||
GpsMap: {
|
||||
value: mapActive,
|
||||
onChange(enabled) {
|
||||
if (enabled) {
|
||||
setEnabledComponent('gpsmap');
|
||||
setAiActive(false);
|
||||
} else {
|
||||
setEnabledComponent('');
|
||||
}
|
||||
setMapActive(enabled);
|
||||
},
|
||||
label: 'GPS',
|
||||
},
|
||||
// GpsMap: {
|
||||
// value: mapActive,
|
||||
// onChange(enabled) {
|
||||
// if (enabled) {
|
||||
// setEnabledComponent('gpsmap');
|
||||
// setAiActive(false);
|
||||
// } else {
|
||||
// setEnabledComponent('');
|
||||
// }
|
||||
// setMapActive(enabled);
|
||||
// },
|
||||
// label: 'GPS',
|
||||
// },
|
||||
AI: {
|
||||
value: aiActive,
|
||||
onChange(enabled) {
|
||||
|
Reference in New Issue
Block a user