using BinaryDad.Extensions;
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Salesforce.NET.Entities;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;

namespace Salesforce.NET
{
    /// <summary>
    /// Client for the Salesforce REST API: record retrieval, creation and update, SOQL queries,
    /// object describe calls and (work in progress) Bulk API 2.0 ingest jobs.
    /// </summary>
    /// <remarks>
    /// The authenticated context is stored through <c>CacheHelper</c> under the fixed key
    /// <c>"SalesforceApiContext"</c>, so every instance of this client shares one session regardless of
    /// the credentials it was constructed with.
    /// </remarks>
    public class SalesforceApiClient
    {
        private const string ContextKey = "SalesforceApiContext";

        // duration handed to CacheHelper.Add for the cached context (presumably minutes, i.e. 24 hours — confirm)
        private const int ContextCacheDuration = 1440;

        private const string apiVersion = "51.0";

        // SOQL dateTime literal, 24-hour clock; 'T' and 'Z' are quoted so they are emitted literally
        private const string SoqlDateTimeFormat = "yyyy-MM-dd'T'HH:mm:ss'Z'";

        private readonly SalesforceCredentials credentials;
        private readonly Task initializationTask;

        /// <summary>
        /// The authenticated session, shared by all client instances through the cache.
        /// </summary>
        private SalesforceContext SalesforceContext
        {
            get => CacheHelper.Get<SalesforceContext>(ContextKey);
            set => CacheHelper.Add(ContextKey, value, ContextCacheDuration);
        }

        /// <summary>
        /// Creates a client and, when no cached context exists yet, starts authenticating in the background.
        /// </summary>
        /// <param name="credentials">Credentials for the OAuth username-password flow.</param>
        /// <exception cref="ArgumentNullException"><paramref name="credentials"/> is null.</exception>
        public SalesforceApiClient(SalesforceCredentials credentials)
        {
            this.credentials = credentials ?? throw new ArgumentNullException(nameof(credentials));

            if (SalesforceContext == null)
            {
                // authenticate from the constructor; the task is awaited before any request is sent
                initializationTask = AuthenticateAsync(credentials);
            }
        }

        /// <summary>
        /// Retrieves a single Salesforce record by ID.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="id">The Salesforce record ID.</param>
        /// <returns>The record deserialized from the response body.</returns>
        /// <exception cref="ArgumentException"><paramref name="id"/> is null or whitespace.</exception>
        public async Task<T> GetAsync<T>(string id) where T : SalesforceEntity
        {
            if (id.IsNullOrWhiteSpace())
            {
                throw new ArgumentException("Record ID cannot be null or empty", nameof(id));
            }

            var salesforceObject = GetObjectName<T>();
            var getApiUrl = await GetEndpointUrlAsync($"sobjects/{salesforceObject}/{Uri.EscapeDataString(id)}");

            return await InvokeRequestAsync(async httpClient =>
            {
                var response = await httpClient.GetAsync(getApiUrl);

                return await response.Content.ReadAsAsync<T>();
            });
        }

        /// <summary>
        /// Retrieves a single Salesforce record matching on an external ID.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="externalIdProperty">
        /// The property representing the external ID to look up. Its <c>JsonProperty</c> name is used when present,
        /// otherwise the CLR property name.
        /// </param>
        /// <param name="id">The external ID value.</param>
        /// <returns>The record deserialized from the response body.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="externalIdProperty"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="externalIdProperty"/> is not a property access, or <paramref name="id"/> is empty.
        /// </exception>
        public async Task<T> GetAsync<T>(Expression<Func<T, object>> externalIdProperty, string id) where T : SalesforceEntity
        {
            if (externalIdProperty == null)
            {
                throw new ArgumentNullException(nameof(externalIdProperty));
            }

            if (id.IsNullOrWhiteSpace())
            {
                throw new ArgumentException("External ID cannot be null or empty", nameof(id));
            }

            var body = externalIdProperty.Body;

            // value-type properties are boxed to object, which wraps the member access in a Convert node
            if (body is UnaryExpression unaryExpression && unaryExpression.NodeType == ExpressionType.Convert)
            {
                body = unaryExpression.Operand;
            }

            if (!(body is MemberExpression memberExpression))
            {
                throw new ArgumentException("External ID must be a property expression", nameof(externalIdProperty));
            }

            var salesforceObject = GetObjectName<T>();
            var propertyName = memberExpression.Member.GetCustomAttribute<JsonPropertyAttribute>()?.PropertyName
                ?? memberExpression.Member.Name;
            var getApiUrl = await GetEndpointUrlAsync($"sobjects/{salesforceObject}/{propertyName}/{Uri.EscapeDataString(id)}");

            return await InvokeRequestAsync(async httpClient =>
            {
                var response = await httpClient.GetAsync(getApiUrl);

                return await response.Content.ReadAsAsync<T>();
            });
        }

