Add Agentic RAG Tool integration with test cases

- Implemented an intelligent retrieval-augmented generation system (`agentic_rag`) that decides dynamically whether knowledge retrieval is needed for a given query.
- Uses Milvus as the vector database for storing document embeddings and running similarity search.
- Added comprehensive test cases for query analysis, storage, retrieval, and error handling.
- Integrated `AgenticRAGTools` into `chat-stream-provider`, enabling tool-based responses (see the sketch below).
- Updated dependencies with `@zilliz/milvus2-sdk-node` for Milvus integration.
- Updated lander hero title.
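
For reference, a minimal non-streaming sketch of the new tool wiring. The real integration lives in `chat-stream-provider` (see the first diff below); the client construction, model name, and two-call loop here are illustrative assumptions, not part of this commit.

```ts
// Hypothetical sketch only: mirrors the provider changes in this commit, but the client
// setup, model name, and message handling are assumptions for illustration.
import { OpenAI } from 'openai';
import { agenticRAG, AgenticRAGTools } from '../tools/agentic-rag.ts';

async function answerWithRag(question: string): Promise<string | null> {
  const client = new OpenAI(); // assumes an OpenAI-compatible endpoint configured via env
  const messages: any[] = [{ role: 'user', content: question }];

  // First call: offer the tool so the model can decide whether retrieval is needed.
  const first = await client.chat.completions.create({
    model: 'gpt-4o-mini', // placeholder model name
    messages,
    tools: [AgenticRAGTools] as any,
  });

  const toolCall = first.choices[0]?.message?.tool_calls?.[0];
  if (toolCall) {
    const args = JSON.parse(toolCall.function.arguments);
    const result = await agenticRAG(args);
    // Feed the serialized tool result back, as the provider does.
    messages.push(first.choices[0].message);
    messages.push({ role: 'tool', tool_call_id: toolCall.id, content: JSON.stringify(result) });
  }

  // Second call: no tools, forcing a plain-text answer grounded in any retrieved context.
  const final = await client.chat.completions.create({ model: 'gpt-4o-mini', messages });
  return final.choices[0]?.message?.content ?? null;
}
```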
geoffsee
2025-07-31 16:15:23 -04:00
parent ae6a6e4064
commit 6c433581d3
10 changed files with 898 additions and 150 deletions

View File

@@ -1,8 +1,7 @@
import { OpenAI } from 'openai';
import ChatSdk from '../chat-sdk/chat-sdk.ts';
-import { getWeather, WeatherTool } from '../tools/weather.ts';
-import { yachtpitAi, YachtpitTools } from '../tools/yachtpit.ts';
+import { agenticRAG, AgenticRAGTools } from '../tools/agentic-rag.ts';
import type { GenericEnv } from '../types';
export interface CommonProviderParams {
@@ -38,14 +37,11 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
const client = this.getOpenAIClient(param);
-const tools = [WeatherTool, YachtpitTools];
+const tools = [AgenticRAGTools];
const callFunction = async (name, args) => {
-if (name === 'get_weather') {
-return getWeather(args.latitude, args.longitude);
-}
-if (name === 'ship_control') {
-return yachtpitAi({ action: args.action, value: args.value });
+if (name === 'agentic_rag') {
+return agenticRAG(args);
}
};
@@ -59,6 +55,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
const streamParams = this.getStreamParams(param, safeMessages);
// Only provide tools on the first call, after that force text response
const currentTools = toolsExecuted ? undefined : tools;
const stream = await client.chat.completions.create({ ...streamParams, tools: currentTools });
let assistantMessage = '';
@@ -170,7 +167,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
choices: [
{
delta: {
-content: `\n`,
+content: `\n ${JSON.stringify(result)}`,
},
},
],
@@ -181,7 +178,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
safeMessages.push({
role: 'tool',
tool_call_id: toolCall.id,
-content: result?.toString() || '',
+content: JSON.stringify(result),
});
} catch (error) {
console.error(`Error executing tool ${name}:`, error);
@@ -236,7 +233,7 @@ export abstract class BaseChatProvider implements ChatStreamProvider {
// Process chunk normally for non-tool-call responses
if (!chunk.choices[0]?.delta?.tool_calls) {
-console.log('after-tool-call-chunk', chunk);
+// console.log('after-tool-call-chunk', chunk);
const shouldBreak = await this.processChunk(chunk, dataCallback);
if (shouldBreak) {
conversationComplete = true;

View File

@@ -0,0 +1,132 @@
import { describe, it, expect } from 'vitest';
import { agenticRAG, AgenticRAGTool } from '../agentic-rag-clean';
describe('Agentic RAG System', () => {
it('should analyze queries correctly', async () => {
// Test factual query
const factualResult = await agenticRAG({
action: 'analyze_query',
query: 'What is artificial intelligence?',
});
expect(factualResult.status).toBe('success');
expect(factualResult.needsRetrieval).toBe(true);
expect(factualResult.data.queryType).toBe('factual');
// Test conversational query
const conversationalResult = await agenticRAG({
action: 'analyze_query',
query: 'Hello, how are you?',
});
expect(conversationalResult.status).toBe('success');
expect(conversationalResult.needsRetrieval).toBe(false);
expect(conversationalResult.data.queryType).toBe('conversational');
// Test creative query
const creativeResult = await agenticRAG({
action: 'analyze_query',
query: 'Write a story about a robot',
});
expect(creativeResult.status).toBe('success');
expect(creativeResult.needsRetrieval).toBe(false);
expect(creativeResult.data.queryType).toBe('creative');
});
it('should search knowledge base for factual queries', async () => {
const result = await agenticRAG({
action: 'search_knowledge',
query: 'What is machine learning?',
top_k: 2,
threshold: 0.1,
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(true);
expect(result.context).toBeDefined();
expect(Array.isArray(result.context)).toBe(true);
expect(result.data.retrieved_documents).toBeDefined();
});
it('should not search for conversational queries', async () => {
const result = await agenticRAG({
action: 'search_knowledge',
query: 'Hello there!',
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(false);
expect(result.data.retrieved_documents).toHaveLength(0);
});
it('should store documents successfully', async () => {
const result = await agenticRAG({
action: 'store_document',
document: {
id: 'test-doc-1',
content: 'This is a test document about neural networks and deep learning.',
metadata: { category: 'AI', author: 'Test Author' },
},
});
expect(result.status).toBe('success');
expect(result.data.document_id).toBe('test-doc-1');
expect(result.data.content_length).toBeGreaterThan(0);
});
it('should get context for factual queries', async () => {
const result = await agenticRAG({
action: 'get_context',
query: 'Tell me about vector databases',
top_k: 2,
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(true);
expect(result.context).toBeDefined();
expect(result.data.context_summary).toBeDefined();
});
it('should handle errors gracefully', async () => {
const result = await agenticRAG({
action: 'analyze_query',
// Missing query parameter
});
expect(result.status).toBe('error');
expect(result.message).toContain('Query is required');
});
it('should have correct tool definition structure', () => {
expect(AgenticRAGTool.type).toBe('function');
expect(AgenticRAGTool.function.name).toBe('agentic_rag');
expect(AgenticRAGTool.function.description).toBeDefined();
expect(AgenticRAGTool.function.parameters.type).toBe('object');
expect(AgenticRAGTool.function.parameters.properties.action).toBeDefined();
expect(AgenticRAGTool.function.parameters.required).toContain('action');
});
it('should demonstrate intelligent retrieval decision making', async () => {
// Test various query types to show intelligent decision making
const queries = [
{ query: 'What is AI?', expectedRetrieval: true },
{ query: 'Hello world', expectedRetrieval: false },
{ query: 'Write a poem', expectedRetrieval: false },
{ query: 'Explain machine learning', expectedRetrieval: true },
{ query: 'How are you doing?', expectedRetrieval: false },
];
for (const testCase of queries) {
const result = await agenticRAG({
action: 'search_knowledge',
query: testCase.query,
});
expect(result.status).toBe('success');
expect(result.needsRetrieval).toBe(testCase.expectedRetrieval);
console.log(`[DEBUG_LOG] Query: "${testCase.query}" - Retrieval needed: ${result.needsRetrieval}`);
}
});
});

View File

@@ -0,0 +1,505 @@
import { MilvusClient, DataType } from '@zilliz/milvus2-sdk-node';
import { OpenAI } from 'openai';
import { ProviderRepository } from '../providers/_ProviderRepository.ts';
/**
* Configuration for the Agentic RAG system
*/
export interface AgenticRAGConfig {
milvusAddress?: string;
collectionName?: string;
embeddingDimension?: number;
topK?: number;
similarityThreshold?: number;
}
/**
* Result structure for Agentic RAG operations
*/
export interface AgenticRAGResult {
message: string;
status: 'success' | 'error';
data?: any;
context?: string[];
relevanceScore?: number;
}
/**
* Document structure for knowledge base
*/
export interface Document {
id: string;
content: string;
metadata?: Record<string, any>;
embedding?: number[];
}
/**
* Agentic RAG Tools for intelligent retrieval-augmented generation
* This system makes intelligent decisions about when and how to retrieve information
*/
export const AgenticRAGTools = {
type: 'function',
function: {
name: 'agentic_rag',
description:
'Intelligent retrieval-augmented generation system that can store documents, search knowledge base, and provide contextual information based on user queries. The system intelligently decides when retrieval is needed.',
parameters: {
type: 'object',
properties: {
action: {
type: 'string',
enum: ['analyze_query', 'search_knowledge', 'store_document', 'get_context', 'semantic_search', 'manage_collection', 'list_collections'],
description: 'Action to perform with the agentic RAG system.',
},
query: {
type: 'string',
description: 'User query or search term for knowledge retrieval.',
},
document: {
type: 'object',
properties: {
content: { type: 'string', description: 'Document content to store' },
metadata: { type: 'object', description: 'Additional metadata for the document' },
id: { type: 'string', description: 'Unique identifier for the document' },
},
description: 'Document to store in the knowledge base.',
},
collection_name: {
type: 'string',
description: 'Name of the collection to work with.',
},
top_k: {
type: 'number',
description: 'Number of similar documents to retrieve (default: 5).',
},
similarity_threshold: {
type: 'number',
description: 'Minimum similarity score for relevant results (0-1, default: 0.58).',
},
context_window: {
type: 'number',
description: 'Maximum number of context tokens to include (default: 2000).',
},
},
required: ['action'],
additionalProperties: false,
},
strict: true,
},
};
/**
* Default configuration for the Agentic RAG system
*/
const DEFAULT_CONFIG: AgenticRAGConfig = {
milvusAddress: 'localhost:19530',
collectionName: 'knowledge_base',
embeddingDimension: 768,
topK: 5,
similarityThreshold: 0.58,
};
/**
* Generates an embedding via the Fireworks OpenAI-compatible embeddings endpoint
* using nomic-ai/nomic-embed-text-v1.5 (768 dimensions). Requires FIREWORKS_API_KEY.
*/
async function generateEmbedding(text: string): Promise<number[] | undefined> {
const embeddingsClient = new OpenAI({
apiKey: process.env.FIREWORKS_API_KEY,
baseURL: ProviderRepository.OPENAI_COMPAT_ENDPOINTS.fireworks,
}).embeddings;
const embeddings = await embeddingsClient.create({
input: [text],
model: 'nomic-ai/nomic-embed-text-v1.5',
dimensions: 768,
});
return embeddings.data.at(0)?.embedding;
}
/**
* Analyze query to determine if retrieval is needed
*/
function analyzeQueryForRetrieval(query: string): {
needsRetrieval: boolean;
confidence: number;
reasoning: string;
queryType: 'factual' | 'conversational' | 'creative' | 'analytical';
} {
const lowerQuery = query.toLowerCase();
// Keywords that suggest factual information is needed
const factualKeywords = [
'what is',
'who is',
'when did',
'where is',
'how does',
'explain',
'define',
'describe',
'tell me about',
'information about',
'details on',
'facts about',
'history of',
'background on',
];
// Keywords that suggest conversational/creative responses
const conversationalKeywords = [
'hello',
'hi',
'how are you',
'thank you',
'please help',
'i think',
'in my opinion',
'what do you think',
'can you help',
];
// Keywords that suggest creative tasks
const creativeKeywords = [
'write a',
'create a',
'generate',
'compose',
'draft',
'story',
'poem',
'essay',
'letter',
'email',
];
let factualScore = 0;
let conversationalScore = 0;
let creativeScore = 0;
factualKeywords.forEach(keyword => {
if (lowerQuery.includes(keyword)) factualScore += 1;
});
conversationalKeywords.forEach(keyword => {
if (lowerQuery.includes(keyword)) conversationalScore += 1;
});
creativeKeywords.forEach(keyword => {
if (lowerQuery.includes(keyword)) creativeScore += 1;
});
// Determine query type and retrieval need
if (factualScore > conversationalScore && factualScore > creativeScore) {
return {
needsRetrieval: true,
confidence: Math.min(factualScore * 0.3, 0.9),
reasoning:
'Query appears to be asking for factual information that may benefit from knowledge retrieval.',
queryType: 'factual',
};
} else if (creativeScore > conversationalScore) {
return {
needsRetrieval: false,
confidence: 0.8,
reasoning: 'Query appears to be requesting creative content generation.',
queryType: 'creative',
};
} else if (conversationalScore > 0) {
return {
needsRetrieval: false,
confidence: 0.7,
reasoning: 'Query appears to be conversational in nature.',
queryType: 'conversational',
};
} else {
return {
needsRetrieval: true,
confidence: 0.6,
reasoning: 'Query type unclear, defaulting to retrieval for comprehensive response.',
queryType: 'analytical',
};
}
}
/**
* Main Agentic RAG function that handles intelligent retrieval decisions
*/
export async function agenticRAG(args: {
action: string;
query?: string;
document?: Document;
collection_name?: string;
top_k?: number;
similarity_threshold?: number;
context_window?: number;
}): Promise<AgenticRAGResult> {
console.log('calling agentic rag tool', args);
const config = { ...DEFAULT_CONFIG };
const collectionName = args.collection_name || config.collectionName!;
const topK = args.top_k || config.topK!;
const similarityThreshold = args.similarity_threshold || config.similarityThreshold!;
const milvusClient = new MilvusClient({ address: config.milvusAddress! });
try {
switch (args.action) {
case 'analyze_query':
if (!args.query) {
return { status: 'error', message: 'Query is required for analysis.' };
}
// eslint-disable-next-line no-case-declarations
const analysis = analyzeQueryForRetrieval(args.query);
return {
status: 'success',
message: `Query analysis complete. Retrieval ${analysis.needsRetrieval ? 'recommended' : 'not needed'}.`,
data: analysis,
};
case 'list_collections':
// eslint-disable-next-line no-case-declarations
const { collection_names } = (await milvusClient.listCollections()) as any as {
collection_names: string[];
};
return {
status: 'success',
message: JSON.stringify(collection_names),
};
case 'search_knowledge':
if (!args.query) {
return { status: 'error', message: 'Query is required for knowledge search.' };
}
// First, analyze if retrieval is needed
// eslint-disable-next-line no-case-declarations
const queryAnalysis = analyzeQueryForRetrieval(args.query);
if (!queryAnalysis.needsRetrieval) {
return {
status: 'success',
message: 'Query analysis suggests retrieval is not needed for this type of query.',
data: {
analysis: queryAnalysis,
retrieved_documents: [],
context: [],
},
};
}
// Generate embedding for the query
// eslint-disable-next-line no-case-declarations
const queryEmbedding = await generateEmbedding(args.query);
// Search for similar documents
// eslint-disable-next-line no-case-declarations
const searchResult = await milvusClient.search({
collection_name: collectionName,
query_vectors: [queryEmbedding],
top_k: topK,
params: { nprobe: 16 },
output_fields: ['content', 'metadata'],
});
// Filter results by similarity threshold
// eslint-disable-next-line no-case-declarations
const relevantResults = searchResult.results.filter(
(result: any) => result.score >= similarityThreshold,
);
// eslint-disable-next-line no-case-declarations
const contextDocuments = relevantResults.map((result: any) => ({
content: result.content,
score: result.score,
metadata: result.metadata,
}));
return {
status: 'success',
message: `Found ${relevantResults.length} relevant documents for query.`,
data: {
analysis: queryAnalysis,
retrieved_documents: contextDocuments,
context: contextDocuments.map((doc: any) => doc.content),
},
context: contextDocuments.map((doc: any) => doc.content),
relevanceScore: relevantResults.length > 0 ? relevantResults.at(0)?.score : 0,
};
case 'store_document':
if (!args.document || !args.document.content) {
return { status: 'error', message: 'Document with content is required for storage.' };
}
// Generate embedding for the document
// eslint-disable-next-line no-case-declarations
const docEmbedding = await generateEmbedding(args.document.content);
// eslint-disable-next-line no-case-declarations
const docId =
args.document.id || `doc_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
// Store document in Milvus
await milvusClient.insert({
collection_name: collectionName,
fields_data: [
{ name: 'id', values: [docId] },
{ name: 'embedding', values: [docEmbedding] },
{ name: 'content', values: [args.document.content] },
{ name: 'metadata', values: [JSON.stringify(args.document.metadata || {})] },
],
});
return {
status: 'success',
message: `Document stored successfully with ID: ${docId}`,
data: { document_id: docId, content_length: args.document.content.length },
};
case 'manage_collection':
try {
// Check if collection exists
const collections = await milvusClient.listCollections();
const collectionExists =
collections.data.filter(c => c.name.includes(collectionName)).length > 0;
if (!collectionExists) {
// Create collection with proper schema for RAG
const collectionSchema = {
collection_name: collectionName,
fields: [
{
name: 'id',
type: DataType.VarChar,
params: { max_length: 100 },
is_primary_key: true,
},
{
name: 'embedding',
type: DataType.FloatVector,
params: { dim: config.embeddingDimension },
},
{ name: 'content', type: DataType.VarChar, params: { max_length: 65535 } },
{ name: 'metadata', type: DataType.VarChar, params: { max_length: 1000 } },
],
};
// @ts-expect-error - collection schema shape does not match the SDK typings
await milvusClient.createCollection(collectionSchema);
// Create index for efficient similarity search
await milvusClient.createIndex({
collection_name: collectionName,
field_name: 'embedding',
index_type: 'IVF_FLAT',
params: { nlist: 1024 },
metric_type: 'COSINE',
});
return {
status: 'success',
message: `Collection '${collectionName}' created successfully with RAG schema.`,
data: { collection_name: collectionName, action: 'created' },
};
} else {
return {
status: 'success',
message: `Collection '${collectionName}' already exists.`,
data: { collection_name: collectionName, action: 'exists' },
};
}
} catch (error: any) {
return {
status: 'error',
message: `Error managing collection: ${error.message}`,
};
}
case 'semantic_search':
if (!args.query) {
return { status: 'error', message: 'Query is required for semantic search.' };
}
// eslint-disable-next-line no-case-declarations
const semanticEmbedding = await generateEmbedding(args.query);
// eslint-disable-next-line no-case-declarations
const semanticResults = await milvusClient.search({
collection_name: collectionName,
query_vectors: [semanticEmbedding],
top_k: topK,
params: { nprobe: 16 },
output_fields: ['content', 'metadata'],
});
return {
status: 'success',
message: `Semantic search completed. Found ${semanticResults.results.length} results.`,
data: {
results: semanticResults.results.map((result: any) => ({
content: result.content,
score: result.score,
metadata: JSON.parse(result.metadata || '{}'),
})),
},
};
case 'get_context':
if (!args.query) {
return { status: 'error', message: 'Query is required to get context.' };
}
// This is a comprehensive context retrieval that combines analysis and search
// eslint-disable-next-line no-case-declarations
const contextAnalysis = analyzeQueryForRetrieval(args.query);
if (contextAnalysis.needsRetrieval) {
const contextEmbedding = await generateEmbedding(args.query);
const contextSearch = await milvusClient.search({
collection_name: collectionName,
query_vectors: [contextEmbedding],
top_k: topK,
params: { nprobe: 16 },
output_fields: ['content', 'metadata'],
});
const contextResults = contextSearch.results
.filter((result: any) => result.score >= similarityThreshold)
.map((result: any) => ({
content: result.content,
score: result.score,
metadata: JSON.parse(result.metadata || '{}'),
}));
return {
status: 'success',
message: `Context retrieved successfully. Found ${contextResults.length} relevant documents.`,
data: {
analysis: contextAnalysis,
context_documents: contextResults,
context_summary: contextResults.map((doc: any) => doc.content).join('\n\n'),
},
context: contextResults.map((doc: any) => doc.content),
};
} else {
return {
status: 'success',
message: 'No context retrieval needed for this query type.',
data: {
analysis: contextAnalysis,
context_documents: [],
context_summary: '',
},
};
}
default:
return { status: 'error', message: 'Invalid action specified.' };
}
} catch (error: any) {
return {
status: 'error',
message: `Integration error: ${error.message}`,
};
}
}
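
For reference, a minimal direct-usage sketch of `agenticRAG`, mirroring the test cases above. It assumes a Milvus instance at the default localhost:19530 address and a FIREWORKS_API_KEY for embeddings; the import path and sample data are illustrative.

```ts
import { agenticRAG } from '../tools/agentic-rag.ts'; // path as used by the provider above

async function demo() {
  // Ensure the collection exists (creates the RAG schema and index if missing).
  await agenticRAG({ action: 'manage_collection', collection_name: 'knowledge_base' });

  // Store a document in the knowledge base.
  await agenticRAG({
    action: 'store_document',
    document: { id: 'demo-1', content: 'Milvus is an open-source vector database.' },
  });

  // Factual queries trigger retrieval; conversational ones return an empty context.
  const ctx = await agenticRAG({
    action: 'get_context',
    query: 'Tell me about vector databases',
    top_k: 3,
  });
  console.log(ctx.status, ctx.context);
}

demo().catch(console.error);
```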