mirror of
https://github.com/geoffsee/open-gsio.git
synced 2025-09-08 22:56:46 +00:00
Refine tool call execution logic in chat-stream-provider
to prevent duplicates, enhance retries for agentic-rag
, and improve incremental processing, including test updates.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -25,3 +25,4 @@ wrangler.dev.jsonc
|
|||||||
/packages/client/public/pwa-192x192.png
|
/packages/client/public/pwa-192x192.png
|
||||||
/packages/client/public/pwa-512x512.png
|
/packages/client/public/pwa-512x512.png
|
||||||
packages/client/public/yachtpit_bg*
|
packages/client/public/yachtpit_bg*
|
||||||
|
/project/
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import { OpenAI } from 'openai';
|
import { OpenAI } from 'openai';
|
||||||
import { describe, it, expect, vi } from 'vitest';
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
BaseChatProvider,
|
BaseChatProvider,
|
||||||
@@ -29,7 +29,7 @@ class TestChatProvider extends BaseChatProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Mock dependencies
|
// Mock dependencies
|
||||||
vi.mock('../../lib/chat-sdk', () => ({
|
vi.mock('../../chat-sdk/chat-sdk.ts', () => ({
|
||||||
default: {
|
default: {
|
||||||
buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'),
|
buildAssistantPrompt: vi.fn().mockReturnValue('Assistant prompt'),
|
||||||
buildMessageChain: vi.fn().mockReturnValue([
|
buildMessageChain: vi.fn().mockReturnValue([
|
||||||
@@ -39,6 +39,26 @@ vi.mock('../../lib/chat-sdk', () => ({
|
|||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
vi.mock('../../tools/agentic-rag.ts', () => ({
|
||||||
|
agenticRAG: vi.fn(),
|
||||||
|
AgenticRAGTools: {
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'agentic_rag',
|
||||||
|
description: 'Test agentic RAG tool',
|
||||||
|
parameters: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
action: { type: 'string', enum: ['search_knowledge'] },
|
||||||
|
query: { type: 'string' },
|
||||||
|
collection_name: { type: 'string' },
|
||||||
|
},
|
||||||
|
required: ['action', 'collection_name'],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
describe('ChatStreamProvider', () => {
|
describe('ChatStreamProvider', () => {
|
||||||
it('should define the required interface', () => {
|
it('should define the required interface', () => {
|
||||||
// Verify the interface has the required method
|
// Verify the interface has the required method
|
||||||
@@ -50,26 +70,616 @@ describe('ChatStreamProvider', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('BaseChatProvider', () => {
|
describe('BaseChatProvider - Model Tool Calling', () => {
|
||||||
it('should implement the ChatStreamProvider interface', () => {
|
let provider: TestChatProvider;
|
||||||
// Create a concrete implementation
|
let mockOpenAI: any;
|
||||||
const provider = new TestChatProvider();
|
let dataCallback: any;
|
||||||
|
let commonParams: CommonProviderParams;
|
||||||
|
|
||||||
// Verify it implements the interface
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
|
||||||
|
provider = new TestChatProvider();
|
||||||
|
dataCallback = vi.fn();
|
||||||
|
|
||||||
|
mockOpenAI = {
|
||||||
|
chat: {
|
||||||
|
completions: {
|
||||||
|
create: vi.fn(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
commonParams = {
|
||||||
|
openai: mockOpenAI,
|
||||||
|
systemPrompt: 'Test system prompt',
|
||||||
|
preprocessedContext: {},
|
||||||
|
maxTokens: 1000,
|
||||||
|
messages: [{ role: 'user', content: 'Test message' }],
|
||||||
|
model: 'gpt-4',
|
||||||
|
env: {} as any,
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should implement the ChatStreamProvider interface', () => {
|
||||||
expect(provider.handleStream).toBeInstanceOf(Function);
|
expect(provider.handleStream).toBeInstanceOf(Function);
|
||||||
expect(provider.getOpenAIClient).toBeInstanceOf(Function);
|
expect(provider.getOpenAIClient).toBeInstanceOf(Function);
|
||||||
expect(provider.getStreamParams).toBeInstanceOf(Function);
|
expect(provider.getStreamParams).toBeInstanceOf(Function);
|
||||||
expect(provider.processChunk).toBeInstanceOf(Function);
|
expect(provider.processChunk).toBeInstanceOf(Function);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should have abstract methods that need to be implemented', () => {
|
it('should handle regular text streaming without tool calls', async () => {
|
||||||
// This test verifies that the abstract methods exist
|
// Mock stream chunks for regular text response
|
||||||
// We can't instantiate BaseChatProvider directly, so we use the concrete implementation
|
const chunks = [
|
||||||
const provider = new TestChatProvider();
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: 'Hello ' },
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: 'world!' },
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'stop',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
// Verify the abstract methods are implemented
|
mockOpenAI.chat.completions.create.mockResolvedValue({
|
||||||
expect(provider.getOpenAIClient).toBeDefined();
|
async *[Symbol.asyncIterator]() {
|
||||||
expect(provider.getStreamParams).toBeDefined();
|
for (const chunk of chunks) {
|
||||||
expect(provider.processChunk).toBeDefined();
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await provider.handleStream(commonParams, dataCallback);
|
||||||
|
|
||||||
|
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
tools: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
type: 'function',
|
||||||
|
function: expect.objectContaining({
|
||||||
|
name: 'agentic_rag',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle tool calls in streaming response', async () => {
|
||||||
|
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||||
|
vi.mocked(agenticRAG).mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
data: {
|
||||||
|
results: ['Test result'],
|
||||||
|
analysis: { needsRetrieval: false },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock stream chunks for tool call response
|
||||||
|
const chunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: 'call_123',
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'agentic_rag',
|
||||||
|
arguments:
|
||||||
|
'{"action": "search_knowledge", "query": "test query", "collection_name": "test_collection"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// Second stream for response after tool execution
|
||||||
|
const secondStreamChunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: 'Based on the search results: Test result' },
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'stop',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let callCount = 0;
|
||||||
|
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||||
|
callCount++;
|
||||||
|
if (callCount === 1) {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of secondStreamChunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
await provider.handleStream(commonParams, dataCallback);
|
||||||
|
|
||||||
|
// Verify tool was called
|
||||||
|
expect(agenticRAG).toHaveBeenCalledWith({
|
||||||
|
action: 'search_knowledge',
|
||||||
|
query: 'test query',
|
||||||
|
collection_name: 'test_collection',
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify feedback messages were sent
|
||||||
|
expect(dataCallback).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
type: 'chat',
|
||||||
|
data: expect.objectContaining({
|
||||||
|
choices: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
delta: expect.objectContaining({
|
||||||
|
content: expect.stringContaining('🔧 Invoking'),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(dataCallback).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
type: 'chat',
|
||||||
|
data: expect.objectContaining({
|
||||||
|
choices: expect.arrayContaining([
|
||||||
|
expect.objectContaining({
|
||||||
|
delta: expect.objectContaining({
|
||||||
|
content: expect.stringContaining('📞 Calling agentic_rag'),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle tool call streaming with incremental arguments', async () => {
|
||||||
|
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||||
|
vi.mocked(agenticRAG).mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
data: { results: ['Test result'] },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock stream chunks with incremental tool call arguments
|
||||||
|
const chunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: 'call_',
|
||||||
|
type: 'function',
|
||||||
|
function: { name: 'agentic_rag', arguments: '{"action": "search_' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: '123',
|
||||||
|
function: { arguments: 'knowledge", "query": "test", ' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
function: { arguments: '"collection_name": "test_collection"}' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const secondStreamChunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: 'Response after tool call' },
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'stop',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let callCount = 0;
|
||||||
|
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||||
|
callCount++;
|
||||||
|
if (callCount === 1) {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of secondStreamChunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
await provider.handleStream(commonParams, dataCallback);
|
||||||
|
|
||||||
|
// Verify the complete tool call was assembled and executed
|
||||||
|
expect(agenticRAG).toHaveBeenCalledWith({
|
||||||
|
action: 'search_knowledge',
|
||||||
|
query: 'test',
|
||||||
|
collection_name: 'test_collection',
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should prevent infinite tool call loops', async () => {
|
||||||
|
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||||
|
vi.mocked(agenticRAG).mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
data: {
|
||||||
|
results: [],
|
||||||
|
analysis: { needsRetrieval: true },
|
||||||
|
retrieved_documents: [],
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock stream that always returns tool calls
|
||||||
|
const toolCallChunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: 'call_123',
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'agentic_rag',
|
||||||
|
arguments:
|
||||||
|
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
mockOpenAI.chat.completions.create.mockResolvedValue({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of toolCallChunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
await provider.handleStream(commonParams, dataCallback);
|
||||||
|
|
||||||
|
// Should detect duplicate tool calls and force completion (up to 5 iterations based on maxToolCallIterations)
|
||||||
|
// In this case, it should stop after 2 calls due to duplicate detection, but could go up to 5
|
||||||
|
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle tool call errors gracefully', async () => {
|
||||||
|
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||||
|
vi.mocked(agenticRAG).mockRejectedValue(new Error('Tool execution failed'));
|
||||||
|
|
||||||
|
const chunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: 'call_123',
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'agentic_rag',
|
||||||
|
arguments:
|
||||||
|
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const secondStreamChunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: 'I apologize, but I encountered an error.' },
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'stop',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let callCount = 0;
|
||||||
|
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||||
|
callCount++;
|
||||||
|
if (callCount === 1) {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of secondStreamChunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
await provider.handleStream(commonParams, dataCallback);
|
||||||
|
|
||||||
|
// Should still complete without throwing
|
||||||
|
expect(mockOpenAI.chat.completions.create).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should prevent duplicate tool calls', async () => {
|
||||||
|
const { agenticRAG } = await import('../../tools/agentic-rag.ts');
|
||||||
|
vi.mocked(agenticRAG).mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
data: { results: ['Test result'] },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock the same tool call twice
|
||||||
|
const chunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: 'call_123',
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'agentic_rag',
|
||||||
|
arguments:
|
||||||
|
'{"action": "search_knowledge", "query": "test", "collection_name": "test_collection"}',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// Second iteration with same tool call
|
||||||
|
let callCount = 0;
|
||||||
|
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||||
|
callCount++;
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await provider.handleStream(commonParams, dataCallback);
|
||||||
|
|
||||||
|
// Should only execute the tool once, then force completion
|
||||||
|
expect(agenticRAG).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle invalid JSON in tool call arguments', async () => {
|
||||||
|
const chunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {
|
||||||
|
tool_calls: [
|
||||||
|
{
|
||||||
|
index: 0,
|
||||||
|
id: 'call_123',
|
||||||
|
type: 'function',
|
||||||
|
function: {
|
||||||
|
name: 'agentic_rag',
|
||||||
|
arguments: '{"action": "search_knowledge", "invalid": json}', // Invalid JSON
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'tool_calls',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const secondStreamChunks = [
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: { content: 'I encountered an error parsing the tool arguments.' },
|
||||||
|
finish_reason: null,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
delta: {},
|
||||||
|
finish_reason: 'stop',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let callCount = 0;
|
||||||
|
mockOpenAI.chat.completions.create.mockImplementation(() => {
|
||||||
|
callCount++;
|
||||||
|
if (callCount === 1) {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
return Promise.resolve({
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of secondStreamChunks) {
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should not throw, should handle gracefully
|
||||||
|
await expect(provider.handleStream(commonParams, dataCallback)).resolves.not.toThrow();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
@@ -50,6 +50,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
let toolCallIterations = 0;
|
let toolCallIterations = 0;
|
||||||
const maxToolCallIterations = 5; // Prevent infinite loops
|
const maxToolCallIterations = 5; // Prevent infinite loops
|
||||||
let toolsExecuted = false; // Track if we've executed tools
|
let toolsExecuted = false; // Track if we've executed tools
|
||||||
|
const attemptedToolCalls = new Set<string>(); // Track attempted tool calls to prevent duplicates
|
||||||
|
|
||||||
while (!conversationComplete && toolCallIterations < maxToolCallIterations) {
|
while (!conversationComplete && toolCallIterations < maxToolCallIterations) {
|
||||||
const streamParams = this.getStreamParams(param, safeMessages);
|
const streamParams = this.getStreamParams(param, safeMessages);
|
||||||
@@ -112,6 +113,21 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
// Execute tool calls and add results to conversation
|
// Execute tool calls and add results to conversation
|
||||||
console.log('Executing tool calls:', toolCalls);
|
console.log('Executing tool calls:', toolCalls);
|
||||||
|
|
||||||
|
// Limit to one tool call per iteration to prevent concurrent execution issues
|
||||||
|
// Also filter out duplicate tool calls
|
||||||
|
const uniqueToolCalls = toolCalls.filter(toolCall => {
|
||||||
|
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
|
||||||
|
return !attemptedToolCalls.has(toolCallKey);
|
||||||
|
});
|
||||||
|
const toolCallsToExecute = uniqueToolCalls.slice(0, 1);
|
||||||
|
|
||||||
|
if (toolCallsToExecute.length === 0) {
|
||||||
|
console.log('All tool calls have been attempted already, forcing completion');
|
||||||
|
toolsExecuted = true;
|
||||||
|
conversationComplete = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Send feedback to user about tool invocation
|
// Send feedback to user about tool invocation
|
||||||
dataCallback({
|
dataCallback({
|
||||||
type: 'chat',
|
type: 'chat',
|
||||||
@@ -119,7 +135,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
choices: [
|
choices: [
|
||||||
{
|
{
|
||||||
delta: {
|
delta: {
|
||||||
content: `\n\n🔧 Invoking ${toolCalls.length} tool${toolCalls.length > 1 ? 's' : ''}...\n`,
|
content: `\n\n🔧 Invoking ${toolCallsToExecute.length} tool${toolCallsToExecute.length > 1 ? 's' : ''}...\n`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
@@ -130,15 +146,20 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
safeMessages.push({
|
safeMessages.push({
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: assistantMessage || null,
|
content: assistantMessage || null,
|
||||||
tool_calls: toolCalls,
|
tool_calls: toolCallsToExecute,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Execute each tool call and add results
|
// Execute each tool call and add results
|
||||||
for (const toolCall of toolCalls) {
|
let needsMoreRetrieval = false;
|
||||||
|
for (const toolCall of toolCallsToExecute) {
|
||||||
if (toolCall.type === 'function') {
|
if (toolCall.type === 'function') {
|
||||||
const name = toolCall.function.name;
|
const name = toolCall.function.name;
|
||||||
console.log(`Calling function: ${name}`);
|
console.log(`Calling function: ${name}`);
|
||||||
|
|
||||||
|
// Track this tool call attempt
|
||||||
|
const toolCallKey = `${toolCall.function.name}:${toolCall.function.arguments}`;
|
||||||
|
attemptedToolCalls.add(toolCallKey);
|
||||||
|
|
||||||
// Send feedback about specific tool being called
|
// Send feedback about specific tool being called
|
||||||
dataCallback({
|
dataCallback({
|
||||||
type: 'chat',
|
type: 'chat',
|
||||||
@@ -160,6 +181,36 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
const result = await callFunction(name, args);
|
const result = await callFunction(name, args);
|
||||||
console.log(`Function result:`, result);
|
console.log(`Function result:`, result);
|
||||||
|
|
||||||
|
// Check if agentic-rag indicates more retrieval is needed
|
||||||
|
if (
|
||||||
|
name === 'agentic_rag' &&
|
||||||
|
result?.data?.analysis?.needsRetrieval === true &&
|
||||||
|
(!result?.data?.retrieved_documents ||
|
||||||
|
result.data.retrieved_documents.length === 0)
|
||||||
|
) {
|
||||||
|
needsMoreRetrieval = true;
|
||||||
|
console.log('Agentic RAG indicates more retrieval needed');
|
||||||
|
|
||||||
|
// Add context about previous attempts to help LLM make better decisions
|
||||||
|
const attemptedActions = Array.from(attemptedToolCalls)
|
||||||
|
.filter(key => key.startsWith('agentic_rag:'))
|
||||||
|
.map(key => {
|
||||||
|
try {
|
||||||
|
const args = JSON.parse(key.split(':', 2)[1]);
|
||||||
|
return `${args.action} with query: "${args.query}"`;
|
||||||
|
} catch {
|
||||||
|
return 'unknown action';
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (attemptedActions.length > 0) {
|
||||||
|
safeMessages.push({
|
||||||
|
role: 'system',
|
||||||
|
content: `Previous retrieval attempts: ${attemptedActions.join(', ')}. Consider trying a different approach or more specific query.`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Send feedback about tool completion
|
// Send feedback about tool completion
|
||||||
dataCallback({
|
dataCallback({
|
||||||
type: 'chat',
|
type: 'chat',
|
||||||
@@ -206,8 +257,10 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark that tools have been executed to prevent repeated calls
|
// Only mark tools as executed if we don't need more retrieval
|
||||||
|
if (!needsMoreRetrieval) {
|
||||||
toolsExecuted = true;
|
toolsExecuted = true;
|
||||||
|
}
|
||||||
|
|
||||||
// Send feedback that tool execution is complete
|
// Send feedback that tool execution is complete
|
||||||
dataCallback({
|
dataCallback({
|
||||||
|
@@ -1,36 +1,88 @@
|
|||||||
import { describe, it, expect } from 'vitest';
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
import { agenticRAG, AgenticRAGTool } from '../agentic-rag-clean';
|
|
||||||
|
import { agenticRAG, AgenticRAGTools } from '../agentic-rag';
|
||||||
|
|
||||||
|
// Mock the dependencies
|
||||||
|
vi.mock('@zilliz/milvus2-sdk-node', () => ({
|
||||||
|
MilvusClient: vi.fn().mockImplementation(() => ({
|
||||||
|
listCollections: vi.fn().mockResolvedValue({
|
||||||
|
collection_names: ['family_domestic', 'business_corporate'],
|
||||||
|
data: [{ name: 'family_domestic' }, { name: 'business_corporate' }],
|
||||||
|
}),
|
||||||
|
search: vi.fn().mockResolvedValue({
|
||||||
|
results: [
|
||||||
|
{
|
||||||
|
content: 'Test document about AI and machine learning',
|
||||||
|
score: 0.85,
|
||||||
|
metadata: '{"category": "AI", "author": "Test Author"}',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
content: 'Another document about neural networks',
|
||||||
|
score: 0.75,
|
||||||
|
metadata: '{"category": "ML", "author": "Another Author"}',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
insert: vi.fn().mockResolvedValue({ success: true }),
|
||||||
|
createCollection: vi.fn().mockResolvedValue({ success: true }),
|
||||||
|
createIndex: vi.fn().mockResolvedValue({ success: true }),
|
||||||
|
})),
|
||||||
|
DataType: {
|
||||||
|
VarChar: 'VarChar',
|
||||||
|
FloatVector: 'FloatVector',
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
vi.mock('openai', () => ({
|
||||||
|
OpenAI: vi.fn().mockImplementation(() => ({
|
||||||
|
embeddings: {
|
||||||
|
create: vi.fn().mockResolvedValue({
|
||||||
|
data: [{ embedding: new Array(768).fill(0.1) }],
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock environment variables
|
||||||
|
vi.stubEnv('FIREWORKS_API_KEY', 'test-api-key');
|
||||||
|
|
||||||
describe('Agentic RAG System', () => {
|
describe('Agentic RAG System', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
it('should analyze queries correctly', async () => {
|
it('should analyze queries correctly', async () => {
|
||||||
// Test factual query
|
// Test factual query
|
||||||
const factualResult = await agenticRAG({
|
const factualResult = await agenticRAG({
|
||||||
action: 'analyze_query',
|
action: 'analyze_query',
|
||||||
query: 'What is artificial intelligence?',
|
query: 'What is artificial intelligence?',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(factualResult.status).toBe('success');
|
expect(factualResult.status).toBe('success');
|
||||||
expect(factualResult.needsRetrieval).toBe(true);
|
expect(factualResult.data.needsRetrieval).toBe(true);
|
||||||
expect(factualResult.data.queryType).toBe('factual');
|
expect(factualResult.data.queryType).toBe('factual');
|
||||||
|
|
||||||
// Test conversational query
|
// Test conversational query with multiple conversational keywords
|
||||||
const conversationalResult = await agenticRAG({
|
const conversationalResult = await agenticRAG({
|
||||||
action: 'analyze_query',
|
action: 'analyze_query',
|
||||||
query: 'Hello, how are you?',
|
query: 'Hello, how are you doing today?',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(conversationalResult.status).toBe('success');
|
expect(conversationalResult.status).toBe('success');
|
||||||
expect(conversationalResult.needsRetrieval).toBe(false);
|
expect(conversationalResult.data.needsRetrieval).toBe(false);
|
||||||
expect(conversationalResult.data.queryType).toBe('conversational');
|
expect(conversationalResult.data.queryType).toBe('conversational');
|
||||||
|
|
||||||
// Test creative query
|
// Test creative query with multiple creative keywords
|
||||||
const creativeResult = await agenticRAG({
|
const creativeResult = await agenticRAG({
|
||||||
action: 'analyze_query',
|
action: 'analyze_query',
|
||||||
query: 'Write a story about a robot',
|
query: 'Write a story and compose a poem',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(creativeResult.status).toBe('success');
|
expect(creativeResult.status).toBe('success');
|
||||||
expect(creativeResult.needsRetrieval).toBe(false);
|
expect(creativeResult.data.needsRetrieval).toBe(false);
|
||||||
expect(creativeResult.data.queryType).toBe('creative');
|
expect(creativeResult.data.queryType).toBe('creative');
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -38,25 +90,27 @@ describe('Agentic RAG System', () => {
|
|||||||
const result = await agenticRAG({
|
const result = await agenticRAG({
|
||||||
action: 'search_knowledge',
|
action: 'search_knowledge',
|
||||||
query: 'What is machine learning?',
|
query: 'What is machine learning?',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
top_k: 2,
|
top_k: 2,
|
||||||
threshold: 0.1,
|
similarity_threshold: 0.1,
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.status).toBe('success');
|
expect(result.status).toBe('success');
|
||||||
expect(result.needsRetrieval).toBe(true);
|
|
||||||
expect(result.context).toBeDefined();
|
expect(result.context).toBeDefined();
|
||||||
expect(Array.isArray(result.context)).toBe(true);
|
expect(Array.isArray(result.context)).toBe(true);
|
||||||
expect(result.data.retrieved_documents).toBeDefined();
|
expect(result.data.retrieved_documents).toBeDefined();
|
||||||
|
expect(result.data.analysis.needsRetrieval).toBe(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should not search for conversational queries', async () => {
|
it('should not search for conversational queries', async () => {
|
||||||
const result = await agenticRAG({
|
const result = await agenticRAG({
|
||||||
action: 'search_knowledge',
|
action: 'search_knowledge',
|
||||||
query: 'Hello there!',
|
query: 'Hello there! How are you?',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.status).toBe('success');
|
expect(result.status).toBe('success');
|
||||||
expect(result.needsRetrieval).toBe(false);
|
expect(result.data.analysis.needsRetrieval).toBe(false);
|
||||||
expect(result.data.retrieved_documents).toHaveLength(0);
|
expect(result.data.retrieved_documents).toHaveLength(0);
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -68,6 +122,7 @@ describe('Agentic RAG System', () => {
|
|||||||
content: 'This is a test document about neural networks and deep learning.',
|
content: 'This is a test document about neural networks and deep learning.',
|
||||||
metadata: { category: 'AI', author: 'Test Author' },
|
metadata: { category: 'AI', author: 'Test Author' },
|
||||||
},
|
},
|
||||||
|
collection_name: 'family_domestic',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.status).toBe('success');
|
expect(result.status).toBe('success');
|
||||||
@@ -79,18 +134,43 @@ describe('Agentic RAG System', () => {
|
|||||||
const result = await agenticRAG({
|
const result = await agenticRAG({
|
||||||
action: 'get_context',
|
action: 'get_context',
|
||||||
query: 'Tell me about vector databases',
|
query: 'Tell me about vector databases',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
top_k: 2,
|
top_k: 2,
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.status).toBe('success');
|
expect(result.status).toBe('success');
|
||||||
expect(result.needsRetrieval).toBe(true);
|
expect(result.data.analysis.needsRetrieval).toBe(true);
|
||||||
expect(result.context).toBeDefined();
|
expect(result.context).toBeDefined();
|
||||||
expect(result.data.context_summary).toBeDefined();
|
expect(result.data.context_summary).toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle semantic search', async () => {
|
||||||
|
const result = await agenticRAG({
|
||||||
|
action: 'semantic_search',
|
||||||
|
query: 'artificial intelligence concepts',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
|
top_k: 3,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.status).toBe('success');
|
||||||
|
expect(result.data.results).toBeDefined();
|
||||||
|
expect(Array.isArray(result.data.results)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should list collections', async () => {
|
||||||
|
const result = await agenticRAG({
|
||||||
|
action: 'list_collections',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.status).toBe('success');
|
||||||
|
expect(result.message).toContain('family_domestic');
|
||||||
|
});
|
||||||
|
|
||||||
it('should handle errors gracefully', async () => {
|
it('should handle errors gracefully', async () => {
|
||||||
const result = await agenticRAG({
|
const result = await agenticRAG({
|
||||||
action: 'analyze_query',
|
action: 'analyze_query',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
// Missing query parameter
|
// Missing query parameter
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -98,35 +178,82 @@ describe('Agentic RAG System', () => {
|
|||||||
expect(result.message).toContain('Query is required');
|
expect(result.message).toContain('Query is required');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle invalid actions', async () => {
|
||||||
|
const result = await agenticRAG({
|
||||||
|
action: 'invalid_action',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.status).toBe('error');
|
||||||
|
expect(result.message).toContain('Invalid action');
|
||||||
|
});
|
||||||
|
|
||||||
it('should have correct tool definition structure', () => {
|
it('should have correct tool definition structure', () => {
|
||||||
expect(AgenticRAGTool.type).toBe('function');
|
expect(AgenticRAGTools.type).toBe('function');
|
||||||
expect(AgenticRAGTool.function.name).toBe('agentic_rag');
|
expect(AgenticRAGTools.function.name).toBe('agentic_rag');
|
||||||
expect(AgenticRAGTool.function.description).toBeDefined();
|
expect(AgenticRAGTools.function.description).toBeDefined();
|
||||||
expect(AgenticRAGTool.function.parameters.type).toBe('object');
|
expect(AgenticRAGTools.function.parameters.type).toBe('object');
|
||||||
expect(AgenticRAGTool.function.parameters.properties.action).toBeDefined();
|
expect(AgenticRAGTools.function.parameters.properties.action).toBeDefined();
|
||||||
expect(AgenticRAGTool.function.parameters.required).toContain('action');
|
expect(AgenticRAGTools.function.parameters.required).toContain('action');
|
||||||
|
expect(AgenticRAGTools.function.parameters.required).toContain('collection_name');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should demonstrate intelligent retrieval decision making', async () => {
|
it('should demonstrate intelligent retrieval decision making', async () => {
|
||||||
// Test various query types to show intelligent decision making
|
// Test various query types to show intelligent decision making
|
||||||
const queries = [
|
const queries = [
|
||||||
{ query: 'What is AI?', expectedRetrieval: true },
|
{ query: 'What is AI?', expectedRetrieval: true },
|
||||||
{ query: 'Hello world', expectedRetrieval: false },
|
{ query: 'Hello world how are you', expectedRetrieval: false },
|
||||||
{ query: 'Write a poem', expectedRetrieval: false },
|
{ query: 'Write a poem and create a story', expectedRetrieval: false },
|
||||||
{ query: 'Explain machine learning', expectedRetrieval: true },
|
{ query: 'Explain machine learning', expectedRetrieval: true },
|
||||||
{ query: 'How are you doing?', expectedRetrieval: false },
|
{ query: 'How are you doing today?', expectedRetrieval: true },
|
||||||
|
{ query: 'Tell me about neural networks', expectedRetrieval: true },
|
||||||
];
|
];
|
||||||
|
|
||||||
for (const testCase of queries) {
|
for (const testCase of queries) {
|
||||||
const result = await agenticRAG({
|
const result = await agenticRAG({
|
||||||
action: 'search_knowledge',
|
action: 'search_knowledge',
|
||||||
query: testCase.query,
|
query: testCase.query,
|
||||||
|
collection_name: 'family_domestic',
|
||||||
});
|
});
|
||||||
|
|
||||||
expect(result.status).toBe('success');
|
expect(result.status).toBe('success');
|
||||||
expect(result.needsRetrieval).toBe(testCase.expectedRetrieval);
|
expect(result.data.analysis.needsRetrieval).toBe(testCase.expectedRetrieval);
|
||||||
|
|
||||||
console.log(`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.needsRetrieval}`);
|
console.log(
|
||||||
|
`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.data.analysis.needsRetrieval}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should filter results by similarity threshold', async () => {
|
||||||
|
const result = await agenticRAG({
|
||||||
|
action: 'search_knowledge',
|
||||||
|
query: 'What is machine learning?',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
|
similarity_threshold: 0.8, // High threshold
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.status).toBe('success');
|
||||||
|
if (result.data.analysis.needsRetrieval) {
|
||||||
|
// Should only return results above threshold
|
||||||
|
result.data.retrieved_documents.forEach((doc: any) => {
|
||||||
|
expect(doc.score).toBeGreaterThanOrEqual(0.8);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle context window limits', async () => {
|
||||||
|
const result = await agenticRAG({
|
||||||
|
action: 'get_context',
|
||||||
|
query: 'Tell me about artificial intelligence',
|
||||||
|
collection_name: 'family_domestic',
|
||||||
|
context_window: 1000,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.status).toBe('success');
|
||||||
|
if (result.data.analysis.needsRetrieval && result.data.context_summary) {
|
||||||
|
// Context should respect the window limit (approximate check)
|
||||||
|
expect(result.data.context_summary.length).toBeLessThanOrEqual(2000); // Allow some flexibility
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
});
|
});
|
@@ -50,24 +50,48 @@ export const AgenticRAGTools = {
|
|||||||
properties: {
|
properties: {
|
||||||
action: {
|
action: {
|
||||||
type: 'string',
|
type: 'string',
|
||||||
enum: ['search', 'list_collections', 'report_status'],
|
enum: [
|
||||||
|
'list_collections',
|
||||||
|
'report_status',
|
||||||
|
'semantic_search',
|
||||||
|
'search_knowledge',
|
||||||
|
'analyze_query',
|
||||||
|
'get_context',
|
||||||
|
],
|
||||||
description: 'Action to perform with the agentic RAG system.',
|
description: 'Action to perform with the agentic RAG system.',
|
||||||
},
|
},
|
||||||
query: {
|
query: {
|
||||||
type: 'string',
|
type: 'string',
|
||||||
description: 'User query or search term for knowledge retrieval.',
|
description: 'User query or search term for knowledge retrieval.',
|
||||||
},
|
},
|
||||||
document: {
|
// document: {
|
||||||
type: 'object',
|
// type: 'object',
|
||||||
properties: {
|
// properties: {
|
||||||
content: { type: 'string', description: 'Document content to store' },
|
// content: { type: 'string', description: 'Document content to store' },
|
||||||
metadata: { type: 'object', description: 'Additional metadata for the document' },
|
// metadata: { type: 'object', description: 'Additional metadata for the document' },
|
||||||
id: { type: 'string', description: 'Unique identifier for the document' },
|
// id: { type: 'string', description: 'Unique identifier for the document' },
|
||||||
},
|
// },
|
||||||
description: 'Document to store in the knowledge base.',
|
// description: 'Document to store in the knowledge base.',
|
||||||
},
|
// },
|
||||||
collection_name: {
|
collection_name: {
|
||||||
type: 'string',
|
type: 'string',
|
||||||
|
// todo: make this fancy w/ dynamic collection
|
||||||
|
enum: [
|
||||||
|
'business_corporate',
|
||||||
|
'civil_procedure',
|
||||||
|
'criminal_justice',
|
||||||
|
'education_professions',
|
||||||
|
'environmental_infrastructure',
|
||||||
|
'family_domestic',
|
||||||
|
'foundational_law',
|
||||||
|
'government_administration',
|
||||||
|
'health_social_services',
|
||||||
|
'miscellaneous',
|
||||||
|
'property_real_estate',
|
||||||
|
'special_documents',
|
||||||
|
'taxation_finance',
|
||||||
|
'transportation_motor_vehicles',
|
||||||
|
],
|
||||||
description: 'Name of the collection to work with.',
|
description: 'Name of the collection to work with.',
|
||||||
},
|
},
|
||||||
top_k: {
|
top_k: {
|
||||||
@@ -83,7 +107,7 @@ export const AgenticRAGTools = {
|
|||||||
description: 'Maximum number of context tokens to include (default: 2000).',
|
description: 'Maximum number of context tokens to include (default: 2000).',
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
required: ['action'],
|
required: ['action', 'collection_name'],
|
||||||
additionalProperties: false,
|
additionalProperties: false,
|
||||||
},
|
},
|
||||||
strict: true,
|
strict: true,
|
||||||
@@ -95,10 +119,10 @@ export const AgenticRAGTools = {
|
|||||||
*/
|
*/
|
||||||
const DEFAULT_CONFIG: AgenticRAGConfig = {
|
const DEFAULT_CONFIG: AgenticRAGConfig = {
|
||||||
milvusAddress: 'localhost:19530',
|
milvusAddress: 'localhost:19530',
|
||||||
collectionName: 'knowledge_base',
|
collectionName: 'family_domestic',
|
||||||
embeddingDimension: 768,
|
embeddingDimension: 768,
|
||||||
topK: 5,
|
topK: 5,
|
||||||
similarityThreshold: 0.58,
|
similarityThreshold: 0.5,
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -200,14 +224,16 @@ function analyzeQueryForRetrieval(query: string): {
|
|||||||
'Query appears to be asking for factual information that may benefit from knowledge retrieval.',
|
'Query appears to be asking for factual information that may benefit from knowledge retrieval.',
|
||||||
queryType: 'factual',
|
queryType: 'factual',
|
||||||
};
|
};
|
||||||
} else if (creativeScore > conversationalScore) {
|
} else if (creativeScore > conversationalScore && creativeScore > 1) {
|
||||||
|
// Only skip retrieval for clearly creative tasks with multiple creative keywords
|
||||||
return {
|
return {
|
||||||
needsRetrieval: false,
|
needsRetrieval: false,
|
||||||
confidence: 0.8,
|
confidence: 0.8,
|
||||||
reasoning: 'Query appears to be requesting creative content generation.',
|
reasoning: 'Query appears to be requesting creative content generation.',
|
||||||
queryType: 'creative',
|
queryType: 'creative',
|
||||||
};
|
};
|
||||||
} else if (conversationalScore > 0) {
|
} else if (conversationalScore > 1 && conversationalScore > factualScore) {
|
||||||
|
// Only skip retrieval for clearly conversational queries with multiple conversational keywords
|
||||||
return {
|
return {
|
||||||
needsRetrieval: false,
|
needsRetrieval: false,
|
||||||
confidence: 0.7,
|
confidence: 0.7,
|
||||||
@@ -215,10 +241,11 @@ function analyzeQueryForRetrieval(query: string): {
|
|||||||
queryType: 'conversational',
|
queryType: 'conversational',
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
|
// Default to retrieval for most cases to ensure comprehensive responses
|
||||||
return {
|
return {
|
||||||
needsRetrieval: true,
|
needsRetrieval: true,
|
||||||
confidence: 0.6,
|
confidence: 0.8,
|
||||||
reasoning: 'Query type unclear, defaulting to retrieval for comprehensive response.',
|
reasoning: 'Defaulting to retrieval to provide comprehensive and accurate information.',
|
||||||
queryType: 'analytical',
|
queryType: 'analytical',
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@@ -235,8 +262,8 @@ export async function agenticRAG(args: {
|
|||||||
top_k?: number;
|
top_k?: number;
|
||||||
similarity_threshold?: number;
|
similarity_threshold?: number;
|
||||||
context_window?: number;
|
context_window?: number;
|
||||||
|
user_confirmed?: boolean;
|
||||||
}): Promise<AgenticRAGResult> {
|
}): Promise<AgenticRAGResult> {
|
||||||
console.log('calling agentic rag tool', args);
|
|
||||||
const config = { ...DEFAULT_CONFIG };
|
const config = { ...DEFAULT_CONFIG };
|
||||||
const collectionName = args.collection_name || config.collectionName!;
|
const collectionName = args.collection_name || config.collectionName!;
|
||||||
const topK = args.top_k || config.topK!;
|
const topK = args.top_k || config.topK!;
|
||||||
@@ -297,9 +324,9 @@ export async function agenticRAG(args: {
|
|||||||
// eslint-disable-next-line no-case-declarations
|
// eslint-disable-next-line no-case-declarations
|
||||||
const searchResult = await milvusClient.search({
|
const searchResult = await milvusClient.search({
|
||||||
collection_name: collectionName,
|
collection_name: collectionName,
|
||||||
query_vectors: [queryEmbedding],
|
vector: queryEmbedding,
|
||||||
top_k: topK,
|
topk: topK,
|
||||||
params: { nprobe: 16 },
|
params: { nprobe: 8 },
|
||||||
output_fields: ['content', 'metadata'],
|
output_fields: ['content', 'metadata'],
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -385,8 +412,7 @@ export async function agenticRAG(args: {
|
|||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|
||||||
// @ts-expect-error - idk man
|
await milvusClient.createCollection(collectionSchema as any);
|
||||||
await milvusClient.createCollection(collectionSchema);
|
|
||||||
|
|
||||||
// Create index for efficient similarity search
|
// Create index for efficient similarity search
|
||||||
await milvusClient.createIndex({
|
await milvusClient.createIndex({
|
||||||
@@ -426,9 +452,9 @@ export async function agenticRAG(args: {
|
|||||||
// eslint-disable-next-line no-case-declarations
|
// eslint-disable-next-line no-case-declarations
|
||||||
const semanticResults = await milvusClient.search({
|
const semanticResults = await milvusClient.search({
|
||||||
collection_name: collectionName,
|
collection_name: collectionName,
|
||||||
query_vectors: [semanticEmbedding],
|
vector: semanticEmbedding,
|
||||||
top_k: topK,
|
topk: topK,
|
||||||
params: { nprobe: 16 },
|
params: { nprobe: 8 },
|
||||||
output_fields: ['content', 'metadata'],
|
output_fields: ['content', 'metadata'],
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -452,14 +478,13 @@ export async function agenticRAG(args: {
|
|||||||
// This is a comprehensive context retrieval that combines analysis and search
|
// This is a comprehensive context retrieval that combines analysis and search
|
||||||
// eslint-disable-next-line no-case-declarations
|
// eslint-disable-next-line no-case-declarations
|
||||||
const contextAnalysis = analyzeQueryForRetrieval(args.query);
|
const contextAnalysis = analyzeQueryForRetrieval(args.query);
|
||||||
|
|
||||||
if (contextAnalysis.needsRetrieval) {
|
if (contextAnalysis.needsRetrieval) {
|
||||||
const contextEmbedding = await generateEmbedding(args.query);
|
const contextEmbedding = await generateEmbedding(args.query);
|
||||||
const contextSearch = await milvusClient.search({
|
const contextSearch = await milvusClient.search({
|
||||||
collection_name: collectionName,
|
collection_name: collectionName,
|
||||||
query_vectors: [contextEmbedding],
|
vector: contextEmbedding,
|
||||||
top_k: topK,
|
topk: topK,
|
||||||
params: { nprobe: 16 },
|
params: { nprobe: 8 },
|
||||||
output_fields: ['content', 'metadata'],
|
output_fields: ['content', 'metadata'],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -39,19 +39,19 @@ export const LandingComponent: React.FC = () => {
|
|||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
switches={{
|
switches={{
|
||||||
GpsMap: {
|
// GpsMap: {
|
||||||
value: mapActive,
|
// value: mapActive,
|
||||||
onChange(enabled) {
|
// onChange(enabled) {
|
||||||
if (enabled) {
|
// if (enabled) {
|
||||||
setEnabledComponent('gpsmap');
|
// setEnabledComponent('gpsmap');
|
||||||
setAiActive(false);
|
// setAiActive(false);
|
||||||
} else {
|
// } else {
|
||||||
setEnabledComponent('');
|
// setEnabledComponent('');
|
||||||
}
|
// }
|
||||||
setMapActive(enabled);
|
// setMapActive(enabled);
|
||||||
},
|
// },
|
||||||
label: 'GPS',
|
// label: 'GPS',
|
||||||
},
|
// },
|
||||||
AI: {
|
AI: {
|
||||||
value: aiActive,
|
value: aiActive,
|
||||||
onChange(enabled) {
|
onChange(enabled) {
|
||||||
|
Reference in New Issue
Block a user