        /// <summary>
        /// Creates a new Salesforce record.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="record">The record to insert; its <c>RecordId</c> is set on success.</param>
        /// <returns>
        /// A result whose <c>Data</c> is the new record ID on success, or which carries the Salesforce errors on failure.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="record"/> is null.</exception>
        public async Task<Result<string>> CreateAsync<T>(T record) where T : SalesforceEntity
        {
            if (record == null)
            {
                throw new ArgumentNullException(nameof(record));
            }

            var result = new Result<string>();
            var salesforceObject = GetObjectName<T>();
            var createApiUrl = await GetEndpointUrlAsync($"sobjects/{salesforceObject}");

            return await InvokeRequestAsync(async httpClient =>
            {
                record.SerializeReadOnlyProperties = false;

                var response = await httpClient.PostAsJsonAsync(createApiUrl, record);
                var responseBody = await response.Content.ReadAsStringAsync();

                if (!response.IsSuccessStatusCode)
                {
                    // error responses are a JSON array of errors, not a create response
                    result.Success = false;
                    result.AddErrors(GetErrorMessages(response, responseBody));

                    return result;
                }

                var createResponse = responseBody.Deserialize<CreateResponse>();

                if (createResponse != null && createResponse.Success)
                {
                    record.RecordId = createResponse.RecordId;
                    result.Data = createResponse.RecordId;
                    result.Success = true;
                }
                else
                {
                    result.AddError("Error creating record");
                }

                return result;
            });
        }

        /// <summary>
        /// Creates a set of new Salesforce records, one request per record.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="records">The records to insert; each successful record has its <c>RecordId</c> set.</param>
        /// <returns>
        /// A result that is successful only if every record was created; <c>Message</c> lists the created IDs and
        /// the errors list the failures.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="records"/> is null.</exception>
        public async Task<Result> CreateAsync<T>(IEnumerable<T> records) where T : SalesforceEntity
        {
            if (records == null)
            {
                throw new ArgumentNullException(nameof(records));
            }

            var result = new Result();
            var salesforceObject = GetObjectName<T>();
            var createApiUrl = await GetEndpointUrlAsync($"sobjects/{salesforceObject}");

            return await InvokeRequestAsync(async httpClient =>
            {
                var messages = new List<string>();

                // default to true until any record fails in the loop
                result.Success = true;

                foreach (var record in records)
                {
                    record.SerializeReadOnlyProperties = false;

                    var response = await httpClient.PostAsJsonAsync(createApiUrl, record);
                    var responseBody = await response.Content.ReadAsStringAsync();

                    if (!response.IsSuccessStatusCode)
                    {
                        result.Success = false;
                        result.AddErrors(GetErrorMessages(response, responseBody));

                        continue;
                    }

                    var createResponse = responseBody.Deserialize<CreateResponse>();

                    if (createResponse != null && createResponse.Success)
                    {
                        record.RecordId = createResponse.RecordId;
                        messages.Add($"Created record with ID {record.RecordId}");
                    }
                    else
                    {
                        // TODO: follow pattern of batch lead create, consolidate response into Result
                        result.Success = false;
                        result.AddError("Error creating record");
                    }
                }

                if (messages.Count > 0)
                {
                    result.Message = messages.Join("; ");
                }

                return result;
            });
        }

        /// <summary>
        /// Creates a set of new Salesforce records through a Bulk API 2.0 ingest job
        /// (TO POSSIBLY REPLACE <c>CreateAsync(IEnumerable)</c>).
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="records">The records to insert.</param>
        /// <returns>Currently always <c>null</c>; see remarks.</returns>
        /// <remarks>
        /// Work in progress: the job status is read once, right after the job is closed, instead of being polled until
        /// the job finishes. The failed-results CSV is downloaded but not parsed, and nothing is returned to the caller.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="records"/> is null.</exception>
        public async Task<Result> CreateBulkAsync<T>(IEnumerable<T> records) where T : SalesforceEntity
        {
            if (records == null)
            {
                throw new ArgumentNullException(nameof(records));
            }

            var salesforceObject = GetObjectName<T>();

            #region Create batch job

            var jobCreateResponse = await InvokeRequestAsync(async httpClient =>
            {
                var createApiUrl = await GetEndpointUrlAsync("jobs/ingest");

                var response = await httpClient.PostAsJsonAsync(createApiUrl, new
                {
                    @object = salesforceObject,
                    contentType = "CSV",
                    operation = "insert",
                    lineEnding = "CRLF"
                });

                return await response.Content.ReadAsAsync<JobCreateResponse>();
            });

            #endregion

            #region Upload content

            var uploadResponse = await InvokeRequestAsync(async httpClient =>
            {
                var insertApiUrl = UrlUtility.Combine(SalesforceContext.InstanceUrl, jobCreateResponse.ContentUrl);

                var recordsTable = records.ToDataTable(info => info.GetCustomAttribute<JsonPropertyAttribute>()?.PropertyName);
                var recordsCsv = recordsTable.ToCsv();

                // Bulk API 2.0 expects the raw CSV as the request body; it must not be Base64-encoded
                var csvContent = new StringContent(recordsCsv);
                csvContent.Headers.ContentType = new MediaTypeHeaderValue("text/csv");

                httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

                var response = await httpClient.PutAsync(insertApiUrl, csvContent);

                return await response.Content.ReadAsStringAsync();
            });

            #endregion

            #region Close job

            var closeResponse = await InvokeRequestAsync(async httpClient =>
            {
                var closeApiUrl = await GetEndpointUrlAsync($"jobs/ingest/{jobCreateResponse.Id}");
                var body = new { state = "UploadComplete" };
                var content = new StringContent(body.Serialize(), Encoding.UTF8, "application/json");
                var message = new HttpRequestMessage(new HttpMethod("PATCH"), closeApiUrl) { Content = content };

                var response = await httpClient.SendAsync(message);

                return await response.Content.ReadAsStringAsync();
            });

            #endregion

            #region Get Status

            var statusResponse = await InvokeRequestAsync(async httpClient =>
            {
                var statusApiUrl = await GetEndpointUrlAsync($"jobs/ingest/{jobCreateResponse.Id}");

                var response = await httpClient.GetAsync(statusApiUrl);

                return await response.Content.ReadAsStringAsync();
            });

            #endregion

            #region Download results

            var downloadResponse = await InvokeRequestAsync(async httpClient =>
            {
                var downloadApiUrl = await GetEndpointUrlAsync($"jobs/ingest/{jobCreateResponse.Id}/failedResults");

                httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("text/csv"));

                var response = await httpClient.GetAsync(downloadApiUrl);

                return await response.Content.ReadAsStringAsync();
            });

            #endregion

            return null;
        }

        /// <summary>
        /// Updates a record by ID, sending only the properties assigned inside <paramref name="update"/>.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="id">The Salesforce record ID.</param>
        /// <param name="update">Callback that assigns the new property values on a tracked, otherwise empty record.</param>
        /// <returns>The update response.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="update"/> is null.</exception>
        public Task<UpdateResponse> UpdateAsync<T>(string id, Action<T> update) where T : SalesforceEntity, new()
        {
            if (update == null)
            {
                throw new ArgumentNullException(nameof(update));
            }

            var record = new T { RecordId = id };

            // start tracking before the callback runs so its assignments are recorded as modifications
            var trackable = record.AsTrackable();

            update(record);

            return UpdateAsync(trackable);
        }

        /// <summary>
        /// Updates a tracked record, sending <c>record.Modified</c> as a PATCH body.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="record">The tracked record; its <c>Value.RecordId</c> identifies the record to update.</param>
        /// <returns>A successful response on a 2xx status, otherwise the response deserialized from the body.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="record"/> or its value is null.</exception>
        /// <exception cref="ArgumentException">The record ID is null or whitespace.</exception>
        public async Task<UpdateResponse> UpdateAsync<T>(Trackable<T> record) where T : SalesforceEntity
        {
            if (record?.Value == null)
            {
                throw new ArgumentNullException(nameof(record));
            }

            if (record.Value.RecordId.IsNullOrWhiteSpace())
            {
                throw new ArgumentException("Record ID cannot be null or empty", nameof(record));
            }

            var salesforceObject = GetObjectName<T>();
            var updateApiUrl = await GetEndpointUrlAsync($"sobjects/{salesforceObject}/{Uri.EscapeDataString(record.Value.RecordId)}");

            return await InvokeRequestAsync(async httpClient =>
            {
                // enums are written by name rather than by numeric value
                var body = JsonConvert.SerializeObject(record.Modified, new StringEnumConverter());
                var content = new StringContent(body, Encoding.UTF8, "application/json");
                var message = new HttpRequestMessage(new HttpMethod("PATCH"), updateApiUrl) { Content = content };

                var response = await httpClient.SendAsync(message);

                if (response.IsSuccessStatusCode)
                {
                    return new UpdateResponse(true);
                }

                return await response.Content.ReadAsAsync<UpdateResponse>();
            });
        }

        /// <summary>
        /// Retrieves Salesforce records using an object query expression.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <param name="query">Filter translated into the SOQL <c>where</c> clause.</param>
        /// <returns>All matching records, following <c>NextRecordsUrl</c> until the query reports it is done.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> is null.</exception>
        /// <exception cref="NotSupportedException">The expression uses an operator that cannot be translated.</exception>
        public async Task<List<T>> QueryAsync<T>(Expression<Func<T, bool>> query) where T : SalesforceEntity
        {
            if (query == null)
            {
                throw new ArgumentNullException(nameof(query));
            }

            var salesforceObject = GetObjectName<T>();

            // build a list of properties used in the "select" portion of the query
            var properties = typeof(T)
                .GetProperties()
                .WhereIsQueryable()
                .WhereIsSerializable()
                .Select(p => p.GetCustomAttribute<JsonPropertyAttribute>()?.PropertyName ?? p.Name)
                .WhereNotNull()
                .Join(", ");

            // generate a SOQL clause from the query expression
            var clause = ConvertExpression(query.Body);
            var soql = $"select {properties} from {salesforceObject} where {clause}";

            // the SOQL must be escaped so characters such as '&', '#' and '+' survive the query string
            var queryApiUrl = await GetEndpointUrlAsync($"query?q={Uri.EscapeDataString(soql)}");

            return await InvokeRequestAsync(async httpClient =>
            {
                var response = await httpClient.GetAsync(queryApiUrl);
                var queryResponse = await response.Content.ReadAsAsync<QueryResponse<T>>();
                var records = queryResponse.Records;

                // TODO: incorporate limit
                while (!queryResponse.Done)
                {
                    var absoluteNextRecordsUrl = UrlUtility.Combine(SalesforceContext.InstanceUrl, queryResponse.NextRecordsUrl);

                    response = await httpClient.GetAsync(absoluteNextRecordsUrl);
                    queryResponse = await response.Content.ReadAsAsync<QueryResponse<T>>();

                    records.AddRange(queryResponse.Records);
                }

                return records;
            });
        }

        /// <summary>
        /// Retrieves schema information (the <c>describe</c> endpoint) for an entity.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <returns>The describe response deserialized from the body.</returns>
        public async Task<SchemaResponse> GetSchema<T>() where T : SalesforceEntity
        {
            var salesforceObject = GetObjectName<T>();
            var schemaApiUrl = await GetEndpointUrlAsync($"sobjects/{salesforceObject}/describe");

            return await InvokeRequestAsync(async httpClient =>
            {
                var response = await httpClient.GetAsync(schemaApiUrl);

                return await response.Content.ReadAsAsync<SchemaResponse>();
            });
        }

        /// <summary>
        /// Validates whether the entity is properly mapped to its schema in Salesforce.
        /// </summary>
        /// <typeparam name="T">Entity type mapped to the Salesforce object.</typeparam>
        /// <returns>
        /// A successful result when every serializable, queryable property matches a Salesforce field
        /// (case-insensitively), otherwise one error per unmatched property.
        /// </returns>
        public async Task<Result<T>> ValidateSchema<T>() where T : SalesforceEntity, new()
        {
            var result = new Result<T>();
            var schema = await GetSchema<T>();

            // get the fields from Salesforce
            var salesforceFields = new HashSet<string>(schema.Fields.Select(f => f.Name), StringComparer.OrdinalIgnoreCase);

            // get the fields for the entity
            var propertyNames = typeof(T)
                .GetProperties()
                .WhereIsSerializable()
                .WhereIsQueryable()
                .GetPropertyNames();

            // find any orphaned/mismatched fields
            var mismatchedFields = propertyNames
                .Where(p => !salesforceFields.Contains(p))
                .ToList();

            result.Success = mismatchedFields.NoneOrNull();

            if (result.Success)
            {
                result.Message = "All fields are valid";
            }
            else
            {
                foreach (var mismatchedField in mismatchedFields)
                {
                    result.AddError($"Field [{mismatchedField}] was not found in source");
                }
            }

            return result;
        }

        #region Private Methods

        /// <summary>
        /// Authenticates with the OAuth username-password flow and caches the resulting context.
        /// </summary>
        /// <param name="credentials">The credentials to authenticate with.</param>
        /// <exception cref="InvalidOperationException">Salesforce did not report a successful authentication.</exception>
        private async Task AuthenticateAsync(SalesforceCredentials credentials)
        {
            var baseApiUrl = credentials.IsProduction
                ? "https://login.salesforce.com/services"
                : "https://test.salesforce.com/services";

            var authenticationUrl = UrlUtility.Combine(baseApiUrl, "oauth2/token");

            var authenticationRequest = new Dictionary<string, string>
            {
                ["grant_type"] = "password",
                ["client_id"] = credentials.ClientId,
                ["client_secret"] = credentials.ClientSecret,
                ["username"] = credentials.Username,
                // the security token is appended to the password
                ["password"] = credentials.Password + credentials.SecurityToken
            };

            // FormUrlEncodedContent form-encodes every value, so characters such as '+' or '&' in any
            // field (not only the password) reach Salesforce intact
            using (var httpClient = new HttpClient())
            using (var requestContent = new FormUrlEncodedContent(authenticationRequest))
            {
                var response = await httpClient.PostAsync(authenticationUrl, requestContent);
                var authenticationResponse = await response.Content.ReadAsAsync<AuthenticationResponse>();

                if (authenticationResponse == null || !authenticationResponse.Success)
                {
                    throw new InvalidOperationException(
                        $"Salesforce authentication failed with status {(int)response.StatusCode} ({response.ReasonPhrase})");
                }

                SalesforceContext = new SalesforceContext(authenticationResponse.AccessToken, authenticationResponse.InstanceUrl, authenticationResponse.TokenType);
            }
        }

        /// <summary>
        /// Ensures authentication has completed, then runs <paramref name="action"/> with an authorized client.
        /// </summary>
        /// <typeparam name="T">The action's result type.</typeparam>
        /// <param name="action">The request to perform.</param>
        /// <returns>The action's result.</returns>
        private async Task<T> InvokeRequestAsync<T>(Func<HttpClient, Task<T>> action)
        {
            await AssertInitializedAsync();

            var context = GetRequiredContext();

            // NOTE(review): a client per request avoids leaking sockets but prevents connection reuse; the
            // actions mutate DefaultRequestHeaders, so sharing one client requires per-request headers first
            using (var httpClient = new HttpClient())
            {
                httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue(context.TokenType.ToString(), context.AccessToken);

                return await action(httpClient);
            }
        }

        /// <summary>
        /// Builds an absolute REST endpoint URL for the configured API version.
        /// </summary>
        /// <param name="relativeApiUrl">Path (and optional query) relative to <c>services/data/vXX.X/</c>.</param>
        /// <returns>The absolute URL on the authenticated instance.</returns>
        private async Task<string> GetEndpointUrlAsync(string relativeApiUrl)
        {
            await AssertInitializedAsync();

            return UrlUtility.Combine(GetRequiredContext().InstanceUrl, $"services/data/v{apiVersion}/{relativeApiUrl}");
        }

        /// <summary>
        /// Returns the cached context, failing clearly instead of with a NullReferenceException.
        /// </summary>
        private SalesforceContext GetRequiredContext()
        {
            return SalesforceContext
                ?? throw new InvalidOperationException("Salesforce authentication has not completed successfully");
        }

        /// <summary>
        /// Gets the Salesforce object name from the object attribute, falling back to the type name.
        /// </summary>
        private static string GetObjectName<T>()
        {
            var type = typeof(T);

            return type.GetCustomAttribute<SalesforceObjectAttribute>()?.ObjectName ?? type.Name;
        }

        /// <summary>
        /// Converts a Salesforce error response into readable messages.
        /// </summary>
        /// <param name="response">The failed response; its status is used when the body cannot be parsed.</param>
        /// <param name="responseBody">The raw body, expected to be a JSON array of Salesforce errors.</param>
        /// <returns>One <c>[ErrorCode] Message</c> entry per error, or a single status-based message.</returns>
        private static List<string> GetErrorMessages(HttpResponseMessage response, string responseBody)
        {
            List<ErrorResponse> errors = null;

            try
            {
                errors = responseBody.Deserialize<List<ErrorResponse>>();
            }
            catch (JsonException)
            {
                // body is not the expected JSON error array (e.g. an HTML gateway page); fall back below
            }

            if (errors == null || errors.Count == 0)
            {
                return new List<string> { $"Request failed with status {(int)response.StatusCode} ({response.ReasonPhrase})" };
            }

            return errors
                .Select(e => $"[{e.ErrorCode}] {e.Message}")
                .ToList();
        }

        /// <summary>
        /// Converts an expression tree into a SOQL query clause.
        /// </summary>
        /// <param name="expression">The expression (or sub-expression) to convert.</param>
        /// <returns>The SOQL text, or an empty string for unsupported node types.</returns>
        private static string ConvertExpression(Expression expression)
        {
            /* NOTE: this is a work in progress, as I did not like the built-in
             * Expression.ToString(), and this allows us to customize the query to our liking */

            switch (expression)
            {
                case BinaryExpression binaryExpression:
                    // traverse the tree and combine all sub expression groups
                    var left = ConvertExpression(binaryExpression.Left);
                    var comparison = GetExpressionType(binaryExpression.NodeType);
                    var right = ConvertExpression(binaryExpression.Right);

                    return $"({left} {comparison} {right})";

                case UnaryExpression unaryExpression
                    when unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked:
                    // the compiler inserts conversions for nullable/enum comparisons; the operand carries the meaning
                    return ConvertExpression(unaryExpression.Operand);

                case ConstantExpression _:
                case MethodCallExpression _:
                    return GetExpressionValue(expression);

                case MemberExpression memberExpression when memberExpression.Expression is ParameterExpression:
                    // NOTE: this should be the property from the delegate's parameter
                    return memberExpression.Member.GetCustomAttribute<JsonPropertyAttribute>()?.PropertyName
                        ?? memberExpression.Member.Name;

                case MemberExpression _:
                    // a captured variable or other member that does not involve the lambda parameter
                    return GetExpressionValue(expression);

                default:
                    return string.Empty;
            }
        }

        /// <summary>
        /// Evaluates an expression and formats the value as a SOQL literal.
        /// </summary>
        /// <param name="expression">An expression that does not reference the lambda parameter.</param>
        /// <returns>The SOQL literal text.</returns>
        private static string GetExpressionValue(Expression expression)
        {
            var value = Expression.Lambda(expression).Compile().DynamicInvoke();

            switch (value)
            {
                case null:
                    return "null";

                case string text:
                    return $"'{EscapeSoqlString(text)}'";

                case DateTime dateTime:
                    // values of unspecified kind are assumed to already be UTC; local values are converted
                    var utcDateTime = dateTime.Kind == DateTimeKind.Local ? dateTime.ToUniversalTime() : dateTime;

                    return utcDateTime.ToString(SoqlDateTimeFormat, CultureInfo.InvariantCulture);

                case DateTimeOffset dateTimeOffset:
                    return dateTimeOffset.UtcDateTime.ToString(SoqlDateTimeFormat, CultureInfo.InvariantCulture);

                case bool boolean:
                    return boolean ? "true" : "false";

                default:
                    // invariant culture keeps '.' as the decimal separator regardless of the machine locale
                    return Convert.ToString(value, CultureInfo.InvariantCulture);
            }
        }

        /// <summary>
        /// Escapes a value for use inside a single-quoted SOQL string literal.
        /// </summary>
        private static string EscapeSoqlString(string value)
        {
            // backslash first, so the escapes added afterwards are not escaped again
            return value
                .Replace("\\", "\\\\")
                .Replace("'", "\\'")
                .Replace("\r", "\\r")
                .Replace("\n", "\\n")
                .Replace("\t", "\\t");
        }

        /// <summary>
        /// Maps a binary expression node type to its SOQL operator.
        /// </summary>
        /// <exception cref="NotSupportedException">The node type has no SOQL equivalent here.</exception>
        private static string GetExpressionType(ExpressionType expressionType)
        {
            switch (expressionType)
            {
                case ExpressionType.Equal:
                    return "=";
                case ExpressionType.NotEqual:
                    return "!=";
                case ExpressionType.GreaterThan:
                    return ">";
                case ExpressionType.GreaterThanOrEqual:
                    return ">=";
                case ExpressionType.LessThan:
                    return "<";
                case ExpressionType.LessThanOrEqual:
                    return "<=";
                case ExpressionType.AndAlso:
                case ExpressionType.And:
                    return "and";
                case ExpressionType.OrElse:
                case ExpressionType.Or:
                    return "or";
                default:
                    throw new NotSupportedException($"Expression type '{expressionType}' is not supported in a Salesforce query");
            }
        }

        /// <summary>
        /// Waits for the constructor's authentication and re-authenticates if the cached context is missing.
        /// </summary>
        private async Task AssertInitializedAsync()
        {
            if (SalesforceContext != null)
            {
                return;
            }

            if (initializationTask != null)
            {
                // rethrows if the authentication started by the constructor failed
                await initializationTask;
            }

            if (SalesforceContext == null)
            {
                // the shared context is gone (e.g. it expired from the cache after this client was created)
                await AuthenticateAsync(credentials);
            }
        }

        #endregion
    }
}